From 3bca90b249d749ed5429f76e380d2ffa52fc0d41 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 30 Jul 2023 13:48:27 +0300
Subject: hires fix checkpoint selection

---
 modules/sd_models.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index acb1e817..cb67e425 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -52,6 +52,7 @@ class CheckpointInfo:
 
         self.shorthash = self.sha256[0:10] if self.sha256 else None
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+        self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
 
         self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
@@ -81,6 +82,7 @@ class CheckpointInfo:
             checkpoints_list.pop(self.title)
 
         self.title = f'{self.name} [{self.shorthash}]'
+        self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
         self.register()
 
         return self.shorthash
@@ -101,14 +103,8 @@ def setup_model():
     enable_midas_autodownload()
 
 
-def checkpoint_tiles():
-    def convert(name):
-        return int(name) if name.isdigit() else name.lower()
-
-    def alphanumeric_key(key):
-        return [convert(c) for c in re.split('([0-9]+)', key)]
-
-    return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+    return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
 
 
 def list_models():
@@ -131,11 +127,14 @@ def list_models():
     elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
         print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
 
-    for filename in sorted(model_list, key=str.lower):
+    for filename in model_list:
         checkpoint_info = CheckpointInfo(filename)
         checkpoint_info.register()
 
 
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
 def get_closet_checkpoint_match(search_string):
     checkpoint_info = checkpoint_aliases.get(search_string, None)
     if checkpoint_info is not None:
@@ -145,6 +144,11 @@
     if found:
         return found[0]
 
+    search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+    found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+    if found:
+        return found[0]
+
     return None
--
cgit v1.2.3

From 4d9b096663288e2aa738723fa63950f3d41f6170 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 31 Jul 2023 10:43:31 +0300
Subject: additional memory improvements when switching between models of different types

---
 modules/sd_models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index cb67e425..4855037a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -582,7 +582,10 @@ def reload_model_weights(sd_model=None, info=None):
         timer.record("find config")
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            sd_model.to(device="meta")
+
+        devices.torch_gc()
 
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
--
cgit v1.2.3
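
The checksum-stripping fallback added in the first commit is easy to study in isolation. The sketch below reproduces the idea outside the webui; re_strip_checksum and the "name [shorthash]" title format come from the patch, while match_checkpoint and the example titles are illustrative stand-ins, not webui code:

import re

# Checkpoint titles look like "filename [shorthash]".
re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")

def match_checkpoint(search_string, titles):
    # First pass: plain substring match, shortest title wins.
    found = sorted([t for t in titles if search_string in t], key=len)
    if found:
        return found[0]

    # Second pass: strip any trailing "[checksum]" so a title saved with a
    # stale hash still resolves to the current checkpoint of the same name.
    stripped = re_strip_checksum.sub('', search_string)
    found = sorted([t for t in titles if stripped in t], key=len)
    return found[0] if found else None

titles = ["v1-5-pruned.safetensors [6ce0161689]", "sd_xl_base_1.0.safetensors [31e35c80fc]"]
print(match_checkpoint("v1-5-pruned.safetensors [deadbeef00]", titles))
# -> "v1-5-pruned.safetensors [6ce0161689]"
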
From b235022c615a7384f73c05fe240d8f4a28d103d4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 1 Aug 2023 00:24:48 +0300
Subject: option to keep multiple models in memory

---
 modules/lowvram.py              |   3 +
 modules/sd_hijack.py            |   6 +-
 modules/sd_hijack_inpainting.py |   5 +-
 modules/sd_models.py            | 136 +++++++++++++++++++++++++++++++++-------
 modules/sd_models_xl.py         |   8 +--
 modules/shared.py               |  12 +++-
 6 files changed, 135 insertions(+), 35 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/lowvram.py b/modules/lowvram.py
index 3f830664..96f52b7b 100644
--- a/modules/lowvram.py
+++ b/modules/lowvram.py
@@ -15,6 +15,9 @@ def send_everything_to_cpu():
 
 
 def setup_for_low_vram(sd_model, use_medvram):
+    if getattr(sd_model, 'lowvram', False):
+        return
+
     sd_model.lowvram = True
 
     parents = {}
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index cfa5f0eb..7d692e3c 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -30,8 +30,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.CrossAttention
 ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
 
 # silence new console spam from SD2
-ldm.modules.attention.print = lambda *args: None
-ldm.modules.diffusionmodules.model.print = lambda *args: None
+ldm.modules.attention.print = shared.ldm_print
+ldm.modules.diffusionmodules.model.print = shared.ldm_print
+ldm.util.print = shared.ldm_print
+ldm.models.diffusion.ddpm.print = shared.ldm_print
 
 optimizers = []
 current_optimizer: sd_hijack_optimizations.SdOptimization = None
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py
index c1977b19..97350f4f 100644
--- a/modules/sd_hijack_inpainting.py
+++ b/modules/sd_hijack_inpainting.py
@@ -91,7 +91,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
     return x_prev, pred_x0, e_t
 
 
-def do_inpainting_hijack():
-    # p_sample_plms is needed because PLMS can't work with dicts as conditionings
-
-    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
diff --git a/modules/sd_models.py b/modules/sd_models.py
index acb1e817..77195f2f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
-from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 
 import tomesd
@@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
 
 class SdModelData:
     def __init__(self):
         self.sd_model = None
+        self.loaded_sd_models = []
         self.was_loaded_at_least_once = False
         self.lock = threading.Lock()
 
@@ -437,6 +437,7 @@ class SdModelData:
 
             try:
                 load_model()
+
             except Exception as e:
                 errors.display(e, "loading stable diffusion model", full_traceback=True)
                 print("", file=sys.stderr)
@@ -448,11 +449,24 @@ class SdModelData:
 
     def set_sd_model(self, v):
         self.sd_model = v
 
+        try:
+            self.loaded_sd_models.remove(v)
+        except ValueError:
+            pass
+
+        if v is not None:
+            self.loaded_sd_models.insert(0, v)
+
 
 model_data = SdModelData()
 
 
 def get_empty_cond(sd_model):
+    from modules import extra_networks, processing
+
+    p = processing.StableDiffusionProcessingTxt2Img()
+    extra_networks.activate(p, {})
+
     if hasattr(sd_model, 'conditioner'):
         d = sd_model.get_learned_conditioning([""])
         return d['crossattn']
@@ -460,19 +474,43 @@ def get_empty_cond(sd_model):
     return sd_model.cond_stage_model([""])
 
 
+def send_model_to_cpu(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.send_everything_to_cpu()
+    else:
+        m.to(devices.cpu)
+
+    devices.torch_gc()
+
+
+def send_model_to_device(m):
+    from modules import lowvram
+
+    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
+    else:
+        m.to(shared.device)
+
+
+def send_model_to_trash(m):
+    m.to(device="meta")
+    devices.torch_gc()
+
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None):
-    from modules import lowvram, sd_hijack
+    from modules import sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
 
+    timer = Timer()
+
     if model_data.sd_model:
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
+        send_model_to_trash(model_data.sd_model)
         model_data.sd_model = None
-        gc.collect()
         devices.torch_gc()
 
-    do_inpainting_hijack()
-
-    timer = Timer()
+    timer.record("unload existing model")
 
     if already_loaded_state_dict is not None:
         state_dict = already_loaded_state_dict
@@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
         load_model_weights(sd_model, checkpoint_info, state_dict, timer)
+    timer.record("load weights from state dict")
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
-    else:
-        sd_model.to(shared.device)
-
+    send_model_to_device(sd_model)
     timer.record("move model to device")
 
     sd_hijack.model_hijack.hijack(sd_model)
@@ -525,7 +560,7 @@
     timer.record("hijack")
 
     sd_model.eval()
-    model_data.sd_model = sd_model
+    model_data.set_sd_model(sd_model)
     model_data.was_loaded_at_least_once = True
 
     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
 
     return sd_model
 
 
+def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
+    """
+    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
+    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loaded model to CPU if necessary).
+    If not, returns the model that can be used to load weights from checkpoint_info's file.
+    If no such model exists, returns None.
+    Additionally deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
+    """
+
+    already_loaded = None
+    for i in reversed(range(len(model_data.loaded_sd_models))):
+        loaded_model = model_data.loaded_sd_models[i]
+        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+            already_loaded = loaded_model
+            continue
+
+        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
+            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
+            model_data.loaded_sd_models.pop()
+            send_model_to_trash(loaded_model)
+            timer.record("send model to trash")
+
+    if shared.opts.sd_checkpoints_keep_in_cpu:
+        send_model_to_cpu(sd_model)
+        timer.record("send model to cpu")
+
+    if already_loaded is not None:
+        send_model_to_device(already_loaded)
+        timer.record("send model to device")
+
+        model_data.set_sd_model(already_loaded)
+        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+        return model_data.sd_model
+    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
+        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
+
+        model_data.sd_model = None
+        load_model(checkpoint_info)
+        return model_data.sd_model
+    elif len(model_data.loaded_sd_models) > 0:
+        sd_model = model_data.loaded_sd_models.pop()
+        model_data.sd_model = sd_model
+
+        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
+        return sd_model
+    else:
+        return None
+
+
 def reload_model_weights(sd_model=None, info=None):
-    from modules import lowvram, devices, sd_hijack
+    from modules import devices, sd_hijack
 
     checkpoint_info = info or select_checkpoint()
 
+    timer = Timer()
+
     if not sd_model:
         sd_model = model_data.sd_model
 
@@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None):
     else:
         current_checkpoint_info = sd_model.sd_checkpoint_info
         if sd_model.sd_model_checkpoint == checkpoint_info.filename:
-            return
-
-    sd_unet.apply_unet("None")
+            return sd_model
 
-    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
-        lowvram.send_everything_to_cpu()
-    else:
-        sd_model.to(devices.cpu)
+    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
+    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
+        return sd_model
 
+    if sd_model is not None:
+        sd_unet.apply_unet("None")
+        send_model_to_cpu(sd_model)
         sd_hijack.model_hijack.undo_hijack(sd_model)
 
-    timer = Timer()
-
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
 
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None):
     timer.record("find config")
 
     if sd_model is None or checkpoint_config != sd_model.used_config:
-        del sd_model
+        if sd_model is not None:
+            send_model_to_trash(sd_model)
+
         load_model(checkpoint_info, already_loaded_state_dict=state_dict)
         return model_data.sd_model
@@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None):
 
     print(f"Weights loaded in {timer.summary()}.")
 
+    model_data.set_sd_model(sd_model)
+
     return sd_model
diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py
index bc219508..01123321 100644
--- a/modules/sd_models_xl.py
+++ b/modules/sd_models_xl.py
@@ -98,10 +98,10 @@ def extend_sdxl(model):
     model.conditioner.wrapped = torch.nn.Module()
 
 
-sgm.modules.attention.print = lambda *args: None
-sgm.modules.diffusionmodules.model.print = lambda *args: None
-sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
-sgm.modules.encoders.modules.print = lambda *args: None
+sgm.modules.attention.print = shared.ldm_print
+sgm.modules.diffusionmodules.model.print = shared.ldm_print
+sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
+sgm.modules.encoders.modules.print = shared.ldm_print
 
 # this gets the code to load the vanilla attention that we override
 sgm.modules.attention.SDP_IS_AVAILABLE = True
diff --git a/modules/shared.py b/modules/shared.py
index aa72c9c8..0184fcd0 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -392,6 +392,7 @@ options_templates.update(options_section(('system', "System"), {
     "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
     "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
     "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
+    "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."),
 }))
 
 options_templates.update(options_section(('training', "Training"), {
@@ -411,7 +412,9 @@ options_templates.update(options_section(('training', "Training"), {
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
-    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
+    "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}),
+    "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"),
+    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
     "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
@@ -889,3 +892,10 @@ def walk_files(path, allowed_extensions=None):
                 continue
 
             yield os.path.join(root, filename)
+
+
+def ldm_print(*args, **kwargs):
+    if opts.hide_ldm_prints:
+        return
+
+    print(*args, **kwargs)
--
cgit v1.2.3
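
The commit above keeps models in model_data.loaded_sd_models, most recently used first, and evicts from the tail once sd_checkpoints_limit is exceeded. The following is a toy illustration of that policy only, with plain strings standing in for model objects; ModelCache is not a webui class:

# Switching to an already-loaded model reorders the list instead of
# reloading weights; exceeding the limit evicts the least recently used.
class ModelCache:
    def __init__(self, limit=2):
        self.limit = limit      # stands in for shared.opts.sd_checkpoints_limit
        self.loaded = []        # most recently used first

    def use(self, name):
        if name in self.loaded:
            self.loaded.remove(name)        # cache hit: no weight load needed
        elif len(self.loaded) >= self.limit:
            evicted = self.loaded.pop()     # evict least recently used
            print(f"unloading {evicted}")
        self.loaded.insert(0, name)
        return name

cache = ModelCache(limit=2)
for checkpoint in ["sd15", "sdxl", "sd15", "anime"]:
    cache.use(checkpoint)
print(cache.loaded)  # ['anime', 'sd15'] after "sdxl" is evicted
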
From 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 1 Aug 2023 17:13:15 +0300
Subject: repair merge error

---
 modules/sd_models.py | 1 -
 1 file changed, 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 40a450df..3c451a4b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
 from ldm.util import instantiate_from_config
 
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
-from modules.sd_hijack_inpainting import do_inpainting_hijack
 from modules.timer import Timer
 
 import tomesd
--
cgit v1.2.3
From 20549a50cb3c41868ce561c6658bfaa0d20ac7ba Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Thu, 3 Aug 2023 22:46:57 +0300
Subject: add style editor dialog

rework toprow for img2img and txt2img to use a class with fields

fix the console error when editing checkpoint user metadata
---
 modules/sd_models.py                           |   2 +-
 modules/styles.py                              |   5 +-
 modules/ui.py                                  | 230 ++++++++++---------------
 modules/ui_common.py                           |  32 +++-
 modules/ui_extra_networks_checkpoints.py       |   2 +-
 modules/ui_extra_networks_hypernets.py         |   2 +-
 modules/ui_extra_networks_textual_inversion.py |   2 +-
 modules/ui_prompt_styles.py                    | 110 ++++++++++++
 style.css                                      |  13 ++
 9 files changed, 248 insertions(+), 150 deletions(-)
 create mode 100644 modules/ui_prompt_styles.py

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 8f72f21d..1d93d893 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -68,7 +68,7 @@ class CheckpointInfo:
 
         self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
 
-        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
 
     def register(self):
         checkpoints_list[self.title] = self
diff --git a/modules/styles.py b/modules/styles.py
index ec0e1bc5..0740fe1b 100644
--- a/modules/styles.py
+++ b/modules/styles.py
@@ -106,10 +106,7 @@ class StyleDatabase:
         if os.path.exists(path):
             shutil.copy(path, f"{path}.bak")
 
-        fd = os.open(path, os.O_RDWR | os.O_CREAT)
-        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
-            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
-            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
+        with open(path, "w", encoding="utf-8-sig", newline='') as file:
             writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
             writer.writeheader()
             writer.writerows(style._asdict() for k, style in self.styles.items())
diff --git a/modules/ui.py b/modules/ui.py
index ac2787eb..c059dcec 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -12,7 +12,7 @@ import numpy as np
 from PIL import Image, PngImagePlugin  # noqa: F401
 
 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
-from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
+from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles
 from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
 from modules.paths import script_path
 from modules.ui_common import create_refresh_button
@@ -92,19 +92,6 @@ def send_gradio_gallery_to_image(x):
     return image_from_url_text(x[0])
 
 
-def add_style(name: str, prompt: str, negative_prompt: str):
-    if name is None:
-        return [gr_show() for x in range(4)]
-
-    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
-    shared.prompt_styles.styles[style.name] = style
-    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
-    # reserialize all styles every time we save them
-    shared.prompt_styles.save_styles(shared.styles_filename)
-
-    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]
-
-
 def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
     from modules import processing, devices
 
@@ -129,13 +116,6 @@ def resize_from_to_html(width, height, scale_by):
     return f"resize: from {width}x{height} to {target_width}x{target_height}"
 
 
-def apply_styles(prompt, prompt_neg, styles):
-    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
-    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)
-
-    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]
-
-
 def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
     if mode in {0, 1, 3, 4}:
         return [interrogation_function(ii_singles[mode]), None]
@@ -267,71 +247,67 @@ def update_token_counter(text, steps):
     return f"{token_count}/{max_length}"
 
 
-def create_toprow(is_img2img):
-    id_part = "img2img" if is_img2img else "txt2img"
+class Toprow:
+    def __init__(self, is_img2img):
+        id_part = "img2img" if is_img2img else "txt2img"
+        self.id_part = id_part
 
-    with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
-        with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-            with gr.Row():
-                with gr.Column(scale=80):
-                    with gr.Row():
-                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
-
-    button_interrogate = None
-    button_deepbooru = None
-    if is_img2img:
-        with gr.Column(scale=1, elem_classes="interrogate-col"):
-            button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
-            button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
-
-    with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
-        with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
-            interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
-            skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
-            submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
-
-            skip.click(
-                fn=lambda: shared.state.skip(),
-                inputs=[],
-                outputs=[],
-            )
+        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
+            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
 
-            interrupt.click(
-                fn=lambda: shared.state.interrupt(),
-                inputs=[],
-                outputs=[],
-            )
+                with gr.Row():
+                    with gr.Column(scale=80):
+                        with gr.Row():
+                            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
+
+            self.button_interrogate = None
+            self.button_deepbooru = None
+            if is_img2img:
+                with gr.Column(scale=1, elem_classes="interrogate-col"):
+                    self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
+                    self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
+
+            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
+                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
+                    self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
+                    self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
+                    self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
+
+                    self.skip.click(
+                        fn=lambda: shared.state.skip(),
+                        inputs=[],
+                        outputs=[],
+                    )
 
-        with gr.Row(elem_id=f"{id_part}_tools"):
-            paste = ToolButton(value=paste_symbol, elem_id="paste")
-            clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
-            extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
-            prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
-            save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")
-            restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
-
-            token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
-            token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
-            negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
-            negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
-
-            clear_prompt_button.click(
-                fn=lambda *x: x,
-                _js="confirm_clear_prompt",
-                inputs=[prompt, negative_prompt],
-                outputs=[prompt, negative_prompt],
-            )
+                    self.interrupt.click(
+                        fn=lambda: shared.state.interrupt(),
+                        inputs=[],
+                        outputs=[],
+                    )
 
-        with gr.Row(elem_id=f"{id_part}_styles_row"):
-            prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
-            create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")
+                with gr.Row(elem_id=f"{id_part}_tools"):
+                    self.paste = ToolButton(value=paste_symbol, elem_id="paste")
+                    self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
+                    self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
+                    self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
+
+                self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
+                self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+                self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
+                self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
+
+                self.clear_prompt_button.click(
+                    fn=lambda *x: x,
+                    _js="confirm_clear_prompt",
+                    inputs=[self.prompt, self.negative_prompt],
+                    outputs=[self.prompt, self.negative_prompt],
+                )
 
-    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button
+        self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
 
 
 def setup_progressbar(*args, **kwargs):
@@ -419,14 +395,14 @@ def create_ui():
         modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
 
     with gr.Blocks(analytics_enabled=False) as txt2img_interface:
-        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
+        toprow = txt2img_toprow = Toprow(is_img2img=False)
 
         dummy_component = gr.Label(visible=False)
         txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
 
         with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
             from modules import ui_extra_networks
-            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
+            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')
 
         with gr.Row().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="txt2img_settings"):
@@ -532,9 +508,9 @@ def create_ui():
             _js="submit",
             inputs=[
                 dummy_component,
-                txt2img_prompt,
-                txt2img_negative_prompt,
-                txt2img_prompt_styles,
+                toprow.prompt,
+                toprow.negative_prompt,
+                toprow.ui_styles.dropdown,
                 steps,
                 sampler_index,
                 restore_faces,
@@ -569,12 +545,12 @@ def create_ui():
             show_progress=False,
         )
 
-        txt2img_prompt.submit(**txt2img_args)
-        submit.click(**txt2img_args)
+        toprow.prompt.submit(**txt2img_args)
+        toprow.submit.click(**txt2img_args)
 
         res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
 
-        restore_progress_button.click(
+        toprow.restore_progress_button.click(
             fn=progress.restore_progress,
             _js="restoreProgressTxt2img",
             inputs=[dummy_component],
@@ -593,7 +569,7 @@ def create_ui():
                 txt_prompt_img
             ],
             outputs=[
-                txt2img_prompt,
+                toprow.prompt,
                 txt_prompt_img
             ],
             show_progress=False,
@@ -607,8 +583,8 @@ def create_ui():
         )
 
         txt2img_paste_fields = [
-            (txt2img_prompt, "Prompt"),
-            (txt2img_negative_prompt, "Negative prompt"),
+            (toprow.prompt, "Prompt"),
+            (toprow.negative_prompt, "Negative prompt"),
             (steps, "Steps"),
             (sampler_index, "Sampler"),
             (restore_faces, "Face restoration"),
@@ -621,7 +597,7 @@ def create_ui():
             (subseed_strength, "Variation seed strength"),
             (seed_resize_from_w, "Seed resize from-1"),
             (seed_resize_from_h, "Seed resize from-2"),
-            (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+            (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
             (denoising_strength, "Denoising strength"),
            (enable_hr, lambda d: "Denoising strength" in d),
             (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -639,12 +615,12 @@ def create_ui():
         ]
         parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
         parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-            paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
+            paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
         ))
 
         txt2img_preview_params = [
-            txt2img_prompt,
-            txt2img_negative_prompt,
+            toprow.prompt,
+            toprow.negative_prompt,
             steps,
             sampler_index,
             cfg_scale,
@@ -653,8 +629,8 @@ def create_ui():
             height,
         ]
 
-        token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
-        negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
+        toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+        toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
         ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
 
@@ -662,13 +638,13 @@ def create_ui():
         modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
 
     with gr.Blocks(analytics_enabled=False) as img2img_interface:
-        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
+        toprow = img2img_toprow = Toprow(is_img2img=True)
 
         img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
 
         with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
             from modules import ui_extra_networks
-            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
+            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')
 
         with FormRow().style(equal_height=False):
             with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -889,7 +865,7 @@ def create_ui():
                 img2img_prompt_img
             ],
             outputs=[
-                img2img_prompt,
+                toprow.prompt,
                 img2img_prompt_img
             ],
             show_progress=False,
@@ -901,9 +877,9 @@ def create_ui():
             inputs=[
                 dummy_component,
                 dummy_component,
-                img2img_prompt,
-                img2img_negative_prompt,
-                img2img_prompt_styles,
+                toprow.prompt,
+                toprow.negative_prompt,
+                toprow.ui_styles.dropdown,
                 init_img,
                 sketch,
                 init_img_with_mask,
@@ -962,11 +938,11 @@ def create_ui():
                 inpaint_color_sketch,
                 init_img_inpaint,
             ],
-            outputs=[img2img_prompt, dummy_component],
+            outputs=[toprow.prompt, dummy_component],
         )
 
-        img2img_prompt.submit(**img2img_args)
-        submit.click(**img2img_args)
+        toprow.prompt.submit(**img2img_args)
+        toprow.submit.click(**img2img_args)
 
         res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
 
@@ -978,7 +954,7 @@ def create_ui():
             show_progress=False,
         )
 
-        restore_progress_button.click(
+        toprow.restore_progress_button.click(
            fn=progress.restore_progress,
             _js="restoreProgressImg2img",
             inputs=[dummy_component],
@@ -991,46 +967,24 @@ def create_ui():
             show_progress=False,
         )
 
-        img2img_interrogate.click(
+        toprow.button_interrogate.click(
             fn=lambda *args: process_interrogate(interrogate, *args),
             **interrogate_args,
         )
 
-        img2img_deepbooru.click(
+        toprow.button_deepbooru.click(
             fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
             **interrogate_args,
         )
 
-        prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
-        style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
-        style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]
-
-        for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
-            button.click(
-                fn=add_style,
-                _js="ask_for_style_name",
-                # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
-                # the same number of parameters, but we only know the style-name after the JavaScript prompt
-                inputs=[dummy_component, prompt, negative_prompt],
-                outputs=[txt2img_prompt_styles, img2img_prompt_styles],
-            )
-
-        for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
-            button.click(
-                fn=apply_styles,
-                _js=js_func,
-                inputs=[prompt, negative_prompt, styles],
-                outputs=[prompt, negative_prompt, styles],
-            )
-
-        token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
-        negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])
+        toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
+        toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
 
         ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
 
         img2img_paste_fields = [
-            (img2img_prompt, "Prompt"),
-            (img2img_negative_prompt, "Negative prompt"),
+            (toprow.prompt, "Prompt"),
+            (toprow.negative_prompt, "Negative prompt"),
             (steps, "Steps"),
             (sampler_index, "Sampler"),
             (restore_faces, "Face restoration"),
@@ -1044,7 +998,7 @@ def create_ui():
             (subseed_strength, "Variation seed strength"),
             (seed_resize_from_w, "Seed resize from-1"),
             (seed_resize_from_h, "Seed resize from-2"),
-            (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
+            (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
             (denoising_strength, "Denoising strength"),
             (mask_blur, "Mask blur"),
             *modules.scripts.scripts_img2img.infotext_fields
@@ -1052,7 +1006,7 @@ def create_ui():
         parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
         parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
         parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
-            paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
+            paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
         ))
 
         modules.scripts.scripts_current = None
diff --git a/modules/ui_common.py b/modules/ui_common.py
index 11eb2a4b..ba75fa73 100644
--- a/modules/ui_common.py
+++ b/modules/ui_common.py
@@ -223,20 +223,44 @@ Requested path was: {f}
 
 
 def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
+    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]
+
+    label = None
+    for comp in refresh_components:
+        label = getattr(comp, 'label', None)
+        if label is not None:
+            break
+
     def refresh():
         refresh_method()
         args = refreshed_args() if callable(refreshed_args) else refreshed_args
 
         for k, v in args.items():
-            setattr(refresh_component, k, v)
+            for comp in refresh_components:
+                setattr(comp, k, v)
 
-        return gr.update(**(args or {}))
+        return [gr.update(**(args or {})) for _ in refresh_components]
 
-    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
+    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
     refresh_button.click(
         fn=refresh,
         inputs=[],
-        outputs=[refresh_component]
+        outputs=[*refresh_components]
     )
 
     return refresh_button
+
+
+def setup_dialog(button_show, dialog, *, button_close=None):
+    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when button_show is clicked, in a fullscreen modal window."""
+
+    dialog.visible = False
+
+    button_show.click(
+        fn=lambda: gr.update(visible=True),
+        inputs=[],
+        outputs=[dialog],
+    ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")
+
+    if button_close:
+        button_close.click(fn=None, _js="closePopup")
diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py
index 2bb0a222..891d8f2c 100644
--- a/modules/ui_extra_networks_checkpoints.py
+++ b/modules/ui_extra_networks_checkpoints.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.refresh_checkpoints()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
 
         path, ext = os.path.splitext(checkpoint.filename)
         return {
diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py
index e53ccb42..514a4562 100644
--- a/modules/ui_extra_networks_hypernets.py
+++ b/modules/ui_extra_networks_hypernets.py
@@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         shared.reload_hypernetworks()
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         full_path = shared.hypernetworks[name]
         path, ext = os.path.splitext(full_path)
 
diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py
index d1794e50..73134698 100644
--- a/modules/ui_extra_networks_textual_inversion.py
+++ b/modules/ui_extra_networks_textual_inversion.py
@@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage):
     def refresh(self):
         sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)
 
-    def create_item(self, name, index=None):
+    def create_item(self, name, index=None, enable_filter=True):
         embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name)
 
         path, ext = os.path.splitext(embedding.filename)
diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py
new file mode 100644
index 00000000..85eb3a64
--- /dev/null
+++ b/modules/ui_prompt_styles.py
@@ -0,0 +1,110 @@
+import gradio as gr
+
+from modules import shared, ui_common, ui_components, styles
+
+styles_edit_symbol = '\U0001f58c\uFE0F'  # 🖌️
+styles_materialize_symbol = '\U0001f4cb'  # 📋
+
+
+def select_style(name):
+    style = shared.prompt_styles.styles.get(name)
+    existing = style is not None
+    empty = not name
+
+    prompt = style.prompt if style else gr.update()
+    negative_prompt = style.negative_prompt if style else gr.update()
+
+    return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty)
+
+
+def save_style(name, prompt, negative_prompt):
+    if not name:
+        return gr.update(visible=False)
+
+    style = styles.PromptStyle(name, prompt, negative_prompt)
+    shared.prompt_styles.styles[style.name] = style
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return gr.update(visible=True)
+
+
+def delete_style(name):
+    if name == "":
+        return
+
+    shared.prompt_styles.styles.pop(name, None)
+    shared.prompt_styles.save_styles(shared.styles_filename)
+
+    return '', '', ''
+
+
+def materialize_styles(prompt, negative_prompt, styles):
+    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
+    negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles)
+
+    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])]
+
+
+def refresh_styles():
+    return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles))
+
+
+class UiPromptStyles:
+    def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt):
+        self.tabname = tabname
+
+        with gr.Row(elem_id=f"{tabname}_styles_row"):
+            self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles")
+            edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles")
+
+        with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog:
+            with gr.Row():
+                self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.")
+                ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles")
+                self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selection dropdown in main UI to the prompt.")
+
+            with gr.Row():
+                self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3)
+
+            with gr.Row():
+                self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3)
+
+            with gr.Row():
+                self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False)
+                self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False)
+                self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close')
+
+        self.selection.change(
+            fn=select_style,
+            inputs=[self.selection],
+            outputs=[self.prompt, self.neg_prompt, self.delete, self.save],
+            show_progress=False,
+        )
+
+        self.save.click(
+            fn=save_style,
+            inputs=[self.selection, self.prompt, self.neg_prompt],
+            outputs=[self.delete],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.delete.click(
+            fn=delete_style,
+            _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }',
+            inputs=[self.selection],
+            outputs=[self.selection, self.prompt, self.neg_prompt],
+            show_progress=False,
+        ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False)
+
+        self.materialize.click(
+            fn=materialize_styles,
+            inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown],
+            show_progress=False,
+        ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False)
+
+        ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close)
diff --git a/style.css b/style.css
index 6c92d6e7..cf8470e4 100644
--- a/style.css
+++ b/style.css
@@ -972,3 +972,16 @@ div.block.gradio-box.edit-user-metadata {
 .edit-user-metadata-buttons{
     margin-top: 1.5em;
 }
+
+
+
+
+div.block.gradio-box.popup-dialog, .popup-dialog {
+    width: 56em;
+    background: var(--body-background-fill);
+    padding: 2em !important;
+}
+
+div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{
+    margin-top: 1em;
+}
--
cgit v1.2.3
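
The dialog's help text above describes how a style's {prompt} token behaves when a style is materialized into the main prompt. A minimal sketch of that substitution rule follows; apply_style_text is an illustrative helper rather than a webui function, and the comma separator used when appending is an assumption made for the example:

def apply_style_text(style_text, prompt):
    # If the style contains the {prompt} token, the user's prompt replaces it;
    # otherwise the style text is appended to the end of the prompt
    # (joined with ", " here, an assumption for illustration).
    if "{prompt}" in style_text:
        return style_text.replace("{prompt}", prompt)
    return f"{prompt}, {style_text}" if prompt else style_text

print(apply_style_text("cinematic still of {prompt}, 35mm film", "a red fox"))
# -> "cinematic still of a red fox, 35mm film"
print(apply_style_text("masterpiece, best quality", "a red fox"))
# -> "a red fox, masterpiece, best quality"
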
From 24f21583cdba2ae6cc51773b956c6ce068d3dfe4 Mon Sep 17 00:00:00 2001
From: AnyISalIn
Date: Fri, 4 Aug 2023 11:43:27 +0800
Subject: fix: prevent cache model.state_dict() after model hijack

Signed-off-by: AnyISalIn
---
 modules/sd_models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1d93d893..ba15b451 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_models_xl.extend_sdxl(model)
 
     model.load_state_dict(state_dict, strict=False)
-    del state_dict
     timer.record("apply weights to model")
 
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+        checkpoints_loaded[checkpoint_info] = state_dict
+
+    del state_dict
 
     if shared.cmd_opts.opt_channelslast:
         model.to(memory_format=torch.channels_last)
--
cgit v1.2.3
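
The fix above caches the raw checkpoint state_dict instead of snapshotting model.state_dict() after loading. A small standalone illustration of why the ordering matters; torch.nn.Linear stands in for the SD model and the in-place multiply stands in for a post-load modification such as a hijack:

import torch

linear = torch.nn.Linear(2, 2)
# Keep a copy of the weights as they came from "disk".
original = {k: v.clone() for k, v in linear.state_dict().items()}

with torch.no_grad():
    linear.weight.mul_(0.5)  # stand-in for an in-place post-load tweak

# The live module's weights no longer match the loaded checkpoint, so
# caching model.state_dict() at this point would poison the RAM cache.
assert not torch.equal(original["weight"], linear.state_dict()["weight"])
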
From c96e4750d895a47290dc7f96e030197069c75fa4 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 7 Aug 2023 08:07:09 +0300
Subject: SD VAE rework 2

- the setting for preferring opts.sd_vae has been inverted and reworded
- resolve_vae function made easier to read and now returns an object rather than a tuple
- if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata
- changing VAE in user metadata for currently loaded model immediately applies the selection
---
 modules/sd_models.py                               |  2 +-
 modules/sd_vae.py                                  | 71 +++++++++++++++++-----
 modules/shared.py                                  |  6 +-
 .../ui_extra_networks_checkpoints_user_metadata.py |  8 ++-
 webui.py                                           |  2 +-
 5 files changed, 69 insertions(+), 20 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index f6051604..d65735e3 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -356,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_vae.delete_base_vae()
         sd_vae.clear_loaded_vae()
-        vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+        vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple()
         sd_vae.load_vae(model, vae_file, vae_source)
         timer.record("load VAE")
 
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 0bd5e19b..38bcb840 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -1,5 +1,7 @@
 import os
 import collections
+from dataclasses import dataclass
+
 from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
 import glob
 from copy import deepcopy
@@ -97,37 +99,74 @@ def find_vae_near_checkpoint(checkpoint_file):
     return None
 
 
-def resolve_vae(checkpoint_file):
-    if shared.cmd_opts.vae_path is not None:
-        return shared.cmd_opts.vae_path, 'from commandline argument'
+@dataclass
+class VaeResolution:
+    vae: str = None
+    source: str = None
+    resolved: bool = True
+
+    def tuple(self):
+        return self.vae, self.source
+
+
+def is_automatic():
+    return shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
+
+
+def resolve_vae_from_setting() -> VaeResolution:
+    if shared.opts.sd_vae == "None":
+        return VaeResolution()
+
+    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
+    if vae_from_options is not None:
+        return VaeResolution(vae_from_options, 'specified in settings')
+
+    if not is_automatic():
+        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+
+    return VaeResolution(resolved=False)
+
 
+def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution:
     metadata = extra_networks.get_user_metadata(checkpoint_file)
     vae_metadata = metadata.get("vae", None)
     if vae_metadata is not None and vae_metadata != "Automatic":
         if vae_metadata == "None":
-            return None, None
+            return VaeResolution()
 
         vae_from_metadata = vae_dict.get(vae_metadata, None)
         if vae_from_metadata is not None:
-            return vae_from_metadata, "from user metadata"
+            return VaeResolution(vae_from_metadata, "from user metadata")
+
+    return VaeResolution(resolved=False)
 
-    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config
 
+def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution:
     vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
     if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic):
-        return vae_near_checkpoint, 'found near the checkpoint'
+        return VaeResolution(vae_near_checkpoint, 'found near the checkpoint')
 
-    if shared.opts.sd_vae == "None":
-        return None, None
+    return VaeResolution(resolved=False)
 
-    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
-    if vae_from_options is not None:
-        return vae_from_options, 'specified in settings'
 
-    if not is_automatic:
-        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")
+def resolve_vae(checkpoint_file) -> VaeResolution:
+    if shared.cmd_opts.vae_path is not None:
+        return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument')
+
+    if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic():
+        return resolve_vae_from_setting()
+
+    res = resolve_vae_from_user_metadata(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_near_checkpoint(checkpoint_file)
+    if res.resolved:
+        return res
+
+    res = resolve_vae_from_setting()
 
-    return None, None
+    return res
 
 
 def load_vae_dict(filename, map_location):
@@ -201,7 +240,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified):
         checkpoint_file = checkpoint_info.filename
 
     if vae_file == unspecified:
-        vae_file, vae_source = resolve_vae(checkpoint_file)
+        vae_file, vae_source = resolve_vae(checkpoint_file).tuple()
     else:
         vae_source = "from function argument"
 
diff --git a/modules/shared.py b/modules/shared.py
index 078e8135..da53f2d9 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -479,7 +479,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
 """),
     "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
     "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"),
-    "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
+    "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
     "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
     "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
     "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"),
@@ -733,6 +733,10 @@ class Options:
         with open(filename, "r", encoding="utf8") as file:
             self.data = json.load(file)
 
+        # 1.6.0 VAE defaults
+        if self.data.get('sd_vae_as_default') is not None and self.data.get('sd_vae_overrides_per_model_preferences') is None:
+            self.data['sd_vae_overrides_per_model_preferences'] = not self.data.get('sd_vae_as_default')
+
         # 1.1.1 quicksettings list migration
         if self.data.get('quicksettings') is not None and self.data.get('quicksettings_list') is None:
             self.data['quicksettings_list'] = [i.strip() for i in self.data.get('quicksettings').split(',')]
diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py
index 2c69aab8..25df0a80 100644
--- a/modules/ui_extra_networks_checkpoints_user_metadata.py
+++ b/modules/ui_extra_networks_checkpoints_user_metadata.py
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from modules import ui_extra_networks_user_metadata, sd_vae
+from modules import ui_extra_networks_user_metadata, sd_vae, shared
 from modules.ui_common import create_refresh_button
 
 
@@ -18,6 +18,10 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
 
         self.write_user_metadata(name, user_metadata)
 
+    def update_vae(self, name):
+        if name == shared.sd_model.sd_checkpoint_info.name_for_extra:
+            sd_vae.reload_vae_weights()
+
     def put_values_into_components(self, name):
         user_metadata = self.get_user_metadata(name)
         values = super().put_values_into_components(name)
@@ -58,3 +62,5 @@ class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataE
         ]
 
         self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)
+
+        self.button_save.click(fn=self.update_vae, inputs=[self.edit_name_input])
diff --git a/webui.py b/webui.py
index 1803ea8a..a5b11575 100644
--- a/webui.py
+++ b/webui.py
@@ -211,7 +211,7 @@ def configure_sigint_handler():
 def configure_opts_onchange():
     shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
     shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
-    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+    shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
     shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
     shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
     shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
--
cgit v1.2.3

From 6e7828e1d271c644840047c3db60e669a232402a Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Mon, 7 Aug 2023 08:16:20 +0300
Subject: apply unet overrides after switching model

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index d65735e3..53c1df54 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -699,6 +699,7 @@ def reload_model_weights(sd_model=None, info=None):
     print(f"Weights loaded in {timer.summary()}.")
 
     model_data.set_sd_model(sd_model)
+    sd_unet.apply_unet()
 
     return sd_model
--
cgit v1.2.3
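
Taken together, the reworked resolution order in the VAE commit is: commandline --vae-path, then the global setting when it overrides per-model preferences, then checkpoint user metadata, then a VAE file found next to the checkpoint, then the global setting itself. The condensed sketch below mirrors that control flow; VaeResolution matches the dataclass in the patch, while the function arguments replace the shared/options lookups, so this is an illustration rather than the module itself:

from dataclasses import dataclass

@dataclass
class VaeResolution:
    vae: str = None
    source: str = None
    resolved: bool = True

    def tuple(self):
        return self.vae, self.source

def resolve_vae(cmdline_vae, setting_overrides, metadata_vae, near_checkpoint_vae, setting_vae):
    # Each branch corresponds to one resolver in modules/sd_vae.py.
    if cmdline_vae is not None:
        return VaeResolution(cmdline_vae, 'from commandline argument')
    if setting_overrides and setting_vae not in ("Automatic", "auto"):
        return VaeResolution(setting_vae, 'specified in settings')
    if metadata_vae is not None:
        return VaeResolution(metadata_vae, 'from user metadata')
    if near_checkpoint_vae is not None:
        return VaeResolution(near_checkpoint_vae, 'found near the checkpoint')
    if setting_vae not in ("Automatic", "auto", "None"):
        return VaeResolution(setting_vae, 'specified in settings')
    return VaeResolution()  # no VAE selected

print(resolve_vae(None, False, None, "model.vae.pt", "Automatic").tuple())
# -> ('model.vae.pt', 'found near the checkpoint')
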