Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/Lora/lora_logger.py         | 33
-rw-r--r-- | extensions-builtin/Lora/network.py              |  1
-rw-r--r-- | extensions-builtin/Lora/network_glora.py        | 33
-rw-r--r-- | extensions-builtin/Lora/networks.py             | 44
-rw-r--r-- | extensions-builtin/mobile/javascript/mobile.js  |  2
5 files changed, 113 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py
new file mode 100644
index 00000000..d51de297
--- /dev/null
+++ b/extensions-builtin/Lora/lora_logger.py
@@ -0,0 +1,33 @@
+import sys
+import copy
+import logging
+
+
+class ColoredFormatter(logging.Formatter):
+    COLORS = {
+        "DEBUG": "\033[0;36m",  # CYAN
+        "INFO": "\033[0;32m",  # GREEN
+        "WARNING": "\033[0;33m",  # YELLOW
+        "ERROR": "\033[0;31m",  # RED
+        "CRITICAL": "\033[0;37;41m",  # WHITE ON RED
+        "RESET": "\033[0m",  # RESET COLOR
+    }
+
+    def format(self, record):
+        colored_record = copy.copy(record)
+        levelname = colored_record.levelname
+        seq = self.COLORS.get(levelname, self.COLORS["RESET"])
+        colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}"
+        return super().format(colored_record)
+
+
+logger = logging.getLogger("lora")
+logger.propagate = False
+
+
+if not logger.handlers:
+    handler = logging.StreamHandler(sys.stdout)
+    handler.setFormatter(
+        ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s")
+    )
+    logger.addHandler(handler)
diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py
index d8e8dfb7..6021fd8d 100644
--- a/extensions-builtin/Lora/network.py
+++ b/extensions-builtin/Lora/network.py
@@ -93,6 +93,7 @@ class Network:  # LoraModule
         self.unet_multiplier = 1.0
         self.dyn_dim = None
         self.modules = {}
+        self.bundle_embeddings = {}
 
         self.mtime = None
         self.mentioned_name = None
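The new lora_logger.py module above configures a colour-coded, stdout-only logger exactly once at import time (a handler is attached only if none exists yet), and networks.py below imports the shared logger object from it. A minimal usage sketch, with made-up messages:

# Minimal usage sketch, assuming lora_logger is importable (it sits next to
# networks.py in extensions-builtin/Lora). The messages here are made up.
from lora_logger import logger

logger.info("network loaded")               # rendered in green on stdout
logger.warning("bundle embedding skipped")  # rendered in yellow on stdout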
diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py
new file mode 100644
index 00000000..492d4870
--- /dev/null
+++ b/extensions-builtin/Lora/network_glora.py
@@ -0,0 +1,33 @@
+
+import network
+
+class ModuleTypeGLora(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]):
+            return NetworkModuleGLora(net, weights)
+
+        return None
+
+# adapted from https://github.com/KohakuBlueleaf/LyCORIS
+class NetworkModuleGLora(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+
+        if hasattr(self.sd_module, 'weight'):
+            self.shape = self.sd_module.weight.shape
+
+        self.w1a = weights.w["a1.weight"]
+        self.w1b = weights.w["b1.weight"]
+        self.w2a = weights.w["a2.weight"]
+        self.w2b = weights.w["b2.weight"]
+
+    def calc_updown(self, orig_weight):
+        w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype)
+        w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype)
+        w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype)
+        w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        output_shape = [w1a.size(0), w1b.size(1)]
+        updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a))
+
+        return self.finalize_updown(updown, orig_weight, output_shape)
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index a21ea0fa..60d8dec4 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -5,6 +5,7 @@ import re
 import lora_patches
 import network
 import network_lora
+import network_glora
 import network_hada
 import network_ia3
 import network_lokr
@@ -15,6 +16,9 @@ import torch
 from typing import Union
 
 from modules import shared, devices, sd_models, errors, scripts, sd_hijack
+import modules.textual_inversion.textual_inversion as textual_inversion
+
+from lora_logger import logger
 
 module_types = [
     network_lora.ModuleTypeLora(),
@@ -23,6 +27,7 @@ module_types = [
     network_lokr.ModuleTypeLokr(),
     network_full.ModuleTypeFull(),
     network_norm.ModuleTypeNorm(),
+    network_glora.ModuleTypeGLora(),
 ]
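network_glora.py adds a GLoRA module type (adapted from KohakuBlueleaf's LyCORIS, per the comment in the file) and registers it in module_types so it is tried alongside the other network formats. For a layer weight W, its calc_updown computes delta = B2 @ B1 + (W @ A2) @ A1, i.e. one plain low-rank term plus one low-rank term that rescales the original weight. A minimal numeric sketch of that formula; the rank and layer sizes below are illustrative assumptions, not values taken from the diff:

# Illustrative shapes only; the diff does not fix these sizes.
import torch

out_features, in_features, rank = 320, 768, 4
W = torch.randn(out_features, in_features)    # original layer weight ("orig_weight")
A1 = torch.randn(rank, in_features)           # "a1.weight"
A2 = torch.randn(in_features, rank)           # "a2.weight"
B1 = torch.randn(rank, in_features)           # "b1.weight"
B2 = torch.randn(out_features, rank)          # "b2.weight"

delta = (B2 @ B1) + (W @ A2) @ A1             # same expression as calc_updown's updown
W_merged = W + delta                          # in practice scaled by the network multiplier
print(W_merged.shape)                         # torch.Size([320, 768])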
@@ -149,9 +154,19 @@ def load_network(name, network_on_disk):
     is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping
 
     matched_networks = {}
+    bundle_embeddings = {}
 
     for key_network, weight in sd.items():
         key_network_without_network_parts, network_part = key_network.split(".", 1)
+        if key_network_without_network_parts == "bundle_emb":
+            emb_name, vec_name = network_part.split(".", 1)
+            emb_dict = bundle_embeddings.get(emb_name, {})
+            if vec_name.split('.')[0] == 'string_to_param':
+                _, k2 = vec_name.split('.', 1)
+                emb_dict['string_to_param'] = {k2: weight}
+            else:
+                emb_dict[vec_name] = weight
+            bundle_embeddings[emb_name] = emb_dict
 
         key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
         sd_module = shared.sd_model.network_layer_mapping.get(key, None)
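The bundle_emb prefix handled above is how embeddings bundled inside a LoRA file are separated from the network weights before the remaining keys go through convert_diffusers_name_to_compvis. A small illustration of what the parsing yields for hypothetical keys; the embedding name, extra fields and tensor sizes are made up:

import torch

# Hypothetical flattened state-dict entries for one bundled embedding "my_style".
sd = {
    "bundle_emb.my_style.string_to_param.*": torch.zeros(2, 768),
    "bundle_emb.my_style.name": "my_style",
}

bundle_embeddings = {}
for key_network, weight in sd.items():
    prefix, network_part = key_network.split(".", 1)
    if prefix == "bundle_emb":
        emb_name, vec_name = network_part.split(".", 1)
        emb_dict = bundle_embeddings.setdefault(emb_name, {})
        if vec_name.split('.')[0] == 'string_to_param':
            _, k2 = vec_name.split('.', 1)
            emb_dict['string_to_param'] = {k2: weight}  # {"*": token-vector tensor}
        else:
            emb_dict[vec_name] = weight                 # e.g. "name" -> "my_style"

# bundle_embeddings == {"my_style": {"string_to_param": {"*": <2x768 tensor>}, "name": "my_style"}}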
@@ -195,6 +210,14 @@ def load_network(name, network_on_disk):
         net.modules[key] = net_module
 
+    embeddings = {}
+    for emb_name, data in bundle_embeddings.items():
+        embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name)
+        embedding.loaded = None
+        embeddings[emb_name] = embedding
+
+    net.bundle_embeddings = embeddings
+
     if keys_failed_to_match:
         logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")
@@ -210,11 +233,15 @@ def purge_networks_from_memory():
 def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
+    emb_db = sd_hijack.model_hijack.embedding_db
     already_loaded = {}
 
     for net in loaded_networks:
         if net.name in names:
             already_loaded[net.name] = net
+            for emb_name, embedding in net.bundle_embeddings.items():
+                if embedding.loaded:
+                    emb_db.register_embedding_by_name(None, shared.sd_model, emb_name)
 
     loaded_networks.clear()
@@ -257,6 +284,21 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
         net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
         loaded_networks.append(net)
 
+        for emb_name, embedding in net.bundle_embeddings.items():
+            if embedding.loaded is None and emb_name in emb_db.word_embeddings:
+                logger.warning(
+                    f'Skip bundle embedding: "{emb_name}"'
+                    ' as it was already loaded from embeddings folder'
+                )
+                continue
+
+            embedding.loaded = False
+            if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
+                embedding.loaded = True
+                emb_db.register_embedding(embedding, shared.sd_model)
+            else:
+                emb_db.skipped_embeddings[name] = embedding
+
     if failed_to_load_networks:
         sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))
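Together with the load_network hunk earlier, each bundled embedding carries a small tri-state loaded flag: None until registration is attempted, True once it has been registered with the embedding database, and False when it is rejected for a shape mismatch; an embedding of the same name already loaded from the embeddings folder takes precedence and leaves the flag at None. A condensed, hypothetical restatement of that decision (the helper name and signature are not part of the diff):

# Hypothetical summary of the branch above; expected_shape == -1 means the
# embedding database accepts vectors of any size.
def bundled_embedding_state(embedding, emb_db, emb_name):
    if embedding.loaded is None and emb_name in emb_db.word_embeddings:
        return None    # folder embedding with the same name wins; nothing is registered
    if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape:
        return True    # registered via emb_db.register_embedding(...)
    return False       # skipped because of a shape mismatch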
@@ -420,6 +462,7 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     self.network_weights_backup = None
     self.network_bias_backup = None
 
+
 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
         return network_forward(self, input, originals.Linear_forward)
@@ -564,6 +607,7 @@ extra_network_lora = None
 available_networks = {}
 available_network_aliases = {}
 loaded_networks = []
+loaded_bundle_embeddings = {}
 networks_in_memory = {}
 available_network_hash_lookup = {}
 forbidden_network_aliases = {}
diff --git a/extensions-builtin/mobile/javascript/mobile.js b/extensions-builtin/mobile/javascript/mobile.js
index 652f07ac..bff1aced 100644
--- a/extensions-builtin/mobile/javascript/mobile.js
+++ b/extensions-builtin/mobile/javascript/mobile.js
@@ -12,6 +12,8 @@ function isMobile() {
 }
 
 function reportWindowSize() {
+    if (gradioApp().querySelector('.toprow-compact-tools')) return; // not applicable for compact prompt layout
+
     var currentlyMobile = isMobile();
     if (currentlyMobile == isSetupForMobile) return;
     isSetupForMobile = currentlyMobile;