diff options
author | Liam <liamthekerr@gmail.com> | 2022-09-27 20:37:24 +0000 |
---|---|---|
committer | Liam <liamthekerr@gmail.com> | 2022-09-27 20:37:24 +0000 |
commit | 981fe9c4a3994bb42ea5ff5212e4fe53b748bdd9 (patch) | |
tree | 80fd53962ccafeb773b2d43b178e1ee39ac03ca3 /modules/sd_samplers.py | |
parent | 5034f7d7597685aaa4779296983be0f49f4f991f (diff) | |
parent | f2a4a2c3a672e22f088a7455d6039557370dd3f2 (diff) | |
download | stable-diffusion-webui-gfx803-981fe9c4a3994bb42ea5ff5212e4fe53b748bdd9.tar.gz stable-diffusion-webui-gfx803-981fe9c4a3994bb42ea5ff5212e4fe53b748bdd9.tar.bz2 stable-diffusion-webui-gfx803-981fe9c4a3994bb42ea5ff5212e4fe53b748bdd9.zip |
Merge remote-tracking branch 'upstream/master' into token_count
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 28 |
1 file changed, 22 insertions, 6 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 1fc9d18c..666ee1ee 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -37,6 +37,11 @@ samplers = [ ]
samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']
+sampler_extra_params = {
+ 'sample_euler':['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_heun' :['s_churn','s_tmin','s_tmax','s_noise'],
+ 'sample_dpm_2':['s_churn','s_tmin','s_tmax','s_noise'],
+}
def setup_img2img_steps(p, steps=None):
if opts.img2img_fix_steps or steps is not None:
@@ -120,9 +125,9 @@ class VanillaStableDiffusionSampler: # existing code fails with certain step counts, like 9
try:
- self.sampler.make_schedule(ddim_num_steps=steps, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
except Exception:
- self.sampler.make_schedule(ddim_num_steps=steps+1, verbose=False)
+ self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)
@@ -149,9 +154,9 @@ class VanillaStableDiffusionSampler: # existing code fails with certain step counts, like 9
try:
- samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
except Exception:
- samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x)
+ samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.ddim_eta)
return samples_ddim
@@ -224,6 +229,7 @@ class KDiffusionSampler: self.model_wrap = k_diffusion.external.CompVisDenoiser(sd_model, quantize=shared.opts.enable_quantization)
self.funcname = funcname
self.func = getattr(k_diffusion.sampling, self.funcname)
+ self.extra_params = sampler_extra_params.get(funcname,[])
self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
self.sampler_noises = None
self.sampler_noise_index = 0
@@ -269,7 +275,12 @@ class KDiffusionSampler: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
steps = steps or p.steps
@@ -286,7 +297,12 @@ class KDiffusionSampler: if self.sampler_noises is not None:
k_diffusion.sampling.torch = TorchHijack(self)
- samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state)
+ extra_params_kwargs = {}
+ for val in self.extra_params:
+ if hasattr(p,val):
+ extra_params_kwargs[val] = getattr(p,val)
+
+ samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)
return samples
|