aboutsummaryrefslogtreecommitdiffstats
path: root/modules/img2img.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-08 07:31:20 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-08 07:31:20 +0000
commit61785cef656335cce3ab50b420301d2821f7c5e1 (patch)
treedf66e649b5e30c21fb68d6b93af9bb48b0f48584 /modules/img2img.py
parent0fedd50886fb2f745cc6faab001090b77fbd0382 (diff)
parent9ddaf8269ebfb11c8fd2e48f0e8d33c125213437 (diff)
downloadstable-diffusion-webui-gfx803-61785cef656335cce3ab50b420301d2821f7c5e1.tar.gz
stable-diffusion-webui-gfx803-61785cef656335cce3ab50b420301d2821f7c5e1.tar.bz2
stable-diffusion-webui-gfx803-61785cef656335cce3ab50b420301d2821f7c5e1.zip
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'modules/img2img.py')
-rw-r--r--modules/img2img.py35
1 files changed, 31 insertions, 4 deletions
diff --git a/modules/img2img.py b/modules/img2img.py
index 3129798d..c2392305 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -1,5 +1,7 @@
import math
-from PIL import Image
+import cv2
+import numpy as np
+from PIL import Image, ImageOps, ImageChops
from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
from modules.shared import opts, state
@@ -16,7 +18,9 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
if is_inpaint:
image = init_img_with_mask['image']
- mask = init_img_with_mask['mask']
+ alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
+ mask = ImageChops.lighter(alpha_mask, init_img_with_mask['mask'].convert('L')).convert('RGBA')
+ image = image.convert('RGB')
else:
image = init_img
mask = None
@@ -57,8 +61,19 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
state.job_count = n_iter
+ do_color_correction = False
+ try:
+ from skimage import exposure
+ do_color_correction = True
+ except:
+ print("Install scikit-image to perform color correction on loopback")
+
+
for i in range(n_iter):
+ if do_color_correction and i == 0:
+ correction_target = cv2.cvtColor(np.asarray(init_img.copy()), cv2.COLOR_RGB2LAB)
+
p.n_iter = 1
p.batch_size = 1
p.do_not_save_grid = True
@@ -69,8 +84,20 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
if initial_seed is None:
initial_seed = processed.seed
initial_info = processed.info
-
- p.init_images = [processed.images[0]]
+
+ init_img = processed.images[0]
+
+ if do_color_correction and correction_target is not None:
+ init_img = Image.fromarray(cv2.cvtColor(exposure.match_histograms(
+ cv2.cvtColor(
+ np.asarray(init_img),
+ cv2.COLOR_RGB2LAB
+ ),
+ correction_target,
+ channel_axis=2
+ ), cv2.COLOR_LAB2RGB).astype("uint8"))
+
+ p.init_images = [init_img]
p.seed = processed.seed + 1
p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
history.append(processed.images[0])