From 2abd89acc66419abf2eee9b03fd093f2737670de Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 28 Jan 2023 20:04:35 +0300 Subject: index on master: 91c8d0d Merge pull request #7231 from EllangoK/master --- extensions-builtin/Lora/lora.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36..568a7675 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -12,7 +12,7 @@ re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+) re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") -def convert_diffusers_name_to_compvis(key): +def convert_diffusers_name_to_compvis(key, is_sd2): def match(match_list, regex): r = re.match(regex, key) if not r: @@ -34,6 +34,14 @@ def convert_diffusers_name_to_compvis(key): return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" if match(m, re_text_block): + if is_sd2: + if 'mlp_fc1' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + elif 'self_attn': + return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" return key @@ -83,9 +91,10 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) keys_failed_to_match = [] + is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers) + fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) key, lora_key = fullkey.split(".", 1) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) @@ -104,9 +113,13 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: + module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: + print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}') + continue assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' with torch.no_grad(): @@ -182,6 +195,10 @@ def lora_Conv2d_forward(self, input): return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) +def lora_NonDynamicallyQuantizableLinear_forward(self, input): + return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input)) + + def list_available_loras(): available_loras.clear() -- cgit v1.2.3 From 04924241218bb51bee255bebc6c66ef1de449f4a Mon Sep 17 00:00:00 2001 From: bluelovers Date: Sun, 12 Mar 2023 10:18:33 +0800 Subject: feat: try sort as ignore-case https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/8368 --- extensions-builtin/Lora/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index cb8f1d36..7d3c0f90 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -192,7 +192,7 @@ def list_available_loras(): glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.safetensors'), recursive=True) + \ glob.glob(os.path.join(shared.cmd_opts.lora_dir, '**/*.ckpt'), recursive=True) - for filename in sorted(candidates): + for filename in sorted(candidates, key=str.lower): if os.path.isdir(filename): continue -- cgit v1.2.3 From 2f0181405f25e1448a55697081e380020fe8c68d Mon Sep 17 00:00:00 2001 From: FNSpd <125805478+FNSpd@users.noreply.github.com> Date: Tue, 21 Mar 2023 14:53:51 +0400 Subject: Update lora.py --- extensions-builtin/Lora/lora.py | 1 + 1 file changed, 1 insertion(+) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 8937b585..7c371deb 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -178,6 +178,7 @@ def load_loras(names, multipliers=None): def lora_forward(module, input, res): + input = devices.cond_cast_unet(input) if len(loaded_loras) == 0: return res -- cgit v1.2.3 From 80b26d2a69617b75d2d01c1e6b7d11445815ed4d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 25 Mar 2023 23:06:33 +0300 Subject: apply Lora by altering layer's weights instead of adding more calculations in forward() --- extensions-builtin/Lora/lora.py | 72 ++++++++++++++++++++++++++++++++--------- 1 file changed, 56 insertions(+), 16 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 7c371deb..a737fec3 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -131,7 +131,7 @@ def load_lora(name, filename): with torch.no_grad(): module.weight.copy_(weight) - module.to(device=devices.device, dtype=devices.dtype) + module.to(device=devices.cpu, dtype=devices.dtype) if lora_key == "lora_up.weight": lora_module.up = module @@ -177,29 +177,69 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_forward(module, input, res): - input = devices.cond_cast_unet(input) - if len(loaded_loras) == 0: - return res +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): + """ + Applies the currently selected set of Loras to the weight of torch layer self. + If weights already have this particular set of loras applied, does nothing. + If not, restores orginal weights from backup and alters weights according to loras. + """ - lora_layer_name = getattr(module, 'lora_layer_name', None) - for lora in loaded_loras: - module = lora.modules.get(lora_layer_name, None) - if module is not None: - if shared.opts.lora_apply_to_outputs and res.shape == input.shape: - res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) - else: - res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + current_names = getattr(self, "lora_current_names", ()) + wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) + + weights_backup = getattr(self, "lora_weights_backup", None) + if weights_backup is None: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup + + if current_names != wanted_names: + if weights_backup is not None: + self.weight.copy_(weights_backup) + + lora_layer_name = getattr(self, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue - return res + with torch.no_grad(): + up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) + down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + setattr(self, "lora_current_names", wanted_names) def lora_Linear_forward(self, input): - return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Linear_forward_before_lora(self, input) + + +def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) def lora_Conv2d_forward(self, input): - return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + lora_apply_weights(self) + + return torch.nn.Conv2d_forward_before_lora(self, input) + + +def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): -- cgit v1.2.3 From 650ddc9dd3c1d126221682be8270f7fba1b5b6ce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 26 Mar 2023 10:44:20 +0300 Subject: Lora support for SD2 --- extensions-builtin/Lora/lora.py | 155 ++++++++++++++++++++++++++++++---------- 1 file changed, 116 insertions(+), 39 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index d4345ada..edd95f78 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} re_digits = re.compile(r"\d+") -re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") -re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") -re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") -re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") +re_x_proj = re.compile(r"(.*)_([qkv]_proj)$") +re_compiled = {} + +suffix_conversion = { + "attentions": {}, + "resnets": { + "conv1": "in_layers_2", + "conv2": "out_layers_3", + "time_emb_proj": "emb_layers_1", + "conv_shortcut": "skip_connection", + } +} def convert_diffusers_name_to_compvis(key, is_sd2): - def match(match_list, regex): + def match(match_list, regex_text): + regex = re_compiled.get(regex_text) + if regex is None: + regex = re.compile(regex_text) + re_compiled[regex_text] = regex + r = re.match(regex, key) if not r: return False @@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2): m = [] - if match(m, re_unet_down_blocks): - return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" + + if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2]) + return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}" - if match(m, re_unet_mid_blocks): - return f"diffusion_model_middle_block_1_{m[1]}" + if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"): + suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3]) + return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}" - if match(m, re_unet_up_blocks): - return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"): + return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op" - if match(m, re_text_block): + if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"): + return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv" + + if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"): if is_sd2: if 'mlp_fc1' in m[1]: return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" @@ -109,16 +131,22 @@ def load_lora(name, filename): sd = sd_models.read_state_dict(filename) - keys_failed_to_match = [] + keys_failed_to_match = {} is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping for key_diffusers, weight in sd.items(): - fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2) - key, lora_key = fullkey.split(".", 1) + key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) + key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: - keys_failed_to_match.append(key_diffusers) + m = re_x_proj.match(key) + if m: + sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) + + if sd_module is None: + keys_failed_to_match[key_diffusers] = key continue lora_module = lora.modules.get(key, None) @@ -133,7 +161,9 @@ def load_lora(name, filename): if type(sd_module) == torch.nn.Linear: module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear: - module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False) + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.MultiheadAttention: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) elif type(sd_module) == torch.nn.Conv2d: module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) else: @@ -190,54 +220,94 @@ def load_loras(names, multipliers=None): loaded_loras.append(lora) -def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear): +def lora_calc_updown(lora, module, target): + with torch.no_grad(): + up = module.up.weight.to(target.device, dtype=target.dtype) + down = module.down.weight.to(target.device, dtype=target.dtype) + + if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): + updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) + else: + updown = up @ down + + updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return updown + + +def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention): """ - Applies the currently selected set of Loras to the weight of torch layer self. + Applies the currently selected set of Loras to the weights of torch layer self. If weights already have this particular set of loras applied, does nothing. If not, restores orginal weights from backup and alters weights according to loras. """ + lora_layer_name = getattr(self, 'lora_layer_name', None) + if lora_layer_name is None: + return + current_names = getattr(self, "lora_current_names", ()) wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras) weights_backup = getattr(self, "lora_weights_backup", None) if weights_backup is None: - weights_backup = self.weight.to(devices.cpu, copy=True) + if isinstance(self, torch.nn.MultiheadAttention): + weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True)) + else: + weights_backup = self.weight.to(devices.cpu, copy=True) + self.lora_weights_backup = weights_backup if current_names != wanted_names: if weights_backup is not None: - self.weight.copy_(weights_backup) + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) - lora_layer_name = getattr(self, 'lora_layer_name', None) for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) - if module is None: + if module is not None and hasattr(self, 'weight'): + self.weight += lora_calc_updown(lora, module, self.weight) continue - with torch.no_grad(): - up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype) - down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype) + module_q = lora.modules.get(lora_layer_name + "_q_proj", None) + module_k = lora.modules.get(lora_layer_name + "_k_proj", None) + module_v = lora.modules.get(lora_layer_name + "_v_proj", None) + module_out = lora.modules.get(lora_layer_name + "_out_proj", None) + + if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out: + updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight) + updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight) + updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight) + updown_qkv = torch.vstack([updown_q, updown_k, updown_v]) - if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1): - updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3) - else: - updown = up @ down + self.in_proj_weight += updown_qkv + self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight) + continue + + if module is None: + continue - self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + print(f'failed to calculate lora weights for layer {lora_layer_name}') setattr(self, "lora_current_names", wanted_names) +def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear): + setattr(self, "lora_current_names", ()) + setattr(self, "lora_weights_backup", None) + + def lora_Linear_forward(self, input): lora_apply_weights(self) return torch.nn.Linear_forward_before_lora(self, input) -def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Linear_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs) @@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input): return torch.nn.Conv2d_forward_before_lora(self, input) -def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs): - setattr(self, "lora_current_names", ()) - setattr(self, "lora_weights_backup", None) +def lora_Conv2d_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs) -def lora_NonDynamicallyQuantizableLinear_forward(self, input): - return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input)) +def lora_MultiheadAttention_forward(self, *args, **kwargs): + lora_apply_weights(self) + + return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs) + + +def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs): + lora_reset_cached_weight(self) + + return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs) def list_available_loras(): -- cgit v1.2.3 From 9d7390d2d19a8baf04ee4ebe598b96ac6ba7f97e Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Mon, 27 Mar 2023 04:28:40 +0300 Subject: convert to python v3.9 --- extensions-builtin/Lora/lora.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index edd95f78..79d11e0e 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -2,6 +2,7 @@ import glob import os import re import torch +from typing import Union from modules import shared, devices, sd_models, errors @@ -235,7 +236,7 @@ def lora_calc_updown(lora, module, target): return updown -def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention): +def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): """ Applies the currently selected set of Loras to the weights of torch layer self. If weights already have this particular set of loras applied, does nothing. -- cgit v1.2.3 From 6a147db1287fe660e1bfb2ebf5b3fadc14835c69 Mon Sep 17 00:00:00 2001 From: camenduru <54370274+camenduru@users.noreply.github.com> Date: Mon, 27 Mar 2023 04:40:31 +0300 Subject: convert to python v3.9 --- extensions-builtin/Lora/lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 79d11e0e..696be8ea 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -296,7 +296,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu setattr(self, "lora_current_names", wanted_names) -def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear): +def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): setattr(self, "lora_current_names", ()) setattr(self, "lora_weights_backup", None) -- cgit v1.2.3