From 269833067de1e7d0b6a6bd65724743d6b88a133f Mon Sep 17 00:00:00 2001 From: Kyle Date: Thu, 2 Feb 2023 09:37:01 -0500 Subject: instruct-pix2pix support --- modules/sd_samplers_kdiffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index aa7f106b..31ee22d3 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module): batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) + cond_in = torch.cat([tensor, uncond, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) -- cgit v1.2.3 From cf0cfefe910b0de18c4751ce8d8cf7a6053a39b0 Mon Sep 17 00:00:00 2001 From: Kyle Date: Thu, 2 Feb 2023 19:15:38 -0500 Subject: Revert "instruct-pix2pix support" This reverts commit 269833067de1e7d0b6a6bd65724743d6b88a133f. --- modules/sd_samplers_kdiffusion.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 31ee22d3..aa7f106b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -77,9 +77,9 @@ class CFGDenoiser(torch.nn.Module): batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [image_cond]) + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -88,7 +88,7 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond, uncond]) + cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) -- cgit v1.2.3 From 6c6c6636bb123d664999c888cda47a1f8bad635b Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 18:19:56 -0500 Subject: Image CFG Added (Full Implementation) Uses separate denoiser for edit (instruct-pix2pix) models No impact to txt2img or regular img2img "Image CFG Scale" will only apply to instruct-pix2pix models and metadata will only be added if using such model --- modules/sd_samplers_kdiffusion.py | 101 +++++++++++++++++++++++++++++++++++--- 1 file changed, 95 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index aa7f106b..a16ba69b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,6 +1,7 @@ from collections import deque import torch import inspect +import einops import k_diffusion.sampling from modules import prompt_parser, devices, sd_samplers_common @@ -40,6 +41,90 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } +class CFGDenoiserEdit(torch.nn.Module): + """ + Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) + that can take a noisy picture and produce a noise-free picture using two guidances (prompts) + instead of one. Originally, the second prompt is just an empty string, but we use non-empty + negative prompt. + """ + + def __init__(self, model): + super().__init__() + self.inner_model = model + self.mask = None + self.nmask = None + self.init_latent = None + self.step = 0 + + def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale): + denoised_uncond = x_out[-uncond.shape[0]:] + denoised = torch.clone(denoised_uncond) + + for i, conds in enumerate(conds_list): + for cond_index, weight in conds: + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index]) + + return denoised + + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale): + if state.interrupted or state.skipped: + raise sd_samplers_common.InterruptedException + + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) + uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + + batch_size = len(conds_list) + repeats = [len(conds_list[i]) for i in range(batch_size)] + + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) + + denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) + cfg_denoiser_callback(denoiser_params) + x_in = denoiser_params.x + image_cond_in = denoiser_params.image_cond + sigma_in = denoiser_params.sigma + + if tensor.shape[1] == uncond.shape[1]: + cond_in = torch.cat([tensor, uncond, uncond]) + + if shared.batch_cond_uncond: + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) + else: + x_out = torch.zeros_like(x_in) + for batch_offset in range(0, x_out.shape[0], batch_size): + a = batch_offset + b = a + batch_size + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) + else: + x_out = torch.zeros_like(x_in) + batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size + for batch_offset in range(0, tensor.shape[0], batch_size): + a = batch_offset + b = min(a + batch_size, tensor.shape[0]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]}) + + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) + + devices.test_for_nans(x_out, "unet") + + if opts.live_preview_content == "Prompt": + sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) + elif opts.live_preview_content == "Negative prompt": + sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) + + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale) + + if self.mask is not None: + denoised = self.init_latent * self.mask + self.nmask * denoised + + self.step += 1 + + return denoised + class CFGDenoiser(torch.nn.Module): """ @@ -78,8 +163,8 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -160,7 +245,7 @@ class KDiffusionSampler: self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) if not shared.sd_model.cond_stage_key == "edit" else CFGDenoiserEdit(self.model_wrap) self.sampler_noises = None self.stop_at = None self.eta = None @@ -260,13 +345,17 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + extra_args={ 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, - 'cond_scale': p.cfg_scale - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + 'cond_scale': p.cfg_scale, + } + + if p.image_cfg_scale: + extra_args['image_cfg_scale'] = p.image_cfg_scale + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From c27c0de0f73c5f533acfa10426dbac7ac988bc85 Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 19:15:32 -0500 Subject: txt2img Hires Fix --- modules/sd_samplers_kdiffusion.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a16ba69b..6107e99e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -352,7 +352,7 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if p.image_cfg_scale: + if hasattr(p, 'image_cfg_scale'): extra_args['image_cfg_scale'] = p.image_cfg_scale samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) -- cgit v1.2.3 From ba6a4e7e9431d02ba3656c6ae44d5dfe29908d68 Mon Sep 17 00:00:00 2001 From: Kyle Date: Fri, 3 Feb 2023 19:46:13 -0500 Subject: Use original CFGDenoiser if image_cfg_scale = 1 If image_cfg_scale is =1 then the original image is not used for the output. We can then use the original CFGDenoiser to get the same result to support AND functionality. Maybe in the future AND can be supported with "Image CFG Scale" --- modules/sd_samplers_kdiffusion.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6107e99e..6c57fdec 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -245,7 +245,7 @@ class KDiffusionSampler: self.funcname = funcname self.func = getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) if not shared.sd_model.cond_stage_key == "edit" else CFGDenoiserEdit(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None self.stop_at = None self.eta = None @@ -280,6 +280,9 @@ class KDiffusionSampler: return p.steps def initialize(self, p): + if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1: + self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap) + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 @@ -352,7 +355,7 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if hasattr(p, 'image_cfg_scale'): + if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None: extra_args['image_cfg_scale'] = p.image_cfg_scale samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) -- cgit v1.2.3 From 72dd5785d9721b95e8d61210a56be8f6c6b1e97d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 4 Feb 2023 11:06:17 +0300 Subject: merge CFGDenoiserEdit and CFGDenoiser into single object --- modules/sd_samplers_kdiffusion.py | 133 +++++++++++--------------------------- 1 file changed, 37 insertions(+), 96 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 6c57fdec..f076fc55 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -41,90 +41,6 @@ sampler_extra_params = { 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], } -class CFGDenoiserEdit(torch.nn.Module): - """ - Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet) - that can take a noisy picture and produce a noise-free picture using two guidances (prompts) - instead of one. Originally, the second prompt is just an empty string, but we use non-empty - negative prompt. - """ - - def __init__(self, model): - super().__init__() - self.inner_model = model - self.mask = None - self.nmask = None - self.init_latent = None - self.step = 0 - - def combine_denoised(self, x_out, conds_list, uncond, cond_scale, image_cfg_scale): - denoised_uncond = x_out[-uncond.shape[0]:] - denoised = torch.clone(denoised_uncond) - - for i, conds in enumerate(conds_list): - for cond_index, weight in conds: - out_cond, out_img_cond, out_uncond = x_out.chunk(3) - denoised[i] = out_uncond[cond_index] + cond_scale * (out_cond[cond_index] - out_img_cond[cond_index]) + image_cfg_scale * (out_img_cond[cond_index] - out_uncond[cond_index]) - - return denoised - - def forward(self, x, sigma, uncond, cond, cond_scale, image_cond, image_cfg_scale): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) - - batch_size = len(conds_list) - repeats = [len(conds_list[i]) for i in range(batch_size)] - - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) - - denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) - cfg_denoiser_callback(denoiser_params) - x_in = denoiser_params.x - image_cond_in = denoiser_params.image_cond - sigma_in = denoiser_params.sigma - - if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond, uncond]) - - if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) - else: - x_out = torch.zeros_like(x_in) - for batch_offset in range(0, x_out.shape[0], batch_size): - a = batch_offset - b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) - else: - x_out = torch.zeros_like(x_in) - batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size - for batch_offset in range(0, tensor.shape[0], batch_size): - a = batch_offset - b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": torch.cat([tensor[a:b]], uncond) , "c_concat": [image_cond_in[a:b]]}) - - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) - - devices.test_for_nans(x_out, "unet") - - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(x_out[0:uncond.shape[0]]) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale, image_cfg_scale) - - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised - - self.step += 1 - - return denoised - class CFGDenoiser(torch.nn.Module): """ @@ -141,6 +57,7 @@ class CFGDenoiser(torch.nn.Module): self.nmask = None self.init_latent = None self.step = 0 + self.image_cfg_scale = None def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -152,19 +69,36 @@ class CFGDenoiser(torch.nn.Module): return denoised + def combine_denoised_for_edit_model(self, x_out, cond_scale): + out_cond, out_img_cond, out_uncond = x_out.chunk(3) + denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond) + + return denoised + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, + # so is_edit_model is set to False to support AND composition. + is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 + conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) uncond = prompt_parser.reconstruct_cond_batch(uncond, self.step) + assert not is_edit_model or all([len(conds) == 1 for conds in conds_list]), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] - x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) - sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) - image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + if not is_edit_model: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) + else: + x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x]) + sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond] + [torch.zeros_like(self.init_latent)]) denoiser_params = CFGDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps) cfg_denoiser_callback(denoiser_params) @@ -173,7 +107,10 @@ class CFGDenoiser(torch.nn.Module): sigma_in = denoiser_params.sigma if tensor.shape[1] == uncond.shape[1]: - cond_in = torch.cat([tensor, uncond]) + if not is_edit_model: + cond_in = torch.cat([tensor, uncond]) + else: + cond_in = torch.cat([tensor, uncond, uncond]) if shared.batch_cond_uncond: x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) @@ -189,7 +126,13 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) + + if not is_edit_model: + c_crossattn = [tensor[a:b]] + else: + c_crossattn = torch.cat([tensor[a:b]], uncond) + + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": c_crossattn, "c_concat": [image_cond_in[a:b]]}) x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) @@ -200,7 +143,10 @@ class CFGDenoiser(torch.nn.Module): elif opts.live_preview_content == "Negative prompt": sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + if not is_edit_model: + denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) + else: + denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) if self.mask is not None: denoised = self.init_latent * self.mask + self.nmask * denoised @@ -280,12 +226,10 @@ class KDiffusionSampler: return p.steps def initialize(self, p): - if shared.sd_model.cond_stage_key == "edit" and getattr(p, 'image_cfg_scale', None) != 1: - self.model_wrap_cfg = CFGDenoiserEdit(self.model_wrap) - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) self.eta = p.eta if p.eta is not None else opts.eta_ancestral k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) @@ -355,9 +299,6 @@ class KDiffusionSampler: 'cond_scale': p.cfg_scale, } - if hasattr(p, 'image_cfg_scale') and p.image_cfg_scale != 1 and p.image_cfg_scale != None: - extra_args['image_cfg_scale'] = p.image_cfg_scale - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3