aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_samplers_kdiffusion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_samplers_kdiffusion.py')
-rw-r--r--modules/sd_samplers_kdiffusion.py95
1 files changed, 77 insertions, 18 deletions
diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py
index aa7f106b..528f513f 100644
--- a/modules/sd_samplers_kdiffusion.py
+++ b/modules/sd_samplers_kdiffusion.py
@@ -1,12 +1,14 @@
from collections import deque
import torch
import inspect
+import einops
import k_diffusion.sampling
from modules import prompt_parser, devices, sd_samplers_common
from modules.shared import opts, state
import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
+from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
samplers_k_diffusion = [
('Euler a', 'sample_euler_ancestral', ['k_euler_a', 'k_euler_ancestral'], {}),
@@ -56,6 +58,7 @@ class CFGDenoiser(torch.nn.Module):
self.nmask = None
self.init_latent = None
self.step = 0
+ self.image_cfg_scale = None
def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
denoised_uncond = x_out[-uncond.shape[0]:]
@@ -67,19 +70,36 @@ class CFGDenoiser(torch.nn.Module):
return denoised
+ def combine_denoised_for_edit_model(self, x_out, cond_scale):
+ out_cond, out_img_cond, out_uncond = x_out.chunk(3)
+ denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)
+
+ return denoised
+
def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
+ # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
+ # so is_edit_model is set to False to support AND composition.
+ is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0
+
conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step)
uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step)
+ assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
- x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
- image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
- sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ if not is_edit_model:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
+ else:
+ x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
+ sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
+ image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)])
denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps)
cfg_denoiser_callback(denoiser_params)
@@ -88,7 +108,10 @@ class CFGDenoiser(torch.nn.Module):
sigma_in = denoiser_params.sigma
if tensor.shape[1] == uncond.shape[1]:
- cond_in = torch.cat([tensor, uncond])
+ if not is_edit_model:
+ cond_in = torch.cat([tensor, uncond])
+ else:
+ cond_in = torch.cat([tensor, uncond, uncond])
if shared.batch_cond_uncond:
x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]})
@@ -104,10 +127,19 @@ class CFGDenoiser(torch.nn.Module):
for batch_offset in range(0, tensor.shape[0], batch_size):
a = batch_offset
b = min(a + batch_size, tensor.shape[0])
- x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]})
+
+ if not is_edit_model:
+ c_crossattn = [tensor[a:b]]
+ else:
+ c_crossattn = torch.cat([tensor[a:b]], uncond)
+
+ x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]})
x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]})
+ denoised_params = CFGDenoisedParams(x_out, state.sampling_step, state.sampling_steps)
+ cfg_denoised_callback(denoised_params)
+
devices.test_for_nans(x_out, "unet")
if opts.live_preview_content == "Prompt":
@@ -115,7 +147,10 @@ class CFGDenoiser(torch.nn.Module):
elif opts.live_preview_content == "Negative prompt":
sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])
- denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ if not is_edit_model:
+ denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ else:
+ denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
if self.mask is not None:
denoised = self.init_latent * self.mask + self.nmask * denoised
@@ -198,6 +233,7 @@ class KDiffusionSampler:
self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
self.model_wrap_cfg.step = 0
+ self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
self.eta = p.eta if p.eta is not None else opts.eta_ancestral
k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
@@ -237,6 +273,16 @@ class KDiffusionSampler:
return sigmas
+ def create_noise_sampler(self, x, sigmas, p):
+ """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
+ if shared.opts.no_dpmpp_sde_batch_determinism:
+ return None
+
+ from k_diffusion.sampling import BrownianTreeNoiseSampler
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
+ return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)
+
def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -246,31 +292,38 @@ class KDiffusionSampler:
xi = x + noise * sigma_sched[0]
extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
## last sigma is zero which isn't allowed by DPM Fast & Adaptive so taking value before last
extra_params_kwargs['sigma_min'] = sigma_sched[-2]
- if 'sigma_max' in inspect.signature(self.func).parameters:
+ if 'sigma_max' in parameters:
extra_params_kwargs['sigma_max'] = sigma_sched[0]
- if 'n' in inspect.signature(self.func).parameters:
+ if 'n' in parameters:
extra_params_kwargs['n'] = len(sigma_sched) - 1
- if 'sigma_sched' in inspect.signature(self.func).parameters:
+ if 'sigma_sched' in parameters:
extra_params_kwargs['sigma_sched'] = sigma_sched
- if 'sigmas' in inspect.signature(self.func).parameters:
+ if 'sigmas' in parameters:
extra_params_kwargs['sigmas'] = sigma_sched
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
self.model_wrap_cfg.init_latent = x
self.last_latent = x
-
- samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={
+ extra_args={
'cond': conditioning,
'image_cond': image_conditioning,
'uncond': unconditional_conditioning,
- 'cond_scale': p.cfg_scale
- }, disable=False, callback=self.callback_state, **extra_params_kwargs))
+ 'cond_scale': p.cfg_scale,
+ }
+
+ samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
return samples
- def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None):
+ def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
steps = steps or p.steps
sigmas = self.get_sigmas(p, steps)
@@ -278,14 +331,20 @@ class KDiffusionSampler:
x = x * sigmas[0]
extra_params_kwargs = self.initialize(p)
- if 'sigma_min' in inspect.signature(self.func).parameters:
+ parameters = inspect.signature(self.func).parameters
+
+ if 'sigma_min' in parameters:
extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
- if 'n' in inspect.signature(self.func).parameters:
+ if 'n' in parameters:
extra_params_kwargs['n'] = steps
else:
extra_params_kwargs['sigmas'] = sigmas
+ if self.funcname == 'sample_dpmpp_sde':
+ noise_sampler = self.create_noise_sampler(x, sigmas, p)
+ extra_params_kwargs['noise_sampler'] = noise_sampler
+
self.last_latent = x
samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
'cond': conditioning,