aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-22 15:02:05 +0000
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-22 15:02:05 +0000
commite6269cba7fd84a76b2bd0012cb954f947a79b6a5 (patch)
tree1b214de7515fbfe010e7891c42c29013a0d731b2 /modules/sd_samplers_kdiffusion.py
parent90ec557d60289a89b4ea6cd9b311658fbe682dc3 (diff)
downloadstable-diffusion-webui-gfx803-e6269cba7fd84a76b2bd0012cb954f947a79b6a5.tar.gz
stable-diffusion-webui-gfx803-e6269cba7fd84a76b2bd0012cb954f947a79b6a5.tar.bz2
stable-diffusion-webui-gfx803-e6269cba7fd84a76b2bd0012cb954f947a79b6a5.zip
Add dropdown for scheduler type
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py19
1 files changed, 15 insertions, 4 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index d428551d..441c040e 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -44,6 +44,12 @@ sampler_extra_params = {
'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
}
+k_diffusion_scheduler = {
+ 'karras': k_diffusion.sampling.get_sigmas_karras,
+ 'exponential': k_diffusion.sampling.get_sigmas_exponential,
+ 'polyexponential': k_diffusion.sampling.get_sigmas_polyexponential
+}
+
class CFGDenoiser(torch.nn.Module):
"""
@@ -305,10 +311,15 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif p.enable_karras:
- sigma_max = p.sigma_max
- sigma_min = p.sigma_min
- rho = p.rho
- sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=sigma_min, sigma_max=sigma_max, rho=rho, device=shared.device)
+ print(p.k_sched_type, p.sigma_min, p.sigma_max, p.rho)
+ sigmas_func = k_diffusion_scheduler[p.k_sched_type]
+ sigmas_kwargs = {
+ 'sigma_min': p.sigma_min,
+ 'sigma_max': p.sigma_max
+ }
+ if p.k_sched_type != 'exponential':
+ sigmas_kwargs['rho'] = p.rho
+ sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())