aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py359
1 files changed, 275 insertions, 84 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 1a76e552..21d1492c 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1,20 +1,20 @@
import json
+import logging
import math
import os
import sys
-import warnings
import hashlib
import torch
import numpy as np
-from PIL import Image, ImageFilter, ImageOps
+from PIL import Image, ImageOps
import random
import cv2
from skimage import exposure
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -24,13 +24,13 @@ import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
-import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
opt_f = 8
@@ -106,6 +106,9 @@ class StableDiffusionProcessing:
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
"""
+ cached_uc = [None, None]
+ cached_c = [None, None]
+
def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -150,6 +153,8 @@ class StableDiffusionProcessing:
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
self.disable_extra_networks = False
+ self.token_merging_ratio = 0
+ self.token_merging_ratio_hr = 0
if not seed_enable_extras:
self.subseed = -1
@@ -165,7 +170,21 @@ class StableDiffusionProcessing:
self.all_subseeds = None
self.iteration = 0
self.is_hr_pass = False
-
+ self.sampler = None
+
+ self.prompts = None
+ self.negative_prompts = None
+ self.extra_network_data = None
+ self.seeds = None
+ self.subseeds = None
+
+ self.step_multiplier = 1
+ self.cached_uc = StableDiffusionProcessing.cached_uc
+ self.cached_c = StableDiffusionProcessing.cached_c
+ self.uc = None
+ self.c = None
+
+ self.user = None
@property
def sd_model(self):
@@ -273,6 +292,64 @@ class StableDiffusionProcessing:
def close(self):
self.sampler = None
+ self.c = None
+ self.uc = None
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessing.cached_c = [None, None]
+ StableDiffusionProcessing.cached_uc = [None, None]
+
+ def get_token_merging_ratio(self, for_hr=False):
+ if for_hr:
+ return self.token_merging_ratio_hr or opts.token_merging_ratio_hr or self.token_merging_ratio or opts.token_merging_ratio
+
+ return self.token_merging_ratio or opts.token_merging_ratio
+
+ def setup_prompts(self):
+ if type(self.prompt) == list:
+ self.all_prompts = self.prompt
+ else:
+ self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+ if type(self.negative_prompt) == list:
+ self.all_negative_prompts = self.negative_prompt
+ else:
+ self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+
+ self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+ self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
+ def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+
+ caches is a list with items described above.
+ """
+ for cache in caches:
+ if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]:
+ return cache[1]
+
+ cache = caches[0]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data)
+ return cache[1]
+
+ def setup_conds(self):
+ sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+ self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+
+ def parse_extra_network_prompts(self):
+ self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
class Processed:
@@ -303,6 +380,8 @@ class Processed:
self.styles = p.styles
self.job_timestamp = state.job_timestamp
self.clip_skip = opts.CLIP_stop_at_last_layers
+ self.token_merging_ratio = p.token_merging_ratio
+ self.token_merging_ratio_hr = p.token_merging_ratio_hr
self.eta = p.eta
self.ddim_discretize = p.ddim_discretize
@@ -310,6 +389,7 @@ class Processed:
self.s_tmin = p.s_tmin
self.s_tmax = p.s_tmax
self.s_noise = p.s_noise
+ self.s_min_uncond = p.s_min_uncond
self.sampler_noise_scheduler_override = p.sampler_noise_scheduler_override
self.prompt = self.prompt if type(self.prompt) != list else self.prompt[0]
self.negative_prompt = self.negative_prompt if type(self.negative_prompt) != list else self.negative_prompt[0]
@@ -360,6 +440,9 @@ class Processed:
def infotext(self, p: StableDiffusionProcessing, index):
return create_infotext(p, self.all_prompts, self.all_seeds, self.all_subseeds, comments=[], position_in_batch=index % self.batch_size, iteration=index // self.batch_size)
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
+
# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
def slerp(val, low, high):
@@ -468,10 +551,17 @@ def program_version():
return res
-def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0):
+def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0, use_main_prompt=False):
index = position_in_batch + iteration * p.batch_size
clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers)
+ enable_hr = getattr(p, 'enable_hr', False)
+ token_merging_ratio = p.get_token_merging_ratio()
+ token_merging_ratio_hr = p.get_token_merging_ratio(for_hr=True)
+
+ uses_ensd = opts.eta_noise_seed_delta != 0
+ if uses_ensd:
+ uses_ensd = sd_samplers_common.is_sampler_using_eta_noise_seed_delta(p)
generation_params = {
"Steps": p.steps,
@@ -485,27 +575,33 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')),
"Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
- "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
+ "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
"Denoising strength": getattr(p, 'denoising_strength', None),
"Conditional mask weight": getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) if p.is_using_inpainting_conditioning else None,
"Clip skip": None if clip_skip <= 1 else clip_skip,
- "ENSD": None if opts.eta_noise_seed_delta == 0 else opts.eta_noise_seed_delta,
+ "ENSD": opts.eta_noise_seed_delta if uses_ensd else None,
+ "Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
+ "Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
"RNG": opts.randn_source if opts.randn_source != "GPU" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ **p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
+ "User": p.user if opts.add_user_name_to_info else None,
}
- generation_params.update(p.extra_generation_params)
-
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
+ prompt_text = p.prompt if use_main_prompt else all_prompts[index]
negative_prompt_text = f"\nNegative prompt: {p.all_negative_prompts[index]}" if p.all_negative_prompts[index] else ""
- return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip()
+ return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
def process_images(p: StableDiffusionProcessing) -> Processed:
+ if p.scripts is not None:
+ p.scripts.before_process(p)
+
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
@@ -523,9 +619,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
if k == 'sd_vae':
sd_vae.reload_vae_weights()
+ sd_models.apply_token_merging(p.sd_model, p.get_token_merging_ratio())
+
res = process_images_inner(p)
finally:
+ sd_models.apply_token_merging(p.sd_model, 0)
+
# restore opts to original state
if p.override_settings_restore_afterwards:
for k, v in stored_opts.items():
@@ -555,15 +655,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
comments = {}
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
+ p.setup_prompts()
if type(seed) == list:
p.all_seeds = seed
@@ -575,8 +667,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
else:
p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
- def infotext(iteration=0, position_in_batch=0):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch)
+ def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
+ return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -587,29 +679,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
infotexts = []
output_images = []
- cached_uc = [None, None]
- cached_c = [None, None]
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
- """
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
- using a cache to store the result if the same arguments have been used before.
-
- cache is an array containing two elements. The first element is a tuple
- representing the previously used arguments, or None if no arguments
- have been used before. The second element is where the previously
- computed result is stored.
- """
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
-
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
-
- cache[0] = (required_prompts, steps)
- return cache[1]
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -618,10 +687,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
sd_vae_approx.model()
+ sd_unet.apply_unet()
+
if state.job_count == -1:
state.job_count = p.n_iter
- extra_network_data = None
for n in range(p.n_iter):
p.iteration = n
@@ -631,25 +701,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
- prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if p.scripts is not None:
- p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- if len(prompts) == 0:
+ if len(p.prompts) == 0:
break
- prompts, extra_network_data = extra_networks.parse_prompts(prompts)
+ p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
- extra_networks.activate(p, extra_network_data)
+ extra_networks.activate(p, p.extra_network_data)
if p.scripts is not None:
- p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -660,14 +730,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- step_multiplier = 1
- if not shared.opts.dont_fix_second_order_samplers_schedule:
- try:
- step_multiplier = 2 if sd_samplers.all_samplers_map.get(p.sampler_name).aliases[0] in ['k_dpmpp_2s_a', 'k_dpmpp_2s_a_ka', 'k_dpmpp_sde', 'k_dpmpp_sde_ka', 'k_dpm_2', 'k_dpm_2_a', 'k_heun'] else 1
- except:
- pass
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
+ p.setup_conds()
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -677,7 +740,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -688,7 +751,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -704,7 +767,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -721,13 +784,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -740,10 +803,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
@@ -765,7 +828,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid = images.image_grid(output_images, p.batch_size)
if opts.return_grid:
- text = infotext()
+ text = infotext(use_main_prompt=True)
infotexts.insert(0, text)
if opts.enable_pnginfo:
grid.info["parameters"] = text
@@ -773,10 +836,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
index_of_first_image = 1
if opts.grid_save:
- images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True)
+ images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
- if not p.disable_extra_networks and extra_network_data:
- extra_networks.deactivate(p, extra_network_data)
+ if not p.disable_extra_networks and p.extra_network_data:
+ extra_networks.deactivate(p, p.extra_network_data)
devices.torch_gc()
@@ -785,7 +848,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
images_list=output_images,
seed=p.all_seeds[0],
info=infotext(),
- comments="".join(f"\n\n{comment}" for comment in comments),
+ comments="".join(f"{comment}\n" for comment in comments),
subseed=p.all_subseeds[0],
index_of_first_image=index_of_first_image,
infotexts=infotexts,
@@ -811,8 +874,10 @@ def old_hires_fix_first_pass_dimensions(width, height):
class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
+ cached_hr_uc = [None, None]
+ cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -823,6 +888,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_sampler_name = hr_sampler_name
+ self.hr_prompt = hr_prompt
+ self.hr_negative_prompt = hr_negative_prompt
+ self.all_hr_prompts = None
+ self.all_hr_negative_prompts = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -834,8 +904,26 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_y = 0
self.applied_old_hires_behavior_to = None
+ self.hr_prompts = None
+ self.hr_negative_prompts = None
+ self.hr_extra_network_data = None
+
+ self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc
+ self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c
+ self.hr_c = None
+ self.hr_uc = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
+
+ if tuple(self.hr_prompt) != tuple(self.prompt):
+ self.extra_generation_params["Hires prompt"] = self.hr_prompt
+
+ if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
+ self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -901,7 +989,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
if self.enable_hr and latent_scale_mode is None:
- assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}"
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
@@ -965,9 +1054,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
shared.state.nextjob()
- img2img_sampler_name = self.sampler_name
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
+
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -978,17 +1069,101 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ if not self.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ with devices.autocast():
+ self.calculate_hr_conds()
+
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
+
+ if self.scripts is not None:
+ self.scripts.before_hr(self)
+
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+
+ sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
self.is_hr_pass = False
return samples
+ def close(self):
+ super().close()
+ self.hr_c = None
+ self.hr_uc = None
+ if not opts.experimental_persistent_cond_cache:
+ StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
+ StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
+
+ def setup_prompts(self):
+ super().setup_prompts()
+
+ if not self.enable_hr:
+ return
+
+ if self.hr_prompt == '':
+ self.hr_prompt = self.prompt
+
+ if self.hr_negative_prompt == '':
+ self.hr_negative_prompt = self.negative_prompt
+
+ if type(self.hr_prompt) == list:
+ self.all_hr_prompts = self.hr_prompt
+ else:
+ self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+ if type(self.hr_negative_prompt) == list:
+ self.all_hr_negative_prompts = self.hr_negative_prompt
+ else:
+ self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+ self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+ self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+
+ def setup_conds(self):
+ super().setup_conds()
+
+ self.hr_uc = None
+ self.hr_c = None
+
+ if self.enable_hr:
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
+
+ def parse_extra_network_prompts(self):
+ res = super().parse_extra_network_prompts()
+
+ if self.enable_hr:
+ self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+ self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+ self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+ return res
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
+ def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs):
super().__init__(**kwargs)
self.init_images = init_images
@@ -999,7 +1174,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_mask = mask
self.latent_mask = None
self.mask_for_overlay = None
- self.mask_blur = mask_blur
+ if mask_blur is not None:
+ mask_blur_x = mask_blur
+ mask_blur_y = mask_blur
+ self.mask_blur_x = mask_blur_x
+ self.mask_blur_y = mask_blur_y
self.inpainting_fill = inpainting_fill
self.inpaint_full_res = inpaint_full_res
self.inpaint_full_res_padding = inpaint_full_res_padding
@@ -1021,8 +1200,17 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
if self.inpainting_mask_invert:
image_mask = ImageOps.invert(image_mask)
- if self.mask_blur > 0:
- image_mask = image_mask.filter(ImageFilter.GaussianBlur(self.mask_blur))
+ if self.mask_blur_x > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_x + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x)
+ image_mask = Image.fromarray(np_mask)
+
+ if self.mask_blur_y > 0:
+ np_mask = np.array(image_mask)
+ kernel_size = 2 * int(4 * self.mask_blur_y + 0.5) + 1
+ np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y)
+ image_mask = Image.fromarray(np_mask)
if self.inpaint_full_res:
self.mask_for_overlay = image_mask
@@ -1141,3 +1329,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
devices.torch_gc()
return samples
+
+ def get_token_merging_ratio(self, for_hr=False):
+ return self.token_merging_ratio or ("token_merging_ratio" in self.override_settings and opts.token_merging_ratio) or opts.token_merging_ratio_img2img or opts.token_merging_ratio