author    | AUTOMATIC1111 <16777216c@gmail.com> | 2023-10-14 09:18:55 +0000
committer | GitHub <noreply@github.com>          | 2023-10-14 09:18:55 +0000
commit    | 4be7b620c2994a14c0ce879d3a6cea0db0f66a84 (patch)
tree      | 3c1ca88bf58264fdc8b50dc6738cc5b7720cb2de /extensions-builtin/Lora
parent    | 19f5795c27f13c84a922bc69f559cc69e408b3e2 (diff)
parent    | a8cbe50c9fa324ed887089e4333452ecc4355c92 (diff)
Merge pull request #13568 from AUTOMATIC1111/lora_emb_bundle
Add lora-embedding bundle system
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r-- | extensions-builtin/Lora/lora_logger.py | 33
-rw-r--r-- | extensions-builtin/Lora/network.py     |  1
-rw-r--r-- | extensions-builtin/Lora/networks.py    | 41
3 files changed, 75 insertions, 0 deletions
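The bundle system lets a single LoRA .safetensors file carry textual-inversion embeddings alongside the network weights, under state-dict keys prefixed with `bundle_emb.`. As a rough illustration of the layout the new parsing code in `load_network` below expects, here is a minimal sketch of packing an embedding into such a file; the file name, embedding name, vector shape, and the `'*'` parameter key are illustrative assumptions, not part of this commit.

```python
# Hypothetical sketch (not part of this commit): bundling a textual-inversion
# embedding with LoRA weights in one .safetensors file, using the
# "bundle_emb.<name>.string_to_param.<key>" layout that load_network() parses.
import torch
from safetensors.torch import save_file

lora_sd = {}  # ordinary LoRA tensors would go here, e.g. "lora_unet_..." keys

emb_name = "my_trigger"            # assumed embedding/trigger name
emb_vectors = torch.randn(2, 768)  # assumed: 2 vectors for an SD1.x text encoder

# '*' is the parameter key commonly used for webui embeddings (assumption)
lora_sd[f"bundle_emb.{emb_name}.string_to_param.*"] = emb_vectors

save_file(lora_sd, "my_lora_with_embedding.safetensors")
```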
diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py
new file mode 100644
index 00000000..d51de297
--- /dev/null
+++ b/extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index d8e8dfb7..6021fd8d 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ class Network:  # LoraModule
         self.unet_multiplier = 1.0
         self.dyn_dim = None
         self.modules = {}
+        self.bundle_embeddings = {}
 
         self.mtime = None
         self.mentioned_name = None
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index ddab3c55..60d8dec4 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -16,6 +16,9 @@ import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
 
 module_types = [
     network_lora.ModuleTypeLora(),
@@ -151,9 +154,19 @@ def load_network(name, network_on_disk):
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
     matched_networks = {}
+    bundle_embeddings = {}
 
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
+        if key_network_without_network_parts == "bundle_emb":
+            emb_name, vec_name = network_part.split(".", 1)
+            emb_dict = bundle_embeddings.get(emb_name, {})
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
+            bundle_embeddings[emb_name] = emb_dict
 
         key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
@@ -197,6 +210,14 @@ def load_network(name, network_on_disk):
         net.modules[key] = net_module
 
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
+
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -212,11 +233,15 @@ def purge_networks_from_memory():
 
 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    emb_db = sd_hijack.model_hijack.embedding_db
     already_loaded = {}
 
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
@@ -259,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
+
+            embedding.loaded = False
+            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
+                emb_db.register_embedding(embedding, shared.sd_model)
+            else:
+                emb_db.skipped_embeddings[name] = embedding
+
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
@@ -567,6 +607,7 @@ extra_network_lora = None
 available_networks = {}
 available_network_aliases = {}
 loaded_networks = []
+loaded_bundle_embeddings = {}
 networks_in_memory = {}
 available_network_hash_lookup = {}
 forbidden_network_aliases = {}
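For reference, reading bundled embeddings back out of such a file follows the same key convention the diff parses above. This is a minimal standalone sketch that mirrors (rather than calls) the webui code; the file path is illustrative.

```python
# Hypothetical sketch: list embeddings bundled in a LoRA file by parsing the
# "bundle_emb.<name>.<field>" keys, mirroring the logic added in load_network().
from safetensors.torch import load_file

sd = load_file("my_lora_with_embedding.safetensors")  # illustrative path

bundle_embeddings = {}
for key, weight in sd.items():
    prefix, _, rest = key.partition(".")
    if prefix != "bundle_emb":
        continue
    emb_name, _, vec_name = rest.partition(".")
    emb_dict = bundle_embeddings.setdefault(emb_name, {})
    if vec_name.split(".")[0] == "string_to_param":
        _, k2 = vec_name.split(".", 1)
        emb_dict["string_to_param"] = {k2: weight}
    else:
        emb_dict[vec_name] = weight

for name, data in bundle_embeddings.items():
    shapes = {k: tuple(v.shape) for k, v in data.get("string_to_param", {}).items()}
    print(f"bundled embedding {name!r}: {shapes}")
```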