From b75b004fe62826455f1aa77e849e7da13902cb17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 23:13:55 +0300 Subject: lora extension rework to include other types of networks --- extensions-builtin/Lora/networks.py | 443 ++++++++++++++++++++++++++++++++++++ 1 file changed, 443 insertions(+) create mode 100644 extensions-builtin/Lora/networks.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 00000000..5b0ddfb6 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,443 @@ +import os +import re + +import network +import network_lora +import network_hada + +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors, scripts, sd_hijack + +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), +] + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + + +def assign_network_names_to_compvis_modules(sd_model): + 
network_layer_mapping = {} + + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + network_name = f'{i}_{name.replace(".", "_")}' + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + for name, module in shared.sd_model.model.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + sd_model.network_layer_mapping = network_layer_mapping + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + sd = sd_models.read_state_dict(network_on_disk.filename) + + # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + if not hasattr(shared.sd_model, 'network_layer_mapping'): + assign_network_names_to_compvis_modules(shared.sd_model) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping + + matched_networks = {} + + for key_network, weight in sd.items(): + key_network_without_network_parts, network_part = key_network.split(".", 1) + + key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None) + + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + + matched_networks[key].w[network_part] = weight + + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + break + + if net_module is None: + raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") + + net.modules[key] = net_module + + if keys_failed_to_match: + print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + + return net + + +def load_networks(names, multipliers=None): + already_loaded = {} + + for net in loaded_networks: + if net.name in names: + already_loaded[net.name] = net + + loaded_networks.clear() + + networks_on_disk = [available_network_aliases.get(name, None) 
for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + + failed_to_load_networks = [] + + for i, name in enumerate(names): + net = already_loaded.get(name, None) + + network_on_disk = networks_on_disk[i] + + if network_on_disk is not None: + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + + net.mentioned_name = name + + network_on_disk.read_hash() + + if net is None: + failed_to_load_networks.append(name) + print(f"Couldn't find network with name {name}") + continue + + net.multiplier = multipliers[i] if multipliers else 1.0 + loaded_networks.append(net) + + if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "network_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores original weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.network_weights_backup = weights_backup + + if current_names != wanted_names: + network_restore_weights_from_backup(self) + + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is not None and hasattr(self, 'weight'): + with torch.no_grad(): + updown = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model.
zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + + module_q = net.modules.get(network_layer_name + "_q_proj", None) + module_k = net.modules.get(network_layer_name + "_k_proj", None) + module_v = net.modules.get(network_layer_name + "_v_proj", None) + module_out = net.modules.get(network_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + with torch.no_grad(): + updown_q = module_q.calc_updown(self.in_proj_weight) + updown_k = module_k.calc_updown(self.in_proj_weight) + updown_v = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate network weights for layer {network_layer_name}') + + self.network_current_names = wanted_names + + +def network_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_networks) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + network_restore_weights_from_backup(module) + network_reset_cached_weight(module) + + y = original_forward(module, input) + + network_layer_name = getattr(module, 'network_layer_name', None) + for lora in loaded_networks: + module = lora.modules.get(network_layer_name, None) + if module is None: + continue + + y = module.forward(y, input) + + return y + + +def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + self.network_current_names = () + self.network_weights_backup = None + + +def network_Linear_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Linear_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Linear_forward_before_network(self, input) + + +def network_Linear_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Conv2d_forward_before_network(self, input) + + +def network_Conv2d_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_forward(self, *args, **kwargs): + network_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", 
".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() -- cgit v1.2.3 From ef5dac7786916dd39711edb2b8e90ce96ef78fca Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:01:17 +0300 Subject: fix --- extensions-builtin/Lora/network_hada.py | 3 --- extensions-builtin/Lora/networks.py | 1 + 2 files changed, 1 insertion(+), 3 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 15e7ffd8..799bb3bc 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -27,9 +27,6 @@ class NetworkModuleHada(network_lyco.NetworkModuleLyco): self.t1 = weights.w.get("hada_t1") self.t2 = weights.w.get("hada_t2") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - def calc_updown(self, orig_weight): w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 5b0ddfb6..90374faa 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -271,6 +271,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) self.weight += updown + continue module_q = net.modules.get(network_layer_name + "_q_proj", None) module_k = net.modules.get(network_layer_name + "_k_proj", None) -- cgit v1.2.3 From 58c3df32f3a73b20ea33d1709a1d25818b8a98dd Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:12:18 +0300 Subject: IA3 support --- extensions-builtin/Lora/network_ia3.py | 32 ++++++++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 ++ 2 files changed, 34 insertions(+) create mode 100644 extensions-builtin/Lora/network_ia3.py (limited to 'extensions-builtin/Lora/networks.py') diff 
--git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py new file mode 100644 index 00000000..99f2307c --- /dev/null +++ b/extensions-builtin/Lora/network_ia3.py @@ -0,0 +1,32 @@ +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeIa3(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in weights.w for x in ["weight"]): + return NetworkModuleIa3(net, weights) + + return None + + +class NetworkModuleIa3(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w = weights.w["weight"] + self.on_input = weights.w["on_input"].item() + + def calc_updown(self, orig_weight): + w = self.w.to(orig_weight.device, dtype=orig_weight.dtype) + + output_shape = [w.size(0), orig_weight.size(1)] + if self.on_input: + output_shape.reverse() + else: + w = w.reshape(-1, 1) + + updown = orig_weight * w + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 90374faa..bf810b5b 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -4,6 +4,7 @@ import re import network import network_lora import network_hada +import network_ia3 import torch from typing import Union @@ -13,6 +14,7 @@ from modules import shared, devices, sd_models, errors, scripts, sd_hijack module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), + network_ia3.ModuleTypeIa3(), ] -- cgit v1.2.3 From 46466f09d0b0c14118033dee6af0f876059776d3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 00:29:07 +0300 Subject: Lokr support --- extensions-builtin/Lora/network_ia3.py | 1 - extensions-builtin/Lora/network_lokr.py | 65 +++++++++++++++++++++++++++++++++ extensions-builtin/Lora/networks.py | 2 + 3 files changed, 67 insertions(+), 1 deletion(-) create mode 100644 extensions-builtin/Lora/network_lokr.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index 99f2307c..d8806da0 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -1,4 +1,3 @@ -import lyco_helpers import network import network_lyco diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py new file mode 100644 index 00000000..f1731924 --- /dev/null +++ b/extensions-builtin/Lora/network_lokr.py @@ -0,0 +1,65 @@ +import torch + +import lyco_helpers +import network +import network_lyco + + +class ModuleTypeLokr(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + has_1 = "lokr_w1" in weights.w or ("lokr_w1a" in weights.w and "lokr_w1b" in weights.w) + has_2 = "lokr_w2" in weights.w or ("lokr_w2a" in weights.w and "lokr_w2b" in weights.w) + if has_1 and has_2: + return NetworkModuleLokr(net, weights) + + return None + + +def make_kron(orig_shape, w1, w2): + if len(w2.shape) == 4: + w1 = w1.unsqueeze(2).unsqueeze(2) + w2 = w2.contiguous() + return torch.kron(w1, w2).reshape(orig_shape) + + +class NetworkModuleLokr(network_lyco.NetworkModuleLyco): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.w1 = weights.w.get("lokr_w1") + self.w1a = weights.w.get("lokr_w1_a") + self.w1b = 
weights.w.get("lokr_w1_b") + self.dim = self.w1b.shape[0] if self.w1b else self.dim + self.w2 = weights.w.get("lokr_w2") + self.w2a = weights.w.get("lokr_w2_a") + self.w2b = weights.w.get("lokr_w2_b") + self.dim = self.w2b.shape[0] if self.w2b else self.dim + self.t2 = weights.w.get("lokr_t2") + + def calc_updown(self, orig_weight): + if self.w1 is not None: + w1 = self.w1.to(orig_weight.device, dtype=orig_weight.dtype) + else: + w1a = self.w1a.to(orig_weight.device, dtype=orig_weight.dtype) + w1b = self.w1b.to(orig_weight.device, dtype=orig_weight.dtype) + w1 = w1a @ w1b + + if self.w2 is not None: + w2 = self.w2.to(orig_weight.device, dtype=orig_weight.dtype) + elif self.t2 is None: + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = w2a @ w2b + else: + t2 = self.t2.to(orig_weight.device, dtype=orig_weight.dtype) + w2a = self.w2a.to(orig_weight.device, dtype=orig_weight.dtype) + w2b = self.w2b.to(orig_weight.device, dtype=orig_weight.dtype) + w2 = lyco_helpers.make_weight_cp(t2, w2a, w2b) + + output_shape = [w1.size(0) * w2.size(0), w1.size(1) * w2.size(1)] + if len(orig_weight.shape) == 4: + output_shape = orig_weight.shape + + updown = make_kron(output_shape, w1, w2) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index bf810b5b..1b358561 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -5,6 +5,7 @@ import network import network_lora import network_hada import network_ia3 +import network_lokr import torch from typing import Union @@ -15,6 +16,7 @@ module_types = [ network_lora.ModuleTypeLora(), network_hada.ModuleTypeHada(), network_ia3.ModuleTypeIa3(), + network_lokr.ModuleTypeLokr(), ] -- cgit v1.2.3 From 238adeaffb037dedbcefe41e7fd4814a1f17baa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:00:47 +0300 Subject: support specifying te and unet weights separately update lora code support full module --- extensions-builtin/Lora/extra_networks_lora.py | 22 ++++++-- extensions-builtin/Lora/lyco_helpers.py | 6 +++ extensions-builtin/Lora/network.py | 40 +++++++++++++- extensions-builtin/Lora/network_full.py | 23 ++++++++ extensions-builtin/Lora/network_hada.py | 3 +- extensions-builtin/Lora/network_ia3.py | 3 +- extensions-builtin/Lora/network_lokr.py | 3 +- extensions-builtin/Lora/network_lora.py | 72 ++++++++++++++++---------- extensions-builtin/Lora/network_lyco.py | 35 ------------- extensions-builtin/Lora/networks.py | 22 ++++++-- 10 files changed, 151 insertions(+), 78 deletions(-) create mode 100644 extensions-builtin/Lora/network_full.py delete mode 100644 extensions-builtin/Lora/network_lyco.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 8a6639cf..084c41d0 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -14,14 +14,28 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier])) names = [] - multipliers = [] + te_multipliers = [] + unet_multipliers = [] + dyn_dims = [] for params in params_list: assert params.items - names.append(params.items[0]) - multipliers.append(float(params.items[1]) if 
len(params.items) > 1 else 1.0) + names.append(params.positional[0]) - networks.load_networks(names, multipliers) + te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0 + te_multiplier = float(params.named.get("te", te_multiplier)) + + unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0 + unet_multiplier = float(params.named.get("unet", unet_multiplier)) + + dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None + dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim + + te_multipliers.append(te_multiplier) + unet_multipliers.append(unet_multiplier) + dyn_dims.append(dyn_dim) + + networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims) if shared.opts.lora_add_hashes_to_infotext: network_hashes = [] diff --git a/extensions-builtin/Lora/lyco_helpers.py b/extensions-builtin/Lora/lyco_helpers.py index 9ea499fb..279b34bc 100644 --- a/extensions-builtin/Lora/lyco_helpers.py +++ b/extensions-builtin/Lora/lyco_helpers.py @@ -13,3 +13,9 @@ def rebuild_conventional(up, down, shape, dyn_dim=None): up = up[:, :dyn_dim] down = down[:dyn_dim, :] return (up @ down).reshape(shape) + + +def rebuild_cp_decomposition(up, down, mid): + up = up.reshape(up.size(0), -1) + down = down.reshape(down.size(0), -1) + return torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down) diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 4ac63722..fe42dbdd 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -68,7 +68,9 @@ class Network: # LoraModule def __init__(self, name, network_on_disk: NetworkOnDisk): self.name = name self.network_on_disk = network_on_disk - self.multiplier = 1.0 + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None self.modules = {} self.mtime = None @@ -88,6 +90,42 @@ class NetworkModule: self.sd_key = weights.sd_key self.sd_module = weights.sd_module + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def multiplier(self): + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + else: + return self.network.unet_multiplier + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + + return 1.0 + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + return updown * self.calc_scale() * self.multiplier() + def calc_updown(self, target): raise NotImplementedError() diff --git a/extensions-builtin/Lora/network_full.py b/extensions-builtin/Lora/network_full.py new file mode 100644 index 00000000..f0d8a6e0 --- /dev/null +++ b/extensions-builtin/Lora/network_full.py @@ -0,0 +1,23 @@ +import lyco_helpers +import network + + +class ModuleTypeFull(network.ModuleType): + def create_module(self, net: network.Network, weights: network.NetworkWeights): + if all(x in 
weights.w for x in ["diff"]): + return NetworkModuleFull(net, weights) + + return None + + +class NetworkModuleFull(network.NetworkModule): + def __init__(self, net: network.Network, weights: network.NetworkWeights): + super().__init__(net, weights) + + self.weight = weights.w.get("diff") + + def calc_updown(self, orig_weight): + output_shape = self.weight.shape + updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype) + + return self.finalize_updown(updown, orig_weight, output_shape) diff --git a/extensions-builtin/Lora/network_hada.py b/extensions-builtin/Lora/network_hada.py index 799bb3bc..5fcb0695 100644 --- a/extensions-builtin/Lora/network_hada.py +++ b/extensions-builtin/Lora/network_hada.py @@ -1,6 +1,5 @@ import lyco_helpers import network -import network_lyco class ModuleTypeHada(network.ModuleType): @@ -11,7 +10,7 @@ class ModuleTypeHada(network.ModuleType): return None -class NetworkModuleHada(network_lyco.NetworkModuleLyco): +class NetworkModuleHada(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_ia3.py b/extensions-builtin/Lora/network_ia3.py index d8806da0..7edc4249 100644 --- a/extensions-builtin/Lora/network_ia3.py +++ b/extensions-builtin/Lora/network_ia3.py @@ -1,5 +1,4 @@ import network -import network_lyco class ModuleTypeIa3(network.ModuleType): @@ -10,7 +9,7 @@ class ModuleTypeIa3(network.ModuleType): return None -class NetworkModuleIa3(network_lyco.NetworkModuleLyco): +class NetworkModuleIa3(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_lokr.py b/extensions-builtin/Lora/network_lokr.py index f1731924..920062e2 100644 --- a/extensions-builtin/Lora/network_lokr.py +++ b/extensions-builtin/Lora/network_lokr.py @@ -2,7 +2,6 @@ import torch import lyco_helpers import network -import network_lyco class ModuleTypeLokr(network.ModuleType): @@ -22,7 +21,7 @@ def make_kron(orig_shape, w1, w2): return torch.kron(w1, w2).reshape(orig_shape) -class NetworkModuleLokr(network_lyco.NetworkModuleLyco): +class NetworkModuleLokr(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) diff --git a/extensions-builtin/Lora/network_lora.py b/extensions-builtin/Lora/network_lora.py index b2d96537..26c0a72c 100644 --- a/extensions-builtin/Lora/network_lora.py +++ b/extensions-builtin/Lora/network_lora.py @@ -1,5 +1,6 @@ import torch +import lyco_helpers import network from modules import devices @@ -16,29 +17,42 @@ class NetworkModuleLora(network.NetworkModule): def __init__(self, net: network.Network, weights: network.NetworkWeights): super().__init__(net, weights) - self.up = self.create_module(weights.w["lora_up.weight"]) - self.down = self.create_module(weights.w["lora_down.weight"]) - self.alpha = weights.w["alpha"] if "alpha" in weights.w else None + self.up_model = self.create_module(weights.w, "lora_up.weight") + self.down_model = self.create_module(weights.w, "lora_down.weight") + self.mid_model = self.create_module(weights.w, "lora_mid.weight", none_ok=True) + + self.dim = weights.w["lora_down.weight"].shape[0] + + def create_module(self, weights, key, none_ok=False): + weight = weights.get(key) - def create_module(self, weight, none_ok=False): if weight is None and none_ok: return None - if type(self.sd_module) == torch.nn.Linear: - module = 
torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.MultiheadAttention: + is_linear = type(self.sd_module) in [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear, torch.nn.MultiheadAttention] + is_conv = type(self.sd_module) in [torch.nn.Conv2d] + + if is_linear: + weight = weight.reshape(weight.shape[0], -1) module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1): + elif is_conv and key == "lora_down.weight" or key == "dyn_up": + if len(weight.shape) == 2: + weight = weight.reshape(weight.shape[0], -1, 1, 1) + + if weight.shape[2] != 1 or weight.shape[3] != 1: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + else: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + elif is_conv and key == "lora_mid.weight": + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], self.sd_module.kernel_size, self.sd_module.stride, self.sd_module.padding, bias=False) + elif is_conv and key == "lora_up.weight" or key == "dyn_down": module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) - elif type(self.sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3): - module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False) else: - print(f'Network layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') - return None + raise AssertionError(f'Lora layer {self.network_key} matched a layer with unsupported type: {type(self.sd_module).__name__}') with torch.no_grad(): + if weight.shape != module.weight.shape: + weight = weight.reshape(module.weight.shape) module.weight.copy_(weight) module.to(device=devices.cpu, dtype=devices.dtype) @@ -46,25 +60,27 @@ class NetworkModuleLora(network.NetworkModule): return module - def calc_updown(self, target): - up = self.up.weight.to(target.device, dtype=target.dtype) - down = self.down.weight.to(target.device, dtype=target.dtype) + def calc_updown(self, orig_weight): + up = self.up_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + down = self.down_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3): - updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) + output_shape = [up.size(0), down.size(1)] + if self.mid_model is not None: + # cp-decomposition + mid = self.mid_model.weight.to(orig_weight.device, dtype=orig_weight.dtype) + updown = lyco_helpers.rebuild_cp_decomposition(up, down, mid) + output_shape += mid.shape[2:] else: - updown = up @ down - - updown = updown * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + if len(down.shape) == 4: + output_shape += down.shape[2:] + updown = lyco_helpers.rebuild_conventional(up, down, output_shape, self.network.dyn_dim) - return updown + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y): - self.up.to(device=devices.device) - self.down.to(device=devices.device) + 
self.up_model.to(device=devices.device) + self.down_model.to(device=devices.device) - return y + self.up(self.down(x)) * self.network.multiplier * (self.alpha / self.up.weight.shape[1] if self.alpha else 1.0) + return y + self.up_model(self.down_model(x)) * self.multiplier() * self.calc_scale() diff --git a/extensions-builtin/Lora/network_lyco.py b/extensions-builtin/Lora/network_lyco.py deleted file mode 100644 index fc135314..00000000 --- a/extensions-builtin/Lora/network_lyco.py +++ /dev/null @@ -1,35 +0,0 @@ -import network - - -class NetworkModuleLyco(network.NetworkModule): - def __init__(self, net: network.Network, weights: network.NetworkWeights): - super().__init__(net, weights) - - if hasattr(self.sd_module, 'weight'): - self.shape = self.sd_module.weight.shape - - self.dim = None - self.bias = weights.w.get("bias") - self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None - self.scale = weights.w["scale"].item() if "scale" in weights.w else None - - def finalize_updown(self, updown, orig_weight, output_shape): - if self.bias is not None: - updown = updown.reshape(self.bias.shape) - updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) - updown = updown.reshape(output_shape) - - if len(output_shape) == 4: - updown = updown.reshape(output_shape) - - if orig_weight.size().numel() == updown.size().numel(): - updown = updown.reshape(orig_weight.shape) - - scale = ( - self.scale if self.scale is not None - else self.alpha / self.dim if self.dim is not None and self.alpha is not None - else 1.0 - ) - - return updown * scale * self.network.multiplier - diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 1b358561..401430e8 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -6,6 +6,7 @@ import network_lora import network_hada import network_ia3 import network_lokr +import network_full import torch from typing import Union @@ -17,6 +18,7 @@ module_types = [ network_hada.ModuleTypeHada(), network_ia3.ModuleTypeIa3(), network_lokr.ModuleTypeLokr(), + network_full.ModuleTypeFull(), ] @@ -52,6 +54,15 @@ def convert_diffusers_name_to_compvis(key, is_sd2): m = [] + if match(m, r"lora_unet_conv_in(.*)"): + return f'diffusion_model_input_blocks_0_0{m[0]}' + + if match(m, r"lora_unet_conv_out(.*)"): + return f'diffusion_model_out_2{m[0]}' + + if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"): + return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" @@ -179,7 +190,7 @@ def load_network(name, network_on_disk): return net -def load_networks(names, multipliers=None): +def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None): already_loaded = {} for net in loaded_networks: @@ -218,7 +229,9 @@ def load_networks(names, multipliers=None): print(f"Couldn't find network with name {name}") continue - net.multiplier = multipliers[i] if multipliers else 1.0 + net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0 + net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0 + net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0 loaded_networks.append(net) if failed_to_load_networks: @@ -250,7 +263,7 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn return current_names = 
getattr(self, "network_current_names", ()) - wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks) weights_backup = getattr(self, "network_weights_backup", None) if weights_backup is None: @@ -288,9 +301,10 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn updown_k = module_k.calc_updown(self.in_proj_weight) updown_v = module_v.calc_updown(self.in_proj_weight) updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + updown_out = module_out.calc_updown(self.out_proj.weight) self.in_proj_weight += updown_qkv - self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + self.out_proj.weight += updown_out continue if module is None: -- cgit v1.2.3 From 35510f7529dc05437a82496187ef06b852be9ab1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 10:06:02 +0300 Subject: add alias to lyco network read networks from LyCORIS dir if it exists add credits --- README.md | 1 + extensions-builtin/Lora/networks.py | 3 ++- extensions-builtin/Lora/scripts/lora_script.py | 5 ++++- modules/extra_networks.py | 16 ++++++++++++++-- 4 files changed, 21 insertions(+), 4 deletions(-) (limited to 'extensions-builtin/Lora/networks.py') diff --git a/README.md b/README.md index e6d8e4bd..b796d150 100644 --- a/README.md +++ b/README.md @@ -168,5 +168,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al - Security advice - RyotaK - UniPC sampler - Wenliang Zhao - https://github.com/wl-zhao/UniPC - TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd +- LyCORIS - KohakuBlueleaf - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. 
- (You) diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 401430e8..7b4c0312 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -11,7 +11,7 @@ import network_full import torch from typing import Union -from modules import shared, devices, sd_models, errors, scripts, sd_hijack +from modules import shared, devices, sd_models, errors, scripts, sd_hijack, paths module_types = [ network_lora.ModuleTypeLora(), @@ -399,6 +399,7 @@ def list_available_networks(): os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + candidates += list(shared.walk_files(os.path.join(paths.models_path, "LyCORIS"), allowed_extensions=[".pt", ".ckpt", ".safetensors"])) for filename in candidates: if os.path.isdir(filename): continue diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 4c75821e..f478f718 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -22,7 +22,10 @@ def unload(): def before_ui(): ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) - extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + extra_network = extra_networks_lora.ExtraNetworkLora() + extra_networks.register_extra_network(extra_network) + extra_networks.register_extra_network_alias(extra_network, "lyco") if not hasattr(torch.nn, 'Linear_forward_before_network'): diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 41799b0a..6ae07e91 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -4,16 +4,22 @@ from collections import defaultdict from modules import errors extra_network_registry = {} +extra_network_aliases = {} def initialize(): extra_network_registry.clear() + extra_network_aliases.clear() def register_extra_network(extra_network): extra_network_registry[extra_network.name] = extra_network +def register_extra_network_alias(extra_network, alias): + extra_network_aliases[alias] = extra_network + + def register_default_extra_networks(): from modules.extra_networks_hypernet import ExtraNetworkHypernet register_extra_network(ExtraNetworkHypernet()) @@ -82,20 +88,26 @@ def activate(p, extra_network_data): """call activate for extra networks in extra_network_data in specified order, then call activate for all remaining registered networks with an empty argument list""" + activated = [] + for extra_network_name, extra_network_args in extra_network_data.items(): extra_network = extra_network_registry.get(extra_network_name, None) + + if extra_network is None: + extra_network = extra_network_aliases.get(extra_network_name, None) + if extra_network is None: print(f"Skipping unknown extra network: {extra_network_name}") continue try: extra_network.activate(p, extra_network_args) + activated.append(extra_network) except Exception as e: errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") for extra_network_name, extra_network in extra_network_registry.items(): - args = extra_network_data.get(extra_network_name, None) - if args is not None: + if extra_network in activated: continue try: -- cgit v1.2.3
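
Taken together, the series above establishes one contract for every network type: calc_updown(orig_weight) returns a delta that network_apply_weights adds to the layer weight in place, and finalize_updown applies the alpha/dim scale plus the per-network multiplier. A minimal self-contained sketch of the conventional LoRA case (random tensors stand in for real checkpoint weights; the rank and alpha values are illustrative, not from any actual network):

    import torch

    def lora_updown(up, down, alpha=None, multiplier=1.0):
        # conventional LoRA delta: up @ down, scaled by alpha/rank and the user
        # multiplier, mirroring calc_scale() where dim is lora_down.weight.shape[0]
        rank = down.shape[0]
        scale = alpha / rank if alpha is not None else 1.0
        return up @ down * scale * multiplier

    weight = torch.randn(320, 768)    # original Linear weight
    up = torch.randn(320, 4)          # lora_up.weight, rank 4
    down = torch.randn(4, 768)        # lora_down.weight
    weight += lora_updown(up, down, alpha=4.0, multiplier=0.8)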
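
network_hada.py builds its delta as the elementwise (Hadamard) product of two low-rank reconstructions, which is what distinguishes LoHa from plain LoRA. A rough standalone equivalent, ignoring the optional t1/t2 Tucker factors and using made-up shapes:

    import torch

    def hada_updown(w1a, w1b, w2a, w2b, scale=1.0):
        # LoHa: (w1a @ w1b) * (w2a @ w2b) -- elementwise, not matrix, product
        return (w1a @ w1b) * (w2a @ w2b) * scale

    w1a, w1b = torch.randn(320, 4), torch.randn(4, 768)
    w2a, w2b = torch.randn(320, 4), torch.randn(4, 768)
    delta = hada_updown(w1a, w1b, w2a, w2b)    # shape (320, 768)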
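
The IA3 module is multiplicative rather than additive: the learned vector rescales the original weight along one axis, with on_input choosing whether input or output features are scaled. Approximately, per the calc_updown in network_ia3.py:

    import torch

    def ia3_updown(orig_weight, w, on_input):
        # IA3: the delta is orig_weight * w, broadcast across rows or columns
        w = w.to(orig_weight.dtype)
        if not on_input:
            w = w.reshape(-1, 1)    # scale per output feature
        return orig_weight * w

    orig = torch.randn(320, 768)
    delta = ia3_updown(orig, torch.randn(320), on_input=False)    # (320, 768)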
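
For Lokr, make_kron expands two small factors into a full-size delta with a Kronecker product, so a weight of shape (320, 320) can be represented by, say, an 8x8 and a 40x40 factor. A quick shape check with arbitrary example dimensions:

    import torch

    w1 = torch.randn(8, 8)
    w2 = torch.randn(40, 40)
    # Kronecker product, as in make_kron for 2D weights
    delta = torch.kron(w1, w2)
    assert delta.shape == (w1.size(0) * w2.size(0), w1.size(1) * w2.size(1))    # (320, 320)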
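
The rebuild_cp_decomposition helper added to lyco_helpers.py recombines a CP/Tucker-style conv factorization in a single einsum; its index pattern can be sanity-checked in isolation (the dimensions below are arbitrary examples, not taken from any model):

    import torch

    mid = torch.randn(4, 4, 3, 3)    # (rank_up, rank_down, kernel_h, kernel_w)
    up = torch.randn(320, 4)         # (out_channels, rank_up), already flattened
    down = torch.randn(4, 320)       # (rank_down, in_channels)
    w = torch.einsum('n m k l, i n, m j -> i j k l', mid, up, down)
    assert w.shape == (320, 320, 3, 3)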
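
After this rework, adding another network type only requires a ModuleType that claims a state dict based on its keys (returning None lets the next entry in module_types try) plus a calc_updown. The skeleton below is a hypothetical example modeled on network_full.py, not part of the patch; "example_diff" is an invented key:

    import network

    class ModuleTypeExample(network.ModuleType):
        def create_module(self, net: network.Network, weights: network.NetworkWeights):
            # claim these weights only if the expected key is present
            if "example_diff" in weights.w:
                return NetworkModuleExample(net, weights)
            return None

    class NetworkModuleExample(network.NetworkModule):
        def __init__(self, net: network.Network, weights: network.NetworkWeights):
            super().__init__(net, weights)
            self.weight = weights.w["example_diff"]

        def calc_updown(self, orig_weight):
            updown = self.weight.to(orig_weight.device, dtype=orig_weight.dtype)
            return self.finalize_updown(updown, orig_weight, self.weight.shape)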
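
Finally, the te/unet/dyn parsing added to extra_networks_lora.py implies the prompt invocations below should behave as annotated once the series lands (named arguments override positional ones; exact defaults follow the parsing code, and "lyco" is the alias registered in the last commit):

    <lora:some_net:0.8>                       te=0.8, unet defaults to 1.0
    <lora:some_net:0.8:0.5>                   te=0.8, unet=0.5
    <lora:some_net:te=0.8:unet=0.5:dyn=16>    named form of the same, plus a dynamic dim cap
    <lyco:some_net:0.8:0.5>                   same network loaded through the "lyco" alias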