aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--modules/images.py15
-rw-r--r--modules/img2img.py2
-rw-r--r--modules/processing.py9
-rw-r--r--modules/scripts.py100
-rw-r--r--modules/txt2img.py2
-rw-r--r--modules/ui.py6
-rw-r--r--scripts/poor_mans_outpainting.py110
7 files changed, 193 insertions, 51 deletions
diff --git a/modules/images.py b/modules/images.py
index b05276c3..4b9667d2 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -39,23 +39,26 @@ def split_grid(image, tile_w=512, tile_h=512, overlap=64):
w = image.width
h = image.height
- now = tile_w - overlap # non-overlap width
- noh = tile_h - overlap
+ non_overlap_width = tile_w - overlap
+ non_overlap_height = tile_h - overlap
- cols = math.ceil((w - overlap) / now)
- rows = math.ceil((h - overlap) / noh)
+ cols = math.ceil((w - overlap) / non_overlap_width)
+ rows = math.ceil((h - overlap) / non_overlap_height)
+
+ dx = (w - tile_w) // (cols-1) if cols > 1 else 0
+ dy = (h - tile_h) // (rows-1) if rows > 1 else 0
grid = Grid([], tile_w, tile_h, w, h, overlap)
for row in range(rows):
row_images = []
- y = row * noh
+ y = row * dy
if y + tile_h >= h:
y = h - tile_h
for col in range(cols):
- x = col * now
+ x = col * dx
if x+tile_w >= w:
x = w - tile_w
diff --git a/modules/img2img.py b/modules/img2img.py
index 06de2db3..d5787dd3 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -130,7 +130,7 @@ def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index
else:
- processed = modules.scripts.run(p, *args)
+ processed = modules.scripts.scripts_img2img.run(p, *args)
if processed is None:
processed = process_images(p)
diff --git a/modules/processing.py b/modules/processing.py
index 2830209e..adc5d851 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -271,7 +271,7 @@ def fill(image, mask):
image_masked = image_masked.convert('RGBa')
- for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
+ for radius, repeats in [(256, 1), (64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
for _ in range(repeats):
image_mod.alpha_composite(blurred)
@@ -290,6 +290,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.denoising_strength: float = denoising_strength
self.init_latent = None
self.image_mask = mask
+ #self.image_unblurred_mask = None
+ self.latent_mask = None
self.mask_for_overlay = None
self.mask_blur = mask_blur
self.inpainting_fill = inpainting_fill
@@ -308,6 +310,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
self.image_mask = ImageOps.invert(self.image_mask)
+ #self.image_unblurred_mask = self.image_mask
+
if self.mask_blur > 0:
self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
@@ -368,7 +372,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
if self.image_mask is not None:
- latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+ init_mask = self.latent_mask if self.latent_mask is not None else self.image_mask
+ latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
latmask = latmask[0]
latmask = np.tile(latmask[None], (4, 1, 1))
diff --git a/modules/scripts.py b/modules/scripts.py
index 99502857..89a0618d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -18,6 +18,9 @@ class Script:
def ui(self, is_img2img):
pass
+ def show(self, is_img2img):
+ return True
+
def run(self, *args):
raise NotImplementedError()
@@ -25,7 +28,7 @@ class Script:
return ""
-scripts = []
+scripts_data = []
def load_scripts(basedir):
@@ -49,10 +52,8 @@ def load_scripts(basedir):
for key, script_class in module.__dict__.items():
if type(script_class) == type and issubclass(script_class, Script):
- obj = script_class()
- obj.filename = path
+ scripts_data.append((script_class, path))
- scripts.append(obj)
except Exception:
print(f"Error loading script: {filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
@@ -69,52 +70,75 @@ def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
return default
-def setup_ui(is_img2img):
- titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in scripts]
+class ScriptRunner:
+ def __init__(self):
+ self.scripts = []
+
+ def setup_ui(self, is_img2img):
+ for script_class, path in scripts_data:
+ script = script_class()
+ script.filename = path
+
+ if not script.show(is_img2img):
+ continue
+
+ self.scripts.append(script)
+
+ titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.scripts]
+
+ dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
+ inputs = [dropdown]
+
+ for script in self.scripts:
+ script.args_from = len(inputs)
+
+ controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
+
+ if controls is None:
+ continue
- dropdown = gr.Dropdown(label="Script", choices=["None"] + titles, value="None", type="index")
+ for control in controls:
+ control.visible = False
- inputs = [dropdown]
+ inputs += controls
+ script.args_to = len(inputs)
- for script in scripts:
- script.args_from = len(inputs)
- controls = script.ui(is_img2img)
+ def select_script(script_index):
+ if 0 < script_index <= len(self.scripts):
+ script = self.scripts[script_index-1]
+ args_from = script.args_from
+ args_to = script.args_to
+ else:
+ args_from = 0
+ args_to = 0
- for control in controls:
- control.visible = False
+ return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
- inputs += controls
- script.args_to = len(inputs)
+ dropdown.change(
+ fn=select_script,
+ inputs=[dropdown],
+ outputs=inputs
+ )
- def select_script(index):
- if index > 0:
- script = scripts[index-1]
- args_from = script.args_from
- args_to = script.args_to
- else:
- args_from = 0
- args_to = 0
+ return inputs
- return [ui.gr_show(True if i == 0 else args_from <= i < args_to) for i in range(len(inputs))]
- dropdown.change(
- fn=select_script,
- inputs=[dropdown],
- outputs=inputs
- )
+ def run(self, p: StableDiffusionProcessing, *args):
+ script_index = args[0]
- return inputs
+ if script_index == 0:
+ return None
+ script = self.scripts[script_index-1]
-def run(p: StableDiffusionProcessing, *args):
- script_index = args[0] - 1
+ if script is None:
+ return None
- if script_index < 0 or script_index >= len(scripts):
- return None
+ script_args = args[script.args_from:script.args_to]
+ processed = script.run(p, *script_args)
- script = scripts[script_index]
+ return processed
- script_args = args[script.args_from:script.args_to]
- processed = script.run(p, *script_args)
- return processed
+scripts_txt2img = ScriptRunner()
+scripts_img2img = ScriptRunner()
diff --git a/modules/txt2img.py b/modules/txt2img.py
index f5ac0540..fb65a7f6 100644
--- a/modules/txt2img.py
+++ b/modules/txt2img.py
@@ -24,7 +24,7 @@ def txt2img(prompt: str, negative_prompt: str, steps: int, sampler_index: int, u
use_GFPGAN=use_GFPGAN
)
- processed = modules.scripts.run(p, *args)
+ processed = modules.scripts.scripts_txt2img.run(p, *args)
if processed is not None:
pass
diff --git a/modules/ui.py b/modules/ui.py
index ccca871a..65d53bcd 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -162,7 +162,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
seed = gr.Number(label='Seed', value=-1)
with gr.Group():
- custom_inputs = modules.scripts.setup_ui(is_img2img=False)
+ custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
with gr.Column(variant='panel'):
with gr.Group():
@@ -244,7 +244,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
with gr.Row():
- inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=True, visible=False)
+ inpaint_full_res = gr.Checkbox(label='Inpaint at full resolution', value=False, visible=False)
inpainting_mask_invert = gr.Radio(label='Masking mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", visible=False)
with gr.Row():
@@ -269,7 +269,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo):
seed = gr.Number(label='Seed', value=-1)
with gr.Group():
- custom_inputs = modules.scripts.setup_ui(is_img2img=True)
+ custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
with gr.Column(variant='panel'):
diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py
new file mode 100644
index 00000000..98e1def0
--- /dev/null
+++ b/scripts/poor_mans_outpainting.py
@@ -0,0 +1,110 @@
+import math
+
+import modules.scripts as scripts
+import gradio as gr
+from PIL import Image, ImageDraw
+
+from modules import images, processing
+from modules.processing import Processed, process_images
+from modules.shared import opts, cmd_opts, state
+
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Poor man's outpainting"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ if not is_img2img:
+ return None
+
+ pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=128, step=8)
+ mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, visible=False)
+ inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", visible=False)
+
+ return [pixels, mask_blur, inpainting_fill]
+
+ def run(self, p, pixels, mask_blur, inpainting_fill):
+ initial_seed = None
+ initial_info = None
+
+ p.mask_blur = mask_blur
+ p.inpainting_fill = inpainting_fill
+ p.inpaint_full_res = False
+
+ init_img = p.init_images[0]
+ target_w = math.ceil((init_img.width + pixels * 2) / 64) * 64
+ target_h = math.ceil((init_img.height + pixels * 2) / 64) * 64
+
+ border_x = (target_w - init_img.width)//2
+ border_y = (target_h - init_img.height)//2
+
+ img = Image.new("RGB", (target_w, target_h))
+ img.paste(init_img, (border_x, border_y))
+
+ mask = Image.new("L", (img.width, img.height), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((border_x + mask_blur * 2, border_y + mask_blur * 2, mask.width - border_x - mask_blur * 2, mask.height - border_y - mask_blur * 2), fill="black")
+
+ latent_mask = Image.new("L", (img.width, img.height), "white")
+ latent_draw = ImageDraw.Draw(latent_mask)
+ latent_draw.rectangle((border_x + 1, border_y + 1, mask.width - border_x - 1, mask.height - border_y - 1), fill="black")
+
+ processing.torch_gc()
+
+ grid = images.split_grid(img, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+ grid_latent_mask = images.split_grid(mask, tile_w=p.width, tile_h=p.height, overlap=pixels)
+
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+ work_mask = []
+ work_latent_mask = []
+ work_results = []
+
+ for (_, _, row), (_, _, row_mask), (_, _, row_latent_mask) in zip(grid.tiles, grid_mask.tiles, grid_latent_mask.tiles):
+ for tiledata, tiledata_mask, tiledata_latent_mask in zip(row, row_mask, row_latent_mask):
+ work.append(tiledata[2])
+ work_mask.append(tiledata_mask[2])
+ work_latent_mask.append(tiledata_latent_mask[2])
+
+ batch_count = len(work)
+ print(f"Poor man's outpainting will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)}.")
+
+ for i in range(batch_count):
+ p.init_images = [work[i]]
+ p.image_mask = work_mask[i]
+ p.latent_mask = work_latent_mask[i]
+
+ state.job = f"Batch {i + 1} out of {batch_count}"
+ processed = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, p.prompt, opts.grid_format, info=initial_info)
+
+ processed = Processed(p, [combined_image], initial_seed, initial_info)
+
+ return processed
+