aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/processing.py')
-rwxr-xr-xmodules/processing.py392
1 files changed, 203 insertions, 189 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 63cd025c..c048ca25 100755
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -14,8 +14,10 @@ from skimage import exposure
from typing import Any, Dict, List
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng
+from modules.rng import slerp # noqa: F401
from modules.sd_hijack import model_hijack
+from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
@@ -83,7 +85,7 @@ def txt2img_image_conditioning(sd_model, x, width, height):
# The "masked-image" in this case will just be all zeros since the entire image is masked.
image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+ image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method))
# Add the fake full 1s mask to the first dimension.
image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
@@ -109,7 +111,7 @@ class StableDiffusionProcessing:
cached_uc = [None, None]
cached_c = [None, None]
- def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = False, tiling: bool = False, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = 1.0, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
+ def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None):
if sampler_index is not None:
print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr)
@@ -147,8 +149,8 @@ class StableDiffusionProcessing:
self.s_min_uncond = s_min_uncond or opts.s_min_uncond
self.s_churn = s_churn or opts.s_churn
self.s_tmin = s_tmin or opts.s_tmin
- self.s_tmax = s_tmax or float('inf') # not representable as a standard ui option
- self.s_noise = s_noise or opts.s_noise
+ self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf')
+ self.s_noise = s_noise if s_noise is not None else opts.s_noise
self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts}
self.override_settings_restore_afterwards = override_settings_restore_afterwards
self.is_using_inpainting_conditioning = False
@@ -171,6 +173,8 @@ class StableDiffusionProcessing:
self.iteration = 0
self.is_hr_pass = False
self.sampler = None
+ self.main_prompt = None
+ self.main_negative_prompt = None
self.prompts = None
self.negative_prompts = None
@@ -183,6 +187,7 @@ class StableDiffusionProcessing:
self.cached_c = StableDiffusionProcessing.cached_c
self.uc = None
self.c = None
+ self.rng: rng.ImageRNG = None
self.user = None
@@ -202,7 +207,7 @@ class StableDiffusionProcessing:
midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
size=conditioning_image.shape[2:],
@@ -215,7 +220,7 @@ class StableDiffusionProcessing:
return conditioning
def edit_image_conditioning(self, source_image):
- conditioning_image = self.sd_model.encode_first_stage(source_image).mode()
+ conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method))
return conditioning_image
@@ -294,7 +299,7 @@ class StableDiffusionProcessing:
self.sampler = None
self.c = None
self.uc = None
- if not opts.experimental_persistent_cond_cache:
+ if not opts.persistent_cond_cache:
StableDiffusionProcessing.cached_c = [None, None]
StableDiffusionProcessing.cached_uc = [None, None]
@@ -318,6 +323,24 @@ class StableDiffusionProcessing:
self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+ self.main_prompt = self.all_prompts[0]
+ self.main_negative_prompt = self.all_negative_prompts[0]
+
+ def cached_params(self, required_prompts, steps, extra_network_data):
+ """Returns parameters that invalidate the cond cache if changed"""
+
+ return (
+ required_prompts,
+ steps,
+ opts.CLIP_stop_at_last_layers,
+ shared.sd_model.sd_checkpoint_info,
+ extra_network_data,
+ opts.sdxl_crop_left,
+ opts.sdxl_crop_top,
+ self.width,
+ self.height,
+ )
+
def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data):
"""
Returns the result of calling function(shared.sd_model, required_prompts, steps)
@@ -331,17 +354,7 @@ class StableDiffusionProcessing:
caches is a list with items described above.
"""
- cached_params = (
- required_prompts,
- steps,
- opts.CLIP_stop_at_last_layers,
- shared.sd_model.sd_checkpoint_info,
- extra_network_data,
- opts.sdxl_crop_left,
- opts.sdxl_crop_top,
- self.width,
- self.height,
- )
+ cached_params = self.cached_params(required_prompts, steps, extra_network_data)
for cache in caches:
if cache[0] is not None and cached_params == cache[0]:
@@ -364,9 +377,16 @@ class StableDiffusionProcessing:
self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)
+ def get_conds(self):
+ return self.c, self.uc
+
def parse_extra_network_prompts(self):
self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)
+ def save_samples(self) -> bool:
+ """Returns whether generated images need to be written to disk"""
+ return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -460,86 +480,17 @@ class Processed:
return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio
-# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3
-def slerp(val, low, high):
- low_norm = low/torch.norm(low, dim=1, keepdim=True)
- high_norm = high/torch.norm(high, dim=1, keepdim=True)
- dot = (low_norm*high_norm).sum(1)
-
- if dot.mean() > 0.9995:
- return low * val + high * (1 - val)
-
- omega = torch.acos(dot)
- so = torch.sin(omega)
- res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high
- return res
-
-
def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None):
- eta_noise_seed_delta = opts.eta_noise_seed_delta or 0
- xs = []
-
- # if we have multiple seeds, this means we are working with batch size>1; this then
- # enables the generation of additional tensors with noise that the sampler will use during its processing.
- # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to
- # produce the same images as with two batches [100], [101].
- if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0):
- sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))]
- else:
- sampler_noises = None
-
- for i, seed in enumerate(seeds):
- noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
-
- subnoise = None
- if subseeds is not None:
- subseed = 0 if i >= len(subseeds) else subseeds[i]
-
- subnoise = devices.randn(subseed, noise_shape)
-
- # randn results depend on device; gpu and cpu get different results for same seed;
- # the way I see it, it's better to do this on CPU, so that everyone gets same result;
- # but the original script had it like this, so I do not dare change it for now because
- # it will break everyone's seeds.
- noise = devices.randn(seed, noise_shape)
-
- if subnoise is not None:
- noise = slerp(subseed_strength, noise, subnoise)
-
- if noise_shape != shape:
- x = devices.randn(seed, shape)
- dx = (shape[2] - noise_shape[2]) // 2
- dy = (shape[1] - noise_shape[1]) // 2
- w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx
- h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy
- tx = 0 if dx < 0 else dx
- ty = 0 if dy < 0 else dy
- dx = max(-dx, 0)
- dy = max(-dy, 0)
-
- x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w]
- noise = x
-
- if sampler_noises is not None:
- cnt = p.sampler.number_of_needed_noises(p)
-
- if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
-
- for j in range(cnt):
- sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
+ g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w)
+ return g.next()
- xs.append(noise)
- if sampler_noises is not None:
- p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises]
-
- x = torch.stack(xs).to(shared.device)
- return x
+class DecodedSamples(list):
+ already_decoded = True
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
- samples = []
+ samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -554,7 +505,7 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
"Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automaticlly revert VAE to 32-bit floats' setting.\n"
+ "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
"To always start with 32-bit VAE, use --no-half-vae commandline flag."
)
@@ -572,14 +523,16 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
return samples
-def decode_first_stage(model, x):
- x = model.decode_first_stage(x.to(devices.dtype_vae))
-
- return x
-
-
def get_fixed_seed(seed):
- if seed is None or seed == '' or seed == -1:
+ if seed == '' or seed is None:
+ seed = -1
+ elif isinstance(seed, str):
+ try:
+ seed = int(seed)
+ except Exception:
+ seed = -1
+
+ if seed == -1:
return int(random.randrange(4294967294))
return seed
@@ -622,10 +575,12 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"CFG scale": p.cfg_scale,
"Image CFG scale": getattr(p, 'image_cfg_scale', None),
"Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
- "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
+ "Face restoration": opts.face_restoration_model if p.restore_faces else None,
"Size": f"{p.width}x{p.height}",
"Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
"Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
+ "VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None,
+ "VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None,
"Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
"Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
"Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -636,8 +591,9 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
+ "Tiling": "True" if p.tiling else None,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
"User": p.user if opts.add_user_name_to_info else None,
@@ -645,8 +601,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])
- prompt_text = p.prompt if use_main_prompt else all_prompts[index]
- negative_prompt_text = f"\nNegative prompt: {all_negative_prompts[index]}" if all_negative_prompts[index] else ""
+ prompt_text = p.main_prompt if use_main_prompt else all_prompts[index]
+ negative_prompt_text = f"\nNegative prompt: {p.main_negative_prompt if use_main_prompt else all_negative_prompts[index]}" if all_negative_prompts[index] else ""
return f"{prompt_text}{negative_prompt_text}\n{generation_params_text}".strip()
@@ -658,6 +614,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}
try:
+ # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
+ if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
+ sd_models.reload_model_weights()
+
# if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
p.override_settings.pop('sd_model_checkpoint', None)
@@ -703,6 +663,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
seed = get_fixed_seed(p.seed)
subseed = get_fixed_seed(p.subseed)
+ if p.restore_faces is None:
+ p.restore_faces = opts.face_restoration
+
+ if p.tiling is None:
+ p.tiling = opts.tiling
+
modules.sd_hijack.model_hijack.apply_circular(p.tiling)
modules.sd_hijack.model_hijack.clear_comments()
@@ -751,11 +717,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if state.interrupted:
break
+ sd_models.reload_model_weights() # model can be changed for example by refiner
+
p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w)
+
if p.scripts is not None:
p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
@@ -777,7 +747,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
# strength, which is saved as "Model Strength: 1.0" in the infotext
if n == 0:
with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file:
- processed = Processed(p, [], p.seed, "")
+ processed = Processed(p, [])
file.write(processed.infotext(p, 0))
p.setup_conds()
@@ -793,7 +763,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ if getattr(samples_ddim, 'already_decoded', False):
+ x_samples_ddim = samples_ddim
+ else:
+ if opts.sd_vae_decode_method != 'Full':
+ p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
+
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -817,6 +794,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
def infotext(index=0, use_main_prompt=False):
return create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts)
+ save_samples = p.save_samples()
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i
@@ -824,7 +803,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
x_sample = x_sample.astype(np.uint8)
if p.restore_faces:
- if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
+ if save_samples and opts.save_images_before_face_restoration:
images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -838,16 +817,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
pp = scripts.PostprocessImageArgs(image)
p.scripts.postprocess_image(p, pp)
image = pp.image
-
if p.color_corrections is not None and i < len(p.color_corrections):
- if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
+ if save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
- if opts.samples_save and not p.do_not_save_samples:
+ if save_samples:
images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
text = infotext(i)
@@ -855,8 +833,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if opts.enable_pnginfo:
image.info["parameters"] = text
output_images.append(image)
-
- if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
+ if save_samples and hasattr(p, 'mask_for_overlay') and p.mask_for_overlay and any([opts.save_mask, opts.save_mask_composite, opts.return_mask, opts.return_mask_composite]):
image_mask = p.mask_for_overlay.convert('RGB')
image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
@@ -892,7 +869,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
grid.info["parameters"] = text
output_images.insert(0, grid)
index_of_first_image = 1
-
if opts.grid_save:
images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not opts.grid_extended_filename, p=p, grid=True)
@@ -935,7 +911,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -946,11 +922,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_checkpoint_name = hr_checkpoint_name
+ self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
+ self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -971,8 +950,55 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_c = None
self.hr_uc = None
+ def calculate_target_resolution(self):
+ if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
+ self.hr_resize_x = self.width
+ self.hr_resize_y = self.height
+ self.hr_upscale_to_x = self.width
+ self.hr_upscale_to_y = self.height
+
+ self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
+ self.applied_old_hires_behavior_to = (self.width, self.height)
+
+ if self.hr_resize_x == 0 and self.hr_resize_y == 0:
+ self.extra_generation_params["Hires upscale"] = self.hr_scale
+ self.hr_upscale_to_x = int(self.width * self.hr_scale)
+ self.hr_upscale_to_y = int(self.height * self.hr_scale)
+ else:
+ self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
+
+ if self.hr_resize_y == 0:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ elif self.hr_resize_x == 0:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+ else:
+ target_w = self.hr_resize_x
+ target_h = self.hr_resize_y
+ src_ratio = self.width / self.height
+ dst_ratio = self.hr_resize_x / self.hr_resize_y
+
+ if src_ratio < dst_ratio:
+ self.hr_upscale_to_x = self.hr_resize_x
+ self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
+ else:
+ self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
+ self.hr_upscale_to_y = self.hr_resize_y
+
+ self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
+ self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_checkpoint_name:
+ self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+ if self.hr_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+ self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -982,51 +1008,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
- if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
- self.hr_resize_x = self.width
- self.hr_resize_y = self.height
- self.hr_upscale_to_x = self.width
- self.hr_upscale_to_y = self.height
+ self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and self.latent_scale_mode is None:
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
- self.width, self.height = old_hires_fix_first_pass_dimensions(self.width, self.height)
- self.applied_old_hires_behavior_to = (self.width, self.height)
-
- if self.hr_resize_x == 0 and self.hr_resize_y == 0:
- self.extra_generation_params["Hires upscale"] = self.hr_scale
- self.hr_upscale_to_x = int(self.width * self.hr_scale)
- self.hr_upscale_to_y = int(self.height * self.hr_scale)
- else:
- self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}"
-
- if self.hr_resize_y == 0:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- elif self.hr_resize_x == 0:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
- else:
- target_w = self.hr_resize_x
- target_h = self.hr_resize_y
- src_ratio = self.width / self.height
- dst_ratio = self.hr_resize_x / self.hr_resize_y
-
- if src_ratio < dst_ratio:
- self.hr_upscale_to_x = self.hr_resize_x
- self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width
- else:
- self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height
- self.hr_upscale_to_y = self.hr_resize_y
-
- self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
- self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
-
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
+ self.calculate_target_resolution()
if not state.processing_has_refined_job_count:
if state.job_count == -1:
@@ -1045,17 +1032,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = self.rng.next()
samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
+ del x
if not self.enable_hr:
return samples
+ if self.latent_scale_mode is None:
+ decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
+ else:
+ decoded_samples = None
+
+ current = shared.sd_model.sd_checkpoint_info
+ try:
+ if self.hr_checkpoint_info is not None:
+ self.sampler = None
+ sd_models.reload_model_weights(info=self.hr_checkpoint_info)
+ devices.torch_gc()
+
+ return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
+ finally:
+ self.sampler = None
+ sd_models.reload_model_weights(info=current)
+ devices.torch_gc()
+
+ def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
self.is_hr_pass = True
target_width = self.hr_upscale_to_x
@@ -1064,7 +1066,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def save_intermediate(image, index):
"""saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images"""
- if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix:
+ if not self.save_samples() or not opts.save_images_before_highres_fix:
return
if not isinstance(image, Image.Image):
@@ -1073,11 +1075,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index)
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix")
- if latent_scale_mode is not None:
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
+
+ self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
+
+ if self.latent_scale_mode is not None:
for i in range(samples.shape[0]):
save_intermediate(samples, i)
- samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"])
+ samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"])
# Avoid making the inpainting conditioning unless necessary as
# this does need some extra compute to decode / encode the image again.
@@ -1086,7 +1092,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
else:
image_conditioning = self.txt2img_image_conditioning(samples)
else:
- decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
batch_images = []
@@ -1103,28 +1108,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
batch_images.append(image)
decoded_samples = torch.from_numpy(np.array(batch_images))
- decoded_samples = decoded_samples.to(shared.device)
- decoded_samples = 2. * decoded_samples - 1.
+ decoded_samples = decoded_samples.to(shared.device, dtype=devices.dtype_vae)
- samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+ samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method))
image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
shared.state.nextjob()
- img2img_sampler_name = self.hr_sampler_name or self.sampler_name
-
- if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
- img2img_sampler_name = 'DDIM'
-
- self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
-
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
- noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self)
+ self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w)
+ noise = self.rng.next()
# GC now before running the next img2img to prevent running out of memory
- x = None
devices.torch_gc()
if not self.disable_extra_networks:
@@ -1143,15 +1142,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
+ decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
+
self.is_hr_pass = False
- return samples
+ return decoded_samples
def close(self):
super().close()
self.hr_c = None
self.hr_uc = None
- if not opts.experimental_persistent_cond_cache:
+ if not opts.persistent_cond_cache:
StableDiffusionProcessingTxt2Img.cached_hr_uc = [None, None]
StableDiffusionProcessingTxt2Img.cached_hr_c = [None, None]
@@ -1184,8 +1185,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if self.hr_c is not None:
return
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
+ hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y)
+ hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True)
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data)
def setup_conds(self):
super().setup_conds()
@@ -1193,7 +1197,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_uc = None
self.hr_c = None
- if self.enable_hr:
+ if self.enable_hr and self.hr_checkpoint_info is None:
if shared.opts.hires_fix_use_firstpass_conds:
self.calculate_hr_conds()
@@ -1206,6 +1210,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
with devices.autocast():
extra_networks.activate(self, self.extra_network_data)
+ def get_conds(self):
+ if self.is_hr_pass:
+ return self.hr_c, self.hr_uc
+
+ return super().get_conds()
+
+
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
@@ -1359,10 +1370,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less")
image = torch.from_numpy(batch_images)
- image = 2. * image - 1.
image = image.to(shared.device, dtype=devices.dtype_vae)
- self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
+ if opts.sd_vae_encode_method != 'Full':
+ self.extra_generation_params['VAE Encoder'] = opts.sd_vae_encode_method
+
+ self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model)
+ devices.torch_gc()
if self.resize_mode == 3:
self.init_latent = torch.nn.functional.interpolate(self.init_latent, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
@@ -1387,7 +1401,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask)
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
- x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
+ x = self.rng.next()
if self.initial_noise_multiplier != 1.0:
self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier