author | AUTOMATIC <16777216c@gmail.com> | 2022-09-15 12:39:30 +0000
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-15 12:39:30 +0000
commit | dc769e097c878927fcd222cd855eb794726e922b (patch)
tree | f093f705cdab8d2928fa680677a72e714c22fe95 /modules/sd_samplers.py
parent | d4dc4c1c633e27e7cb1a7208e8c39376dbee2d97 (diff)
parent | f2693bec08d2c2e513cb35fa24402396505a01a9 (diff)
Merge branch 'prompt_editing'
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 44
1 file changed, 28 insertions, 16 deletions
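The diff threads a per-step counter through both the DDIM/PLMS wrapper and the k-diffusion CFGDenoiser, and routes cond/uncond through prompt_parser.reconstruct_cond_batch so the conditioning can change from step to step (prompt editing). reconstruct_cond_batch itself lives in modules/prompt_parser.py and is not part of this diff; below is a minimal sketch of the behavior the callers here rely on, assuming each image in the batch carries a schedule of (end_at_step, cond) entries. The ScheduledCond name and its fields are assumptions for illustration, not the repository's actual code.

```python
from collections import namedtuple

import torch

# Hypothetical schedule entry: `cond` applies up to sampling step `end_at_step`.
ScheduledCond = namedtuple('ScheduledCond', ['end_at_step', 'cond'])

def reconstruct_cond_batch_sketch(c, current_step):
    # c: one schedule per image in the batch, each a list of ScheduledCond
    # sorted by end_at_step. Stack the tensors active at current_step.
    first = c[0][0].cond
    res = torch.zeros((len(c),) + first.shape, device=first.device, dtype=first.dtype)
    for i, schedule in enumerate(c):
        chosen = schedule[-1].cond  # fall back to the last entry
        for entry in schedule:
            if current_step <= entry.end_at_step:
                chosen = entry.cond
                break
        res[i] = chosen
    return res
```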
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py
index f5e81f34..df3a6fe8 100644
--- a/modules/sd_samplers.py
+++ b/modules/sd_samplers.py
@@ -7,6 +7,7 @@ from PIL import Image
import k_diffusion.sampling
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
+from modules import prompt_parser
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -53,20 +54,6 @@ def store_latent(decoded):
            shared.state.current_image = sample_to_image(decoded)
-def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs):
-    if sampler_wrapper.mask is not None:
-        img_orig = sampler_wrapper.sampler.model.q_sample(sampler_wrapper.init_latent, ts)
-        x_dec = img_orig * sampler_wrapper.mask + sampler_wrapper.nmask * x_dec
-
-    res = sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs)
-
-    if sampler_wrapper.mask is not None:
-        store_latent(sampler_wrapper.init_latent * sampler_wrapper.mask + sampler_wrapper.nmask * res[1])
-    else:
-        store_latent(res[1])
-
-    return res
-
def extended_tdqm(sequence, *args, desc=None, **kwargs):
    state.sampling_steps = len(sequence)
@@ -93,6 +80,25 @@ class VanillaStableDiffusionSampler:
        self.mask = None
        self.nmask = None
        self.init_latent = None
+        self.step = 0
+
+    def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs):
+        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step)
+
+        if self.mask is not None:
+            img_orig = self.sampler.model.q_sample(self.init_latent, ts)
+            x_dec = img_orig * self.mask + self.nmask * x_dec
+
+        res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs)
+
+        if self.mask is not None:
+            store_latent(self.init_latent * self.mask + self.nmask * res[1])
+        else:
+            store_latent(res[1])
+
+        self.step += 1
+        return res

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
        t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
@@ -105,7 +111,7 @@ class VanillaStableDiffusionSampler:
        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
-        self.sampler.p_sample_ddim = lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs)
+        self.sampler.p_sample_ddim = self.p_sample_ddim_hook
        self.mask = p.mask
        self.nmask = p.nmask
        self.init_latent = p.init_latent
@@ -117,7 +123,7 @@ class VanillaStableDiffusionSampler:
    def sample(self, p, x, conditioning, unconditional_conditioning):
        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
-                setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
        self.mask = None
        self.nmask = None
        self.init_latent = None
@@ -138,8 +144,12 @@ class CFGDenoiser(torch.nn.Module):
        self.mask = None
        self.nmask = None
        self.init_latent = None
+        self.step = 0

    def forward(self, x, sigma, uncond, cond, cond_scale):
+        cond = prompt_parser.reconstruct_cond_batch(cond, self.step)
+        uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+
        if shared.batch_cond_uncond:
            x_in = torch.cat([x] * 2)
            sigma_in = torch.cat([sigma] * 2)
@@ -154,6 +164,8 @@ class CFGDenoiser(torch.nn.Module):
        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

+        self.step += 1
+
        return denoised
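One note on the refactor: the old code re-created a lambda capturing self every time sampling started, whereas the new code assigns the bound method self.p_sample_ddim_hook directly, so per-call state such as self.step lives on the wrapper itself. A minimal, self-contained sketch of this monkey-patching pattern follows; all names in it are hypothetical, not the repository's code.

```python
# Sketch of the hook pattern from the diff: replace an object's per-step
# method with a wrapper's bound method that rewrites arguments and counts steps.
class DummySampler:
    def p_sample_ddim(self, x_dec, cond, ts):
        print(f"step with cond={cond!r} at t={ts}")
        return x_dec

class SamplerHook:
    def __init__(self, sampler, conds_per_step):
        self.sampler = sampler
        self.orig_p_sample_ddim = sampler.p_sample_ddim  # keep the original
        self.conds_per_step = conds_per_step
        self.step = 0
        # a bound method carries `self` along, so no lambda capture is needed
        sampler.p_sample_ddim = self.p_sample_ddim_hook

    def p_sample_ddim_hook(self, x_dec, cond, ts):
        # swap in the conditioning scheduled for the current step
        cond = self.conds_per_step[min(self.step, len(self.conds_per_step) - 1)]
        res = self.orig_p_sample_ddim(x_dec, cond, ts)
        self.step += 1
        return res

sampler = DummySampler()
SamplerHook(sampler, conds_per_step=["a cat", "a tiger"])
sampler.p_sample_ddim(x_dec=0.0, cond="ignored", ts=10)  # uses "a cat"
sampler.p_sample_ddim(x_dec=0.0, cond="ignored", ts=9)   # uses "a tiger"
```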