diff options
author | Karun <karun.ellango7@gmail.com> | 2023-03-25 09:12:55 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-03-25 09:12:55 +0000 |
commit | 63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b (patch) | |
tree | 9a7c38070d83b409895704125525dfc70cc21215 /modules/models/diffusion/uni_pc/sampler.py | |
parent | ca2b8faa83076a21dd14c974f03f88eb6da57485 (diff) | |
parent | 70615448b2ef3285dba9bb1992974cb1eaf10995 (diff) | |
download | stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.gz stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.tar.bz2 stable-diffusion-webui-gfx803-63a2f8d8225228a52b3ca7f19d2ba1fd07a6234b.zip |
Merge branch 'master' into master
Diffstat (limited to 'modules/models/diffusion/uni_pc/sampler.py')
-rw-r--r-- | modules/models/diffusion/uni_pc/sampler.py | 100 |
1 files changed, 100 insertions, 0 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py new file mode 100644 index 00000000..a241c8a7 --- /dev/null +++ b/modules/models/diffusion/uni_pc/sampler.py @@ -0,0 +1,100 @@ +"""SAMPLING ONLY.""" + +import torch + +from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC +from modules import shared, devices + + +class UniPCSampler(object): + def __init__(self, model, **kwargs): + super().__init__() + self.model = model + to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device) + self.before_sample = None + self.after_sample = None + self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod)) + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != devices.device: + attr = attr.to(devices.device) + setattr(self, name, attr) + + def set_hooks(self, before_sample, after_sample, after_update): + self.before_sample = before_sample + self.after_sample = after_sample + self.after_update = after_update + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + elif isinstance(conditioning, list): + for ctmp in conditioning: + if ctmp.shape[0] != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + # print(f'Data shape for UniPC sampling is {size}') + + device = self.model.betas.device + if x_T is None: + img = torch.randn(size, device=device) + else: + img = x_T + + ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod) + + # SD 1.X is "noise", SD 2.X is "v" + model_type = "v" if self.model.parameterization == "v" else "noise" + + model_fn = model_wrapper( + lambda x, t, c: self.model.apply_model(x, t, c), + ns, + model_type=model_type, + guidance_type="classifier-free", + #condition=conditioning, + #unconditional_condition=unconditional_conditioning, + guidance_scale=unconditional_guidance_scale, + ) + + uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update) + x = uni_pc.sample(img, steps=S, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) + + return x.to(device), None |