diff options
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 13 |
1 files changed, 13 insertions, 0 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 2e1f7715..8d6eb762 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -26,6 +26,17 @@ samplers_k_diffusion = [ ('DPM adaptive', 'sample_dpm_adaptive', ['k_dpm_ad']),
]
+if opts.show_karras_scheduler_variants:
+ k_diffusion.sampling.sample_dpm_2_ka = k_diffusion.sampling.sample_dpm_2
+ k_diffusion.sampling.sample_dpm_2_ancestral_ka = k_diffusion.sampling.sample_dpm_2_ancestral
+ k_diffusion.sampling.sample_lms_ka = k_diffusion.sampling.sample_lms
+ samplers_k_diffusion_ka = [
+ ('LMS K Scheduling', 'sample_lms_ka', ['k_lms_ka']),
+ ('DPM2 K Scheduling', 'sample_dpm_2_ka', ['k_dpm_2_ka']),
+ ('DPM2 a K Scheduling', 'sample_dpm_2_ancestral_ka', ['k_dpm_2_a_ka']),
+ ]
+ samplers_k_diffusion.extend(samplers_k_diffusion_ka)
+
samplers_data_k_diffusion = [
SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
for label, funcname, aliases in samplers_k_diffusion
@@ -345,6 +356,8 @@ class KDiffusionSampler: if p.sampler_noise_scheduler_override:
sigmas = p.sampler_noise_scheduler_override(steps)
+ elif self.funcname.endswith('ka'):
+ sigmas = k_diffusion.sampling.get_sigmas_karras(n=steps, sigma_min=0.1, sigma_max=10, device=shared.device)
else:
sigmas = self.model_wrap.get_sigmas(steps)
x = x * sigmas[0]
|