aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/soft_inpainting.py
diff options
context:
space:
mode:
authorCodeHatchling <steve@codehatch.com>2023-12-08 03:19:35 +0000
committerCodeHatchling <steve@codehatch.com>2023-12-08 03:19:35 +0000
commitf284ae23bcdfa212cf4763659c06e124ec5b1456 (patch)
tree17ef1bf34e8bffa9e3e4819d142782dd65d7e450 /scripts/soft_inpainting.py
parent0ef4a4cb2365051b1e308f0136a0d8c01d071569 (diff)
downloadstable-diffusion-webui-gfx803-f284ae23bcdfa212cf4763659c06e124ec5b1456.tar.gz
stable-diffusion-webui-gfx803-f284ae23bcdfa212cf4763659c06e124ec5b1456.tar.bz2
stable-diffusion-webui-gfx803-f284ae23bcdfa212cf4763659c06e124ec5b1456.zip
Added parameters for the composite stage, fixed batched generation.
Diffstat (limited to 'scripts/soft_inpainting.py')
-rw-r--r--scripts/soft_inpainting.py198
1 files changed, 155 insertions, 43 deletions
diff --git a/scripts/soft_inpainting.py b/scripts/soft_inpainting.py
index 1f451b55..1b21aee9 100644
--- a/scripts/soft_inpainting.py
+++ b/scripts/soft_inpainting.py
@@ -6,22 +6,34 @@ import modules.scripts as scripts
class SoftInpaintingSettings:
- def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation):
+ def __init__(self,
+ mask_blend_power,
+ mask_blend_scale,
+ inpaint_detail_preservation,
+ composite_mask_influence,
+ composite_difference_threshold,
+ composite_difference_contrast):
self.mask_blend_power = mask_blend_power
self.mask_blend_scale = mask_blend_scale
self.inpaint_detail_preservation = inpaint_detail_preservation
+ self.composite_mask_influence = composite_mask_influence
+ self.composite_difference_threshold = composite_difference_threshold
+ self.composite_difference_contrast = composite_difference_contrast
def add_generation_params(self, dest):
dest[enabled_gen_param_label] = True
dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
+ dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
+ dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
+ dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast
# ------------------- Methods -------------------
-def latent_blend(soft_inpainting, a, b, t):
+def latent_blend(settings, a, b, t):
"""
Interpolates two latent image representations according to the parameter t,
where the interpolated vectors' magnitudes are also interpolated separately.
@@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t):
# Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
- soft_inpainting.inpaint_detail_preservation) * one_minus_t3
+ settings.inpaint_detail_preservation) * one_minus_t3
b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
- soft_inpainting.inpaint_detail_preservation) * t3
+ settings.inpaint_detail_preservation) * t3
desired_magnitude = a_magnitude
- desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
+ desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
del a_magnitude, b_magnitude, t3, one_minus_t3
# Change the linearly interpolated image vectors' magnitudes to the value we want.
@@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t):
return image_interp_scaled
-def get_modified_nmask(soft_inpainting, nmask, sigma):
+def get_modified_nmask(settings, nmask, sigma):
"""
Converts a negative mask representing the transparency of the original latent vectors being overlayed
to a mask that is scaled according to the denoising strength for this step.
@@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
NOTE: "mask" is not used
"""
import torch
- return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
+ return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)
def apply_adaptive_masks(
+ settings:SoftInpaintingSettings,
+ nmask,
latent_orig,
latent_processed,
overlay_images,
@@ -108,11 +122,13 @@ def apply_adaptive_masks(
from PIL import Image, ImageOps, ImageFilter
# TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
- # latent_mask = p.nmask[0].float().cpu()
+ latent_mask = nmask[0].float()
# convert the original mask into a form we use to scale distances for thresholding
- # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
- # mask_scalar = mask_scalar / (1.00001-mask_scalar)
- # mask_scalar = mask_scalar.numpy()
+ mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
+ mask_scalar = (0.5 * (1-settings.composite_mask_influence)
+ + mask_scalar * settings.composite_mask_influence)
+ mask_scalar = mask_scalar / (1.00001-mask_scalar)
+ mask_scalar = mask_scalar.cpu().numpy()
latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)
@@ -128,10 +144,10 @@ def apply_adaptive_masks(
percentile_min=0.25, percentile_max=0.75, min_width=1)
# The distance at which opacity of original decreases to 50%
- # half_weighted_distance = 1 # * mask_scalar
- # converted_mask = converted_mask / half_weighted_distance
+ half_weighted_distance = settings.composite_difference_threshold * mask_scalar
+ converted_mask = converted_mask / half_weighted_distance
- converted_mask = 1 / (1 + converted_mask ** 2)
+ converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
converted_mask = smootherstep(converted_mask)
converted_mask = 1 - converted_mask
converted_mask = 255. * converted_mask
@@ -161,7 +177,7 @@ def apply_adaptive_masks(
def apply_masks(
- soft_inpainting,
+ settings,
nmask,
overlay_images,
width, height,
@@ -172,7 +188,7 @@ def apply_masks(
from PIL import Image, ImageOps, ImageFilter
converted_mask = nmask[0].float()
- converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
+ converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
converted_mask = 255. * converted_mask
converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
converted_mask = Image.fromarray(converted_mask)
@@ -395,7 +411,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
# ------------------- Constants -------------------
-default = SoftInpaintingSettings(1, 0.5, 4)
+default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)
enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
@@ -404,25 +420,37 @@ enabled_el_id = "soft_inpainting_enabled"
ui_labels = SoftInpaintingSettings(
"Schedule bias",
"Preservation strength",
- "Transition contrast boost")
+ "Transition contrast boost",
+ "Mask influence",
+ "Difference threshold",
+ "Difference contrast")
ui_info = SoftInpaintingSettings(
"Shifts when preservation of original content occurs during denoising.",
"How strongly partially masked content should be preserved.",
- "Amplifies the contrast that may be lost in partially masked regions.")
+ "Amplifies the contrast that may be lost in partially masked regions.",
+ "How strongly the original mask should bias the difference threshold.",
+ "How much an image region can change before the original pixels are not blended in anymore.",
+ "How sharp the transition should be between blended and not blended.")
gen_param_labels = SoftInpaintingSettings(
"Soft inpainting schedule bias",
"Soft inpainting preservation strength",
- "Soft inpainting transition contrast boost")
+ "Soft inpainting transition contrast boost",
+ "Soft inpainting mask influence",
+ "Soft inpainting difference threshold",
+ "Soft inpainting difference contrast")
el_ids = SoftInpaintingSettings(
"mask_blend_power",
"mask_blend_scale",
- "inpaint_detail_preservation")
+ "inpaint_detail_preservation",
+ "composite_mask_influence",
+ "composite_difference_threshold",
+ "composite_difference_contrast")
-# -----
+# ------------------- Script -------------------
class Script(scripts.Script):
@@ -449,28 +477,62 @@ class Script(scripts.Script):
**High _Mask blur_** values are recommended!
""")
- result = SoftInpaintingSettings(
+ power = \
gr.Slider(label=ui_labels.mask_blend_power,
info=ui_info.mask_blend_power,
minimum=0,
maximum=8,
step=0.1,
value=default.mask_blend_power,
- elem_id=el_ids.mask_blend_power),
+ elem_id=el_ids.mask_blend_power)
+ scale = \
gr.Slider(label=ui_labels.mask_blend_scale,
info=ui_info.mask_blend_scale,
minimum=0,
maximum=8,
step=0.05,
value=default.mask_blend_scale,
- elem_id=el_ids.mask_blend_scale),
+ elem_id=el_ids.mask_blend_scale)
+ detail = \
gr.Slider(label=ui_labels.inpaint_detail_preservation,
info=ui_info.inpaint_detail_preservation,
minimum=1,
maximum=32,
step=0.5,
value=default.inpaint_detail_preservation,
- elem_id=el_ids.inpaint_detail_preservation))
+ elem_id=el_ids.inpaint_detail_preservation)
+
+ gr.Markdown(
+ """
+ ### Pixel Composite Settings
+ """)
+
+ mask_inf = \
+ gr.Slider(label=ui_labels.composite_mask_influence,
+ info=ui_info.composite_mask_influence,
+ minimum=0,
+ maximum=1,
+ step=0.05,
+ value=default.composite_mask_influence,
+ elem_id=el_ids.composite_mask_influence)
+
+ dif_thresh = \
+ gr.Slider(label=ui_labels.composite_difference_threshold,
+ info=ui_info.composite_difference_threshold,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_threshold,
+ elem_id=el_ids.composite_difference_threshold)
+
+ dif_contr = \
+ gr.Slider(label=ui_labels.composite_difference_contrast,
+ info=ui_info.composite_difference_contrast,
+ minimum=0,
+ maximum=8,
+ step=0.25,
+ value=default.composite_difference_contrast,
+ elem_id=el_ids.composite_difference_contrast)
with gr.Accordion("Help", open=False):
gr.Markdown(
@@ -507,41 +569,86 @@ class Script(scripts.Script):
- **High values**: Stronger contrast, may over-saturate colors.
""")
+ gr.Markdown(
+ """
+ ## Pixel Composite Settings
+
+ Masks are generated based on how much a part of the image changed after denoising.
+ These masks are used to blend the original and final images together.
+ If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_mask_influence}
+
+ This parameter controls how much the mask should bias this sensitivity to difference.
+
+ - **0**: Ignore the mask, only consider differences in image content.
+ - **1**: Follow the mask closely despite image content changes.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_threshold}
+
+ This value represents the difference at which the opacity of the original pixels will have less than 50% opacity.
+
+ - **Low values**: Two images patches must be almost the same in order to retain original pixels.
+ - **High values**: Two images patches can be very different and still retain original pixels.
+ """)
+
+ gr.Markdown(
+ f"""
+ ### {ui_labels.composite_difference_contrast}
+
+ This value represents the difference at which the opacity of the original pixels will have less than 50% opacity.
+
+ - **Low values**: Two images patches must be almost the same in order to retain original pixels.
+ - **High values**: Two images patches can be very different and still retain original pixels.
+ """)
+
self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
- (result.mask_blend_power, gen_param_labels.mask_blend_power),
- (result.mask_blend_scale, gen_param_labels.mask_blend_scale),
- (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)]
+ (power, gen_param_labels.mask_blend_power),
+ (scale, gen_param_labels.mask_blend_scale),
+ (detail, gen_param_labels.inpaint_detail_preservation),
+ (mask_inf, gen_param_labels.composite_mask_influence),
+ (dif_thresh, gen_param_labels.composite_difference_threshold),
+ (dif_contr, gen_param_labels.composite_difference_contrast)]
self.paste_field_names = []
for _, field_name in self.infotext_fields:
self.paste_field_names.append(field_name)
return [soft_inpainting_enabled,
- result.mask_blend_power,
- result.mask_blend_scale,
- result.inpaint_detail_preservation]
-
- def process(self, p, enabled, power, scale, detail_preservation):
+ power,
+ scale,
+ detail,
+ mask_inf,
+ dif_thresh,
+ dif_contr]
+
+ def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return
# Shut off the rounding it normally does.
p.mask_round = False
- settings = SoftInpaintingSettings(power, scale, detail_preservation)
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# p.extra_generation_params["Mask rounding"] = False
settings.add_generation_params(p.extra_generation_params)
- def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation):
+ def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return
- if mba.sigma is None:
+ if mba.is_final_blend:
mba.blended_latent = mba.current_latent
return
- settings = SoftInpaintingSettings(power, scale, detail_preservation)
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
# todo: Why is sigma 2D? Both values are the same.
mba.blended_latent = latent_blend(settings,
@@ -549,11 +656,11 @@ class Script(scripts.Script):
mba.current_latent,
get_modified_nmask(settings, mba.nmask, mba.sigma[0]))
- def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation):
+ def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return
- settings = SoftInpaintingSettings(power, scale, detail_preservation)
+ settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)
from modules import images
from modules.shared import opts
@@ -570,15 +677,20 @@ class Script(scripts.Script):
self.overlay_images.append(image.convert('RGBA'))
+ if len(p.init_images) == 1:
+ self.overlay_images = self.overlay_images * p.batch_size
+
if getattr(ps.samples, 'already_decoded', False):
- self.masks_for_overlay = apply_masks(soft_inpainting=settings,
+ self.masks_for_overlay = apply_masks(settings=settings,
nmask=p.nmask,
overlay_images=self.overlay_images,
width=p.width,
height=p.height,
paste_to=p.paste_to)
else:
- self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent,
+ self.masks_for_overlay = apply_adaptive_masks(settings=settings,
+ nmask=p.nmask,
+ latent_orig=p.init_latent,
latent_processed=ps.samples,
overlay_images=self.overlay_images,
width=p.width,
@@ -586,7 +698,7 @@ class Script(scripts.Script):
paste_to=p.paste_to)
- def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation):
+ def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
if not enabled:
return