aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_compvis.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_compvis.py')
-rw-r--r--modules/sd_samplers_compvis.py12
1 files changed, 11 insertions, 1 deletions
diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py
index 4a8396f9..5df926d3 100644
--- a/modules/sd_samplers_compvis.py
+++ b/modules/sd_samplers_compvis.py
@@ -19,7 +19,8 @@ samplers_data_compvis = [
class VanillaStableDiffusionSampler:
def __init__(self, constructor, sd_model):
- self.sampler = constructor(sd_model)
+ self.p = None
+ self.sampler = constructor(shared.sd_model)
self.is_ddim = hasattr(self.sampler, 'p_sample_ddim')
self.is_plms = hasattr(self.sampler, 'p_sample_plms')
self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler)
@@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler:
self.nmask = None
self.init_latent = None
self.sampler_noises = None
+ self.steps = None
self.step = 0
self.stop_at = None
self.eta = None
@@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler:
return 0
def launch_sampling(self, steps, func):
+ self.steps = steps
state.sampling_steps = steps
state.sampling_step = 0
@@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler:
return res
+ def update_inner_model(self):
+ self.sampler.model = shared.sd_model
+
def before_sample(self, x, ts, cond, unconditional_conditioning):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
+ sd_samplers_common.apply_refiner(self)
+
if self.stop_at is not None and self.step > self.stop_at:
raise sd_samplers_common.InterruptedException
@@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler:
self.update_step(x)
def initialize(self, p):
+ self.p = p
+
if self.is_ddim:
self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
else: