diff options
author | space-nuko <24979496+space-nuko@users.noreply.github.com> | 2023-02-10 12:47:08 +0000 |
---|---|---|
committer | space-nuko <24979496+space-nuko@users.noreply.github.com> | 2023-02-10 12:47:08 +0000 |
commit | 21880eb9e57b884635a07d2360831b4186afddf4 (patch) | |
tree | ecc0969bb4e36b1addb157464b6dae86faefe583 /modules/models/diffusion/uni_pc/sampler.py | |
parent | 125319988984987801dc4b4ab1e5ed36e9b211c5 (diff) | |
download | stable-diffusion-webui-gfx803-21880eb9e57b884635a07d2360831b4186afddf4.tar.gz stable-diffusion-webui-gfx803-21880eb9e57b884635a07d2360831b4186afddf4.tar.bz2 stable-diffusion-webui-gfx803-21880eb9e57b884635a07d2360831b4186afddf4.zip |
Fix logspam and live previews
Diffstat (limited to 'modules/models/diffusion/uni_pc/sampler.py')
-rw-r--r-- | modules/models/diffusion/uni_pc/sampler.py | 20 |
1 files changed, 15 insertions, 5 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py index 7cccd8a2..219e9862 100644 --- a/modules/models/diffusion/uni_pc/sampler.py +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -19,9 +19,10 @@ class UniPCSampler(object): attr = attr.to(torch.device("cuda")) setattr(self, name, attr) - def set_hooks(self, before, after): - self.before_sample = before - self.after_sample = after + def set_hooks(self, before_sample, after_sample, after_update): + self.before_sample = before_sample + self.after_sample = after_sample + self.after_update = after_update @torch.no_grad() def sample(self, @@ -50,9 +51,17 @@ class UniPCSampler(object): ): if conditioning is not None: if isinstance(conditioning, dict): - cbs = conditioning[list(conditioning.keys())[0]].shape[0] + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] if cbs != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: if conditioning.shape[0] != batch_size: print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") @@ -60,6 +69,7 @@ class UniPCSampler(object): # sampling C, H, W = shape size = (batch_size, C, H, W) + print(f'Data shape for UniPC sampling is {size}, eta {eta}') device = self.model.betas.device if x_T is None: @@ -79,7 +89,7 @@ class UniPCSampler(object): guidance_scale=unconditional_guidance_scale, ) - uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample) + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True) return x.to(device), None |