From 3bca90b249d749ed5429f76e380d2ffa52fc0d41 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 30 Jul 2023 13:48:27 +0300
Subject: hires fix checkpoint selection
---
modules/sd_models.py | 22 +++++++++++++---------
1 file changed, 13 insertions(+), 9 deletions(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index acb1e817..cb67e425 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -52,6 +52,7 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
@@ -81,6 +82,7 @@ class CheckpointInfo:
checkpoints_list.pop(self.title)
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -101,14 +103,8 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -131,11 +127,14 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -145,6 +144,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None
--
cgit v1.2.3
From 4d9b096663288e2aa738723fa63950f3d41f6170 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 31 Jul 2023 10:43:31 +0300
Subject: additional memory improvements when switching between models of
different types
---
modules/sd_models.py | 5 ++++-
1 file changed, 4 insertions(+), 1 deletion(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index cb67e425..4855037a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -582,7 +582,10 @@ def reload_model_weights(sd_model=None, info=None):
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
- del sd_model
+ if sd_model is not None:
+ sd_model.to(device="meta")
+
+ devices.torch_gc()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
--
cgit v1.2.3
From b235022c615a7384f73c05fe240d8f4a28d103d4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 1 Aug 2023 00:24:48 +0300
Subject: option to keep multiple models in memory
---
modules/lowvram.py | 3 +
modules/sd_hijack.py | 6 +-
modules/sd_hijack_inpainting.py | 5 +-
modules/sd_models.py | 136 +++++++++++++++++++++++++++++++++-------
modules/sd_models_xl.py | 8 +--
modules/shared.py | 12 +++-
6 files changed, 135 insertions(+), 35 deletions(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/lowvram.py b/modules/lowvram.py
index 3f830664..96f52b7b 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,9 @@ def send_everything_to_cpu():
def setup_for_low_vram(sd_model, use_medvram):
+ if getattr(sd_model, 'lowvram', False):
+ return
+
sd_model.lowvram = True
parents = {}
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfa5f0eb..7d692e3c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -30,8 +30,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
# silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index c1977b19..97350f4f 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -91,7 +91,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
return x_prev, pred_x0, e_t
-def do_inpainting_hijack():
- # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
- ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index acb1e817..77195f2f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
-from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
def __init__(self):
self.sd_model = None
+ self.loaded_sd_models = []
self.was_loaded_at_least_once = False
self.lock = threading.Lock()
@@ -437,6 +437,7 @@ class SdModelData:
try:
load_model()
+
except Exception as e:
errors.display(e, "loading stable diffusion model", full_traceback=True)
print("", file=sys.stderr)
@@ -448,11 +449,24 @@ class SdModelData:
def set_sd_model(self, v):
self.sd_model = v
+ try:
+ self.loaded_sd_models.remove(v)
+ except ValueError:
+ pass
+
+ if v is not None:
+ self.loaded_sd_models.insert(0, v)
+
model_data = SdModelData()
def get_empty_cond(sd_model):
+ from modules import extra_networks, processing
+
+ p = processing.StableDiffusionProcessingTxt2Img()
+ extra_networks.activate(p, {})
+
if hasattr(sd_model, 'conditioner'):
d = sd_model.get_learned_conditioning([""])
return d['crossattn']
@@ -460,19 +474,43 @@ def get_empty_cond(sd_model):
return sd_model.cond_stage_model([""])
+def send_model_to_cpu(m):
+ from modules import lowvram
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.send_everything_to_cpu()
+ else:
+ m.to(devices.cpu)
+
+ devices.torch_gc()
+
+
+def send_model_to_device(m):
+ from modules import lowvram
+
+ if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+ else:
+ m.to(shared.device)
+
+
+def send_model_to_trash(m):
+ m.to(device="meta")
+ devices.torch_gc()
+
+
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
- from modules import lowvram, sd_hijack
+ from modules import sd_hijack
checkpoint_info = checkpoint_info or select_checkpoint()
+ timer = Timer()
+
if model_data.sd_model:
- sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+ send_model_to_trash(model_data.sd_model)
model_data.sd_model = None
- gc.collect()
devices.torch_gc()
- do_inpainting_hijack()
-
- timer = Timer()
+ timer.record("unload existing model")
if already_loaded_state_dict is not None:
state_dict = already_loaded_state_dict
@@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+ timer.record("load weights from state dict")
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
- else:
- sd_model.to(shared.device)
-
+ send_model_to_device(sd_model)
timer.record("move model to device")
sd_hijack.model_hijack.hijack(sd_model)
@@ -525,7 +560,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
timer.record("hijack")
sd_model.eval()
- model_data.sd_model = sd_model
+ model_data.set_sd_model(sd_model)
model_data.was_loaded_at_least_once = True
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model
@@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
return sd_model
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+ """
+ Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+ If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
+ If not, returns the model that can be used to load weights from checkpoint_info's file.
+ If no such model exists, returns None.
+ Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+ """
+
+ already_loaded = None
+ for i in reversed(range(len(model_data.loaded_sd_models))):
+ loaded_model = model_data.loaded_sd_models[i]
+ if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ already_loaded = loaded_model
+ continue
+
+ if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+ print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+ model_data.loaded_sd_models.pop()
+ send_model_to_trash(loaded_model)
+ timer.record("send model to trash")
+
+ if shared.opts.sd_checkpoints_keep_in_cpu:
+ send_model_to_cpu(sd_model)
+ timer.record("send model to cpu")
+
+ if already_loaded is not None:
+ send_model_to_device(already_loaded)
+ timer.record("send model to device")
+
+ model_data.set_sd_model(already_loaded)
+ print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ return model_data.sd_model
+ elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+ print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+ model_data.sd_model = None
+ load_model(checkpoint_info)
+ return model_data.sd_model
+ elif len(model_data.loaded_sd_models) > 0:
+ sd_model = model_data.loaded_sd_models.pop()
+ model_data.sd_model = sd_model
+
+ print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+ return sd_model
+ else:
+ return None
+
+
def reload_model_weights(sd_model=None, info=None):
- from modules import lowvram, devices, sd_hijack
+ from modules import devices, sd_hijack
checkpoint_info = info or select_checkpoint()
+ timer = Timer()
+
if not sd_model:
sd_model = model_data.sd_model
@@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None):
else:
current_checkpoint_info = sd_model.sd_checkpoint_info
if sd_model.sd_model_checkpoint == checkpoint_info.filename:
- return
-
- sd_unet.apply_unet("None")
+ return sd_model
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
- lowvram.send_everything_to_cpu()
- else:
- sd_model.to(devices.cpu)
+ sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+ if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+ return sd_model
+ if sd_model is not None:
+ sd_unet.apply_unet("None")
+ send_model_to_cpu(sd_model)
sd_hijack.model_hijack.undo_hijack(sd_model)
- timer = Timer()
-
state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None):
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
- del sd_model
+ if sd_model is not None:
+ send_model_to_trash(sd_model)
+
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model
@@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
+ model_data.set_sd_model(sd_model)
+
return sd_model
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index bc219508..01123321 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -98,10 +98,10 @@ def extend_sdxl(model):
model.conditioner.wrapped = torch.nn.Module()
-sgm.modules.attention.print = lambda *args: None
-sgm.modules.diffusionmodules.model.print = lambda *args: None
-sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
-sgm.modules.encoders.modules.print = lambda *args: None
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
diff --git a/modules/shared.py b/modules/shared.py
index aa72c9c8..0184fcd0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -392,6 +392,7 @@ options_templates.update(options_section(('system', "System"), {
"print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
"list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
"disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+ "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
}))
options_templates.update(options_section(('training', "Training"), {
@@ -411,7 +412,9 @@ options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('sd', "Stable Diffusion"), {
"sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
- "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+ "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+ "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+ "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
"sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@@ -889,3 +892,10 @@ def walk_files(path, allowed_extensions=None):
continue
yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+ if opts.hide_ldm_prints:
+ return
+
+ print(*args, **kwargs)
--
cgit v1.2.3
From 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 1 Aug 2023 17:13:15 +0300
Subject: repair merge error
---
modules/sd_models.py | 1 -
1 file changed, 1 deletion(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 40a450df..3c451a4b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
--
cgit v1.2.3
From 20549a50cb3c41868ce561c6658bfaa0d20ac7ba Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Thu, 3 Aug 2023 22:46:57 +0300
Subject: add style editor dialog rework toprow for img2img and txt2img to use
a class with fields fix the console error when editing checkpoint user
metadata
---
modules/sd_models.py | 2 +-
modules/styles.py | 5 +-
modules/ui.py | 230 ++++++++++---------------
modules/ui_common.py | 32 +++-
modules/ui_extra_networks_checkpoints.py | 2 +-
modules/ui_extra_networks_hypernets.py | 2 +-
modules/ui_extra_networks_textual_inversion.py | 2 +-
modules/ui_prompt_styles.py | 110 ++++++++++++
style.css | 13 ++
9 files changed, 248 insertions(+), 150 deletions(-)
create mode 100644 modules/ui_prompt_styles.py
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8f72f21d..1d93d893 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -68,7 +68,7 @@ class CheckpointInfo:
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc5..0740fe1b 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@ class StyleDatabase:
if os.path.exists(path):
shutil.copy(path, f"{path}.bak")
- fd = os.open(path, os.O_RDWR | os.O_CREAT)
- with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
- # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
- # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+ with open(path, "w", encoding="utf-8-sig", newline='') as file:
writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
writer.writeheader()
writer.writerows(style._asdict() for k, style in self.styles.items())
diff --git a/modules/ui.py b/modules/ui.py
index ac2787eb..c059dcec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.ui_common import create_refresh_button
@@ -92,19 +92,6 @@ def send_gradio_gallery_to_image(x):
return image_from_url_text(x[0])
-def add_style(name: str, prompt: str, negative_prompt: str):
- if name is None:
- return [gr_show() for x in range(4)]
-
- style = modules.styles.PromptStyle(name, prompt, negative_prompt)
- shared.prompt_styles.styles[style.name] = style
- # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
- # reserialize all styles every time we save them
- shared.prompt_styles.save_styles(shared.styles_filename)
-
- return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
from modules import processing, devices
@@ -129,13 +116,6 @@ def resize_from_to_html(width, height, scale_by):
return f"resize: from {width}x{height} to {target_width}x{target_height}"
-def apply_styles(prompt, prompt_neg, styles):
- prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
- prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
- return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
if mode in {0, 1, 3, 4}:
return [interrogation_function(ii_singles[mode]), None]
@@ -267,71 +247,67 @@ def update_token_counter(text, steps):
return f"{token_count}/{max_length}"
-def create_toprow(is_img2img):
- id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+ def __init__(self, is_img2img):
+ id_part = "img2img" if is_img2img else "txt2img"
+ self.id_part = id_part
- with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
- with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- with gr.Row():
- with gr.Column(scale=80):
- with gr.Row():
- negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
- button_interrogate = None
- button_deepbooru = None
- if is_img2img:
- with gr.Column(scale=1, elem_classes="interrogate-col"):
- button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
- button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
- with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
- with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
- interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
- skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
- submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
- skip.click(
- fn=lambda: shared.state.skip(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+ with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
- interrupt.click(
- fn=lambda: shared.state.interrupt(),
- inputs=[],
- outputs=[],
- )
+ with gr.Row():
+ with gr.Column(scale=80):
+ with gr.Row():
+ self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+
+ self.button_interrogate = None
+ self.button_deepbooru = None
+ if is_img2img:
+ with gr.Column(scale=1, elem_classes="interrogate-col"):
+ self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+ self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+ with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+ with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+ self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+ self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+ self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+ self.skip.click(
+ fn=lambda: shared.state.skip(),
+ inputs=[],
+ outputs=[],
+ )
- with gr.Row(elem_id=f"{id_part}_tools"):
- paste = ToolButton(value=paste_symbol, elem_id="paste")
- clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
- extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
- prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
- save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
- restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
- token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
- token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
- negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
- negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
- clear_prompt_button.click(
- fn=lambda *x: x,
- _js="confirm_clear_prompt",
- inputs=[prompt, negative_prompt],
- outputs=[prompt, negative_prompt],
- )
+ self.interrupt.click(
+ fn=lambda: shared.state.interrupt(),
+ inputs=[],
+ outputs=[],
+ )
- with gr.Row(elem_id=f"{id_part}_styles_row"):
- prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
- create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+ with gr.Row(elem_id=f"{id_part}_tools"):
+ self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+ self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+ self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+ self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+ self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+ self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+ self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+ self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+ self.clear_prompt_button.click(
+ fn=lambda *x: x,
+ _js="confirm_clear_prompt",
+ inputs=[self.prompt, self.negative_prompt],
+ outputs=[self.prompt, self.negative_prompt],
+ )
- return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+ self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
def setup_progressbar(*args, **kwargs):
@@ -419,14 +395,14 @@ def create_ui():
modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
with gr.Blocks(analytics_enabled=False) as txt2img_interface:
- txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+ toprow = txt2img_toprow = Toprow(is_img2img=False)
dummy_component = gr.Label(visible=False)
txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+ extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
with gr.Row().style(equal_height=False):
with gr.Column(variant='compact', elem_id="txt2img_settings"):
@@ -532,9 +508,9 @@ def create_ui():
_js="submit",
inputs=[
dummy_component,
- txt2img_prompt,
- txt2img_negative_prompt,
- txt2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
steps,
sampler_index,
restore_faces,
@@ -569,12 +545,12 @@ def create_ui():
show_progress=False,
)
- txt2img_prompt.submit(**txt2img_args)
- submit.click(**txt2img_args)
+ toprow.prompt.submit(**txt2img_args)
+ toprow.submit.click(**txt2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressTxt2img",
inputs=[dummy_component],
@@ -593,7 +569,7 @@ def create_ui():
txt_prompt_img
],
outputs=[
- txt2img_prompt,
+ toprow.prompt,
txt_prompt_img
],
show_progress=False,
@@ -607,8 +583,8 @@ def create_ui():
)
txt2img_paste_fields = [
- (txt2img_prompt, "Prompt"),
- (txt2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -621,7 +597,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(enable_hr, lambda d: "Denoising strength" in d),
(hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -639,12 +615,12 @@ def create_ui():
]
parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
))
txt2img_preview_params = [
- txt2img_prompt,
- txt2img_negative_prompt,
+ toprow.prompt,
+ toprow.negative_prompt,
steps,
sampler_index,
cfg_scale,
@@ -653,8 +629,8 @@ def create_ui():
height,
]
- token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
@@ -662,13 +638,13 @@ def create_ui():
modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
with gr.Blocks(analytics_enabled=False) as img2img_interface:
- img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
+ toprow = img2img_toprow = Toprow(is_img2img=True)
img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
from modules import ui_extra_networks
- extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+ extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
with FormRow().style(equal_height=False):
with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -889,7 +865,7 @@ def create_ui():
img2img_prompt_img
],
outputs=[
- img2img_prompt,
+ toprow.prompt,
img2img_prompt_img
],
show_progress=False,
@@ -901,9 +877,9 @@ def create_ui():
inputs=[
dummy_component,
dummy_component,
- img2img_prompt,
- img2img_negative_prompt,
- img2img_prompt_styles,
+ toprow.prompt,
+ toprow.negative_prompt,
+ toprow.ui_styles.dropdown,
init_img,
sketch,
init_img_with_mask,
@@ -962,11 +938,11 @@ def create_ui():
inpaint_color_sketch,
init_img_inpaint,
],
- outputs=[img2img_prompt, dummy_component],
+ outputs=[toprow.prompt, dummy_component],
)
- img2img_prompt.submit(**img2img_args)
- submit.click(**img2img_args)
+ toprow.prompt.submit(**img2img_args)
+ toprow.submit.click(**img2img_args)
res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
@@ -978,7 +954,7 @@ def create_ui():
show_progress=False,
)
- restore_progress_button.click(
+ toprow.restore_progress_button.click(
fn=progress.restore_progress,
_js="restoreProgressImg2img",
inputs=[dummy_component],
@@ -991,46 +967,24 @@ def create_ui():
show_progress=False,
)
- img2img_interrogate.click(
+ toprow.button_interrogate.click(
fn=lambda *args: process_interrogate(interrogate, *args),
**interrogate_args,
)
- img2img_deepbooru.click(
+ toprow.button_deepbooru.click(
fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
**interrogate_args,
)
- prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
- style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
- style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
- for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
- button.click(
- fn=add_style,
- _js="ask_for_style_name",
- # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
- # the same number of parameters, but we only know the style-name after the JavaScript prompt
- inputs=[dummy_component, prompt, negative_prompt],
- outputs=[txt2img_prompt_styles, img2img_prompt_styles],
- )
-
- for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
- button.click(
- fn=apply_styles,
- _js=js_func,
- inputs=[prompt, negative_prompt, styles],
- outputs=[prompt, negative_prompt, styles],
- )
-
- token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
- negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+ toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+ toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
img2img_paste_fields = [
- (img2img_prompt, "Prompt"),
- (img2img_negative_prompt, "Negative prompt"),
+ (toprow.prompt, "Prompt"),
+ (toprow.negative_prompt, "Negative prompt"),
(steps, "Steps"),
(sampler_index, "Sampler"),
(restore_faces, "Face restoration"),
@@ -1044,7 +998,7 @@ def create_ui():
(subseed_strength, "Variation seed strength"),
(seed_resize_from_w, "Seed resize from-1"),
(seed_resize_from_h, "Seed resize from-2"),
- (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+ (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
(denoising_strength, "Denoising strength"),
(mask_blur, "Mask blur"),
*modules.scripts.scripts_img2img.infotext_fields
@@ -1052,7 +1006,7 @@ def create_ui():
parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
- paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+ paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
))
modules.scripts.scripts_current = None
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b..ba75fa73 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -223,20 +223,44 @@ Requested path was: {f}
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+ refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+ label = None
+ for comp in refresh_components:
+ label = getattr(comp, 'label', None)
+ if label is not None:
+ break
+
def refresh():
refresh_method()
args = refreshed_args() if callable(refreshed_args) else refreshed_args
for k, v in args.items():
- setattr(refresh_component, k, v)
+ for comp in refresh_components:
+ setattr(comp, k, v)
- return gr.update(**(args or {}))
+ return [gr.update(**(args or {})) for _ in refresh_components]
- refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+ refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
refresh_button.click(
fn=refresh,
inputs=[],
- outputs=[refresh_component]
+ outputs=[*refresh_components]
)
return refresh_button
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+ """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""
+
+ dialog.visible = False
+
+ button_show.click(
+ fn=lambda: gr.update(visible=True),
+ inputs=[],
+ outputs=[dialog],
+ ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+ if button_close:
+ button_close.click(fn=None, _js="closePopup")
+
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 2bb0a222..891d8f2c 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.refresh_checkpoints()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
path, ext = os.path.splitext(checkpoint.filename)
return {
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb42..514a4562 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
shared.reload_hypernetworks()
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
full_path = shared.hypernetworks[name]
path, ext = os.path.splitext(full_path)
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e50..73134698 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
def refresh(self):
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
- def create_item(self, name, index=None):
+ def create_item(self, name, index=None, enable_filter=True):
embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
path, ext = os.path.splitext(embedding.filename)
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 00000000..85eb3a64
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️
+styles_materialize_symbol = '\U0001f4cb' # 📋
+
+
+def select_style(name):
+ style = shared.prompt_styles.styles.get(name)
+ existing = style is not None
+ empty = not name
+
+ prompt = style.prompt if style else gr.update()
+ negative_prompt = style.negative_prompt if style else gr.update()
+
+ return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+ if not name:
+ return gr.update(visible=False)
+
+ style = styles.PromptStyle(name, prompt, negative_prompt)
+ shared.prompt_styles.styles[style.name] = style
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return gr.update(visible=True)
+
+
+def delete_style(name):
+ if name == "":
+ return
+
+ shared.prompt_styles.styles.pop(name, None)
+ shared.prompt_styles.save_styles(shared.styles_filename)
+
+ return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+ prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+ negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+ return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+ return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+ def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+ self.tabname = tabname
+
+ with gr.Row(elem_id=f"{tabname}_styles_row"):
+ self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+ edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+ with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+ with gr.Row():
+ self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+ ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+ self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.")
+
+ with gr.Row():
+ self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+ with gr.Row():
+ self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+ with gr.Row():
+ self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+ self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+ self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+ self.selection.change(
+ fn=select_style,
+ inputs=[self.selection],
+ outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+ show_progress=False,
+ )
+
+ self.save.click(
+ fn=save_style,
+ inputs=[self.selection, self.prompt, self.neg_prompt],
+ outputs=[self.delete],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.delete.click(
+ fn=delete_style,
+ _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+ inputs=[self.selection],
+ outputs=[self.selection, self.prompt, self.neg_prompt],
+ show_progress=False,
+ ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+ self.materialize.click(
+ fn=materialize_styles,
+ inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+ show_progress=False,
+ ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+ ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
+
+
+
+
diff --git a/style.css b/style.css
index 6c92d6e7..cf8470e4 100644
--- a/style.css
+++ b/style.css
@@ -972,3 +972,16 @@ div.block.gradio-box.edit-user-metadata {
.edit-user-metadata-buttons{
margin-top: 1.5em;
}
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+ width: 56em;
+ background: var(--body-background-fill);
+ padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+ margin-top: 1em;
+}
--
cgit v1.2.3
From 24f21583cdba2ae6cc51773b956c6ce068d3dfe4 Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Fri, 4 Aug 2023 11:43:27 +0800
Subject: fix: prevent cache model.state_dict() after model hijack
Signed-off-by: AnyISalIn
---
modules/sd_models.py | 5 +++--
1 file changed, 3 insertions(+), 2 deletions(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1d93d893..ba15b451 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False)
- del state_dict
timer.record("apply weights to model")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ checkpoints_loaded[checkpoint_info] = state_dict
+
+ del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
--
cgit v1.2.3
From c96e4750d895a47290dc7f96e030197069c75fa4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 7 Aug 2023 08:07:09 +0300
Subject: SD VAE rework 2
- the setting for preferring opts.sd_vae has been inverted and reworded
- resolve_vae function made easier to read and now returns an object rather than a tuple
- if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata
- changing VAE in user metadata for currently loaded model immediately applies the selection
---
modules/sd_models.py | 2 +-
modules/sd_vae.py | 71 +++++++++++++++++-----
modules/shared.py | 6 +-
.../ui_extra_networks_checkpoints_user_metadata.py | 8 ++-
webui.py | 2 +-
5 files changed, 69 insertions(+), 20 deletions(-)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6051604..d65735e3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -356,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
sd_vae.load_vae(model, vae_file, vae_source)
timer.record("load VAE")
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0bd5e19b..38bcb840 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,5 +1,7 @@
import os
import collections
+from dataclasses import dataclass
+
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
import glob
from copy import deepcopy
@@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
return None
-def resolve_vae(checkpoint_file):
- if shared.cmd_opts.vae_path is not None:
- return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+ vae: str = None
+ source: str = None
+ resolved: bool = True
+
+ def tuple(self):
+ return self.vae, self.source
+
+
+def is_automatic():
+ return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
+ if shared.opts.sd_vae == "None":
+ return VaeResolution()
+
+ vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+ if vae_from_options is not None:
+ return VaeResolution(vae_from_options, 'specified in settings')
+
+ if not is_automatic():
+ print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+ return VaeResolution(resolved=False)
+
+
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
metadata = extra_networks.get_user_metadata(checkpoint_file)
vae_metadata = metadata.get("vae", None)
if vae_metadata is not None and vae_metadata != "Automatic":
if vae_metadata == "None":
- return None, None
+ return VaeResolution()
vae_from_metadata = vae_dict.get(vae_metadata, None)
if vae_from_metadata is not None:
- return vae_from_metadata, "from user metadata"
+ return VaeResolution(vae_from_metadata, "from user metadata")
+
+ return VaeResolution(resolved=False)
- is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
- return vae_near_checkpoint, 'found near the checkpoint'
+ return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
- if shared.opts.sd_vae == "None":
- return None, None
+ return VaeResolution(resolved=False)
- vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
- if vae_from_options is not None:
- return vae_from_options, 'specified in settings'
- if not is_automatic:
- print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+def resolve_vae(checkpoint_file) -> VaeResolution:
+ if shared.cmd_opts.vae_path is not None:
+ return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
+
+ if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+ return resolve_vae_from_setting()
+
+ res = resolve_vae_from_user_metadata(checkpoint_file)
+ if res.resolved:
+ return res
+
+ res = resolve_vae_near_checkpoint(checkpoint_file)
+ if res.resolved:
+ return res
+
+ res = resolve_vae_from_setting()
- return None, None
+ return res
def load_vae_dict(filename, map_location):
@@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
checkpoint_file = checkpoint_info.filename
if vae_file == unspecified:
- vae_file, vae_source = resolve_vae(checkpoint_file)
+ vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
else:
vae_source = "from function argument"
diff --git a/modules/shared.py b/modules/shared.py
index 078e8135..da53f2d9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -479,7 +479,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
"""),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
- "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+ "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
"auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
"sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
"sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
@@ -733,6 +733,10 @@ class Options:
with open(filename, "r", encoding="utf8") as file:
self.data = json.load(file)
+ # 1.6.0 VAE defaults
+ if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+ self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
# 1.1.1 quicksettings list migration
if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
index 2c69aab8..25df0a80 100644
--- a/modules/ui_extra_networks_checkpoints_user_metadata.py
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -1,6 +1,6 @@
import gradio as gr
-from modules import ui_extra_networks_user_metadata, sd_vae
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
from modules.ui_common import create_refresh_button
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
self.write_user_metadata(name, user_metadata)
+ def update_vae(self, name):
+ if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+ sd_vae.reload_vae_weights()
+
def put_values_into_components(self, name):
user_metadata = self.get_user_metadata(name)
values = super().put_values_into_components(name)
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
]
self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+ self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
+
diff --git a/webui.py b/webui.py
index 1803ea8a..a5b11575 100644
--- a/webui.py
+++ b/webui.py
@@ -211,7 +211,7 @@ def configure_sigint_handler():
def configure_opts_onchange():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
--
cgit v1.2.3
From 6e7828e1d271c644840047c3db60e669a232402a Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 7 Aug 2023 08:16:20 +0300
Subject: apply unet overrides after switching model
---
modules/sd_models.py | 1 +
1 file changed, 1 insertion(+)
(limited to 'modules/sd_models.py')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d65735e3..53c1df54 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -699,6 +699,7 @@ def reload_model_weights(sd_model=None, info=None):
print(f"Weights loaded in {timer.summary()}.")
model_data.set_sd_model(sd_model)
+ sd_unet.apply_unet()
return sd_model
--
cgit v1.2.3