diff options
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 16 |
1 files changed, 13 insertions, 3 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 140b5dea..6b7979e2 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -70,13 +70,14 @@ def extended_tdqm(sequence, *args, desc=None, **kwargs): state.sampling_steps = len(sequence)
state.sampling_step = 0
- for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs):
+ for x in tqdm.tqdm(sequence, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
+ shared.total_tqdm.update()
ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs)
@@ -86,7 +87,7 @@ ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdq class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)
- self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else None
+ self.orig_p_sample_ddim = self.sampler.p_sample_ddim if hasattr(self.sampler, 'p_sample_ddim') else self.sampler.p_sample_plms
self.mask = None
self.nmask = None
self.init_latent = None
@@ -112,6 +113,13 @@ class VanillaStableDiffusionSampler: return samples
def sample(self, p, x, conditioning, unconditional_conditioning):
+ for fieldname in ['p_sample_ddim', 'p_sample_plms']:
+ if hasattr(self.sampler, fieldname):
+ setattr(self.sampler, fieldname, lambda x_dec, cond, ts, *args, **kwargs: p_sample_ddim_hook(self, x_dec, cond, ts, *args, **kwargs))
+ self.mask = None
+ self.nmask = None
+ self.init_latent = None
+
samples_ddim, _ = self.sampler.sample(S=p.steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
return samples_ddim
@@ -146,13 +154,14 @@ def extended_trange(count, *args, **kwargs): state.sampling_steps = count
state.sampling_step = 0
- for x in tqdm.trange(count, *args, desc=state.job, **kwargs):
+ for x in tqdm.trange(count, *args, desc=state.job, file=shared.progress_print_out, **kwargs):
if state.interrupted:
break
yield x
state.sampling_step += 1
+ shared.total_tqdm.update()
class KDiffusionSampler:
@@ -168,6 +177,7 @@ class KDiffusionSampler: def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning):
t_enc = int(min(p.denoising_strength, 0.999) * p.steps)
sigmas = self.model_wrap.get_sigmas(p.steps)
+
noise = noise * sigmas[p.steps - t_enc - 1]
xi = x + noise
|