diff options
author | Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com> | 2022-10-29 14:23:19 +0000 |
---|---|---|
committer | Martin Cairns <4314538+MartinCairnsSQL@users.noreply.github.com> | 2022-10-29 14:23:19 +0000 |
commit | de1dc0d279a877d5d9f512befe30a7d7e5cf3881 (patch) | |
tree | 1b74779eb96766f3970a905cec52ff4ed44bfdd4 /modules/sd_samplers.py | |
parent | 35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff) | |
download | stable-diffusion-webui-gfx803-de1dc0d279a877d5d9f512befe30a7d7e5cf3881.tar.gz stable-diffusion-webui-gfx803-de1dc0d279a877d5d9f512befe30a7d7e5cf3881.tar.bz2 stable-diffusion-webui-gfx803-de1dc0d279a877d5d9f512befe30a7d7e5cf3881.zip |
Add adjust_steps_if_invalid to find next valid step for ddim uniform sampler
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 28 |
1 files changed, 15 insertions, 13 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 3670b57d..aca014e8 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,5 +1,6 @@ from collections import namedtuple
import numpy as np
+from math import floor
import torch
import tqdm
from PIL import Image
@@ -205,17 +206,22 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
+
+ def adjust_steps_if_invalid(self, p, num_steps):
+ if self.config.name == 'DDIM' and p.ddim_discretize == 'uniform':
+ valid_step = 999 / (1000 // num_steps)
+ if valid_step == floor(valid_step):
+ return int(valid_step) + 1
+
+ return num_steps
+
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
-
+ steps = self.adjust_steps_if_invalid(p, steps)
self.initialize(p)
- # existing code fails with certain step counts, like 9
- try:
- self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
- except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
-
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
self.init_latent = x
@@ -239,18 +245,14 @@ class VanillaStableDiffusionSampler: self.last_latent = x
self.step = 0
- steps = steps or p.steps
+ steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
if image_conditioning is not None:
conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
- # existing code fails with certain step counts, like 9
- try:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
- except Exception:
- samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
+ samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
return samples_ddim
|