aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-11-27 09:54:45 +0000
committerGitHub <noreply@github.com>2022-11-27 09:54:45 +0000
commit6df49457183fedec635d30dd0c611844b0932c85 (patch)
tree7cd692e17d8ff273916c52d30c77370f1871de54 /modules/sd_samplers.py
parent45fd785436068f3b1c09fb7bc575118b6059fc7b (diff)
parentb48b7999c86fd6d7f006f76adf5a484175782c37 (diff)
downloadstable-diffusion-webui-gfx803-6df49457183fedec635d30dd0c611844b0932c85.tar.gz
stable-diffusion-webui-gfx803-6df49457183fedec635d30dd0c611844b0932c85.tar.bz2
stable-diffusion-webui-gfx803-6df49457183fedec635d30dd0c611844b0932c85.zip
Merge branch 'master' into DPM++SDE
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r--modules/sd_samplers.py14
1 files changed, 8 insertions, 6 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index 80e91d62..43ce34eb 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -129,7 +129,8 @@ class InterruptedException(BaseException):
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
+ self.is_plms = hasattr(self.sampler, 'p_sample_plms')
+ self.orig_p_sample_ddim = self.sampler.p_sample_plms if self.is_plms else self.sampler.p_sample_ddim
self.mask = None
self.nmask = None
self.init_latent = None
@@ -220,7 +221,6 @@ class VanillaStableDiffusionSampler:
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
-
def adjust_steps_if_invalid(self, p, num_steps):
if (self.config.name == 'DDIM' and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS'):
valid_step = 999 / (1000 // num_steps)
@@ -229,7 +229,6 @@ class VanillaStableDiffusionSampler:
return num_steps
-
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = setup_img2img_steps(p, steps)
steps = self.adjust_steps_if_invalid(p, steps)
@@ -262,9 +261,10 @@ class VanillaStableDiffusionSampler:
steps = self.adjust_steps_if_invalid(p, steps or p.steps)
# Wrap the conditioning models with additional image conditioning for inpainting model
+ # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape
if image_conditioning is not None:
- conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]}
- unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]}
+ conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]}
+ unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]}
samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0])
@@ -352,7 +352,9 @@ class TorchHijack:
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
+ denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
+
+ self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.extra_params = sampler_extra_params.get(funcname, [])