aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-24 12:35:58 +0000
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-05-24 12:35:58 +0000
commit4b88e24ebe776680b327e33fe96d7fcf38e2e5d2 (patch)
tree573e6f8a43bb1cabcfe55b3dc0f80fafcbe984a7 /modules/sd_samplers_kdiffusion.py
parent1601fccebca2dc5a806a0d2f0d33aa2da81a28fb (diff)
downloadstable-diffusion-webui-gfx803-4b88e24ebe776680b327e33fe96d7fcf38e2e5d2.tar.gz
stable-diffusion-webui-gfx803-4b88e24ebe776680b327e33fe96d7fcf38e2e5d2.tar.bz2
stable-diffusion-webui-gfx803-4b88e24ebe776680b327e33fe96d7fcf38e2e5d2.zip
improvements
See: https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/10649#issuecomment-1561047723
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index a4c797c6..d2d172e4 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -296,12 +296,6 @@ class KDiffusionSampler:
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
- if opts.k_sched_type != "Automatic":
- p.extra_generation_params["KDiffusion Scheduler Type"] = opts.k_sched_type
- p.extra_generation_params["KDiffusion Scheduler sigma_max"] = opts.sigma_max
- p.extra_generation_params["KDiffusion Scheduler sigma_min"] = opts.sigma_min
- p.extra_generation_params["KDiffusion Scheduler rho"] = opts.rho
-
extra_params_kwargs = {}
for param_name in self.extra_params:
if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
@@ -326,14 +320,27 @@ class KDiffusionSampler:
if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
elif opts.k_sched_type != "Automatic":
- sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
- sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ m_sigma_min, m_sigma_max = (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())
+ sigma_min, sigma_max = (0.1, 10)
sigmas_kwargs = {
- 'sigma_min': opts.sigma_min or sigma_min,
- 'sigma_max': opts.sigma_max or sigma_max
+ 'sigma_min': sigma_min if opts.use_old_karras_scheduler_sigmas else m_sigma_min,
+ 'sigma_max': sigma_max if opts.use_old_karras_scheduler_sigmas else m_sigma_max
}
+
+ sigmas_func = k_diffusion_scheduler[opts.k_sched_type]
+ p.extra_generation_params["KDiff Sched Type"] = opts.k_sched_type
+
+ if opts.sigma_min != 0.3:
+ # take 0.0 as model default
+ sigmas_kwargs['sigma_min'] = opts.sigma_min or m_sigma_min
+ p.extra_generation_params["KDiff Sched min sigma"] = opts.sigma_min
+ if opts.sigma_max != 14.6:
+ sigmas_kwargs['sigma_max'] = opts.sigma_max or m_sigma_max
+ p.extra_generation_params["KDiff Sched max sigma"] = opts.sigma_max
if opts.k_sched_type != 'exponential':
sigmas_kwargs['rho'] = opts.rho
+ p.extra_generation_params["KDiff Sched rho"] = opts.rho
+
sigmas = sigmas_func(n=steps, **sigmas_kwargs, device=shared.device)
elif self.config is not None and self.config.options.get('scheduler', None) == 'karras':
sigma_min, sigma_max = (0.1, 10) if opts.use_old_karras_scheduler_sigmas else (self.model_wrap.sigmas[0].item(), self.model_wrap.sigmas[-1].item())