aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
authordevdn <zboodles2@gmail.com>2023-03-28 22:18:28 +0000
committerdevdn <zboodles2@gmail.com>2023-03-29 00:56:01 +0000
commit42082e8a3239c1c32cd9e2a03a20b610af857b51 (patch)
tree2a39c6b03a71b551bf8038832d13abe35b312fd2 /modules/sd_samplers_kdiffusion.py
parent3856ada5cc9ac4124e20ff311ce7aa77330845d9 (diff)
downloadstable-diffusion-webui-gfx803-42082e8a3239c1c32cd9e2a03a20b610af857b51.tar.gz
stable-diffusion-webui-gfx803-42082e8a3239c1c32cd9e2a03a20b610af857b51.tar.bz2
stable-diffusion-webui-gfx803-42082e8a3239c1c32cd9e2a03a20b610af857b51.zip
performance increase
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py22
1 files changed, 17 insertions, 5 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index e9f08518..6a54ce32 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -76,7 +76,7 @@ class CFGDenoiser(torch.nn.Module):
return denoised
- def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
+ def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -116,6 +116,12 @@ class CFGDenoiser(torch.nn.Module):
tensor = denoiser_params.text_cond
uncond = denoiser_params.text_uncond
+ sigma_thresh = s_min_uncond
+ if(torch.dot(sigma,sigma) < sigma.shape[0] * (sigma_thresh*sigma_thresh) and not is_edit_model):
+ uncond = torch.zeros([0,0,uncond.shape[2]])
+ x_in=x_in[:x_in.shape[0]//2]
+ sigma_in=sigma_in[:sigma_in.shape[0]//2]
+
if tensor.shape[1] == uncond.shape[1]:
if not is_edit_model:
cond_in = torch.cat([tensor, uncond])
@@ -144,7 +150,8 @@ class CFGDenoiser(torch.nn.Module):
x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(c_crossattn, image_cond_in[a:b]))
- x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
+ if uncond.shape[0]:
+ x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=make_condition_dict([uncond], image_cond_in[-uncond.shape[0]:]))
denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
cfg_denoised_callback(denoised_params)
@@ -157,7 +164,10 @@ class CFGDenoiser(torch.nn.Module):
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
if not is_edit_model:
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if uncond.shape[0]:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ else:
+ denoised = x_out
else:
denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
@@ -165,7 +175,6 @@ class CFGDenoiser(torch.nn.Module):
denoised = self.init_latent * self.mask + self.nmask * denoised
self.step += 1
-
return denoised
@@ -244,6 +253,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.step = 0
self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
+ self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -326,6 +336,7 @@ class KDiffusionSampler:
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}
samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
@@ -359,7 +370,8 @@ class KDiffusionSampler:
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
+ 'cond_scale': p.cfg_scale,
+ 's_min_uncond': self.s_min_uncond
}, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples