From a243bc7859b7ab92a28d28c11b0ed5525fa0d6ba Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 6 Sep 2022 02:09:01 +0300 Subject: added progressbar added an option to disable progressbar added interrupt support to DDIM/PLMS --- modules/sd_samplers.py | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 6f028f5f..896e8b3f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,10 +1,12 @@ from collections import namedtuple + +import ldm.models.diffusion.ddim import torch import tqdm import k_diffusion.sampling -from ldm.models.diffusion.ddim import DDIMSampler -from ldm.models.diffusion.plms import PLMSSampler +import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -29,8 +31,8 @@ samplers_data_k_diffusion = [ samplers = [ *samplers_data_k_diffusion, - SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []), - SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []), + SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), []), + SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), []), ] samplers_for_img2img = [x for x in samplers if x.name != 'PLMS'] @@ -43,6 +45,23 @@ def p_sample_ddim_hook(sampler_wrapper, x_dec, cond, ts, *args, **kwargs): return sampler_wrapper.orig_p_sample_ddim(x_dec, cond, ts, *args, **kwargs) +def extended_tdqm(sequence, *args, desc=None, **kwargs): + state.sampling_steps = len(sequence) + state.sampling_step = 0 + + for x in tqdm.tqdm(sequence, *args, desc=state.job, **kwargs): + if state.interrupted: + break + + yield x + + state.sampling_step += 1 + + +ldm.models.diffusion.ddim.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs) +ldm.models.diffusion.plms.tqdm = lambda *args, desc=None, **kwargs: extended_tdqm(*args, desc=desc, **kwargs) + + class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): self.sampler = constructor(sd_model) @@ -102,13 +121,18 @@ class CFGDenoiser(torch.nn.Module): return denoised -def extended_trange(*args, **kwargs): - for x in tqdm.trange(*args, desc=state.job, **kwargs): +def extended_trange(count, *args, **kwargs): + state.sampling_steps = count + state.sampling_step = 0 + + for x in tqdm.trange(count, *args, desc=state.job, **kwargs): if state.interrupted: break yield x + state.sampling_step += 1 + class KDiffusionSampler: def __init__(self, funcname, sd_model): -- cgit v1.2.3