diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/masking.py | 99 | ||||
-rw-r--r-- | modules/processing.py | 59 | ||||
-rw-r--r-- | modules/shared.py | 2 | ||||
-rw-r--r-- | modules/ui.py | 57 |
4 files changed, 139 insertions, 78 deletions
diff --git a/modules/masking.py b/modules/masking.py new file mode 100644 index 00000000..fd8d9241 --- /dev/null +++ b/modules/masking.py @@ -0,0 +1,99 @@ +from PIL import Image, ImageFilter, ImageOps
+
+
def get_crop_region(mask, pad=0):
    """Find a rectangular region that contains all masked areas in an image.

    mask: 2D numpy array of shape (h, w); zero entries are unmasked, anything
          else is masked.
    pad:  optional padding, in pixels, added on every side (clamped to the image).

    Returns (x1, y1, x2, y2) coordinates of the rectangle. For example, if a user
    has painted the top-right part of a 512x512 image, the result may be
    (256, 0, 512, 256). For an all-zero mask every scan runs to the far edge,
    yielding the same degenerate result as the loop-based version this replaces.
    """

    h, w = mask.shape

    # One boolean per column/row: "contains any masked pixel". These vectorized
    # reductions replace four O(h*w) Python scan loops with C-speed numpy calls.
    cols = mask.any(axis=0)
    rows = mask.any(axis=1)

    if cols.any():
        # argmax on a boolean vector returns the index of the first True,
        # i.e. exactly the count of leading all-zero columns/rows.
        crop_left = int(cols.argmax())
        crop_right = int(cols[::-1].argmax())
        crop_top = int(rows.argmax())
        crop_bottom = int(rows[::-1].argmax())
    else:
        # Empty mask: the original per-side loops ran to completion.
        crop_left = crop_right = w
        crop_top = crop_bottom = h

    return (
        int(max(crop_left - pad, 0)),
        int(max(crop_top - pad, 0)),
        int(min(w - crop_right + pad, w)),
        int(min(h - crop_bottom + pad, h))
    )
+
+
def expand_crop_region(crop_region, processing_width, processing_height, image_width, image_height):
    """Expand a crop region from get_crop_region() to match the aspect ratio of
    the dimensions it will be processed at; returns the expanded region.

    For example, if the user drew a mask in a 128x32 region and the processing
    dimensions are 512x512, the region is expanded to 128x128. The expanded
    region is slid and, as a last resort, clamped so it stays inside the image.

    Raises ZeroDivisionError if the crop region has zero height (as before).
    """

    x1, y1, x2, y2 = crop_region

    ratio_crop_region = (x2 - x1) / (y2 - y1)
    ratio_processing = processing_width / processing_height

    if ratio_crop_region > ratio_processing:
        # Crop is proportionally wider than the processing size: grow its height.
        # BUGFIX: height matching aspect ratio w/h is width / ratio, not
        # width * ratio — the multiply only coincidentally worked for square
        # processing sizes (ratio == 1).
        desired_height = (x2 - x1) / ratio_processing
        desired_height_diff = int(desired_height - (y2 - y1))
        y1 -= desired_height_diff // 2
        y2 += desired_height_diff - desired_height_diff // 2
        if y2 >= image_height:
            # Slide the region up so it fits below the image's bottom edge.
            diff = y2 - image_height
            y2 -= diff
            y1 -= diff
        if y1 < 0:
            # Slide down past the top edge; may re-overshoot the bottom,
            # clamped immediately below.
            y2 -= y1
            y1 = 0
        if y2 >= image_height:
            y2 = image_height
    else:
        # Crop is proportionally taller (or equal): grow its width.
        desired_width = (y2 - y1) * ratio_processing
        desired_width_diff = int(desired_width - (x2 - x1))
        x1 -= desired_width_diff // 2
        x2 += desired_width_diff - desired_width_diff // 2
        if x2 >= image_width:
            diff = x2 - image_width
            x2 -= diff
            x1 -= diff
        if x1 < 0:
            x2 -= x1
            x1 = 0
        if x2 >= image_width:
            x2 = image_width

    return x1, y1, x2, y2
+
+
def fill(image, mask):
    """Fill masked regions of an image with colors blurred in from the
    unmasked areas. Not extremely effective."""

    width, height = image.size
    result = Image.new('RGBA', (width, height))

    # Paste only the unmasked pixels ('RGBa' is PIL's premultiplied-alpha mode);
    # the inverted mask keeps masked areas fully transparent.
    visible = Image.new('RGBa', (width, height))
    visible.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))

    visible = visible.convert('RGBa')

    # Stack progressively sharper blurred copies: wide blurs reach far into the
    # masked holes, narrow ones restore detail near the mask edges.
    for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
        layer = visible.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
        for _ in range(repeats):
            result.alpha_composite(layer)

    return result.convert("RGB")
+
diff --git a/modules/processing.py b/modules/processing.py index 147d64e3..1afbe39c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ import cv2 from skimage import exposure
import modules.sd_hijack
-from modules import devices, prompt_parser
+from modules import devices, prompt_parser, masking
from modules.sd_hijack import model_hijack
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.shared import opts, cmd_opts, state
@@ -365,58 +365,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
return samples_ddim
-
-def get_crop_region(mask, pad=0):
- h, w = mask.shape
-
- crop_left = 0
- for i in range(w):
- if not (mask[:, i] == 0).all():
- break
- crop_left += 1
-
- crop_right = 0
- for i in reversed(range(w)):
- if not (mask[:, i] == 0).all():
- break
- crop_right += 1
-
- crop_top = 0
- for i in range(h):
- if not (mask[i] == 0).all():
- break
- crop_top += 1
-
- crop_bottom = 0
- for i in reversed(range(h)):
- if not (mask[i] == 0).all():
- break
- crop_bottom += 1
-
- return (
- int(max(crop_left-pad, 0)),
- int(max(crop_top-pad, 0)),
- int(min(w - crop_right + pad, w)),
- int(min(h - crop_bottom + pad, h))
- )
-
-
-def fill(image, mask):
- image_mod = Image.new('RGBA', (image.width, image.height))
-
- image_masked = Image.new('RGBa', (image.width, image.height))
- image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
-
- image_masked = image_masked.convert('RGBa')
-
- for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
- blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
- for _ in range(repeats):
- image_mod.alpha_composite(blurred)
-
- return image_mod.convert("RGB")
-
-
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
@@ -456,7 +404,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.inpaint_full_res:
self.mask_for_overlay = self.image_mask
mask = self.image_mask.convert('L')
- crop_region = get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
+ crop_region = masking.get_crop_region(np.array(mask), opts.upscale_at_full_resolution_padding)
+ crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height)
x1, y1, x2, y2 = crop_region
mask = mask.crop(crop_region)
@@ -494,7 +443,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): if self.image_mask is not None:
if self.inpainting_fill != 1:
- image = fill(image, latent_mask)
+ image = masking.fill(image, latent_mask)
if add_color_corrections:
self.color_corrections.append(setup_color_correction(image))
diff --git a/modules/shared.py b/modules/shared.py index 3c3aa9b6..c5742c10 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -31,7 +31,7 @@ parser.add_argument("--embeddings-dir", type=str, default=os.path.join(script_pa parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
-parser.add_argument("--always-batch-cond-uncond", action='store_true', help="a workaround test; may help with speed if you use --lowvram")
+parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram")
parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)")
diff --git a/modules/ui.py b/modules/ui.py index 451ad253..ada84d33 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -164,7 +164,6 @@ def wrap_gradio_call(func): def check_progress_call():
-
if shared.state.job_count == 0:
return "", gr_show(False), gr_show(False)
@@ -201,6 +200,12 @@ def check_progress_call(): return f"<span style='display: none'>{time.time()}</span><p>{progressbar}</p>", preview_visibility, image
def check_progress_call_initial():
    # First-poll variant of check_progress_call.
    # NOTE(review): -1 presumably marks "job started, total count not yet
    # known" so the first poll renders before a real count is set — confirm
    # against shared.State's handling of job_count.
    shared.state.job_count = -1

    return check_progress_call()
+
+
def roll_artist(prompt):
allowed_cats = set([x for x in shared.artist_db.categories() if len(opts.random_artist_categories)==0 or x in opts.random_artist_categories])
artist = random.choice([x for x in shared.artist_db.artists if x.category in allowed_cats])
@@ -308,14 +313,30 @@ def create_toprow(is_img2img): prompt_style_apply = gr.Button('Apply style', elem_id="style_apply")
save_style = gr.Button('Create style', elem_id="style_create")
- check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
+ return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style
- return prompt, roll, prompt_style, negative_prompt, prompt_style2, submit, interrogate, prompt_style_apply, save_style, check_progress
+
def setup_progressbar(progressbar, preview):
    """Create two hidden buttons whose click events poll generation progress,
    updating the given progressbar HTML and preview image components.

    NOTE(review): the buttons are invisible, so they are presumably clicked
    programmatically from frontend JS via their elem_ids — confirm against
    the accompanying javascript.
    """
    # Regular poll: refresh progress text and preview.
    check_progress = gr.Button('Check progress', elem_id="check_progress", visible=False)
    check_progress.click(
        fn=check_progress_call,
        show_progress=False,
        inputs=[],
        # preview appears twice: one output drives visibility, one the image value.
        outputs=[progressbar, preview, preview],
    )

    # Initial poll: resets job_count before delegating (see check_progress_call_initial).
    check_progress_initial = gr.Button('Check progress (first)', elem_id="check_progress_initial", visible=False)
    check_progress_initial.click(
        fn=check_progress_call_initial,
        show_progress=False,
        inputs=[],
        outputs=[progressbar, preview, preview],
    )
def create_ui(txt2img, img2img, run_extras, run_pnginfo):
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style, check_progress = create_toprow(is_img2img=False)
+ txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, txt2img_prompt_style_apply, txt2img_save_style = create_toprow(is_img2img=False)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -348,6 +369,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False)
txt2img_gallery = gr.Gallery(label='Output', elem_id='txt2img_gallery').style(grid=4)
+ setup_progressbar(progressbar, txt2img_preview)
+
with gr.Group():
with gr.Row():
save = gr.Button('Save')
@@ -384,19 +407,13 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): txt2img_gallery,
generation_info,
html_info
- ]
+ ],
+ show_progress=False,
)
txt2img_prompt.submit(**txt2img_args)
submit.click(**txt2img_args)
- check_progress.click(
- fn=check_progress_call,
- show_progress=False,
- inputs=[],
- outputs=[progressbar, txt2img_preview, txt2img_preview],
- )
-
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
@@ -429,7 +446,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): )
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style, check_progress = create_toprow(is_img2img=True)
+ img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_prompt_style_apply, img2img_save_style = create_toprow(is_img2img=True)
with gr.Row().style(equal_height=False):
with gr.Column(variant='panel'):
@@ -485,6 +502,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): img2img_preview = gr.Image(elem_id='img2img_preview', visible=False)
img2img_gallery = gr.Gallery(label='Output', elem_id='img2img_gallery').style(grid=4)
+ setup_progressbar(progressbar, img2img_preview)
+
with gr.Group():
with gr.Row():
save = gr.Button('Save')
@@ -589,7 +608,8 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): img2img_gallery,
generation_info,
html_info
- ]
+ ],
+ show_progress=False,
)
img2img_prompt.submit(**img2img_args)
@@ -601,13 +621,6 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): outputs=[img2img_prompt],
)
- check_progress.click(
- fn=check_progress_call,
- show_progress=False,
- inputs=[],
- outputs=[progressbar, img2img_preview, img2img_preview],
- )
-
interrupt.click(
fn=lambda: shared.state.interrupt(),
inputs=[],
@@ -616,7 +629,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo): save.click(
fn=wrap_gradio_call(save_files),
- _js = "(x, y, z) => [x, y, selected_gallery_index()]",
+ _js="(x, y, z) => [x, y, selected_gallery_index()]",
inputs=[
generation_info,
img2img_gallery,
|