diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-03 09:08:45 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-03 09:08:45 +0000 |
commit | 345028099d893f8a66726cfd13627d8cc1bcc724 (patch) | |
tree | acb1da553620b7e7139db840ef43accf71b786a8 /modules/processing.py | |
parent | d7b67d9b40e47ede766d3beb149b0c2b74651ece (diff) | |
download | stable-diffusion-webui-gfx803-345028099d893f8a66726cfd13627d8cc1bcc724.tar.gz stable-diffusion-webui-gfx803-345028099d893f8a66726cfd13627d8cc1bcc724.tar.bz2 stable-diffusion-webui-gfx803-345028099d893f8a66726cfd13627d8cc1bcc724.zip |
split codebase into multiple files; to anyone this affects negatively: sorry
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 409 |
1 files changed, 409 insertions, 0 deletions
diff --git a/modules/processing.py b/modules/processing.py new file mode 100644 index 00000000..faf56c9c --- /dev/null +++ b/modules/processing.py @@ -0,0 +1,409 @@ +import contextlib
+import json
+import math
+import os
+import sys
+
+import torch
+import numpy as np
+from PIL import Image, ImageFilter, ImageOps
+import random
+
+from modules.sd_hijack import model_hijack
+from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
+import modules.gfpgan_model as gfpgan
+import modules.images as images
+
+# some of those options should not be changed at all because they would break the model, so I removed them from options.
+opt_C = 4
+opt_f = 8
+
+
+def torch_gc():
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+ torch.cuda.ipc_collect()
+
+
+class StableDiffusionProcessing:
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", seed=-1, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, prompt_matrix=False, use_GFPGAN=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
+ self.sd_model = sd_model
+ self.outpath_samples: str = outpath_samples
+ self.outpath_grids: str = outpath_grids
+ self.prompt: str = prompt
+ self.negative_prompt: str = (negative_prompt or "")
+ self.seed: int = seed
+ self.sampler_index: int = sampler_index
+ self.batch_size: int = batch_size
+ self.n_iter: int = n_iter
+ self.steps: int = steps
+ self.cfg_scale: float = cfg_scale
+ self.width: int = width
+ self.height: int = height
+ self.prompt_matrix: bool = prompt_matrix
+ self.use_GFPGAN: bool = use_GFPGAN
+ self.do_not_save_samples: bool = do_not_save_samples
+ self.do_not_save_grid: bool = do_not_save_grid
+ self.extra_generation_params: dict = extra_generation_params
+ self.overlay_images = overlay_images
+ self.paste_to = None
+
+ def init(self):
+ pass
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ raise NotImplementedError()
+
+
+class Processed:
+ def __init__(self, p: StableDiffusionProcessing, images_list, seed, info):
+ self.images = images_list
+ self.prompt = p.prompt
+ self.seed = seed
+ self.info = info
+ self.width = p.width
+ self.height = p.height
+ self.sampler = samplers[p.sampler_index].name
+ self.cfg_scale = p.cfg_scale
+ self.steps = p.steps
+
+ def js(self):
+ obj = {
+ "prompt": self.prompt,
+ "seed": int(self.seed),
+ "width": self.width,
+ "height": self.height,
+ "sampler": self.sampler,
+ "cfg_scale": self.cfg_scale,
+ "steps": self.steps,
+ }
+
+ return json.dumps(obj)
+
+
+def create_random_tensors(shape, seeds):
+ xs = []
+ for seed in seeds:
+ torch.manual_seed(seed)
+
+ # randn results depend on device; gpu and cpu get different results for same seed;
+ # the way I see it, it's better to do this on CPU, so that everyone gets same result;
+ # but the original script had it like this so I do not dare change it for now because
+ # it will break everyone's seeds.
+ xs.append(torch.randn(shape, device=shared.device))
+ x = torch.stack(xs)
+ return x
+
+
+def process_images(p: StableDiffusionProcessing) -> Processed:
+ """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
+
+ prompt = p.prompt
+
+ assert p.prompt is not None
+ torch_gc()
+
+ seed = int(random.randrange(4294967294) if p.seed == -1 else p.seed)
+
+ os.makedirs(p.outpath_samples, exist_ok=True)
+ os.makedirs(p.outpath_grids, exist_ok=True)
+
+ comments = []
+
+ prompt_matrix_parts = []
+ if p.prompt_matrix:
+ all_prompts = []
+ prompt_matrix_parts = prompt.split("|")
+ combination_count = 2 ** (len(prompt_matrix_parts) - 1)
+ for combination_num in range(combination_count):
+ selected_prompts = [text.strip().strip(',') for n, text in enumerate(prompt_matrix_parts[1:]) if combination_num & (1 << n)]
+
+ if opts.prompt_matrix_add_to_start:
+ selected_prompts = selected_prompts + [prompt_matrix_parts[0]]
+ else:
+ selected_prompts = [prompt_matrix_parts[0]] + selected_prompts
+
+ all_prompts.append(", ".join(selected_prompts))
+
+ p.n_iter = math.ceil(len(all_prompts) / p.batch_size)
+ all_seeds = len(all_prompts) * [seed]
+
+ print(f"Prompt matrix will create {len(all_prompts)} images using a total of {p.n_iter} batches.")
+ else:
+ all_prompts = p.batch_size * p.n_iter * [prompt]
+ all_seeds = [seed + x for x in range(len(all_prompts))]
+
+ def infotext(iteration=0, position_in_batch=0):
+ generation_params = {
+ "Steps": p.steps,
+ "Sampler": samplers[p.sampler_index].name,
+ "CFG scale": p.cfg_scale,
+ "Seed": all_seeds[position_in_batch + iteration * p.batch_size],
+ "GFPGAN": ("GFPGAN" if p.use_GFPGAN else None)
+ }
+
+ if p.extra_generation_params is not None:
+ generation_params.update(p.extra_generation_params)
+
+ generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
+
+ return f"{prompt}\n{generation_params_text}".strip() + "".join(["\n\n" + x for x in comments])
+
+ if os.path.exists(cmd_opts.embeddings_dir):
+ model_hijack.load_textual_inversion_embeddings(cmd_opts.embeddings_dir, p.sd_model)
+
+ output_images = []
+ precision_scope = torch.autocast if cmd_opts.precision == "autocast" else contextlib.nullcontext
+ ema_scope = (contextlib.nullcontext if cmd_opts.lowvram else p.sd_model.ema_scope)
+ with torch.no_grad(), precision_scope("cuda"), ema_scope():
+ p.init()
+
+ for n in range(p.n_iter):
+ if state.interrupted:
+ break
+
+ prompts = all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ seeds = all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+
+ uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt])
+ c = p.sd_model.get_learned_conditioning(prompts)
+
+ if len(model_hijack.comments) > 0:
+ comments += model_hijack.comments
+
+ # we manually generate all input noises because each one should have a specific seed
+ x = create_random_tensors([opt_C, p.height // opt_f, p.width // opt_f], seeds=seeds)
+
+ if p.n_iter > 1:
+ shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+
+ samples_ddim = p.sample(x=x, conditioning=c, unconditional_conditioning=uc)
+
+ x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
+ x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
+
+ for i, x_sample in enumerate(x_samples_ddim):
+ x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
+ x_sample = x_sample.astype(np.uint8)
+
+ if p.use_GFPGAN:
+ torch_gc()
+
+ x_sample = gfpgan.gfpgan_fix_faces(x_sample)
+
+ image = Image.fromarray(x_sample)
+
+ if p.overlay_images is not None and i < len(p.overlay_images):
+ overlay = p.overlay_images[i]
+
+ if p.paste_to is not None:
+ x, y, w, h = p.paste_to
+ base_image = Image.new('RGBA', (overlay.width, overlay.height))
+ image = images.resize_image(1, image, w, h)
+ base_image.paste(image, (x, y))
+ image = base_image
+
+ image = image.convert('RGBA')
+ image.alpha_composite(overlay)
+ image = image.convert('RGB')
+
+ if opts.samples_save and not p.do_not_save_samples:
+ images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i))
+
+ output_images.append(image)
+
+ unwanted_grid_because_of_img_count = len(output_images) < 2 and opts.grid_only_if_multiple
+ if not p.do_not_save_grid and not unwanted_grid_because_of_img_count:
+ return_grid = opts.return_grid
+
+ if p.prompt_matrix:
+ grid = images.image_grid(output_images, p.batch_size, rows=1 << ((len(prompt_matrix_parts)-1)//2))
+
+ try:
+ grid = images.draw_prompt_matrix(grid, p.width, p.height, prompt_matrix_parts)
+ except Exception:
+ import traceback
+ print("Error creating prompt_matrix text:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ return_grid = True
+ else:
+ grid = images.image_grid(output_images, p.batch_size)
+
+ if return_grid:
+ output_images.insert(0, grid)
+
+ if opts.grid_save:
+ images.save_image(grid, p.outpath_grids, "grid", seed, prompt, opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename)
+
+ torch_gc()
+ return Processed(p, output_images, seed, infotext())
+
+
+class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
+ sampler = None
+
+ def init(self):
+ self.sampler = samplers[self.sampler_index].constructor(self.sd_model)
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ samples_ddim = self.sampler.sample(self, x, conditioning, unconditional_conditioning)
+ return samples_ddim
+
+
+def get_crop_region(mask, pad=0):
+ h, w = mask.shape
+
+ crop_left = 0
+ for i in range(w):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_left += 1
+
+ crop_right = 0
+ for i in reversed(range(w)):
+ if not (mask[:, i] == 0).all():
+ break
+ crop_right += 1
+
+ crop_top = 0
+ for i in range(h):
+ if not (mask[i] == 0).all():
+ break
+ crop_top += 1
+
+ crop_bottom = 0
+ for i in reversed(range(h)):
+ if not (mask[i] == 0).all():
+ break
+ crop_bottom += 1
+
+ return (
+ int(max(crop_left-pad, 0)),
+ int(max(crop_top-pad, 0)),
+ int(min(w - crop_right + pad, w)),
+ int(min(h - crop_bottom + pad, h))
+ )
+
+
+def fill(image, mask):
+ image_mod = Image.new('RGBA', (image.width, image.height))
+
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert('L')))
+
+ image_masked = image_masked.convert('RGBa')
+
+ for radius, repeats in [(64, 1), (16, 2), (4, 4), (2, 2), (0, 1)]:
+ blurred = image_masked.filter(ImageFilter.GaussianBlur(radius)).convert('RGBA')
+ for _ in range(repeats):
+ image_mod.alpha_composite(blurred)
+
+ return image_mod.convert("RGB")
+
+
+class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
+ sampler = None
+
+ def __init__(self, init_images=None, resize_mode=0, denoising_strength=0.75, mask=None, mask_blur=4, inpainting_fill=0, inpaint_full_res=True, **kwargs):
+ super().__init__(**kwargs)
+
+ self.init_images = init_images
+ self.resize_mode: int = resize_mode
+ self.denoising_strength: float = denoising_strength
+ self.init_latent = None
+ self.image_mask = mask
+ self.mask_for_overlay = None
+ self.mask_blur = mask_blur
+ self.inpainting_fill = inpainting_fill
+ self.inpaint_full_res = inpaint_full_res
+ self.mask = None
+ self.nmask = None
+
+ def init(self):
+ self.sampler = samplers_for_img2img[self.sampler_index].constructor(self.sd_model)
+ crop_region = None
+
+ if self.image_mask is not None:
+ if self.mask_blur > 0:
+ self.image_mask = self.image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur)).convert('L')
+
+ if self.inpaint_full_res:
+ self.mask_for_overlay = self.image_mask
+ mask = self.image_mask.convert('L')
+ crop_region = get_crop_region(np.array(mask), 64)
+ x1, y1, x2, y2 = crop_region
+
+ mask = mask.crop(crop_region)
+ self.image_mask = images.resize_image(2, mask, self.width, self.height)
+ self.paste_to = (x1, y1, x2-x1, y2-y1)
+ else:
+ self.image_mask = images.resize_image(self.resize_mode, self.image_mask, self.width, self.height)
+ self.mask_for_overlay = self.image_mask
+
+ self.overlay_images = []
+
+ imgs = []
+ for img in self.init_images:
+ image = img.convert("RGB")
+
+ if crop_region is None:
+ image = images.resize_image(self.resize_mode, image, self.width, self.height)
+
+ if self.image_mask is not None:
+ if self.inpainting_fill != 1:
+ image = fill(image, self.mask_for_overlay)
+
+ image_masked = Image.new('RGBa', (image.width, image.height))
+ image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L')))
+
+ self.overlay_images.append(image_masked.convert('RGBA'))
+
+ if crop_region is not None:
+ image = image.crop(crop_region)
+ image = images.resize_image(2, image, self.width, self.height)
+
+ image = np.array(image).astype(np.float32) / 255.0
+ image = np.moveaxis(image, 2, 0)
+
+ imgs.append(image)
+
+ if len(imgs) == 1:
+ batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0)
+ if self.overlay_images is not None:
+ self.overlay_images = self.overlay_images * self.batch_size
+ elif len(imgs) <= self.batch_size:
+ self.batch_size = len(imgs)
+ batch_images = np.array(imgs)
+ else:
+ raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
+
+ image = torch.from_numpy(batch_images)
+ image = 2. * image - 1.
+ image = image.to(shared.device)
+
+ self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+
+ if self.image_mask is not None:
+ latmask = self.image_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
+ latmask = np.moveaxis(np.array(latmask, dtype=np.float64), 2, 0) / 255
+ latmask = latmask[0]
+ latmask = np.tile(latmask[None], (4, 1, 1))
+
+ self.mask = torch.asarray(1.0 - latmask).to(shared.device).type(self.sd_model.dtype)
+ self.nmask = torch.asarray(latmask).to(shared.device).type(self.sd_model.dtype)
+
+ if self.inpainting_fill == 2:
+ self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], [self.seed + x + 1 for x in range(self.init_latent.shape[0])]) * self.nmask
+ elif self.inpainting_fill == 3:
+ self.init_latent = self.init_latent * self.mask
+
+ def sample(self, x, conditioning, unconditional_conditioning):
+ samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning)
+
+ if self.mask is not None:
+ samples = samples * self.nmask + self.init_latent * self.mask
+
+ return samples
|