diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-30 07:47:09 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-30 07:47:09 +0000 |
commit | 040ec7a80e23d340efe1108b9de5ead62d9011a9 (patch) | |
tree | abb66ecdc936a08f7b83efc1653d55ae052b9709 /modules/sd_samplers_kdiffusion.py | |
parent | 4df63d2d197f26181758b5108f003f225fe84874 (diff) | |
download | stable-diffusion-webui-gfx803-040ec7a80e23d340efe1108b9de5ead62d9011a9.tar.gz stable-diffusion-webui-gfx803-040ec7a80e23d340efe1108b9de5ead62d9011a9.tar.bz2 stable-diffusion-webui-gfx803-040ec7a80e23d340efe1108b9de5ead62d9011a9.zip |
make the program read Eta and Eta DDIM from generation parameters
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r-- | modules/sd_samplers_kdiffusion.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index adb6883e..aa7f106b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch
import inspect
import k_diffusion.sampling
-from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_compvis
+from modules import prompt_parser, devices, sd_samplers_common
from modules.shared import opts, state
import modules.shared as shared
@@ -164,7 +164,6 @@ class KDiffusionSampler: self.sampler_noises = None
self.stop_at = None
self.eta = None
- self.default_eta = 1.0
self.config = None
self.last_latent = None
@@ -199,7 +198,7 @@ class KDiffusionSampler: self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
- self.eta = p.eta or opts.eta_ancestral
+ self.eta = p.eta if p.eta is not None else opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -209,6 +208,9 @@ class KDiffusionSampler: extra_params_kwargs[param_name] = getattr(p, param_name)
if 'eta' in inspect.signature(self.func).parameters:
+ if self.eta != 1.0:
+ p.extra_generation_params["Eta"] = self.eta
+
extra_params_kwargs['eta'] = self.eta
return extra_params_kwargs
|