aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--extensions-builtin/Lora/ui_edit_user_metadata.py2
-rw-r--r--javascript/localization.js10
-rw-r--r--modules/cmd_args.py2
-rw-r--r--modules/devices.py83
-rw-r--r--modules/errors.py50
-rw-r--r--modules/extensions.py10
-rw-r--r--modules/extra_networks.py19
-rw-r--r--modules/generation_parameters_copypaste.py3
-rw-r--r--modules/gradio_extensons.py60
-rw-r--r--modules/hypernetworks/hypernetwork.py5
-rw-r--r--modules/images.py2
-rw-r--r--modules/img2img.py6
-rw-r--r--modules/processing.py111
-rw-r--r--modules/prompt_parser.py9
-rw-r--r--modules/rng_philox.py102
-rw-r--r--modules/scripts.py60
-rw-r--r--modules/sd_hijack.py6
-rw-r--r--modules/sd_hijack_clip.py2
-rw-r--r--modules/sd_hijack_optimizations.py4
-rw-r--r--modules/sd_models.py29
-rw-r--r--modules/sd_samplers_common.py20
-rw-r--r--modules/sd_samplers_kdiffusion.py9
-rw-r--r--modules/sd_vae.py16
-rw-r--r--modules/shared.py53
-rw-r--r--modules/styles.py5
-rw-r--r--modules/textual_inversion/textual_inversion.py4
-rw-r--r--modules/txt2img.py3
-rw-r--r--modules/ui.py352
-rw-r--r--modules/ui_checkpoint_merger.py2
-rw-r--r--modules/ui_common.py34
-rw-r--r--modules/ui_components.py2
-rw-r--r--modules/ui_extensions.py26
-rw-r--r--modules/ui_extra_networks.py13
-rw-r--r--modules/ui_extra_networks_checkpoints.py5
-rw-r--r--modules/ui_extra_networks_checkpoints_user_metadata.py60
-rw-r--r--modules/ui_extra_networks_hypernets.py2
-rw-r--r--modules/ui_extra_networks_textual_inversion.py2
-rw-r--r--modules/ui_postprocessing.py2
-rw-r--r--modules/ui_prompt_styles.py110
-rw-r--r--modules/ui_settings.py2
-rw-r--r--requirements.txt4
-rw-r--r--requirements_versions.txt14
-rw-r--r--scripts/xyz_grid.py13
-rw-r--r--style.css29
-rw-r--r--webui.py43
45 files changed, 924 insertions, 476 deletions
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py
index 2ca997f7..390d9dde 100644
--- a/extensions-builtin/Lora/ui_edit_user_metadata.py
+++ b/extensions-builtin/Lora/ui_edit_user_metadata.py
@@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor)
random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
with gr.Column(scale=1, min_width=120):
- generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+ generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
self.edit_notes = gr.TextArea(label='Notes', lines=4)
diff --git a/javascript/localization.js b/javascript/localization.js
index eb22b8a7..0c9032f9 100644
--- a/javascript/localization.js
+++ b/javascript/localization.js
@@ -11,11 +11,11 @@ var ignore_ids_for_localization = {
train_hypernetwork: 'OPTION',
txt2img_styles: 'OPTION',
img2img_styles: 'OPTION',
- setting_random_artist_categories: 'SPAN',
- setting_face_restoration_model: 'SPAN',
- setting_realesrgan_enabled_models: 'SPAN',
- extras_upscaler_1: 'SPAN',
- extras_upscaler_2: 'SPAN',
+ setting_random_artist_categories: 'OPTION',
+ setting_face_restoration_model: 'OPTION',
+ setting_realesrgan_enabled_models: 'OPTION',
+ extras_upscaler_1: 'OPTION',
+ extras_upscaler_2: 'OPTION',
};
var re_num = /^[.\d]+$/;
diff --git a/modules/cmd_args.py b/modules/cmd_args.py
index cb4ec5f7..64f21e01 100644
--- a/modules/cmd_args.py
+++ b/modules/cmd_args.py
@@ -112,3 +112,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi
parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server')
parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api')
parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn')
+parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False)
+parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False)
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -90,23 +93,87 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
+def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared
diff --git a/modules/errors.py b/modules/errors.py
index dffabe45..192cd8ff 100644
--- a/modules/errors.py
+++ b/modules/errors.py
@@ -84,3 +84,53 @@ def run(code, task):
code()
except Exception as e:
display(task, e)
+
+
+def check_versions():
+ from packaging import version
+ from modules import shared
+
+ import torch
+ import gradio
+
+ expected_torch_version = "2.0.0"
+ expected_xformers_version = "0.0.20"
+ expected_gradio_version = "3.39.0"
+
+ if version.parse(torch.__version__) < version.parse(expected_torch_version):
+ print_error_explanation(f"""
+You are running torch {torch.__version__}.
+The program is tested to work with torch {expected_torch_version}.
+To reinstall the desired version, run with commandline flag --reinstall-torch.
+Beware that this will cause a lot of large files to be downloaded, as well as
+there are reports of issues with training tab on the latest version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if shared.xformers_available:
+ import xformers
+
+ if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
+ print_error_explanation(f"""
+You are running xformers {xformers.__version__}.
+The program is tested to work with xformers {expected_xformers_version}.
+To reinstall the desired version, run with commandline flag --reinstall-xformers.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
+ if gradio.__version__ != expected_gradio_version:
+ print_error_explanation(f"""
+You are running gradio {gradio.__version__}.
+The program is designed to work with gradio {expected_gradio_version}.
+Using a different version of gradio is extremely likely to break the program.
+
+Reasons why you have the mismatched gradio version can be:
+ - you use --skip-install flag.
+ - you use webui.py to start the program instead of launch.py.
+ - an extension installs the incompatible gradio version.
+
+Use --skip-version-check commandline argument to disable this check.
+ """.strip())
+
diff --git a/modules/extensions.py b/modules/extensions.py
index 3ad5ed53..e4633af4 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True)
def active():
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all":
return []
- elif shared.opts.disable_all_extensions == "extra":
+ elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra":
return [x for x in extensions if x.enabled and x.is_builtin]
else:
return [x for x in extensions if x.enabled]
@@ -141,8 +141,12 @@ def list_extensions():
if not os.path.isdir(extensions_dir):
return
- if shared.opts.disable_all_extensions == "all":
+ if shared.cmd_opts.disable_all_extensions:
+ print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
+ elif shared.opts.disable_all_extensions == "all":
print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
+ elif shared.cmd_opts.disable_extra_extensions:
+ print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***")
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
diff --git a/modules/extra_networks.py b/modules/extra_networks.py
index 6ae07e91..fa28ac75 100644
--- a/modules/extra_networks.py
+++ b/modules/extra_networks.py
@@ -1,3 +1,5 @@
+import json
+import os
import re
from collections import defaultdict
@@ -177,3 +179,20 @@ def parse_prompts(prompts):
return res, extra_data
+
+def get_user_metadata(filename):
+ if filename is None:
+ return {}
+
+ basename, ext = os.path.splitext(filename)
+ metadata_filename = basename + '.json'
+
+ metadata = {}
+ try:
+ if os.path.isfile(metadata_filename):
+ with open(metadata_filename, "r", encoding="utf8") as file:
+ metadata = json.load(file)
+ except Exception as e:
+ errors.display(e, f"reading extra network user metadata from {metadata_filename}")
+
+ return metadata
diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py
index a3448be9..4e286558 100644
--- a/modules/generation_parameters_copypaste.py
+++ b/modules/generation_parameters_copypaste.py
@@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model
if "Hires sampler" not in res:
res["Hires sampler"] = "Use same sampler"
+ if "Hires checkpoint" not in res:
+ res["Hires checkpoint"] = "Use same checkpoint"
+
if "Hires prompt" not in res:
res["Hires prompt"] = ""
diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py
new file mode 100644
index 00000000..5af7fd8e
--- /dev/null
+++ b/modules/gradio_extensons.py
@@ -0,0 +1,60 @@
+import gradio as gr
+
+from modules import scripts
+
+def add_classes_to_gradio_component(comp):
+ """
+ this adds gradio-* to the component for css styling (ie gradio-button to gr.Button), as well as some others
+ """
+
+ comp.elem_classes = [f"gradio-{comp.get_block_name()}", *(comp.elem_classes or [])]
+
+ if getattr(comp, 'multiselect', False):
+ comp.elem_classes.append('multiselect')
+
+
+def IOComponent_init(self, *args, **kwargs):
+ self.webui_tooltip = kwargs.pop('tooltip', None)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.before_component(self, **kwargs)
+
+ scripts.script_callbacks.before_component_callback(self, **kwargs)
+
+ res = original_IOComponent_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ scripts.script_callbacks.after_component_callback(self, **kwargs)
+
+ if scripts.scripts_current is not None:
+ scripts.scripts_current.after_component(self, **kwargs)
+
+ return res
+
+
+def Block_get_config(self):
+ config = original_Block_get_config(self)
+
+ webui_tooltip = getattr(self, 'webui_tooltip', None)
+ if webui_tooltip:
+ config["webui_tooltip"] = webui_tooltip
+
+ return config
+
+
+def BlockContext_init(self, *args, **kwargs):
+ res = original_BlockContext_init(self, *args, **kwargs)
+
+ add_classes_to_gradio_component(self)
+
+ return res
+
+
+original_IOComponent_init = gr.components.IOComponent.__init__
+original_Block_get_config = gr.blocks.Block.get_config
+original_BlockContext_init = gr.blocks.BlockContext.__init__
+
+gr.components.IOComponent.__init__ = IOComponent_init
+gr.blocks.Block.get_config = Block_get_config
+gr.blocks.BlockContext.__init__ = BlockContext_init
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index c4821d21..70f1cbd2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -10,7 +10,7 @@ import torch
import tqdm
from einops import rearrange, repeat
from ldm.util import default
-from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
+from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors
from modules.textual_inversion import textual_inversion, logging
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum
@@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,
def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
- from modules import images
+ from modules import images, processing
save_hypernetwork_every = save_hypernetwork_every or 0
create_image_every = create_image_every or 0
diff --git a/modules/images.py b/modules/images.py
index 38aa933d..ba3c43a4 100644
--- a/modules/images.py
+++ b/modules/images.py
@@ -318,7 +318,7 @@ def resize_image(resize_mode, im, width, height, upscaler_name=None):
return res
-invalid_filename_chars = '<>:"/\\|?*\n'
+invalid_filename_chars = '<>:"/\\|?*\n\r\t'
invalid_filename_prefix = ' '
invalid_filename_postfix = ' .'
re_nonletters = re.compile(r'[\s' + string.punctuation + ']+')
diff --git a/modules/img2img.py b/modules/img2img.py
index 68e415ef..d8e1c534 100644
--- a/modules/img2img.py
+++ b/modules/img2img.py
@@ -3,7 +3,7 @@ from contextlib import closing
from pathlib import Path
import numpy as np
-from PIL import Image, ImageOps, ImageFilter, ImageEnhance, ImageChops, UnidentifiedImageError
+from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError
import gradio as gr
from modules import sd_samplers, images as imgutil
@@ -129,9 +129,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s
mask = None
elif mode == 2: # inpaint
image, mask = init_img_with_mask["image"], init_img_with_mask["mask"]
- alpha_mask = ImageOps.invert(image.split()[-1]).convert('L').point(lambda x: 255 if x > 0 else 0, mode='1')
- mask = mask.convert('L').point(lambda x: 255 if x > 128 else 0, mode='1')
- mask = ImageChops.lighter(alpha_mask, mask).convert('L')
+ mask = mask.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
image = image.convert("RGB")
elif mode == 3: # inpaint sketch
image = inpaint_color_sketch
diff --git a/modules/processing.py b/modules/processing.py
index b0992ee1..ae58b108 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
from einops import repeat, rearrange
from blendmodes.blend import blendLayers, BlendType
+decode_first_stage = sd_samplers_common.decode_first_stage
# some of those options should not be changed at all because they would break the model, so I removed them from options.
opt_C = 4
@@ -492,7 +493,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8)
subnoise = None
- if subseeds is not None:
+ if subseeds is not None and subseed_strength != 0:
subseed = 0 if i >= len(subseeds) else subseeds[i]
subnoise = devices.randn(subseed, noise_shape)
@@ -524,7 +525,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
cnt = p.sampler.number_of_needed_noises(p)
if eta_noise_seed_delta > 0:
- torch.manual_seed(seed + eta_noise_seed_delta)
+ devices.manual_seed(seed + eta_noise_seed_delta)
for j in range(cnt):
sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape)))
@@ -538,8 +539,12 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see
return x
+class DecodedSamples(list):
+ already_decoded = True
+
+
def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
- samples = []
+ samples = DecodedSamples()
for i in range(batch.shape[0]):
sample = decode_first_stage(model, batch[i:i + 1])[0]
@@ -572,12 +577,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
return samples
-def decode_first_stage(model, x):
- x = model.decode_first_stage(x.to(devices.dtype_vae))
-
- return x
-
-
def get_fixed_seed(seed):
if seed is None or seed == '' or seed == -1:
return int(random.randrange(4294967294))
@@ -636,7 +635,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
"Token merging ratio": None if token_merging_ratio == 0 else token_merging_ratio,
"Token merging ratio hr": None if not enable_hr or token_merging_ratio_hr == 0 else token_merging_ratio_hr,
"Init image hash": getattr(p, 'init_img_hash', None),
- "RNG": opts.randn_source if opts.randn_source != "GPU" else None,
+ "RNG": opts.randn_source if opts.randn_source != "GPU" and opts.randn_source != "NV" else None,
"NGMS": None if p.s_min_uncond == 0 else p.s_min_uncond,
**p.extra_generation_params,
"Version": program_version() if opts.add_version_to_infotext else None,
@@ -793,7 +792,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
- x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+ if getattr(samples_ddim, 'already_decoded', False):
+ x_samples_ddim = samples_ddim
+ else:
+ x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
+
x_samples_ddim = torch.stack(x_samples_ddim).float()
x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
@@ -935,7 +938,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
cached_hr_uc = [None, None]
cached_hr_c = [None, None]
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -946,11 +949,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_checkpoint_name = hr_checkpoint_name
+ self.hr_checkpoint_info = None
self.hr_sampler_name = hr_sampler_name
self.hr_prompt = hr_prompt
self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
+ self.latent_scale_mode = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -973,6 +979,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_checkpoint_name:
+ self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
+
+ if self.hr_checkpoint_info is None:
+ raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}')
+
+ self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title
+
if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
@@ -982,6 +996,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
+ self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
+ if self.enable_hr and self.latent_scale_mode is None:
+ if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
+ raise Exception(f"could not find upscaler named {self.hr_upscaler}")
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -1020,14 +1039,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f
self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f
- # special case: the user has chosen to do nothing
- if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height:
- self.enable_hr = False
- self.denoising_strength = None
- self.extra_generation_params.pop("Hires upscale", None)
- self.extra_generation_params.pop("Hires resize", None)
- return
-
if not state.processing_has_refined_job_count:
if state.job_count == -1:
state.job_count = self.n_iter
@@ -1045,17 +1056,32 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
- latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
- if self.enable_hr and latent_scale_mode is None:
- if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers):
- raise Exception(f"could not find upscaler named {self.hr_upscaler}")
-