diff options
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r-- | extensions-builtin/Lora/network.py | 1 | ||||
-rw-r--r-- | extensions-builtin/Lora/networks.py | 27 | ||||
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 3 | ||||
-rw-r--r-- | extensions-builtin/Lora/ui_edit_user_metadata.py | 2 |
4 files changed, 29 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 8ecfa29a..0a18d69e 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -1,3 +1,4 @@ +from __future__ import annotations
import os
from collections import namedtuple
import enum
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index af8188e3..bc722e90 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -163,6 +163,11 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+ # some SD1 Loras also have correct compvis keys
+ if sd_module is None:
+ key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
+ sd_module = shared.sd_model.network_layer_mapping.get(key, None)
+
if sd_module is None:
keys_failed_to_match[key_network] = key
continue
@@ -190,6 +195,15 @@ def load_network(name, network_on_disk): return net
+def purge_networks_from_memory():
+ while len(networks_in_memory) > shared.opts.lora_in_memory_limit and len(networks_in_memory) > 0:
+ name = next(iter(networks_in_memory))
+ networks_in_memory.pop(name, None)
+
+ devices.torch_gc()
+
+
+
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
already_loaded = {}
@@ -207,15 +221,19 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No failed_to_load_networks = []
- for i, name in enumerate(names):
+ for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
net = already_loaded.get(name, None)
- network_on_disk = networks_on_disk[i]
-
if network_on_disk is not None:
+ if net is None:
+ net = networks_in_memory.get(name)
+
if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
try:
net = load_network(name, network_on_disk)
+
+ networks_in_memory.pop(name, None)
+ networks_in_memory[name] = net
except Exception as e:
errors.display(e, f"loading network {network_on_disk.filename}")
continue
@@ -237,6 +255,8 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if failed_to_load_networks:
sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks))
+ purge_networks_from_memory()
+
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
weights_backup = getattr(self, "network_weights_backup", None)
@@ -457,6 +477,7 @@ def infotext_pasted(infotext, params): available_networks = {}
available_network_aliases = {}
loaded_networks = []
+networks_in_memory = {}
available_network_hash_lookup = {}
forbidden_network_aliases = {}
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index cd28afc9..6ab8b6e7 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -65,6 +65,7 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"),
"lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"),
"lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}),
+ "lora_in_memory_limit": shared.OptionInfo(0, "Number of Lora networks to keep cached in memory", gr.Number, {"precision": 0}),
}))
@@ -121,3 +122,5 @@ def infotext_pasted(infotext, d): script_callbacks.on_infotext_pasted(infotext_pasted)
+
+shared.opts.onchange("lora_in_memory_limit", networks.purge_networks_from_memory)
diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 2ca997f7..390d9dde 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -167,7 +167,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) random_prompt = gr.Textbox(label='Random prompt', lines=4, max_lines=4, interactive=False)
with gr.Column(scale=1, min_width=120):
- generate_random_prompt = gr.Button('Generate').style(full_width=True, size="lg")
+ generate_random_prompt = gr.Button('Generate', size="lg", scale=1)
self.edit_notes = gr.TextArea(label='Notes', lines=4)
|