author    | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 07:44:20 +0000
committer | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 07:44:20 +0000
commit    | 650ddc9dd3c1d126221682be8270f7fba1b5b6ce (patch)
tree      | cd1c607f63bd6fe487ccc0335546b7c561b1fc9c /extensions-builtin/Lora/lora.py
parent    | b705c9b72b290afc825f2a96b2f8d01d72028062 (diff)
download  | stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.tar.gz
          | stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.tar.bz2
          | stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.zip
Lora support for SD2
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r-- | extensions-builtin/Lora/lora.py | 155
1 file changed, 116 insertions(+), 39 deletions(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index d4345ada..edd95f78 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -8,14 +8,27 @@ from modules import shared, devices, sd_models, errors
metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20}
re_digits = re.compile(r"\d+")
-re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)")
-re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
-re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
+re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")
+re_compiled = {}
+
+suffix_conversion = {
+ "attentions": {},
+ "resnets": {
+ "conv1": "in_layers_2",
+ "conv2": "out_layers_3",
+ "time_emb_proj": "emb_layers_1",
+ "conv_shortcut": "skip_connection",
+ }
+}
def convert_diffusers_name_to_compvis(key, is_sd2):
- def match(match_list, regex):
+ def match(match_list, regex_text):
+ regex = re_compiled.get(regex_text)
+ if regex is None:
+ regex = re.compile(regex_text)
+ re_compiled[regex_text] = regex
+
r = re.match(regex, key)
if not r:
return False
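For readers skimming the hunk above: the per-call regex strings are now memoized in `re_compiled`, so each pattern is compiled only once even though it is written inline at every call site. A minimal standalone sketch of the same pattern (the `cached_match` name and the demo key are mine, for illustration only):

```python
import re

re_digits = re.compile(r"\d+")
re_compiled = {}  # pattern text -> compiled regex, filled lazily


def cached_match(match_list, regex_text, key):
    # Compile each pattern on first use, then reuse the cached object.
    regex = re_compiled.get(regex_text)
    if regex is None:
        regex = re.compile(regex_text)
        re_compiled[regex_text] = regex

    r = re.match(regex, key)
    if not r:
        return False

    # Digit groups become ints so they can be used in arithmetic later.
    match_list.clear()
    match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()])
    return True


m = []
ok = cached_match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)",
                  "lora_unet_down_blocks_0_resnets_1_conv1")  # made-up demo key
print(ok, m)  # True [0, 'resnets', 1, 'conv1']
```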
@@ -26,16 +39,25 @@ def convert_diffusers_name_to_compvis(key, is_sd2):
m = []
- if match(m, re_unet_down_blocks):
- return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}"
+ if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
+
+ if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
+ return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"
- if match(m, re_unet_mid_blocks):
- return f"diffusion_model_middle_block_1_{m[1]}"
+ if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
+ suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
+ return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"
- if match(m, re_unet_up_blocks):
- return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
+ if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
+ return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"
- if match(m, re_text_block):
+ if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
+ return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"
+
+ if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
if is_sd2:
if 'mlp_fc1' in m[1]:
return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
@@ -109,16 +131,22 @@ def load_lora(name, filename):
sd = sd_models.read_state_dict(filename)
- keys_failed_to_match = []
+ keys_failed_to_match = {}
is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
for key_diffusers, weight in sd.items():
- fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
- key, lora_key = fullkey.split(".", 1)
+ key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1)
+ key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2)
sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
+
if sd_module is None:
- keys_failed_to_match.append(key_diffusers)
+ m = re_x_proj.match(key)
+ if m:
+ sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None)
+
+ if sd_module is None:
+ keys_failed_to_match[key_diffusers] = key
continue
lora_module = lora.modules.get(key, None)
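The loading change above splits the `.lora_up.weight` / `.lora_down.weight` suffix off before converting the name, and uses `re_x_proj` as a fallback so SD2 text-encoder keys can resolve to the enclosing `MultiheadAttention` layer. Roughly (the key below is a constructed example):

```python
import re

re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")

key_diffusers = "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_up.weight"
stem, lora_key = key_diffusers.split(".", 1)
# stem     == "lora_te_text_model_encoder_layers_0_self_attn_q_proj"
# lora_key == "lora_up.weight"

# convert_diffusers_name_to_compvis(stem, is_sd2=True) yields something like
# "model_transformer_resblocks_0_attn_q_proj". SD2's lora_layer_mapping has no
# entry for the individual q/k/v projections, so re_x_proj strips the suffix
# and the lookup retries with the MultiheadAttention layer's own name.
m = re_x_proj.match("model_transformer_resblocks_0_attn_q_proj")
print(m.group(1), "|", m.group(2))  # model_transformer_resblocks_0_attn | q_proj
```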
@@ -133,7 +161,9 @@ def load_lora(name, filename):
if type(sd_module) == torch.nn.Linear:
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
- module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+ elif type(sd_module) == torch.nn.MultiheadAttention:
+ module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
elif type(sd_module) == torch.nn.Conv2d:
module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
else:
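Since SD2's attention projections are not separate `Linear` submodules, the hunk above stores the LoRA factors for `NonDynamicallyQuantizableLinear` and `MultiheadAttention` targets in plain `torch.nn.Linear` carriers sized from the saved tensor. A toy illustration with invented dimensions:

```python
import torch

weight = torch.zeros(8, 1024)  # hypothetical lora_down tensor: (rank, features)
# Only the in/out feature sizes matter; the carrier just holds the weight.
module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
with torch.no_grad():
    module.weight.copy_(weight)
print(module)  # Linear(in_features=1024, out_features=8, bias=False)
```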
@@ -190,54 +220,94 @@ def load_loras(names, multipliers=None):
loaded_loras.append(lora)
-def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear):
+def lora_calc_updown(lora, module, target):
+ with torch.no_grad():
+ up = module.up.weight.to(target.device, dtype=target.dtype)
+ down = module.down.weight.to(target.device, dtype=target.dtype)
+
+ if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
+ updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+ else:
+ updown = up @ down
+
+ updown = updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+ return updown
+
+
+def lora_apply_weights(self: torch.nn.Conv2d | torch.nn.Linear | torch.nn.MultiheadAttention):
"""
- Applies the currently selected set of Loras to the weight of torch layer self.
+ Applies the currently selected set of Loras to the weights of torch layer self.
If weights already have this particular set of loras applied, does nothing.
If not, restores original weights from backup and alters weights according to loras.
"""
+ lora_layer_name = getattr(self, 'lora_layer_name', None)
+ if lora_layer_name is None:
+ return
+
current_names = getattr(self, "lora_current_names", ())
wanted_names = tuple((x.name, x.multiplier) for x in loaded_loras)
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
- weights_backup = self.weight.to(devices.cpu, copy=True)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
+ else:
+ weights_backup = self.weight.to(devices.cpu, copy=True)
+
self.lora_weights_backup = weights_backup
if current_names != wanted_names:
if weights_backup is not None:
- self.weight.copy_(weights_backup)
+ if isinstance(self, torch.nn.MultiheadAttention):
+ self.in_proj_weight.copy_(weights_backup[0])
+ self.out_proj.weight.copy_(weights_backup[1])
+ else:
+ self.weight.copy_(weights_backup)
- lora_layer_name = getattr(self, 'lora_layer_name', None)
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
- if module is None:
+ if module is not None and hasattr(self, 'weight'):
+ self.weight += lora_calc_updown(lora, module, self.weight)
continue
- with torch.no_grad():
- up = module.up.weight.to(self.weight.device, dtype=self.weight.dtype)
- down = module.down.weight.to(self.weight.device, dtype=self.weight.dtype)
+ module_q = lora.modules.get(lora_layer_name + "_q_proj", None)
+ module_k = lora.modules.get(lora_layer_name + "_k_proj", None)
+ module_v = lora.modules.get(lora_layer_name + "_v_proj", None)
+ module_out = lora.modules.get(lora_layer_name + "_out_proj", None)
+
+ if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
+ updown_q = lora_calc_updown(lora, module_q, self.in_proj_weight)
+ updown_k = lora_calc_updown(lora, module_k, self.in_proj_weight)
+ updown_v = lora_calc_updown(lora, module_v, self.in_proj_weight)
+ updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
- if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
- updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
- else:
- updown = up @ down
+ self.in_proj_weight += updown_qkv
+ self.out_proj.weight += lora_calc_updown(lora, module_out, self.out_proj.weight)
+ continue
+
+ if module is None:
+ continue
- self.weight += updown * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ print(f'failed to calculate lora weights for layer {lora_layer_name}')
setattr(self, "lora_current_names", wanted_names)
+def lora_reset_cached_weight(self: torch.nn.Conv2d | torch.nn.Linear):
+ setattr(self, "lora_current_names", ())
+ setattr(self, "lora_weights_backup", None)
+
+
def lora_Linear_forward(self, input):
lora_apply_weights(self)
return torch.nn.Linear_forward_before_lora(self, input)
-def lora_Linear_load_state_dict(self: torch.nn.Linear, *args, **kwargs):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+def lora_Linear_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
return torch.nn.Linear_load_state_dict_before_lora(self, *args, **kwargs)
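The extracted `lora_calc_updown` computes the standard low-rank update, `updown = up @ down`, scaled by the user multiplier and by `alpha / rank` (where `up.weight.shape[1]` is the rank). A self-contained shape check under assumed dimensions:

```python
import torch

rank, features = 4, 320                # assumed sizes for illustration
up = torch.randn(features, rank)       # plays the role of module.up.weight
down = torch.randn(rank, features)     # plays the role of module.down.weight
multiplier, alpha = 1.0, 4.0

# up.shape[1] is the rank, so alpha / rank is the usual LoRA scaling factor.
updown = (up @ down) * multiplier * (alpha / up.shape[1])
assert updown.shape == (features, features)  # matches the target weight shape
```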
@@ -248,15 +318,22 @@ def lora_Conv2d_forward(self, input):
return torch.nn.Conv2d_forward_before_lora(self, input)
-def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
- setattr(self, "lora_current_names", ())
- setattr(self, "lora_weights_backup", None)
+def lora_Conv2d_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
-def lora_NonDynamicallyQuantizableLinear_forward(self, input):
- return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
+def lora_MultiheadAttention_forward(self, *args, **kwargs):
+ lora_apply_weights(self)
+
+ return torch.nn.MultiheadAttention_forward_before_lora(self, *args, **kwargs)
+
+
+def lora_MultiheadAttention_load_state_dict(self, *args, **kwargs):
+ lora_reset_cached_weight(self)
+
+ return torch.nn.MultiheadAttention_load_state_dict_before_lora(self, *args, **kwargs)
def list_available_loras():
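These hooks assume the original `forward` / `_load_from_state_dict` methods were stashed on `torch.nn` under `*_before_lora` names before patching; in the webui that wiring lives in the extension's script file, not in lora.py. A hedged sketch of what it looks like for the new MultiheadAttention hooks (names assumed to mirror the existing Linear/Conv2d wiring):

```python
import torch

# Sketch only: stash the originals once, then install the LoRA-aware hooks.
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
    torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
    torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict

torch.nn.MultiheadAttention.forward = lora_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = lora_MultiheadAttention_load_state_dict
```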