diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-03 14:21:15 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-03 14:21:15 +0000 |
commit | f40617d6c4e366773677baa8d7f4114ba2893282 (patch) | |
tree | 7e5e1810017d11133c4294190124ad74e1b2d961 /modules/sd_samplers.py | |
parent | 345028099d893f8a66726cfd13627d8cc1bcc724 (diff) | |
download | stable-diffusion-webui-gfx803-f40617d6c4e366773677baa8d7f4114ba2893282.tar.gz stable-diffusion-webui-gfx803-f40617d6c4e366773677baa8d7f4114ba2893282.tar.bz2 stable-diffusion-webui-gfx803-f40617d6c4e366773677baa8d7f4114ba2893282.zip |
support for scripts
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 32 |
1 files changed, 21 insertions, 11 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 54c5fd7c..6f028f5f 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -9,18 +9,28 @@ from ldm.models.diffusion.plms import PLMSSampler from modules.shared import opts, cmd_opts, state
import modules.shared as shared
-SamplerData = namedtuple('SamplerData', ['name', 'constructor'])
+
+SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases'])
+
+samplers_k_diffusion = [
+ ('Euler a', 'sample_euler_ancestral', ['k_euler_a']),
+ ('Euler', 'sample_euler', ['k_euler']),
+ ('LMS', 'sample_lms', ['k_lms']),
+ ('Heun', 'sample_heun', ['k_heun']),
+ ('DPM2', 'sample_dpm_2', ['k_dpm_2']),
+ ('DPM2 a', 'sample_dpm_2_ancestral', ['k_dpm_2_a']),
+]
+
+samplers_data_k_diffusion = [
+ SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases)
+ for label, funcname, aliases in samplers_k_diffusion
+ if hasattr(k_diffusion.sampling, funcname)
+]
+
samplers = [
- *[SamplerData(x[0], lambda model, funcname=x[1]: KDiffusionSampler(funcname, model)) for x in [
- ('Euler a', 'sample_euler_ancestral'),
- ('Euler', 'sample_euler'),
- ('LMS', 'sample_lms'),
- ('Heun', 'sample_heun'),
- ('DPM2', 'sample_dpm_2'),
- ('DPM2 a', 'sample_dpm_2_ancestral'),
- ] if hasattr(k_diffusion.sampling, x[1])],
- SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSample, model)),
- SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model)),
+ *samplers_data_k_diffusion,
+ SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(DDIMSampler, model), []),
+ SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(PLMSSampler, model), []),
]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
|