From a5121e7a0623db328a9462d340d389ed6737374a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:37:18 +0300 Subject: fixes for B007 --- modules/extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 1978673d..f9db41bc 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -91,7 +91,7 @@ def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call deactivate for all remaining registered networks""" - for extra_network_name, extra_network_args in extra_network_data.items(): + for extra_network_name in extra_network_data: extra_network = extra_network_registry.get(extra_network_name, None) if extra_network is None: continue -- cgit v1.2.3 From 21ee46eea791d83b3b49cedd2306c7f0f1807250 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Fri, 19 May 2023 15:35:16 +0300 Subject: Deduplicate default extra network registration --- modules/extra_networks.py | 5 +++++ modules/ui_extra_networks.py | 9 +++++++++ webui.py | 16 ++++++---------- 3 files changed, 20 insertions(+), 10 deletions(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f9db41bc..94347275 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -14,6 +14,11 @@ def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network +def register_default_extra_networks(): + from modules.extra_networks_hypernet import ExtraNetworkHypernet + register_extra_network(ExtraNetworkHypernet()) + + class ExtraNetworkParams: def __init__(self, items=None): self.items = items or [] diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 24eeef0e..19fbaae5 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -236,6 +236,15 @@ def initialize(): extra_pages.clear() +def register_default_pages(): + from modules.ui_extra_networks_textual_inversion import ExtraNetworksPageTextualInversion + from modules.ui_extra_networks_hypernets import ExtraNetworksPageHypernetworks + from modules.ui_extra_networks_checkpoints import ExtraNetworksPageCheckpoints + register_page(ExtraNetworksPageTextualInversion()) + register_page(ExtraNetworksPageHypernetworks()) + register_page(ExtraNetworksPageCheckpoints()) + + class ExtraNetworksUi: def __init__(self): self.pages = None diff --git a/webui.py b/webui.py index 30e4f239..ad6be239 100644 --- a/webui.py +++ b/webui.py @@ -34,8 +34,7 @@ startup_timer.record("import gradio") import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") -from modules import extra_networks, ui_extra_networks_checkpoints -from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion +from modules import extra_networks from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors @@ -214,12 +213,11 @@ def initialize(): startup_timer.record("reload hypernets") ui_extra_networks.initialize() - ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) - ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) - ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + ui_extra_networks.register_default_pages() extra_networks.initialize() - extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + extra_networks.register_default_extra_networks() + startup_timer.record("extra networks") if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: @@ -420,12 +418,10 @@ def webui(): startup_timer.record("reload hypernetworks") ui_extra_networks.initialize() - ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) - ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) - ui_extra_networks.register_page(ui_extra_networks_checkpoints.ExtraNetworksPageCheckpoints()) + ui_extra_networks.register_default_pages() extra_networks.initialize() - extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + extra_networks.register_default_extra_networks() startup_timer.record("initialize extra networks") -- cgit v1.2.3 From 39ec4f06ffb2c26e1298b2c5d80874dc3fd693ac Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 19 May 2023 22:59:29 +0300 Subject: calculate hashes for Lora add lora hashes to infotext when pasting infotext, use infotext's lora hashes to find local loras for entries whose hashes match loras the user has --- extensions-builtin/Lora/extra_networks_lora.py | 18 +++++++ extensions-builtin/Lora/lora.py | 59 ++++++++++++++++++----- extensions-builtin/Lora/scripts/lora_script.py | 32 +++++++++++- extensions-builtin/Lora/ui_extra_networks_lora.py | 5 +- modules/extra_networks.py | 9 ++++ modules/hashes.py | 29 ++++++++--- 6 files changed, 130 insertions(+), 22 deletions(-) (limited to 'modules/extra_networks.py') diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index ccb249ac..b5fea4d2 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -23,5 +23,23 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): lora.load_loras(names, multipliers) + if shared.opts.lora_add_hashes_to_infotext: + lora_hashes = [] + for item in lora.loaded_loras: + shorthash = item.lora_on_disk.shorthash + if not shorthash: + continue + + alias = item.mentioned_name + if not alias: + continue + + alias = alias.replace(":", "").replace(",", "") + + lora_hashes.append(f"{alias}: {shorthash}") + + if lora_hashes: + p.extra_generation_params["Lora hashes"] = ", ".join(lora_hashes) + def deactivate(self, p): pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index fa57d466..eec14712 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -3,7 +3,7 @@ import re import torch from typing import Union -from modules import shared, devices, sd_models, errors, scripts, sd_hijack +from modules import shared, devices, sd_models, errors, scripts, sd_hijack, hashes metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} @@ -76,9 +76,9 @@ class LoraOnDisk: self.name = name self.filename = filename self.metadata = {} + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" - _, ext = os.path.splitext(filename) - if ext.lower() == ".safetensors": + if self.is_safetensors: try: self.metadata = sd_models.read_metadata_from_safetensors(filename) except Exception as e: @@ -94,14 +94,43 @@ class LoraOnDisk: self.ssmd_cover_images = self.metadata.pop('ssmd_cover_images', None) # those are cover images and they are too big to display in UI as text self.alias = self.metadata.get('ss_output_name', self.name) + self.hash = None + self.shorthash = None + self.set_hash( + self.metadata.get('sshs_model_hash') or + hashes.sha256_from_cache(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or + '' + ) + + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + + if self.shorthash: + available_lora_hash_lookup[self.shorthash] = self + + def read_hash(self): + if not self.hash: + self.set_hash(hashes.sha256(self.filename, "lora/" + self.name, use_addnet_hash=self.is_safetensors) or '') + + def get_alias(self): + if shared.opts.lora_preferred_name == "Filename" or self.alias.lower() in forbidden_lora_aliases: + return self.name + else: + return self.alias + class LoraModule: - def __init__(self, name): + def __init__(self, name, lora_on_disk: LoraOnDisk): self.name = name + self.lora_on_disk = lora_on_disk self.multiplier = 1.0 self.modules = {} self.mtime = None + self.mentioned_name = None + """the text that was used to add lora to prompt - can be either name or an alias""" + class LoraUpDownModule: def __init__(self): @@ -126,11 +155,11 @@ def assign_lora_names_to_compvis_modules(sd_model): sd_model.lora_layer_mapping = lora_layer_mapping -def load_lora(name, filename): - lora = LoraModule(name) - lora.mtime = os.path.getmtime(filename) +def load_lora(name, lora_on_disk): + lora = LoraModule(name, lora_on_disk) + lora.mtime = os.path.getmtime(lora_on_disk.filename) - sd = sd_models.read_state_dict(filename) + sd = sd_models.read_state_dict(lora_on_disk.filename) # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 if not hasattr(shared.sd_model, 'lora_layer_mapping'): @@ -191,7 +220,7 @@ def load_lora(name, filename): raise AssertionError(f"Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha") if len(keys_failed_to_match) > 0: - print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + print(f"Failed to match keys when loading Lora {lora_on_disk.filename}: {keys_failed_to_match}") return lora @@ -217,14 +246,19 @@ def load_loras(names, multipliers=None): lora = already_loaded.get(name, None) lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: try: - lora = load_lora(name, lora_on_disk.filename) + lora = load_lora(name, lora_on_disk) except Exception as e: errors.display(e, f"loading Lora {lora_on_disk.filename}") continue + lora.mentioned_name = name + + lora_on_disk.read_hash() + if lora is None: failed_to_load_loras.append(name) print(f"Couldn't find Lora with name {name}") @@ -403,7 +437,8 @@ def list_available_loras(): available_loras.clear() available_lora_aliases.clear() forbidden_lora_aliases.clear() - forbidden_lora_aliases.update({"none": 1}) + available_lora_hash_lookup.clear() + forbidden_lora_aliases.update({"none": 1, "Addams": 1}) os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) @@ -457,8 +492,10 @@ def infotext_pasted(infotext, params): if added: params["Prompt"] += "\n" + "".join(added) + available_loras = {} available_lora_aliases = {} +available_lora_hash_lookup = {} forbidden_lora_aliases = {} loaded_loras = [] diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index a6b340ee..e650f469 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -1,3 +1,5 @@ +import re + import torch import gradio as gr from fastapi import FastAPI @@ -54,7 +56,8 @@ script_callbacks.on_infotext_pasted(lora.infotext_pasted) shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), { "sd_lora": shared.OptionInfo("None", "Add Lora to prompt", gr.Dropdown, lambda: {"choices": ["None", *lora.available_loras]}, refresh=lora.list_available_loras), - "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), + "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), + "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), })) @@ -84,3 +87,30 @@ def api_loras(_: gr.Blocks, app: FastAPI): script_callbacks.on_app_started(api_loras) +re_lora = re.compile(" Date: Wed, 31 May 2023 22:45:16 +0300 Subject: fix [Bug]: LoRA don't apply on dropdown list sd_lora #10880 --- modules/extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 34a3ba63..f4743928 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -26,7 +26,7 @@ class ExtraNetworkParams: self.named = {} for item in self.items: - parts = item.split('=', 2) + parts = item.split('=', 2) if isinstance(item, str) else [item] if len(parts) == 2: self.named[parts[0]] = parts[1] else: -- cgit v1.2.3 From f098e726d3e63aae8a6276ce83c55ac905c4379c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 4 Jun 2023 04:24:44 +0900 Subject: fix conds caching with extra network --- modules/extra_networks.py | 3 +++ modules/processing.py | 24 ++++++++++++------------ 2 files changed, 15 insertions(+), 12 deletions(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index f4743928..5c5c9a53 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -32,6 +32,9 @@ class ExtraNetworkParams: else: self.positional.append(item) + def __eq__(self, other): + return self.items == other.items and self.positional == other.positional and self.named == other.named + class ExtraNetwork: def __init__(self, name): diff --git a/modules/processing.py b/modules/processing.py index 362ab4c2..22ddc256 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -171,6 +171,7 @@ class StableDiffusionProcessing: self.prompts = None self.negative_prompts = None + self.extra_network_data = None self.seeds = None self.subseeds = None @@ -311,7 +312,7 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def get_conds_with_caching(self, function, required_prompts, steps, cache): + def get_conds_with_caching(self, function, required_prompts, steps, cache, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -321,21 +322,21 @@ class StableDiffusionProcessing: have been used before. The second element is where the previously computed result is stored. """ - if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) == cache[0]: + if cache[0] is not None and (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) == cache[0]: return cache[1] with devices.autocast(): cache[1] = function(shared.sd_model, required_prompts, steps) - cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info) + cache[0] = (required_prompts, steps, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data) return cache[1] def setup_conds(self): sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c, self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -681,7 +682,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.job_count == -1: state.job_count = p.n_iter - extra_network_data = None for n in range(p.n_iter): p.iteration = n @@ -702,11 +702,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if len(p.prompts) == 0: break - extra_network_data = p.parse_extra_network_prompts() + p.extra_network_data = p.parse_extra_network_prompts() if not p.disable_extra_networks: with devices.autocast(): - extra_networks.activate(p, extra_network_data) + extra_networks.activate(p, p.extra_network_data) if p.scripts is not None: p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -828,8 +828,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) - if not p.disable_extra_networks and extra_network_data: - extra_networks.deactivate(p, extra_network_data) + if not p.disable_extra_networks and p.extra_network_data: + extra_networks.deactivate(p, p.extra_network_data) devices.torch_gc() @@ -1101,8 +1101,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): super().setup_conds() if self.enable_hr: - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c) + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data) def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() -- cgit v1.2.3 From 0a277ab59183df638650d8373d785c94d14634ed Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sun, 4 Jun 2023 05:19:47 +0900 Subject: remove redone compare --- modules/extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 5c5c9a53..1f093df2 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -33,7 +33,7 @@ class ExtraNetworkParams: self.positional.append(item) def __eq__(self, other): - return self.items == other.items and self.positional == other.positional and self.named == other.named + return self.items == other.items class ExtraNetwork: -- cgit v1.2.3 From 9c2a7f1e8bafcb59e566bf568fdefe1be95905fe Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 19 Jun 2023 15:37:20 +0900 Subject: add callback after_extra_networks_activate --- modules/extra_networks.py | 3 +++ modules/scripts.py | 23 +++++++++++++++++++++++ 2 files changed, 26 insertions(+) (limited to 'modules/extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 1f093df2..41799b0a 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -103,6 +103,9 @@ def activate(p, extra_network_data): except Exception as e: errors.display(e, f"activating extra network {extra_network_name}") + if p.scripts is not None: + p.scripts.after_extra_networks_activate(p, batch_number=p.iteration, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds, extra_network_data=extra_network_data) + def deactivate(p, extra_network_data): """call deactivate for extra networks in extra_network_data in specified order, then call diff --git a/modules/scripts.py b/modules/scripts.py index 99bf836a..340f1480 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -116,6 +116,21 @@ class Script: pass + def after_extra_networks_activate(self, p, *args, **kwargs): + """ + Calledafter extra networks activation, before conds calculation + allow modification of the network after extra networks activation been applied + won't be call if p.disable_extra_networks + + **kwargs will have those items: + - batch_number - index of current batch, from 0 to number of batches-1 + - prompts - list of prompts for current batch; you can change contents of this list but changing the number of entries will likely break things + - seeds - list of seeds for current batch + - subseeds - list of subseeds for current batch + - extra_network_data - list of ExtraNetworkParams for current stage + """ + pass + def process_batch(self, p, *args, **kwargs): """ Same as process(), but called for every batch. @@ -483,6 +498,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running before_process_batch: {script.filename}", exc_info=True) + def after_extra_networks_activate(self, p, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.after_extra_networks_activate(p, *script_args, **kwargs) + except Exception: + errors.report(f"Error running after_extra_networks_activate: {script.filename}", exc_info=True) + def process_batch(self, p, **kwargs): for script in self.alwayson_scripts: try: -- cgit v1.2.3