aboutsummaryrefslogtreecommitdiffstats
path: root/webui.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-08-30 11:04:49 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-08-30 11:04:49 +0000
commitff98e09d726fae5c87aab8c1316150865edf89b9 (patch)
tree74af44418ac738a8e1013fb9771c763c847e04bb /webui.py
parent54f74d44725b09540352f52ddf3b73e63a19bdda (diff)
downloadstable-diffusion-webui-gfx803-ff98e09d726fae5c87aab8c1316150865edf89b9.tar.gz
stable-diffusion-webui-gfx803-ff98e09d726fae5c87aab8c1316150865edf89b9.tar.bz2
stable-diffusion-webui-gfx803-ff98e09d726fae5c87aab8c1316150865edf89b9.zip
UI options for mask blur and inpainting fill
Diffstat (limited to 'webui.py')
-rw-r--r--webui.py43
1 files changed, 34 insertions, 9 deletions
diff --git a/webui.py b/webui.py
index 6998464c..27838325 100644
--- a/webui.py
+++ b/webui.py
@@ -789,7 +789,7 @@ class EmbeddingsWithFixes(nn.Module):
class StableDiffusionProcessing:
- def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None):
+ def __init__(self, outpath=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None):
self.outpath: str = outpath
self.prompt: str = prompt
self.seed: int = seed
@@ -805,6 +805,7 @@ class StableDiffusionProcessing:
self.do_not_save_samples: bool = do_not_save_samples
self.do_not_save_grid: bool = do_not_save_grid
self.extra_generation_params: dict = extra_generation_params
+ self.overlay_images = overlay_images
def init(self):
pass
@@ -950,6 +951,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
image = Image.fromarray(x_sample)
+ if p.overlay_images is not None and i < len(p.overlay_images):
+ image = image.convert('RGBA')
+ image.alpha_composite(p.overlay_images[i])
+ image = image.convert('RGB')
+
if not p.do_not_save_samples:
save_image(image, sample_path, f"{base_count:05}", seeds[i], prompts[i], opts.samples_format, info=infotext())
@@ -1122,7 +1128,7 @@ def fill(image, mask):
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, **kwargs):
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -1131,6 +1137,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = None
self.original_mask = mask
self.mask_blur = mask_blur
+ self.inpainting_fill = inpainting_fill
self.mask = None
self.nmask = None
@@ -1149,14 +1156,22 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.mask = torch.asarray(1.0 - latmask).to(device).type(sd_model.dtype)
self.nmask = torch.asarray(latmask).to(device).type(sd_model.dtype)
+ self.overlay_images = []
+
imgs = []
for img in self.init_images:
image = img.convert("RGB")
image = resize_image(self.resize_mode, image, self.width, self.height)
- if self.original_mask is not None
- image = fill(image, self.original_mask)
+ if self.original_mask is not None:
+ if self.inpainting_fill == 0:
+ image = fill(image, self.original_mask)
+
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.original_mask.convert('L')))
+
+ self.overlay_images.append(image_masked.convert('RGBA'))
image = np.array(image).astype(np.float32) / 255.0
image = np.moveaxis(image, 2, 0)
@@ -1165,6 +1180,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if len(imgs) == 1:
batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+ if self.overlay_images is not None:
+ self.overlay_images = self.overlay_images * self.batch_size
elif len(imgs) <= self.batch_size:
self.batch_size = len(imgs)
batch_images = np.array(imgs)
@@ -1178,15 +1195,19 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image))
def sample(self, x, conditioning, unconditional_conditioning):
- t_enc = int(self.denoising_strength * self.steps)
+ t_enc = int(min(self.denoising_strength, 0.999) * self.steps)
sigmas = self.sampler.model_wrap.get_sigmas(self.steps)
noise = x * sigmas[self.steps - t_enc - 1]
xi = self.init_latent + noise
- sigma_sched = sigmas[self.steps - t_enc - 1:]
- #if self.mask is not None:
- # xi = xi * self.mask + noise * self.nmask
+ if self.mask is not None:
+ if self.inpainting_fill == 2:
+ xi = xi * self.mask + noise * self.nmask
+ elif self.inpainting_fill == 3:
+ xi = xi * self.mask
+
+ sigma_sched = sigmas[self.steps - t_enc - 1:]
def mask_cb(v):
v["denoised"][:] = v["denoised"][:] * self.nmask + self.init_latent * self.mask
@@ -1199,7 +1220,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
return samples_ddim
-def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
+def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, loopback: bool, sd_upscale: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int):
outpath = opts.outdir or "outputs/img2img-samples"
if init_img_with_mask is not None:
@@ -1226,6 +1247,8 @@ def img2img(prompt: str, init_img, init_img_with_mask, ddim_steps: int, sampler_
use_GFPGAN=use_GFPGAN,
init_images=[image],
mask=mask,
+ mask_blur=mask_blur,
+ inpainting_fill=inpainting_fill,
resize_mode=resize_mode,
denoising_strength=denoising_strength,
extra_generation_params={"Denoising Strength": denoising_strength}
@@ -1327,6 +1350,8 @@ img2img_interface = gr.Interface(
gr.Image(label="Image for inpainting with mask", source="upload", interactive=True, type="pil", tool="sketch"),
gr.Slider(minimum=1, maximum=150, step=1, label="Sampling Steps", value=20),
gr.Radio(label='Sampling method', choices=[x.name for x in samplers_for_img2img], value=samplers_for_img2img[0].name, type="index"),
+ gr.Slider(label='Inpainting: mask blur', minimum=0, maximum=64, step=1, value=4),
+ gr.Radio(label='Inpainting: masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index"),
gr.Checkbox(label='Fix faces using GFPGAN', value=False, visible=have_gfpgan),
gr.Checkbox(label='Create prompt matrix (separate multiple prompts using |, and get all combinations of them)', value=False),
gr.Checkbox(label='Loopback (use images from previous batch when creating next batch)', value=False),