From b75b004fe62826455f1aa77e849e7da13902cb17 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 23:13:55 +0300 Subject: lora extension rework to include other types of networks --- extensions-builtin/Lora/networks.py | 443 ++++++++++++++++++++++++++++++++++++ 1 file changed, 443 insertions(+) create mode 100644 extensions-builtin/Lora/networks.py (limited to 'extensions-builtin/Lora/networks.py') diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py new file mode 100644 index 00000000..5b0ddfb6 --- /dev/null +++ b/extensions-builtin/Lora/networks.py @@ -0,0 +1,443 @@ +import os +import re + +import network +import network_lora +import network_hada + +import torch +from typing import Union + +from modules import shared, devices, sd_models, errors, scripts, sd_hijack + +module_types = [ + network_lora.ModuleTypeLora(), + network_hada.ModuleTypeHada(), +] + + +re_digits = re.compile(r"\d+") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} + + +def convert_diffusers_name_to_compvis(key, is_sd2): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" + + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" + + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + + return key + + +def assign_network_names_to_compvis_modules(sd_model): + network_layer_mapping = {} + + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + network_name = f'{i}_{name.replace(".", "_")}' + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + for name, module in shared.sd_model.model.named_modules(): + network_name = name.replace(".", "_") + network_layer_mapping[network_name] = module + module.network_layer_name = network_name + + sd_model.network_layer_mapping = network_layer_mapping + + +def load_network(name, network_on_disk): + net = network.Network(name, network_on_disk) + net.mtime = os.path.getmtime(network_on_disk.filename) + + sd = sd_models.read_state_dict(network_on_disk.filename) + + # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0 + if not hasattr(shared.sd_model, 'network_layer_mapping'): + assign_network_names_to_compvis_modules(shared.sd_model) + + keys_failed_to_match = {} + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.network_layer_mapping + + matched_networks = {} + + for key_network, weight in sd.items(): + key_network_without_network_parts, network_part = key_network.split(".", 1) + + key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2) + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.network_layer_mapping.get(m.group(1), None) + + # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" + if sd_module is None and "lora_unet" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts: + key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model") + sd_module = shared.sd_model.network_layer_mapping.get(key, None) + + if sd_module is None: + keys_failed_to_match[key_network] = key + continue + + if key not in matched_networks: + matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module) + + matched_networks[key].w[network_part] = weight + + for key, weights in matched_networks.items(): + net_module = None + for nettype in module_types: + net_module = nettype.create_module(net, weights) + if net_module is not None: + break + + if net_module is None: + raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}") + + net.modules[key] = net_module + + if keys_failed_to_match: + print(f"Failed to match keys when loading network {network_on_disk.filename}: {keys_failed_to_match}") + + return net + + +def load_networks(names, multipliers=None): + already_loaded = {} + + for net in loaded_networks: + if net.name in names: + already_loaded[net.name] = net + + loaded_networks.clear() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + if any(x is None for x in networks_on_disk): + list_available_networks() + + networks_on_disk = [available_network_aliases.get(name, None) for name in names] + + failed_to_load_networks = [] + + for i, name in enumerate(names): + net = already_loaded.get(name, None) + + network_on_disk = networks_on_disk[i] + + if network_on_disk is not None: + if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime: + try: + net = load_network(name, network_on_disk) + except Exception as e: + errors.display(e, f"loading network {network_on_disk.filename}") + continue + + net.mentioned_name = name + + network_on_disk.read_hash() + + if net is None: + failed_to_load_networks.append(name) + print(f"Couldn't find network with name {name}") + continue + + net.multiplier = multipliers[i] if multipliers else 1.0 + loaded_networks.append(net) + + if failed_to_load_networks: + sd_hijack.model_hijack.comments.append("Failed to find networks: " + ", ".join(failed_to_load_networks)) + + +def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "network_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + +def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + """ + Applies the currently selected set of networks to the weights of torch layer self. + If weights already have this particular set of networks applied, does nothing. + If not, restores orginal weights from backup and alters weights according to networks. + """ + + network_layer_name = getattr(self, 'network_layer_name', None) + if network_layer_name is None: + return + + current_names = getattr(self, "network_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_networks) + + weights_backup = getattr(self, "network_weights_backup", None) + if weights_backup is None: + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + + self.network_weights_backup = weights_backup + + if current_names != wanted_names: + network_restore_weights_from_backup(self) + + for net in loaded_networks: + module = net.modules.get(network_layer_name, None) + if module is not None and hasattr(self, 'weight'): + with torch.no_grad(): + updown = module.calc_updown(self.weight) + + if len(self.weight.shape) == 4 and self.weight.shape[1] == 9: + # inpainting model. zero pad updown to make channel[1] 4 to 9 + updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5)) + + self.weight += updown + + module_q = net.modules.get(network_layer_name + "_q_proj", None) + module_k = net.modules.get(network_layer_name + "_k_proj", None) + module_v = net.modules.get(network_layer_name + "_v_proj", None) + module_out = net.modules.get(network_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + with torch.no_grad(): + updown_q = module_q.calc_updown(self.in_proj_weight) + updown_k = module_k.calc_updown(self.in_proj_weight) + updown_v = module_v.calc_updown(self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) + + self.in_proj_weight += updown_qkv + self.out_proj.weight += module_out.calc_updown(self.out_proj.weight) + continue + + if module is None: + continue + + print(f'failed to calculate network weights for layer {network_layer_name}') + + self.network_current_names = wanted_names + + +def network_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_networks) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + network_restore_weights_from_backup(module) + network_reset_cached_weight(module) + + y = original_forward(module, input) + + network_layer_name = getattr(module, 'network_layer_name', None) + for lora in loaded_networks: + module = lora.modules.get(network_layer_name, None) + if module is None: + continue + + y = module.forward(y, input) + + return y + + +def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): + self.network_current_names = () + self.network_weights_backup = None + + +def network_Linear_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Linear_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Linear_forward_before_network(self, input) + + +def network_Linear_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + + +def network_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + + network_apply_weights(self) + + return torch.nn.Conv2d_forward_before_network(self, input) + + +def network_Conv2d_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_forward(self, *args, **kwargs): + network_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + + +def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): + network_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + + +def list_available_networks(): + available_networks.clear() + available_network_aliases.clear() + forbidden_network_aliases.clear() + available_network_hash_lookup.clear() + forbidden_network_aliases.update({"none": 1, "Addams": 1}) + + os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True) + + candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"])) + for filename in candidates: + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + try: + entry = network.NetworkOnDisk(name, filename) + except OSError: # should catch FileNotFoundError and PermissionError etc. + errors.report(f"Failed to load network {name} from {filename}", exc_info=True) + continue + + available_networks[name] = entry + + if entry.alias in available_network_aliases: + forbidden_network_aliases[entry.alias.lower()] = 1 + + available_network_aliases[name] = entry + available_network_aliases[entry.alias] = entry + + +re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)") + + +def infotext_pasted(infotext, params): + if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]: + return # if the other extension is active, it will handle those fields, no need to do anything + + added = [] + + for k in params: + if not k.startswith("AddNet Model "): + continue + + num = k[13:] + + if params.get("AddNet Module " + num) != "LoRA": + continue + + name = params.get("AddNet Model " + num) + if name is None: + continue + + m = re_network_name.match(name) + if m: + name = m.group(1) + + multiplier = params.get("AddNet Weight A " + num, "1.0") + + added.append(f"") + + if added: + params["Prompt"] += "\n" + "".join(added) + + +available_networks = {} +available_network_aliases = {} +loaded_networks = [] +available_network_hash_lookup = {} +forbidden_network_aliases = {} + +list_available_networks() -- cgit v1.2.3