| author | space-nuko <24979496+space-nuko@users.noreply.github.com> | 2023-02-10 11:30:20 +0000 |
| --- | --- | --- |
| committer | space-nuko <24979496+space-nuko@users.noreply.github.com> | 2023-02-10 11:30:20 +0000 |
| commit | 125319988984987801dc4b4ab1e5ed36e9b211c5 (patch) | |
| tree | 075923068ff40724e27b4cfd4ebd13b22f0bae84 /modules/models/diffusion/uni_pc/sampler.py | |
| parent | ea9bd9fc7409109adcd61b897abc2c8881161256 (diff) | |
Working UniPC (for batch size 1)
Diffstat (limited to 'modules/models/diffusion/uni_pc/sampler.py')
-rw-r--r-- | modules/models/diffusion/uni_pc/sampler.py | 85 |
1 file changed, 85 insertions, 0 deletions
diff --git a/modules/models/diffusion/uni_pc/sampler.py b/modules/models/diffusion/uni_pc/sampler.py
new file mode 100644
index 00000000..7cccd8a2
--- /dev/null
+++ b/modules/models/diffusion/uni_pc/sampler.py
@@ -0,0 +1,85 @@
+"""SAMPLING ONLY."""
+
+import torch
+
+from .uni_pc import NoiseScheduleVP, model_wrapper, UniPC
+
+class UniPCSampler(object):
+    def __init__(self, model, **kwargs):
+        super().__init__()
+        self.model = model
+        to_torch = lambda x: x.clone().detach().to(torch.float32).to(model.device)
+        self.before_sample = None
+        self.after_sample = None
+        self.register_buffer('alphas_cumprod', to_torch(model.alphas_cumprod))
+
+    def register_buffer(self, name, attr):
+        if type(attr) == torch.Tensor:
+            if attr.device != torch.device("cuda"):
+                attr = attr.to(torch.device("cuda"))
+        setattr(self, name, attr)
+
+    def set_hooks(self, before, after):
+        self.before_sample = before
+        self.after_sample = after
+
+    @torch.no_grad()
+    def sample(self,
+               S,
+               batch_size,
+               shape,
+               conditioning=None,
+               callback=None,
+               normals_sequence=None,
+               img_callback=None,
+               quantize_x0=False,
+               eta=0.,
+               mask=None,
+               x0=None,
+               temperature=1.,
+               noise_dropout=0.,
+               score_corrector=None,
+               corrector_kwargs=None,
+               verbose=True,
+               x_T=None,
+               log_every_t=100,
+               unconditional_guidance_scale=1.,
+               unconditional_conditioning=None,
+               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
+               **kwargs
+               ):
+        if conditioning is not None:
+            if isinstance(conditioning, dict):
+                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
+                if cbs != batch_size:
+                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
+            else:
+                if conditioning.shape[0] != batch_size:
+                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
+
+        # sampling
+        C, H, W = shape
+        size = (batch_size, C, H, W)
+
+        device = self.model.betas.device
+        if x_T is None:
+            img = torch.randn(size, device=device)
+        else:
+            img = x_T
+
+        ns = NoiseScheduleVP('discrete', alphas_cumprod=self.alphas_cumprod)
+
+        model_fn = model_wrapper(
+            lambda x, t, c: self.model.apply_model(x, t, c),
+            ns,
+            model_type="noise",
+            guidance_type="classifier-free",
+            #condition=conditioning,
+            #unconditional_condition=unconditional_conditioning,
+            guidance_scale=unconditional_guidance_scale,
+        )
+
+        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
+        x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)
+
+        return x.to(device), None
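For context, a minimal usage sketch of the new sampler follows. It is not part of the commit: `model` is assumed to be an already-loaded ldm-style `LatentDiffusion` instance exposing `apply_model`, `betas`, `alphas_cumprod`, `get_learned_conditioning`, and `decode_first_stage`; the prompt, step count, and latent shape are illustrative placeholders. Per the commit message, only batch size 1 is expected to work at this point.

```python
import torch

from modules.models.diffusion.uni_pc.sampler import UniPCSampler

# `model` is assumed to be an already-loaded ldm-style LatentDiffusion instance.
sampler = UniPCSampler(model)

with torch.no_grad():
    # Conditioning tensors; get_learned_conditioning is the usual ldm helper.
    cond = model.get_learned_conditioning(["a photo of a cat"])
    uncond = model.get_learned_conditioning([""])

    # shape is (C, H, W) in latent space; a 512x512 image at 8x downscaling
    # gives a 4x64x64 latent. Batch size is passed separately (only 1 works here).
    samples, _ = sampler.sample(
        S=20,                               # number of UniPC steps
        batch_size=1,
        shape=(4, 64, 64),
        conditioning=cond,
        unconditional_conditioning=uncond,
        unconditional_guidance_scale=7.5,   # classifier-free guidance scale
    )

    # Decode latents back to image space.
    images = model.decode_first_stage(samples)
```

Note that `UniPCSampler.sample` accepts the same keyword surface as the existing DDIM/PLMS samplers but only wires `conditioning`, `unconditional_conditioning`, `unconditional_guidance_scale`, `x_T`, `S`, `batch_size`, and `shape` into UniPC; the remaining arguments are accepted for interface compatibility and ignored.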