diff options
Diffstat (limited to 'modules/sd_samplers_cfg_denoiser.py')
-rw-r--r-- | modules/sd_samplers_cfg_denoiser.py | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py
index b8101d38..a700e692 100644
--- a/modules/sd_samplers_cfg_denoiser.py
+++ b/modules/sd_samplers_cfg_denoiser.py
@@ -6,6 +6,7 @@ import modules.shared as shared
from modules.script_callbacks import CFGDenoiserParams, cfg_denoiser_callback
from modules.script_callbacks import CFGDenoisedParams, cfg_denoised_callback
from modules.script_callbacks import AfterCFGCallbackParams, cfg_after_cfg_callback
+import modules.soft_inpainting as si
def catenate_conds(conds):
@@ -43,6 +44,7 @@ class CFGDenoiser(torch.nn.Module):
self.model_wrap = None
self.mask = None
self.nmask = None
+ self.soft_inpainting: si.SoftInpaintingParameters = None
self.init_latent = None
self.steps = None
"""number of steps as specified by user in UI"""
@@ -56,6 +58,9 @@ class CFGDenoiser(torch.nn.Module):
self.sampler = sampler
self.model_wrap = None
self.p = None
+
+ # NOTE: masking before denoising can cause the original latents to be oversmoothed
+ # as the original latents do not have noise
self.mask_before_denoising = False
@property
@@ -89,6 +94,7 @@ class CFGDenoiser(torch.nn.Module):
self.sampler.sampler_extra_args['uncond'] = uc
def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
+
if state.interrupted or state.skipped:
raise sd_samplers_common.InterruptedException
@@ -105,8 +111,15 @@ class CFGDenoiser(torch.nn.Module):
assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"
+ # Blend in the original latents (before)
if self.mask_before_denoising and self.mask is not None:
- x = self.init_latent * self.mask + self.nmask * x
+ if self.soft_inpainting is None:
+ x = self.init_latent * self.mask + self.nmask * x
+ else:
+ x = si.latent_blend(self.soft_inpainting,
+ self.init_latent,
+ x,
+ si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
batch_size = len(conds_list)
repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +220,15 @@ class CFGDenoiser(torch.nn.Module):
else:
denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)
+ # Blend in the original latents (after)
if not self.mask_before_denoising and self.mask is not None:
- denoised = self.init_latent * self.mask + self.nmask * denoised
+ if self.soft_inpainting is None:
+ denoised = self.init_latent * self.mask + self.nmask * denoised
+ else:
+ denoised = si.latent_blend(self.soft_inpainting,
+ self.init_latent,
+ denoised,
+ si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))
self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)
|