aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_compvis.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-30 07:11:30 +0000
committerAUTOMATIC <16777216c@gmail.com>2023-01-30 07:11:30 +0000
commit4df63d2d197f26181758b5108f003f225fe84874 (patch)
treeb7421906e69147a1b76560cd823640d784eee2cc /modules/sd_samplers_compvis.py
parent274474105a5166a985a47508ffd0695db41623a5 (diff)
downloadstable-diffusion-webui-gfx803-4df63d2d197f26181758b5108f003f225fe84874.tar.gz
stable-diffusion-webui-gfx803-4df63d2d197f26181758b5108f003f225fe84874.tar.bz2
stable-diffusion-webui-gfx803-4df63d2d197f26181758b5108f003f225fe84874.zip
split samplers into one more files for k-diffusion
Diffstat (limited to 'modules/sd_samplers_compvis.py')
-rw-r--r--modules/sd_samplers_compvis.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 3d35ff72..88541193 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -1,4 +1,6 @@
import math
+import ldm.models.diffusion.ddim
+import ldm.models.diffusion.plms
import numpy as np
import torch
@@ -7,6 +9,12 @@ from modules.shared import state
from modules import sd_samplers_common, prompt_parser, shared
+samplers_data_compvis = [
+ sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {}),
+ sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {}),
+]
+
+
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
self.sampler = constructor(sd_model)