aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py119
1 files changed, 90 insertions, 29 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 92fdebad..ad716e11 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -9,7 +9,7 @@ from dataclasses import dataclass, field
import torch
import numpy as np
-from PIL import Image, ImageOps
+from PIL import Image, ImageOps, ImageFilter
import random
import cv2
from skimage import exposure
@@ -62,6 +62,16 @@ def apply_color_correction(correction, original_image):
return image.convert('RGB')
+def uncrop(image, dest_size, paste_loc):
+ x, y, w, h = paste_loc
+ base_image = Image.new('RGBA', dest_size)
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ return image
+
+
def apply_overlay(image, paste_loc, index, overlays):
if overlays is None or index >= len(overlays):
return image
@@ -69,11 +79,7 @@ def apply_overlay(image, paste_loc, index, overlays):
overlay = overlays[index]
if paste_loc is not None:
- x, y, w, h = paste_loc
- base_image = Image.new('RGBA', (overlay.width, overlay.height))
- image = images.resize_image(1, image, w, h)
- base_image.paste(image, (x, y))
- image = base_image
+ image = uncrop(image, (overlay.width, overlay.height), paste_loc)
image = image.convert('RGBA')
image.alpha_composite(overlay)
@@ -140,6 +146,7 @@ class StableDiffusionProcessing:
do_not_save_grid: bool = False
extra_generation_params: dict[str, Any] = None
overlay_images: list = None
+ masks_for_overlay: list = None
eta: float = None
do_not_reload_embeddings: bool = False
denoising_strength: float = 0
@@ -865,11 +872,66 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if getattr(samples_ddim, 'already_decoded', False):
x_samples_ddim = samples_ddim
+ # todo: generate masks the old fashioned way
else:
if opts.sd_vae_decode_method != 'Full':
p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ # Generate the mask(s) based on similarity between the original and denoised latent vectors
+ if getattr(p, "image_mask", None) is not None:
+ # latent_mask = p.nmask[0].float().cpu()
+
+ # convert the original mask into a form we use to scale distances for thresholding
+ # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
+ # mask_scalar = mask_scalar / (1.00001-mask_scalar)
+ # mask_scalar = mask_scalar.numpy()
+
+ latent_orig = p.init_latent
+ latent_proc = samples_ddim
+ latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)
+
+ kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)
+
+ for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
+ converted_mask = distance_map.float().cpu().numpy()
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.9, percentile_max=1, min_width=1)
+ converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
+ percentile_min=0.25, percentile_max=0.75, min_width=1)
+
+ # The distance at which opacity of original decreases to 50%
+ # half_weighted_distance = 1 # * mask_scalar
+ # converted_mask = converted_mask / half_weighted_distance
+
+ converted_mask = 1 / (1 + converted_mask ** 2)
+ converted_mask = images.smootherstep(converted_mask)
+ converted_mask = 1 - converted_mask
+ converted_mask = 255. * converted_mask
+ converted_mask = converted_mask.astype(np.uint8)
+ converted_mask = Image.fromarray(converted_mask)
+ converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
+ converted_mask = create_binary_mask(converted_mask)
+
+ # Remove aliasing artifacts using a gaussian blur.
+ converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))
+
+ # Expand the mask to fit the whole image if needed.
+ if p.paste_to is not None:
+ converted_mask = uncrop(converted_mask,
+ (overlay_image.width, overlay_image.height),
+ p.paste_to)
+
+ p.masks_for_overlay[i] = converted_mask
+
+ image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
+ image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
+ mask=ImageOps.invert(converted_mask.convert('L')))
+
+ p.overlay_images[i] = image_masked.convert('RGBA')
+
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
+ target_device=devices.cpu,
+ check_for_nans=True)
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -892,7 +954,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
x_samples_ddim = batch_params.images
def infotext(index=0, use_main_prompt=False):
- return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+ return create_infotext(p, p.prompts, p.seeds, p.subseeds,
+ use_main_prompt=use_main_prompt, index=index,
+ all_negative_prompts=p.negative_prompts)
save_samples = p.save_samples()
@@ -923,19 +987,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
+ # If the intention is to show the output from the model
+ # that is being composited over the original image,
+ # we need to keep the original image around
+ # and use it in the composite step.
+ original_denoised_image = image.copy()
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if save_samples:
- images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i],
+ p.prompts[i], opts.samples_format, info=infotext(i), p=p)
text = infotext(i)
infotexts.append(text)
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
- if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
- image_mask = p.mask_for_overlay.convert('RGB')
- image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
+ if save_samples and hasattr(p, 'masks_for_overlay') and p.masks_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+ image_mask = p.masks_for_overlay[i].convert('RGB')
+ image_mask_composite = Image.composite(
+ original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size),
+ images.resize_image(2, p.masks_for_overlay[i], image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
@@ -1364,7 +1436,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
nmask: torch.Tensor = field(default=None, init=False)
image_conditioning: torch.Tensor = field(default=None, init=False)
init_img_hash: str = field(default=None, init=False)
- mask_for_overlay: Image = field(default=None, init=False)
init_latent: torch.Tensor = field(default=None, init=False)
def __post_init__(self):
@@ -1415,12 +1486,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
- np_mask = np.array(image_mask).astype(np.float32)
- np_mask /= 255
- np_mask = 1-pow(1-np_mask, 100)
- np_mask *= 255
- np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
mask = image_mask.convert('L')
crop_region = masking.get_crop_region(np.array(mask), self.inpaint_full_res_padding)
crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
@@ -1431,13 +1496,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.paste_to = (x1, y1, x2-x1, y2-y1)
else:
image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height)
- np_mask = np.array(image_mask).astype(np.float32)
- np_mask /= 255
- np_mask = 1-pow(1-np_mask, 100)
- np_mask *= 255
- np_mask = np.clip(np_mask, 0, 255).astype(np.uint8)
- self.mask_for_overlay = Image.fromarray(np_mask)
+ self.masks_for_overlay = []
self.overlay_images = []
latent_mask = self.latent_mask if self.latent_mask is not None else image_mask
@@ -1459,10 +1519,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
image = images.resize_image(self.resize_mode, image, self.width, self.height)
if image_mask is not None:
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
-
- self.overlay_images.append(image_masked.convert('RGBA'))
+ self.overlay_images.append(image)
+ self.masks_for_overlay.append(image_mask)
# crop_region is not None if we are doing inpaint full res
if crop_region is not None:
@@ -1486,6 +1544,9 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.overlay_images is not None:
self.overlay_images = self.overlay_images * self.batch_size
+ if self.masks_for_overlay is not None:
+ self.masks_for_overlay = self.masks_for_overlay * self.batch_size
+
if self.color_corrections is not None and len(self.color_corrections) == 1:
self.color_corrections = self.color_corrections * self.batch_size