From e90d4334ad37024a802f4ef27069b625a6508f72 Mon Sep 17 00:00:00 2001 From: CodeHatchling Date: Wed, 6 Dec 2023 16:54:42 -0700 Subject: A custom blending function can be provided by p, replacing the use of soft_inpainting. --- modules/sd_samplers_cfg_denoiser.py | 34 +++++++++++++++++----------------- 1 file changed, 17 insertions(+), 17 deletions(-) (limited to 'modules/sd_samplers_cfg_denoiser.py') diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a700e692..f13e8dcc 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -6,7 +6,6 @@ import modules.shared as shared from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback -import modules.soft_inpainting as si def catenate_conds(conds): @@ -44,7 +43,6 @@ class CFGDenoiser(torch.nn.Module): self.model_wrap = None self.mask = None self.nmask = None - self.soft_inpainting: si.SoftInpaintingParameters = None self.init_latent = None self.steps = None """number of steps as specified by user in UI""" @@ -94,7 +92,6 @@ class CFGDenoiser(torch.nn.Module): self.sampler.sampler_extra_args['uncond'] = uc def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): - if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -111,15 +108,24 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + # If we use masks, blending between the denoised and original latent images occurs here. + def apply_blend(latent): + if hasattr(self.p, "denoiser_masked_blend_function") and callable(self.p.denoiser_masked_blend_function): + return self.p.denoiser_masked_blend_function( + self, + # Using an argument dictionary so that arguments can be added without breaking extensions. + args= + { + "denoiser": self, + "current_latent": latent, + "sigma": sigma + }) + else: + return self.init_latent * self.mask + self.nmask * latent + # Blend in the original latents (before) if self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - x = self.init_latent * self.mask + self.nmask * x - else: - x = si.latent_blend(self.soft_inpainting, - self.init_latent, - x, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + x = apply_blend(x) batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -222,13 +228,7 @@ class CFGDenoiser(torch.nn.Module): # Blend in the original latents (after) if not self.mask_before_denoising and self.mask is not None: - if self.soft_inpainting is None: - denoised = self.init_latent * self.mask + self.nmask * denoised - else: - denoised = si.latent_blend(self.soft_inpainting, - self.init_latent, - denoised, - si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma)) + denoised = apply_blend(denoised) self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) -- cgit v1.2.3