Diffstat (limited to 'modules/sd_samplers.py')
 modules/sd_samplers.py | 51
 1 file changed, 34 insertions(+), 17 deletions(-)
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index d478c5bc..02ffce0e 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,6 +7,7 @@ from PIL import Image
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
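(Note: prompt_parser.reconstruct_cond_batch, used in the hunks below, picks for each image in the batch the conditioning that is active at the current sampling step, which is what makes prompt editing work. The snippet below is only an illustrative sketch of that idea under assumed data shapes, not the code from modules/prompt_parser.py.)

# Illustrative sketch only; assumes each batch entry is a list of
# (end_at_step, cond_tensor) pairs sorted by end_at_step, as a
# prompt-editing schedule would produce.
import torch

def reconstruct_cond_batch_sketch(scheduled_batch, current_step):
    conds = []
    for schedule in scheduled_batch:
        chosen = schedule[-1][1]  # fall back to the last scheduled cond
        for end_at_step, cond in schedule:
            if current_step <= end_at_step:
                chosen = cond
                break
        conds.append(chosen)
    return torch.stack(conds)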
@@ -53,20 +54,6 @@ def store_latent(decoded):
shared.state.current_image = sample_to_image(decoded)
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
- if sampler_wrapper.mask is not None:
- img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
- x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
- res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
- if sampler_wrapper.mask is not None:
- store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
- else:
- store_latent(res[1])
-
- return res
-
def extended_tdqm(sequence, *args, desc=None, **kwargs):
state.sampling_steps = len(sequence)
@@ -94,10 +81,29 @@ class VanillaStableDiffusionSampler:
self.nmask = None
self.init_latent = None
self.sampler_noises = None
+ self.step = 0
def number_of_needed_noises(self, p):
return 0
+ def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+ if self.mask is not None:
+ img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+ x_dec = img_orig * self.mask + self.nmask * x_dec
+
+ res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+ if self.mask is not None:
+ store_latent(self.init_latent * self.mask + self.nmask * res[1])
+ else:
+ store_latent(res[1])
+
+ self.step += 1
+ return res
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
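(The mask/nmask blend in p_sample_ddim_hook above keeps the original image in masked regions while letting the sampler repaint the rest. A minimal self-contained sketch of that blend, with made-up tensor shapes and assuming nmask is the complement of mask, looks like this.)

# Minimal sketch of the inpainting blend; names and shapes are illustrative.
import torch

def blend_with_original(x_dec, init_latent_noised, mask):
    nmask = 1.0 - mask  # nmask marks the region being repainted
    return init_latent_noised * mask + nmask * x_dec

x_dec = torch.randn(1, 4, 64, 64)        # current denoising state
init_noised = torch.randn(1, 4, 64, 64)  # q_sample(init_latent, ts)
mask = (torch.rand(1, 1, 64, 64) > 0.5).float()
blended = blend_with_original(x_dec, init_noised, mask)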
@@ -109,10 +115,11 @@ class VanillaStableDiffusionSampler:
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
- self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+ self.sampler.p_sample_ddim = self.p_sample_ddim_hook
self.mask = p.mask
self.nmask = p.nmask
self.init_latent = p.init_latent
+ self.step = 0
samples = self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)
@@ -121,10 +128,11 @@ class VanillaStableDiffusionSampler:
def sample(self, p, x, conditioning, unconditional_conditioning):
for fieldname in ['p_sample_ddim', 'p_sample_plms']:
if hasattr(self.sampler, fieldname):
- setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+ setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
# existing code fails with certain step counts, like 9
try:
@@ -142,8 +150,12 @@ class CFGDenoiser(torch.nn.Module):
self.mask = None
self.nmask = None
self.init_latent = None
+ self.step = 0
def forward(self, x, sigma, uncond, cond, cond_scale):
+ cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+ uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
if shared.batch_cond_uncond:
x_in = torch.cat([x] * 2)
sigma_in = torch.cat([sigma] * 2)
@@ -158,6 +170,8 @@ class CFGDenoiser(torch.nn.Module):
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
+ self.step += 1
+
return denoised
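(CFGDenoiser combines an unconditional and a conditional denoising pass using the cfg scale; the elided body of forward does roughly the following, shown here as an assumed, simplified formulation without the batching and mask handling.)

# Assumed, simplified classifier-free guidance step (not the exact forward body).
import torch

def cfg_combine(denoised_uncond, denoised_cond, cond_scale):
    # push the prediction away from the unconditional result toward the conditional one
    return denoised_uncond + (denoised_cond - denoised_uncond) * cond_scale

u = torch.randn(1, 4, 64, 64)
c = torch.randn(1, 4, 64, 64)
out = cfg_combine(u, c, cond_scale=7.0)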
@@ -191,7 +205,7 @@ class TorchHijack:
class KDiffusionSampler:
def __init__(self, funcname, sd_model):
- self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model)
+ self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
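(The new quantize flag passed to k_diffusion.external.CompVisDenoiser makes the wrapper snap requested sigmas to the model's discrete training sigmas when enabled. The snippet below is only a conceptual sketch of that rounding, not k-diffusion's implementation, and the schedule values are made up.)

# Conceptual sketch of sigma quantization (assumption; not k-diffusion's code).
import torch

def quantize_sigma(sigma, training_sigmas):
    idx = torch.argmin((training_sigmas - sigma).abs())
    return training_sigmas[idx]

training_sigmas = torch.linspace(0.03, 14.6, 1000)  # made-up discrete schedule
print(quantize_sigma(torch.tensor(7.1), training_sigmas))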
@@ -228,6 +242,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.mask = p.mask
self.model_wrap_cfg.nmask = p.nmask
self.model_wrap_cfg.init_latent = p.init_latent
+ self.model_wrap_cfg.step = 0
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
@@ -241,6 +256,8 @@ class KDiffusionSampler:
sigmas = self.model_wrap.get_sigmas(p.steps)
x = x * sigmas[0]
+ self.model_wrap_cfg.step = 0
+
if hasattr(k_diffusion.sampling, 'trange'):
k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(*args, **kwargs)
|