From 7d4d871d4679b5b78ff67b501da5367413542984 Mon Sep 17 00:00:00 2001 From: dongwenpu Date: Sun, 10 Sep 2023 17:53:42 +0800 Subject: fix: lora-bias-backup don't reset cache --- extensions-builtin/Lora/networks.py | 1 + 1 file changed, 1 insertion(+) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 96f935b2..315682b3 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -418,6 +418,7 @@ def network_forward(module, input, original_forward): def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): self.network_current_names = () self.network_weights_backup = None + self.network_bias_backup = None def network_Linear_forward(self, input): -- cgit v1.2.3 From 2aa485b5afb13fd6aab79777e4dfc488591b2f1c Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Mon, 9 Oct 2023 22:52:09 +0800 Subject: add lora bundle system --- extensions-builtin/Lora/network.py | 1 + extensions-builtin/Lora/networks.py | 48 +++++++++++++++++++++++++++++++++++++ 2 files changed, 49 insertions(+) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index d8e8dfb7..6021fd8d 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -93,6 +93,7 @@ class Network: # LoraModule self.unet_multiplier = 1.0 self.dyn_dim = None self.modules = {} + self.bundle_embeddings = {} self.mtime = None self.mentioned_name = None diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 315682b3..652b8ebe 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -15,6 +15,7 @@ import torch from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack +from modules.textual_inversion.textual_inversion import Embedding module_types = [ network_lora.ModuleTypeLora(), @@ -149,9 +150,15 @@ def load_network(name, network_on_disk): is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping matched_networks = {} + bundle_embeddings = {} for key_network, weight in sd.items(): key_network_without_network_parts, network_part = key_network.split(".", 1) + if key_network_without_network_parts == "bundle_emb": + emb_name, vec_name = network_part.split(".", 1) + emb_dict = bundle_embeddings.get(emb_name, {}) + emb_dict[vec_name] = weight + bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -195,6 +202,8 @@ def load_network(name, network_on_disk): net.modules[key] = net_module + net.bundle_embeddings = bundle_embeddings + if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -210,11 +219,14 @@ def purge_networks_from_memory(): def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): + emb_db = sd_hijack.model_hijack.embedding_db already_loaded = {} for net in loaded_networks: if net.name in names: already_loaded[net.name] = net + for emb_name in net.bundle_embeddings: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -257,6 +269,41 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) + for emb_name, data in net.bundle_embeddings.items(): + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, emb_name) + embedding.vectors = vectors + embedding.shape = shape + + if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + emb_db.register_embedding(embedding, shared.sd_model) + else: + emb_db.skipped_embeddings[name] = embedding + if failed_to_load_networks: sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks)) @@ -565,6 +612,7 @@ extra_network_lora = None available_networks = {} available_network_aliases = {} loaded_networks = [] +loaded_bundle_embeddings = {} networks_in_memory = {} available_network_hash_lookup = {} forbidden_network_aliases = {} -- cgit v1.2.3 From 3d8b1af6beb9015f6b3573661d8ed00275f6129f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:09:33 +0800 Subject: Support string_to_param nested dict format: bundle_emb.EMBNAME.string_to_param.KEYNAME --- extensions-builtin/Lora/networks.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 652b8ebe..ab3517d8 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -157,7 +157,11 @@ def load_network(name, network_on_disk): if key_network_without_network_parts == "bundle_emb": emb_name, vec_name = network_part.split(".", 1) emb_dict = bundle_embeddings.get(emb_name, {}) - emb_dict[vec_name] = weight + if vec_name.split('.')[0] == 'string_to_param': + _, k2 = vec_name.split('.', 1) + emb_dict['string_to_param'] = {k2: weight} + else: + emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) @@ -301,6 +305,7 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: emb_db.register_embedding(embedding, shared.sd_model) + print(f'registered bundle embedding: {embedding.name}') else: emb_db.skipped_embeddings[name] = embedding -- cgit v1.2.3 From 2282eb8dd5905e8ed71231a0b8fc77599d10c12f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 12:11:00 +0800 Subject: Remove dev debug print --- extensions-builtin/Lora/networks.py | 1 - 1 file changed, 1 deletion(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index ab3517d8..465e24c8 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -305,7 +305,6 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: emb_db.register_embedding(embedding, shared.sd_model) - print(f'registered bundle embedding: {embedding.name}') else: emb_db.skipped_embeddings[name] = embedding -- cgit v1.2.3 From 81e94de3185d42dba4e7bb72cf836f683f28b03f Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Tue, 10 Oct 2023 14:44:20 +0800 Subject: Add warning when meet emb name conflicting Choose standalone embedding (in /embeddings folder) first --- extensions-builtin/Lora/lora_logger.py | 33 ++++++++++++++ extensions-builtin/Lora/networks.py | 80 ++++++++++++++++++++-------------- 2 files changed, 81 insertions(+), 32 deletions(-) create mode 100644 extensions-builtin/Lora/lora_logger.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/lora_logger.py b/extensions-builtin/Lora/lora_logger.py new file mode 100644 index 00000000..d50e90f0 --- /dev/null +++ b/extensions-builtin/Lora/lora_logger.py @@ -0,0 +1,33 @@ +import sys +import copy +import logging + + +class ColoredFormatter(logging.Formatter): + COLORS = { + "DEBUG": "\033[0;36m", # CYAN + "INFO": "\033[0;32m", # GREEN + "WARNING": "\033[0;33m", # YELLOW + "ERROR": "\033[0;31m", # RED + "CRITICAL": "\033[0;37;41m", # WHITE ON RED + "RESET": "\033[0m", # RESET COLOR + } + + def format(self, record): + colored_record = copy.copy(record) + levelname = colored_record.levelname + seq = self.COLORS.get(levelname, self.COLORS["RESET"]) + colored_record.levelname = f"{seq}{levelname}{self.COLORS['RESET']}" + return super().format(colored_record) + + +logger = logging.getLogger("lora") +logger.propagate = False + + +if not logger.handlers: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter( + ColoredFormatter("[%(name)s]-%(levelname)s: %(message)s") + ) + logger.addHandler(handler) \ No newline at end of file diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 465e24c8..12f70576 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -17,6 +17,8 @@ from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack from modules.textual_inversion.textual_inversion import Embedding +from lora_logger import logger + module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), @@ -206,7 +208,40 @@ def load_network(name, network_on_disk): net.modules[key] = net_module - net.bundle_embeddings = bundle_embeddings + embeddings = {} + for emb_name, data in bundle_embeddings.items(): + # textual inversion embeddings + if 'string_to_param' in data: + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, emb_name) + embedding.vectors = vectors + embedding.shape = shape + embedding.loaded = None + embeddings[emb_name] = embedding + + net.bundle_embeddings = embeddings if keys_failed_to_match: logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}") @@ -229,8 +264,9 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No for net in loaded_networks: if net.name in names: already_loaded[net.name] = net - for emb_name in net.bundle_embeddings: - emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded: + emb_db.register_embedding_by_name(None, shared.sd_model, emb_name) loaded_networks.clear() @@ -273,37 +309,17 @@ def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=No net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) - for emb_name, data in net.bundle_embeddings.items(): - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, emb_name) - embedding.vectors = vectors - embedding.shape = shape + for emb_name, embedding in net.bundle_embeddings.items(): + if embedding.loaded is None and emb_name in emb_db.word_embeddings: + logger.warning( + f'Skip bundle embedding: "{emb_name}"' + ' as it was already loaded from embeddings folder' + ) + continue + embedding.loaded = False if emb_db.expected_shape == -1 or emb_db.expected_shape == embedding.shape: + embedding.loaded = True emb_db.register_embedding(embedding, shared.sd_model) else: emb_db.skipped_embeddings[name] = embedding -- cgit v1.2.3 From 906d1179e9a333eeb0f12a95b592dd5b44eb0aaa Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 11 Oct 2023 21:26:58 -0700 Subject: support inference with LyCORIS GLora networks --- extensions-builtin/Lora/network_glora.py | 33 ++++++++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 ++ 2 files changed, 35 insertions(+) create mode 100644 extensions-builtin/Lora/network_glora.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_glora.py b/extensions-builtin/Lora/network_glora.py new file mode 100644 index 00000000..492d4870 --- /dev/null +++ b/extensions-builtin/Lora/network_glora.py @@ -0,0 +1,33 @@ + +import network + +class ModuleTypeGLora(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["a1.weight", "a2.weight", "alpha", "b1.weight", "b2.weight"]): + return NetworkModuleGLora(net, weights) + + return None + +# adapted from https://github.com/KohakuBlueleaf/LyCORIS +class NetworkModuleGLora(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.w1a = weights.w["a1.weight"] + self.w1b = weights.w["b1.weight"] + self.w2a = weights.w["a2.weight"] + self.w2b = weights.w["b2.weight"] + + def calc_updown(self, orig_weight): + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w1a.size(0), w1b.size(1)] + updown = ((w2b @ w1b) + ((orig_weight @ w2a) @ w1a)) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 315682b3..ddab3c55 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -5,6 +5,7 @@ import re import lora_patches import network import network_lora +import network_glora import network_hada import network_ia3 import network_lokr @@ -23,6 +24,7 @@ module_types = [ network_lokr.ModuleTypeLokr(), network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), + network_glora.ModuleTypeGLora(), ] -- cgit v1.2.3 From a8cbe50c9fa324ed887089e4333452ecc4355c92 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 12:14:56 +0300 Subject: remove duplicated code --- extensions-builtin/Lora/networks.py | 31 +---------- modules/textual_inversion/textual_inversion.py | 74 ++++++++++++++------------ 2 files changed, 42 insertions(+), 63 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 12f70576..d5f0f9f1 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -15,7 +15,7 @@ import torch from typing import Union from modules import shared, devices, sd_models, errors, scripts, sd_hijack -from modules.textual_inversion.textual_inversion import Embedding +import modules.textual_inversion.textual_inversion as textual_inversion from lora_logger import logger @@ -210,34 +210,7 @@ def load_network(name, network_on_disk): embeddings = {} for emb_name, data in bundle_embeddings.items(): - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {emb_name} in lora: {name} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, emb_name) - embedding.vectors = vectors - embedding.shape = shape + embedding = textual_inversion.create_embedding_from_data(data, emb_name, filename=network_on_disk.filename + "/" + emb_name) embedding.loaded = None embeddings[emb_name] = embedding diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 401a0a2a..04dda585 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -181,40 +181,7 @@ class EmbeddingDatabase: else: return - - # textual inversion embeddings - if 'string_to_param' in data: - param_dict = data['string_to_param'] - param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 - assert len(param_dict) == 1, 'embedding file has multiple terms in it' - emb = next(iter(param_dict.items()))[1] - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding - vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} - shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] - vectors = data['clip_g'].shape[0] - elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts - assert len(data.keys()) == 1, 'embedding file has multiple terms in it' - - emb = next(iter(data.values())) - if len(emb.shape) == 1: - emb = emb.unsqueeze(0) - vec = emb.detach().to(devices.device, dtype=torch.float32) - shape = vec.shape[-1] - vectors = vec.shape[0] - else: - raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") - - embedding = Embedding(vec, name) - embedding.step = data.get('step', None) - embedding.sd_checkpoint = data.get('sd_checkpoint', None) - embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - embedding.vectors = vectors - embedding.shape = shape - embedding.filename = path - embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') + embedding = create_embedding_from_data(data, name, filename=filename, filepath=path) if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) @@ -313,6 +280,45 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'): return fn +def create_embedding_from_data(data, name, filename='unknown embedding file', filepath=None): + if 'string_to_param' in data: # textual inversion embeddings + param_dict = data['string_to_param'] + param_dict = getattr(param_dict, '_parameters', param_dict) # fix for torch 1.12.1 loading saved file from torch 1.11 + assert len(param_dict) == 1, 'embedding file has multiple terms in it' + emb = next(iter(param_dict.items()))[1] + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + elif type(data) == dict and 'clip_g' in data and 'clip_l' in data: # SDXL embedding + vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()} + shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1] + vectors = data['clip_g'].shape[0] + elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts + assert len(data.keys()) == 1, 'embedding file has multiple terms in it' + + emb = next(iter(data.values())) + if len(emb.shape) == 1: + emb = emb.unsqueeze(0) + vec = emb.detach().to(devices.device, dtype=torch.float32) + shape = vec.shape[-1] + vectors = vec.shape[0] + else: + raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.") + + embedding = Embedding(vec, name) + embedding.step = data.get('step', None) + embedding.sd_checkpoint = data.get('sd_checkpoint', None) + embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) + embedding.vectors = vectors + embedding.shape = shape + + if filepath: + embedding.filename = filepath + embedding.set_hash(hashes.sha256(filepath, "textual_inversion/" + name) or '') + + return embedding + + def write_loss(log_directory, filename, step, epoch_len, values): if shared.opts.training_write_csv_every == 0: return -- cgit v1.2.3 From ec718f76b58b183859ed732e11ec748c41a13f76 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Tue, 17 Oct 2023 23:35:50 -0700 Subject: wip incorrect OFT implementation --- extensions-builtin/Lora/network_oft.py | 82 ++++++++++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 5 +++ 2 files changed, 87 insertions(+) create mode 100644 extensions-builtin/Lora/network_oft.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py new file mode 100644 index 00000000..9ddb175c --- /dev/null +++ b/extensions-builtin/Lora/network_oft.py @@ -0,0 +1,82 @@ +import torch +import network + + +class ModuleTypeOFT(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["oft_blocks"]): + return NetworkModuleOFT(net, weights) + + return None + +# adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py +class NetworkModuleOFT(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.oft_blocks = weights.w["oft_blocks"] + self.alpha = weights.w["alpha"] + + self.dim = self.oft_blocks.shape[0] + self.num_blocks = self.dim + + #if type(self.alpha) == torch.Tensor: + # self.alpha = self.alpha.detach().numpy() + + if "Linear" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_features + elif "Conv" in self.sd_module.__class__.__name__: + self.out_dim = self.sd_module.out_channels + + self.constraint = self.alpha * self.out_dim + self.block_size = self.out_dim // self.num_blocks + + self.oft_multiplier = self.multiplier() + + # replace forward method of original linear rather than replacing the module + # self.org_forward = self.sd_module.forward + # self.sd_module.forward = self.forward + + def get_weight(self): + block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + + return R + + def calc_updown(self, orig_weight): + oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) + block_Q = oft_blocks - oft_blocks.transpose(1, 2) + norm_Q = torch.norm(block_Q.flatten()) + new_norm_Q = torch.clamp(norm_Q, max=self.constraint) + block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) + I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) + block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) + + block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + R = torch.block_diag(*block_R_weighted) + #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) + # W = R*W_0 + updown = orig_weight + R + output_shape = [R.size(0), orig_weight.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) + + # def forward(self, x, y=None): + # x = self.org_forward(x) + # if self.oft_multiplier == 0.0: + # return x + + # R = self.get_weight().to(x.device, dtype=x.dtype) + # if x.dim() == 4: + # x = x.permute(0, 2, 3, 1) + # x = torch.matmul(x, R) + # x = x.permute(0, 3, 1, 2) + # else: + # x = torch.matmul(x, R) + # return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 60d8dec4..bd1f1b75 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -11,6 +11,7 @@ import network_ia3 import network_lokr import network_full import network_norm +import network_oft import torch from typing import Union @@ -28,6 +29,7 @@ module_types = [ network_full.ModuleTypeFull(), network_norm.ModuleTypeNorm(), network_glora.ModuleTypeGLora(), + network_oft.ModuleTypeOFT(), ] @@ -183,6 +185,9 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: -- cgit v1.2.3 From 1c6efdbba774d603c592debaccd6f5ad827bd1b2 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:16:01 -0700 Subject: inference working but SLOW --- extensions-builtin/Lora/network_oft.py | 73 +++++++++++++++++----------------- extensions-builtin/Lora/networks.py | 42 +++++++++++++++++-- 2 files changed, 75 insertions(+), 40 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 9ddb175c..f085eca5 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -12,6 +12,7 @@ class ModuleTypeOFT(network.ModuleType): # adapted from https://github.com/kohya-ss/sd-scripts/blob/main/networks/oft.py class NetworkModuleOFT(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) self.oft_blocks = weights.w["oft_blocks"] @@ -20,24 +21,29 @@ class NetworkModuleOFT(network.NetworkModule): self.dim = self.oft_blocks.shape[0] self.num_blocks = self.dim - #if type(self.alpha) == torch.Tensor: - # self.alpha = self.alpha.detach().numpy() - if "Linear" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_features elif "Conv" in self.sd_module.__class__.__name__: self.out_dim = self.sd_module.out_channels - self.constraint = self.alpha * self.out_dim + self.constraint = self.alpha + #self.constraint = self.alpha * self.out_dim self.block_size = self.out_dim // self.num_blocks - self.oft_multiplier = self.multiplier() + self.org_module: list[torch.Module] = [self.sd_module] + + self.R = self.get_weight() - # replace forward method of original linear rather than replacing the module - # self.org_forward = self.sd_module.forward - # self.sd_module.forward = self.forward + self.apply_to() + + # replace forward method of original linear rather than replacing the module + def apply_to(self): + self.org_forward = self.org_module[0].forward + self.org_module[0].forward = self.forward - def get_weight(self): + def get_weight(self, multiplier=None): + if not multiplier: + multiplier = self.multiplier() block_Q = self.oft_blocks - self.oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) new_norm_Q = torch.clamp(norm_Q, max=self.constraint) @@ -45,38 +51,31 @@ class NetworkModuleOFT(network.NetworkModule): I = torch.eye(self.block_size, device=self.oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I + block_R_weighted = multiplier * block_R + (1 - multiplier) * I R = torch.block_diag(*block_R_weighted) return R def calc_updown(self, orig_weight): - oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype) - block_Q = oft_blocks - oft_blocks.transpose(1, 2) - norm_Q = torch.norm(block_Q.flatten()) - new_norm_Q = torch.clamp(norm_Q, max=self.constraint) - block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8)) - I = torch.eye(self.block_size, device=oft_blocks.device).unsqueeze(0).repeat(self.num_blocks, 1, 1) - block_R = torch.matmul(I + block_Q, (I - block_Q).inverse()) - - block_R_weighted = self.oft_multiplier * block_R + (1 - self.oft_multiplier) * I - R = torch.block_diag(*block_R_weighted) - #R = self.get_weight().to(orig_weight.device, dtype=orig_weight.dtype) - # W = R*W_0 - updown = orig_weight + R - output_shape = [R.size(0), orig_weight.size(1)] + R = self.R + if orig_weight.dim() == 4: + weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + else: + weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R + output_shape = [orig_weight.size(0), R.size(1)] + #output_shape = [R.size(0), orig_weight.size(1)] return self.finalize_updown(updown, orig_weight, output_shape) - # def forward(self, x, y=None): - # x = self.org_forward(x) - # if self.oft_multiplier == 0.0: - # return x - - # R = self.get_weight().to(x.device, dtype=x.dtype) - # if x.dim() == 4: - # x = x.permute(0, 2, 3, 1) - # x = torch.matmul(x, R) - # x = x.permute(0, 3, 1, 2) - # else: - # x = torch.matmul(x, R) - # return x + def forward(self, x, y=None): + x = self.org_forward(x) + if self.multiplier() == 0.0: + return x + R = self.get_weight().to(x.device, dtype=x.dtype) + if x.dim() == 4: + x = x.permute(0, 2, 3, 1) + x = torch.matmul(x, R) + x = x.permute(0, 3, 1, 2) + else: + x = torch.matmul(x, R) + return x diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bd1f1b75..e5e73450 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,6 +169,10 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict + + #if key_network_without_network_parts == "oft_unet": + # print(key_network_without_network_parts) + # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -185,15 +189,39 @@ def load_network(name, network_on_disk): elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - elif sd_module is None and "oft_unet" in key_network_without_network_parts: - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") - sd_module = shared.sd_model.network_layer_mapping.get(key, None) # some SD1 Loras also have correct compvis keys if sd_module is None: key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "oft_unet" in key_network_without_network_parts: + # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] + # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] + UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] + # TODO: Change matchedm odules based on whether all linear, conv, etc + + key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + #key_no_suffix = key.rsplit("_to_", 1)[0] + ## Match all modules of class CrossAttention + #replace_module_list = [] + #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: + # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] + + #matched_module = replace_module_list.get(key_no_suffix, None) + #if key.endswith('to_q'): + # sd_module = matched_module.to_q or None + #if key.endswith('to_k'): + # sd_module = matched_module.to_k or None + #if key.endswith('to_v'): + # sd_module = matched_module.to_v or None + #if key.endswith('to_out_0'): + # sd_module = matched_module.to_out[0] or None + #if key.endswith('to_out_1'): + # sd_module = matched_module.to_out[1] or None + + if sd_module is None: keys_failed_to_match[key_network] = key continue @@ -214,6 +242,14 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module + + # replaces forward method of original Linear + # applied_to_count = 0 + #for key, created_module in net.modules.items(): + # if isinstance(created_module, network_oft.NetworkModuleOFT): + # net_module.apply_to() + #applied_to_count += 1 + # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): -- cgit v1.2.3 From d10c4db57ed08234a7aed5f530f269ff78544ab0 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 19 Oct 2023 12:52:14 -0700 Subject: style: formatting --- extensions-builtin/Lora/network_oft.py | 4 ++-- extensions-builtin/Lora/networks.py | 35 ---------------------------------- 2 files changed, 2 insertions(+), 37 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 2af1bc4c..0a87958e 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -37,7 +37,7 @@ class NetworkModuleOFT(network.NetworkModule): def apply_to(self): self.org_forward = self.org_module[0].forward self.org_module[0].forward = self.forward - + def get_weight(self, oft_blocks, multiplier=None): block_Q = oft_blocks - oft_blocks.transpose(1, 2) norm_Q = torch.norm(block_Q.flatten()) @@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule): output_shape = self.oft_blocks.shape return self.finalize_updown(updown, orig_weight, output_shape) - + def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index e5e73450..78a97033 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -169,10 +169,6 @@ def load_network(name, network_on_disk): else: emb_dict[vec_name] = weight bundle_embeddings[emb_name] = emb_dict - - #if key_network_without_network_parts == "oft_unet": - # print(key_network_without_network_parts) - # pass key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) sd_module = shared.sd_model.network_layer_mapping.get(key, None) @@ -196,31 +192,8 @@ def load_network(name, network_on_disk): sd_module = shared.sd_model.network_layer_mapping.get(key, None) elif sd_module is None and "oft_unet" in key_network_without_network_parts: - # UNET_TARGET_REPLACE_MODULE_ALL_LINEAR = ["Transformer2DModel"] - # UNET_TARGET_REPLACE_MODULE_CONV2D_3X3 = ["ResnetBlock2D", "Downsample2D", "Upsample2D"] - UNET_TARGET_REPLACE_MODULE_ATTN_ONLY = ["CrossAttention"] - # TODO: Change matchedm odules based on whether all linear, conv, etc - key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) - #key_no_suffix = key.rsplit("_to_", 1)[0] - ## Match all modules of class CrossAttention - #replace_module_list = [] - #for module_type in UNET_TARGET_REPLACE_MODULE_ATTN_ONLY: - # replace_module_list += [module for k, module in shared.sd_model.network_layer_mapping.items() if module_type in module.__class__.__name__] - - #matched_module = replace_module_list.get(key_no_suffix, None) - #if key.endswith('to_q'): - # sd_module = matched_module.to_q or None - #if key.endswith('to_k'): - # sd_module = matched_module.to_k or None - #if key.endswith('to_v'): - # sd_module = matched_module.to_v or None - #if key.endswith('to_out_0'): - # sd_module = matched_module.to_out[0] or None - #if key.endswith('to_out_1'): - # sd_module = matched_module.to_out[1] or None - if sd_module is None: keys_failed_to_match[key_network] = key @@ -242,14 +215,6 @@ def load_network(name, network_on_disk): raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") net.modules[key] = net_module - - # replaces forward method of original Linear - # applied_to_count = 0 - #for key, created_module in net.modules.items(): - # if isinstance(created_module, network_oft.NetworkModuleOFT): - # net_module.apply_to() - #applied_to_count += 1 - # print(f'Applied OFT modules: {applied_to_count}') embeddings = {} for emb_name, data in bundle_embeddings.items(): -- cgit v1.2.3 From 65ccd6305fcf72347d5ed68f03095dced865ef6e Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Thu, 2 Nov 2023 00:11:32 -0700 Subject: detect diag_oft type --- extensions-builtin/Lora/networks.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 78a97033..7f814706 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -191,10 +191,17 @@ def load_network(name, network_on_disk): key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # kohya_ss OFT module elif sd_module is None and "oft_unet" in key_network_without_network_parts: key = key_network_without_network_parts.replace("oft_unet", "diffusion_model") sd_module = shared.sd_model.network_layer_mapping.get(key, None) + # KohakuBlueLeaf OFT module + if sd_module is None and "oft_diag" in key: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + if sd_module is None: keys_failed_to_match[key_network] = key continue -- cgit v1.2.3