diff options
Diffstat (limited to 'modules/sd_samplers_compvis.py')
-rw-r--r-- | modules/sd_samplers_compvis.py | 20 |
1 files changed, 11 insertions, 9 deletions
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 86fa1c5b..946079ae 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler: return x, ts, cond, unconditional_conditioning
- def after_sample(self, x, ts, cond, uncond, res):
- if self.is_unipc:
- # unipc model_fn returns (pred_x0)
- # p_sample_ddim returns (x_prev, pred_x0)
- res = (None, res[0])
-
+ def update_step(self, last_latent):
if self.mask is not None:
- self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
+ self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
else:
- self.last_latent = res[1]
+ self.last_latent = last_latent
sd_samplers_common.store_latent(self.last_latent)
@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler: state.sampling_step = self.step
shared.total_tqdm.update()
+ def after_sample(self, x, ts, cond, uncond, res):
+ if not self.is_unipc:
+ self.update_step(res[1])
+
return x, ts, cond, uncond, res
+ def unipc_after_update(self, x, model_x):
+ self.update_step(x)
+
def initialize(self, p):
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
if self.eta != 0.0:
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler: if hasattr(self.sampler, fieldname):
setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
if self.is_unipc:
- self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
+ self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))
self.mask = p.mask if hasattr(p, 'mask') else None
self.nmask = p.nmask if hasattr(p, 'nmask') else None
|