diff options
Diffstat (limited to 'modules')
-rw-r--r-- | modules/gfpgan_model.py | 58 | ||||
-rw-r--r-- | modules/images.py | 290 | ||||
-rw-r--r-- | modules/img2img.py | 133 | ||||
-rw-r--r-- | modules/lowvram.py | 73 | ||||
-rw-r--r-- | modules/paths.py | 21 | ||||
-rw-r--r-- | modules/processing.py | 409 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 70 | ||||
-rw-r--r-- | modules/scripts.py | 53 | ||||
-rw-r--r-- | modules/sd_hijack.py | 208 | ||||
-rw-r--r-- | modules/sd_samplers.py | 137 | ||||
-rw-r--r-- | modules/shared.py | 121 | ||||
-rw-r--r-- | modules/txt2img.py | 52 | ||||
-rw-r--r-- | modules/ui.py | 539 |
13 files changed, 2164 insertions, 0 deletions
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py new file mode 100644 index 00000000..3f42c163 --- /dev/null +++ b/modules/gfpgan_model.py @@ -0,0 +1,58 @@ +import os
+import sys
+import traceback
+
+from modules.paths import script_path
+from modules.shared import cmd_opts
+
+
+def gfpgan_model_path():
+ places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]
+ files = [cmd_opts.gfpgan_model] + [os.path.join(dirname, cmd_opts.gfpgan_model) for dirname in places]
+ found = [x for x in files if os.path.exists(x)]
+
+ if len(found) == 0:
+ raise Exception("GFPGAN model not found in paths: " + ", ".join(files))
+
+ return found[0]
+
+
+loaded_gfpgan_model = None
+
+
+def gfpgan():
+ global loaded_gfpgan_model
+
+ if loaded_gfpgan_model is None and gfpgan_constructor is not None:
+ loaded_gfpgan_model = gfpgan_constructor(model_path=gfpgan_model_path(), upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
+
+ return loaded_gfpgan_model
+
+
+def gfpgan_fix_faces(np_image):
+ np_image_bgr = np_image[:, :, ::-1]
+ cropped_faces, restored_faces, gfpgan_output_bgr = gfpgan().enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
+ np_image = gfpgan_output_bgr[:, :, ::-1]
+
+ return np_image
+
+
+have_gfpgan = False
+gfpgan_constructor = None
+
+def setup_gfpgan():
+ try:
+ gfpgan_model_path()
+
+ if os.path.exists(cmd_opts.gfpgan_dir):
+ sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
+ from gfpgan import GFPGANer
+
+ global have_gfpgan
+ have_gfpgan = True
+
+ global gfpgan_constructor
+ gfpgan_constructor = GFPGANer
+ except Exception:
+ print("Error setting up GFPGAN:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
diff --git a/modules/images.py b/modules/images.py new file mode 100644 index 00000000..b05276c3 --- /dev/null +++ b/modules/images.py @@ -0,0 +1,290 @@ +import math
+import os
+from collections import namedtuple
+import re
+
+import numpy as np
+from PIL import Image, ImageFont, ImageDraw, PngImagePlugin
+
+from modules.shared import opts
+
+LANCZOS = (Image.Resampling.LANCZOS if hasattr(Image, 'Resampling') else Image.LANCZOS)
+
+
+def image_grid(imgs, batch_size=1, rows=None):
+ if rows is None:
+ if opts.n_rows > 0:
+ rows = opts.n_rows
+ elif opts.n_rows == 0:
+ rows = batch_size
+ else:
+ rows = math.sqrt(len(imgs))
+ rows = round(rows)
+
+ cols = math.ceil(len(imgs) / rows)
+
+ w, h = imgs[0].size
+ grid = Image.new('RGB', size=(cols * w, rows * h), color='black')
+
+ for i, img in enumerate(imgs):
+ grid.paste(img, box=(i % cols * w, i // cols * h))
+
+ return grid
+
+
+Grid = namedtuple("Grid", ["tiles", "tile_w", "tile_h", "image_w", "image_h", "overlap"])
+
+
+def split_grid(image, tile_w=512, tile_h=512, overlap=64):
+ w = image.width
+ h = image.height
+
+ now = tile_w - overlap # non-overlap width
+ noh = tile_h - overlap
+
+ cols = math.ceil((w - overlap) / now)
+ rows = math.ceil((h - overlap) / noh)
+
+ grid = Grid([], tile_w, tile_h, w, h, overlap)
+ for row in range(rows):
+ row_images = []
+
+ y = row * noh
+
+ if y + tile_h >= h:
+ y = h - tile_h
+
+ for col in range(cols):
+ x = col * now
+
+ if x+tile_w >= w:
+ x = w - tile_w
+
+ tile = image.crop((x, y, x + tile_w, y + tile_h))
+
+ row_images.append([x, tile_w, tile])
+
+ grid.tiles.append([y, tile_h, row_images])
+
+ return grid
+
+
+def combine_grid(grid):
+ def make_mask_image(r):
+ r = r * 255 / grid.overlap
+ r = r.astype(np.uint8)
+ return Image.fromarray(r, 'L')
+
+ mask_w = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((1, grid.overlap)).repeat(grid.tile_h, axis=0))
+ mask_h = make_mask_image(np.arange(grid.overlap, dtype=np.float32).reshape((grid.overlap, 1)).repeat(grid.image_w, axis=1))
+
+ combined_image = Image.new("RGB", (grid.image_w, grid.image_h))
+ for y, h, row in grid.tiles:
+ combined_row = Image.new("RGB", (grid.image_w, h))
+ for x, w, tile in row:
+ if x == 0:
+ combined_row.paste(tile, (0, 0))
+ continue
+
+ combined_row.paste(tile.crop((0, 0, grid.overlap, h)), (x, 0), mask=mask_w)
+ combined_row.paste(tile.crop((grid.overlap, 0, w, h)), (x + grid.overlap, 0))
+
+ if y == 0:
+ combined_image.paste(combined_row, (0, 0))
+ continue
+
+ combined_image.paste(combined_row.crop((0, 0, combined_row.width, grid.overlap)), (0, y), mask=mask_h)
+ combined_image.paste(combined_row.crop((0, grid.overlap, combined_row.width, h)), (0, y + grid.overlap))
+
+ return combined_image
+
+
+class GridAnnotation:
+ def __init__(self, text='', is_active=True):
+ self.text = text
+ self.is_active = is_active
+ self.size = None
+
+
+def draw_grid_annotations(im, width, height, hor_texts, ver_texts):
+ def wrap(drawing, text, font, line_length):
+ lines = ['']
+ for word in text.split():
+ line = f'{lines[-1]} {word}'.strip()
+ if drawing.textlength(line, font=font) <= line_length:
+ lines[-1] = line
+ else:
+ lines.append(word)
+ return lines
+
+ def draw_texts(drawing, draw_x, draw_y, lines):
+ for i, line in enumerate(lines):
+ drawing.multiline_text((draw_x, draw_y + line.size[1] / 2), line.text, font=fnt, fill=color_active if line.is_active else color_inactive, anchor="mm", align="center")
+
+ if not line.is_active:
+ drawing.line((draw_x - line.size[0]//2, draw_y + line.size[1]//2, draw_x + line.size[0]//2, draw_y + line.size[1]//2), fill=color_inactive, width=4)
+
+ draw_y += line.size[1] + line_spacing
+
+ fontsize = (width + height) // 25
+ line_spacing = fontsize // 2
+ fnt = ImageFont.truetype(opts.font, fontsize)
+ color_active = (0, 0, 0)
+ color_inactive = (153, 153, 153)
+
+ pad_left = width * 3 // 4 if len(ver_texts) > 0 else 0
+
+ cols = im.width // width
+ rows = im.height // height
+
+ assert cols == len(hor_texts), f'bad number of horizontal texts: {len(hor_texts)}; must be {cols}'
+ assert rows == len(ver_texts), f'bad number of vertical texts: {len(ver_texts)}; must be {rows}'
+
+ calc_img = Image.new("RGB", (1, 1), "white")
+ calc_d = ImageDraw.Draw(calc_img)
+
+ for texts, allowed_width in zip(hor_texts + ver_texts, [width] * len(hor_texts) + [pad_left] * len(ver_texts)):
+ items = [] + texts
+ texts.clear()
+
+ for line in items:
+ wrapped = wrap(calc_d, line.text, fnt, allowed_width)
+ texts += [GridAnnotation(x, line.is_active) for x in wrapped]
+
+ for line in texts:
+ bbox = calc_d.multiline_textbbox((0, 0), line.text, font=fnt)
+ line.size = (bbox[2] - bbox[0], bbox[3] - bbox[1])
+
+ hor_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing for lines in hor_texts]
+ ver_text_heights = [sum([line.size[1] + line_spacing for line in lines]) - line_spacing * len(lines) for lines in ver_texts]
+
+ pad_top = max(hor_text_heights) + line_spacing * 2
+
+ result = Image.new("RGB", (im.width + pad_left, im.height + pad_top), "white")
+ result.paste(im, (pad_left, pad_top))
+
+ d = ImageDraw.Draw(result)
+
+ for col in range(cols):
+ x = pad_left + width * col + width / 2
+ y = pad_top / 2 - hor_text_heights[col] / 2
+
+ draw_texts(d, x, y, hor_texts[col])
+
+ for row in range(rows):
+ x = pad_left / 2
+ y = pad_top + height * row + height / 2 - ver_text_heights[row] / 2
+
+ draw_texts(d, x, y, ver_texts[row])
+
+ return result
+
+
+def draw_prompt_matrix(im, width, height, all_prompts):
+ prompts = all_prompts[1:]
+ boundary = math.ceil(len(prompts) / 2)
+
+ prompts_horiz = prompts[:boundary]
+ prompts_vert = prompts[boundary:]
+
+ hor_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_horiz)] for pos in range(1 << len(prompts_horiz))]
+ ver_texts = [[GridAnnotation(x, is_active=pos & (1 << i) != 0) for i, x in enumerate(prompts_vert)] for pos in range(1 << len(prompts_vert))]
+
+ return draw_grid_annotations(im, width, height, hor_texts, ver_texts)
+
+
+def resize_image(resize_mode, im, width, height):
+ if resize_mode == 0:
+ res = im.resize((width, height), resample=LANCZOS)
+ elif resize_mode == 1:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio > src_ratio else im.width * height // im.height
+ src_h = height if ratio <= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGB", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+ else:
+ ratio = width / height
+ src_ratio = im.width / im.height
+
+ src_w = width if ratio < src_ratio else im.width * height // im.height
+ src_h = height if ratio >= src_ratio else im.height * width // im.width
+
+ resized = im.resize((src_w, src_h), resample=LANCZOS)
+ res = Image.new("RGB", (width, height))
+ res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
+
+ if ratio < src_ratio:
+ fill_height = height // 2 - src_h // 2
+ res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
+ res.paste(resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)), box=(0, fill_height + src_h))
+ elif ratio > src_ratio:
+ fill_width = width // 2 - src_w // 2
+ res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
+ res.paste(resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)), box=(fill_width + src_w, 0))
+
+ return res
+
+
+invalid_filename_chars = '<>:"/\\|?*\n'
+
+
+def sanitize_filename_part(text):
+ return text.replace(' ', '_').translate({ord(x): '' for x in invalid_filename_chars})[:128]
+
+
+def save_image(image, path, basename, seed=None, prompt=None, extension='png', info=None, short_filename=False, no_prompt=False):
+ if short_filename or prompt is None or seed is None:
+ file_decoration = ""
+ elif opts.save_to_dirs:
+ file_decoration = f"-{seed}"
+ else:
+ file_decoration = f"-{seed}-{sanitize_filename_part(prompt)[:128]}"
+
+ if extension == 'png' and opts.enable_pnginfo and info is not None:
+ pnginfo = PngImagePlugin.PngInfo()
+ pnginfo.add_text("parameters", info)
+ else:
+ pnginfo = None
+
+ if opts.save_to_dirs and not no_prompt:
+ words = re.findall(r'\w+', prompt or "")
+ if len(words) == 0:
+ words = ["empty"]
+
+ dirname = " ".join(words[0:opts.save_to_dirs_prompt_len])
+ path = os.path.join(path, dirname)
+
+ os.makedirs(path, exist_ok=True)
+
+ filecount = len([x for x in os.listdir(path) if os.path.splitext(x)[1] == '.' + extension])
+ fullfn = "a.png"
+ fullfn_without_extension = "a"
+ for i in range(100):
+ fn = f"{filecount:05}" if basename == '' else f"{basename}-{filecount:04}"
+ fullfn = os.path.join(path, f"{fn}{file_decoration}.{extension}")
+ fullfn_without_extension = os.path.join(path, f"{fn}{file_decoration}")
+ if not os.path.exists(fullfn):
+ break
+
+ image.save(fullfn, quality=opts.jpeg_quality, pnginfo=pnginfo)
+
+ target_side_length = 4000
+ oversize = image.width > target_side_length or image.height > target_side_length
+ if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
+ ratio = image.width / image.height
+
+ if oversize and ratio > 1:
+ image = image.resize((target_side_length, image.height * target_side_length // image.width), LANCZOS)
+ elif oversize:
+ image = image.resize((image.width * target_side_length // image.height, target_side_length), LANCZOS)
+
+ image.save(f"{fullfn_without_extension}.jpg", quality=opts.jpeg_quality, pnginfo=pnginfo)
+
+ if opts.save_txt and info is not None:
+ with open(f"{fullfn_without_extension}.txt", "w", encoding="utf8") as file:
+ file.write(info + "\n")
+
diff --git a/modules/img2img.py b/modules/img2img.py new file mode 100644 index 00000000..f2817ba8 --- /dev/null +++ b/modules/img2img.py @@ -0,0 +1,133 @@ +import math
+from PIL import Image
+
+from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images
+from modules.shared import opts, state
+import modules.shared as shared
+import modules.processing as processing
+from modules.ui import plaintext_to_html
+import modules.images as images
+
+def img2img(prompt: str, init_img, init_img_with_mask, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, use_GFPGAN: bool, prompt_matrix, mode: int, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, height: int, width: int, resize_mode: int, upscaler_name: str, upscale_overlap: int, inpaint_full_res: bool):
+ is_inpaint = mode == 1
+ is_loopback = mode == 2
+ is_upscale = mode == 3
+
+ if is_inpaint:
+ image = init_img_with_mask['image']
+ mask = init_img_with_mask['mask']
+ else:
+ image = init_img
+ mask = None
+
+ assert 0. <= denoising_strength <= 1., 'can only work with strength in [0.0, 1.0]'
+
+ p = StableDiffusionProcessingImg2Img(
+ sd_model=shared.sd_model,
+ outpath_samples=opts.outdir_samples or opts.outdir_img2img_samples,
+ outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids,
+ prompt=prompt,
+ seed=seed,
+ sampler_index=sampler_index,
+ batch_size=batch_size,
+ n_iter=n_iter,
+ steps=steps,
+ cfg_scale=cfg_scale,
+ width=width,
+ height=height,
+ prompt_matrix=prompt_matrix,
+ use_GFPGAN=use_GFPGAN,
+ init_images=[image],
+ mask=mask,
+ mask_blur=mask_blur,
+ inpainting_fill=inpainting_fill,
+ resize_mode=resize_mode,
+ denoising_strength=denoising_strength,
+ inpaint_full_res=inpaint_full_res,
+ extra_generation_params={"Denoising Strength": denoising_strength}
+ )
+
+ if is_loopback:
+ output_images, info = None, None
+ history = []
+ initial_seed = None
+ initial_info = None
+
+ for i in range(n_iter):
+ p.n_iter = 1
+ p.batch_size = 1
+ p.do_not_save_grid = True
+
+ state.job = f"Batch {i + 1} out of {n_iter}"
+ processed = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ p.init_images = [processed.images[0]]
+ p.seed = processed.seed + 1
+ p.denoising_strength = max(p.denoising_strength * 0.95, 0.1)
+ history.append(processed.images[0])
+
+ grid = images.image_grid(history, batch_size, rows=1)
+
+ images.save_image(grid, p.outpath_grids, "grid", initial_seed, prompt, opts.grid_format, info=info, short_filename=not opts.grid_extended_filename)
+
+ processed = Processed(p, history, initial_seed, initial_info)
+
+ elif is_upscale:
+ initial_seed = None
+ initial_info = None
+
+ upscaler = shared.sd_upscalers.get(upscaler_name, next(iter(shared.sd_upscalers.values())))
+ img = upscaler(init_img)
+
+ processing.torch_gc()
+
+ grid = images.split_grid(img, tile_w=width, tile_h=height, overlap=upscale_overlap)
+
+ p.n_iter = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ work = []
+ work_results = []
+
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ work.append(tiledata[2])
+
+ batch_count = math.ceil(len(work) / p.batch_size)
+ print(f"SD upscaling will process a total of {len(work)} images tiled as {len(grid.tiles[0][2])}x{len(grid.tiles)} in a total of {batch_count} batches.")
+
+ for i in range(batch_count):
+ p.init_images = work[i*p.batch_size:(i+1)*p.batch_size]
+
+ state.job = f"Batch {i + 1} out of {batch_count}"
+ processed = process_images(p)
+
+ if initial_seed is None:
+ initial_seed = processed.seed
+ initial_info = processed.info
+
+ p.seed = processed.seed + 1
+ work_results += processed.images
+
+ image_index = 0
+ for y, h, row in grid.tiles:
+ for tiledata in row:
+ tiledata[2] = work_results[image_index] if image_index < len(work_results) else Image.new("RGB", (p.width, p.height))
+ image_index += 1
+
+ combined_image = images.combine_grid(grid)
+
+ if opts.samples_save:
+ images.save_image(combined_image, p.outpath_samples, "", initial_seed, prompt, opts.grid_format, info=initial_info)
+
+ processed = Processed(p, [combined_image], initial_seed, initial_info)
+
+ else:
+ processed = process_images(p)
+
+ return processed.images, processed.js(), plaintext_to_html(processed.info)
diff --git a/modules/lowvram.py b/modules/lowvram.py new file mode 100644 index 00000000..4b78deab --- /dev/null +++ b/modules/lowvram.py @@ -0,0 +1,73 @@ +import torch
+
+module_in_gpu = None
+cpu = torch.device("cpu")
+gpu = torch.device("cuda")
+device = gpu if torch.cuda.is_available() else cpu
+
+
+def setup_for_low_vram(sd_model, use_medvram):
+ parents = {}
+
+ def send_me_to_gpu(module, _):
+ """send this module to GPU; send whatever tracked module was previous in GPU to CPU;
+ we add this as forward_pre_hook to a lot of modules and this way all but one of them will
+ be in CPU
+ """
+ global module_in_gpu
+
+ module = parents.get(module, module)
+
+ if module_in_gpu == module:
+ return
+
+ if module_in_gpu is not None:
+ module_in_gpu.to(cpu)
+
+ module.to(gpu)
+ module_in_gpu = module
+
+ # see below for register_forward_pre_hook;
+ # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
+ # useless here, and we just replace those methods
+ def first_stage_model_encode_wrap(self, encoder, x):
+ send_me_to_gpu(self, None)
+ return encoder(x)
+
+ def first_stage_model_decode_wrap(self, decoder, z):
+ send_me_to_gpu(self, None)
+ return decoder(z)
+
+ # remove three big modules, cond, first_stage, and unet from the model and then
+ # send the model to GPU. Then put modules back. the modules will be in CPU.
+ stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = None, None, None
+ sd_model.to(device)
+ sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.model = stored
+
+ # register hooks for those the first two models
+ sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
+ sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
+ sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+ parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
+
+ if use_medvram:
+ sd_model.model.register_forward_pre_hook(send_me_to_gpu)
+ else:
+ diff_model = sd_model.model.diffusion_model
+
+ # the third remaining model is still too big for 4 GB, so we also do the same for its submodules
+ # so that only one of them is in GPU at a time
+ stored = diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = None, None, None, None
+ sd_model.model.to(device)
+ diff_model.input_blocks, diff_model.middle_block, diff_model.output_blocks, diff_model.time_embed = stored
+
+ # install hooks for bits of third model
+ diff_model.time_embed.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.input_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
+ diff_model.middle_block.register_forward_pre_hook(send_me_to_gpu)
+ for block in diff_model.output_blocks:
+ block.register_forward_pre_hook(send_me_to_gpu)
diff --git a/modules/paths.py b/modules/paths.py new file mode 100644 index 00000000..6d11b304 --- /dev/null +++ b/modules/paths.py @@ -0,0 +1,21 @@ +import argparse
+import os
+import sys
+
+script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
+sys.path.insert(0, script_path)
+
+# use current directory as SD dir if it has related files, otherwise parent dir of script as stated in guide
+sd_path = os.path.abspath('.') if os.path.exists('./ldm/models/diffusion/ddpm.py') else os.path.dirname(script_path)
+
+# add parent directory to path; this is where Stable diffusion repo should be
+path_dirs = [
+ (sd_path, 'ldm', 'Stable Diffusion'),
+ (os.path.join(sd_path, '../taming-transformers'), 'taming', 'Taming Transformers')
+]
+for d, must_exist, what in path_dirs:
+ must_exist_path = os.path.abspath(os.path.join(script_path, d, must_exist))
+ if not os.path.exists(must_exist_path):
+ print(f"Warning: {what} not found at path {must_exist_path}", file=sys.stderr)
+ else:
+ sys.path.append(os.path.join(script_path, d))
diff --git a/modules/processing.py b/modules/processing.py new file mode 100644 index 00000000..faf56c9c --- /dev/null +++ b/modules/processing.py @@ -0,0 +1,409 @@ +import contextlib
+import json
+import math
+import os
+import sys
+
+import torch
+import numpy as np
+from PIL import Image, ImageFilter, ImageOps
+import random
+
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
+import modules.gfpgan_model as gfpgan
+import modules.images as images
+
+# some of those options should not be changed at all because they would break the model, so I removed them from options.
+opt_C = 4
+opt_f = 8
+
+
+def torch_gc():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+class StableDiffusionProcessing:
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ self.sd_model = sd_model
+ self.outpath_samples: str = outpath_samples
+ self.outpath_grids: str = outpath_grids
+ self.prompt: str = prompt
+ self.negative_prompt: str = (negative_prompt or "")
+ self.seed: int = seed
+ self.sampler_index: int = sampler_index
+ self.batch_size: int = batch_size
+ self.n_iter: int = n_iter
+ self.steps: int = steps
+ self.cfg_scale: float = cfg_scale
+ self.width: int = width
+ self.height: int = height
+ self.prompt_matrix: bool = prompt_matrix
+ self.use_GFPGAN: bool = use_GFPGAN
+ self.do_not_save_samples: bool = do_not_save_samples
+ self.do_not_save_grid: bool = do_not_save_grid
+ self.extra_generation_params: dict = extra_generation_params
+ self.overlay_images = overlay_images
+ self.paste_to = None
+
+ def init(self):
+ pass
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ raise NotImplementedError()
+
+
+class Processed:
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
+ self.images = images_list
+ self.prompt = p.prompt
+ self.seed = seed
+ self.info = info
+ self.width = p.width
+ self.height = p.height
+ self.sampler = samplers[p.sampler_index].name
+ self.cfg_scale = p.cfg_scale
+ self.steps = p.steps
+
+ def js(self):
+ obj = {
+ "prompt": self.prompt,
+ "seed": int(self.seed),
+ "width": self.width,
+ "height": self.height,
+ "sampler": self.sampler,
+ "cfg_scale": self.cfg_scale,
+ "steps": self.steps,
+ }
+
+ return json.dumps(obj)
+
+
+def create_random_tensors(shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so I do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=shared.device))
+ x = torch.stack(xs)
+ return x
+
+
+def process_images(p: StableDiffusionProcessing) -> Processed:
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+ prompt = p.prompt
+
+ assert p.prompt is not None
+ torch_gc()
+
+ seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
+
+ os.makedirs(p.outpath_samples, exist_ok=True)
+ os.makedirs(p.outpath_grids, exist_ok=True)
+
+ comments = []
+
+ prompt_matrix_parts = []
+ if p.prompt_matrix:
+ all_prompts = []
+ prompt_matrix_parts = prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
+
+ if opts.prompt_matrix_add_to_start:
+ selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
+ else:
+ selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
+
+ all_prompts.append(", ".join(selected_prompts))
+
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
+ all_seeds = len(all_prompts) * [seed]
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
+ else:
+ all_prompts = p.batch_size * p.n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ def infotext(iteration=0, position_in_batch=0):
+ generation_params = {
+ "Steps": p.steps,
+ "Sampler": samplers[p.sampler_index].name,
+ "CFG scale": p.cfg_scale,
+ "Seed": all_seeds[position_in_batch + iteration * p.batch_size],
+ "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
+ }
+
+ if p.extra_generation_params is not None:
+ generation_params.update(p.extra_generation_params)
+
+ generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+
+ return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
+
+ if os.path.exists(cmd_opts.embeddings_dir):
+ model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+
+ output_images = []
+ precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
+ ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
+ p.init()
+
+ for n in range(p.n_iter):
+ if state.interrupted:
+ break
+
+ prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+
+ uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ c = p.sd_model.get_learned_conditioning(prompts)
+
+ if len(model_hijack.comments) > 0:
+ comments += model_hijack.comments
+
+ # we manually generate all input noises because each one should have a specific seed
+ x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
+
+ if p.n_iter > 1:
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+
+ samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
+
+ x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+
+ if p.use_GFPGAN:
+ torch_gc()
+
+ x_sample = gfpgan.gfpgan_fix_faces(x_sample)
+
+ image = Image.fromarray(x_sample)
+
+ if p.overlay_images is not None and i < len(p.overlay_images):
+ overlay = p.overlay_images[i]
+
+ if p.paste_to is not None:
+ x, y, w, h = p.paste_to
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+
+ if opts.samples_save and not p.do_not_save_samples:
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))
+
+ output_images.append(image)
+
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
+ if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
+ return_grid = opts.return_grid
+
+ if p.prompt_matrix:
+ grid = images.image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
+
+ try:
+ grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+ except Exception:
+ import traceback
+ print("Error creating prompt_matrix text:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return_grid = True
+ else:
+ grid = images.image_grid(output_images, p.batch_size)
+
+ if return_grid:
+ output_images.insert(0, grid)
+
+ if opts.grid_save:
+ images.save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
+
+ torch_gc()
+ return Processed(p, output_images, seed, infotext())
+
+
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+ sampler = None
+
+ def init(self):
+ self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ return samples_ddim
+
+
+def get_crop_region(mask, pad=0):
+ h, w = mask.shape
+
+ crop_left = 0
+ for i in range(w):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_left += 1
+
+ crop_right = 0
+ for i in reversed(range(w)):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_right += 1
+
+ crop_top = 0
+ for i in range(h):
+ if not (mask[i] == 0).all():
+ break
+ crop_top += 1
+
+ crop_bottom = 0
+ for i in reversed(range(h)):
+ if not (mask[i] == 0).all():
+ break
+ crop_bottom += 1
+
+ return (
+ int(max(crop_left-pad, 0)),
|