| author    | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 04:04:43 +0000 |
|-----------|---------------------------------|---------------------------|
| committer | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 04:04:43 +0000 |
| commit    | b705c9b72b290afc825f2a96b2f8d01d72028062 (patch) | |
| tree      | 0dad0f2de613c0014184f07758eeccd8f171e1d4 /extensions-builtin/Lora/lora.py | |
| parent    | 80b26d2a69617b75d2d01c1e6b7d11445815ed4d (diff) | |
| parent    | 7cb31a278e8f27367792b66cdd3bcfba41093b32 (diff) | |
Merge branch 'lora_sd2' into lora_inplace
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r-- | extensions-builtin/Lora/lora.py | 21 |
1 file changed, 19 insertions, 2 deletions
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index a737fec3..d4345ada 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -14,7 +14,7 @@
 re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)")
 re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")
-def convert_diffusers_name_to_compvis(key):
+def convert_diffusers_name_to_compvis(key, is_sd2):
     def match(match_list, regex):
         r = re.match(regex, key)
         if not r:
@@ -36,6 +36,14 @@ def convert_diffusers_name_to_compvis(key):
         return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}"
     if match(m, re_text_block):
+        if is_sd2:
+            if 'mlp_fc1' in m[1]:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
+            elif 'mlp_fc2' in m[1]:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
+            else:
+                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"
+
         return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"
     return key
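Editor's note (not part of the commit): the hunk above rewrites diffusers-style text-encoder LoRA keys into the open_clip resblock names that SD2 checkpoints use, while SD1 keys keep the CLIP transformer naming. Below is a rough, self-contained sketch of that translation; the regex is the one from the diff, but `translate_text_encoder_key` and the example key are made up for illustration.

```python
import re

# Regex copied from the diff; the helper name below is hypothetical.
re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)")

def translate_text_encoder_key(key, is_sd2):
    r = re_text_block.match(key)
    if not r:
        return key
    layer, rest = r.group(1), r.group(2)
    if is_sd2:
        # open_clip names its MLP layers c_fc/c_proj and its attention block "attn",
        # so the diffusers-style names are remapped accordingly.
        if 'mlp_fc1' in rest:
            return f"model_transformer_resblocks_{layer}_{rest.replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in rest:
            return f"model_transformer_resblocks_{layer}_{rest.replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"model_transformer_resblocks_{layer}_{rest.replace('self_attn', 'attn')}"
    # SD1 keeps the CLIP transformer/text_model naming.
    return f"transformer_text_model_encoder_layers_{layer}_{rest}"

key = "lora_te_text_model_encoder_layers_0_mlp_fc1.lora_up.weight"
print(translate_text_encoder_key(key, is_sd2=True))
# model_transformer_resblocks_0_mlp_c_fc.lora_up.weight
print(translate_text_encoder_key(key, is_sd2=False))
# transformer_text_model_encoder_layers_0_mlp_fc1.lora_up.weight
```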
@@ -102,9 +110,10 @@ def load_lora(name, filename):
     sd = sd_models.read_state_dict(filename)
     keys_failed_to_match = []
+    is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping
     for key_diffusers, weight in sd.items():
-        fullkey = convert_diffusers_name_to_compvis(key_diffusers)
+        fullkey = convert_diffusers_name_to_compvis(key_diffusers, is_sd2)
         key, lora_key = fullkey.split(".", 1)
         sd_module = shared.sd_model.lora_layer_mapping.get(key, None)
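Editor's note (not part of the commit): the `is_sd2` test keys off `lora_layer_mapping`, a dict of flattened module names that the extension builds elsewhere from the model's named modules. A minimal sketch of that idea follows, assuming a toy module tree with open_clip-style names; the class and helper are stand-ins, not the extension's code.

```python
import torch

# Toy stand-in for an SD2 text encoder: open_clip exposes model.transformer.resblocks.<n>,
# so the flattened key "model_transformer_resblocks" exists for the container itself.
class ToyOpenClipTextModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Module()
        self.model.transformer = torch.nn.Module()
        self.model.transformer.resblocks = torch.nn.ModuleList([torch.nn.Linear(8, 8)])

def build_layer_mapping(model: torch.nn.Module) -> dict:
    # Same shape of data as lora_layer_mapping: module paths with dots
    # replaced by underscores, mapped to the module objects.
    return {name.replace(".", "_"): module for name, module in model.named_modules()}

mapping = build_layer_mapping(ToyOpenClipTextModel())
print("model_transformer_resblocks" in mapping)  # True -> treated as SD2
```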
@@ -123,9 +132,13 @@ def load_lora(name, filename):
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
+        elif type(sd_module) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear:
+            module = torch.nn.modules.linear.NonDynamicallyQuantizableLinear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
         else:
+            print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
+            continue
             assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}'
         with torch.no_grad():
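Editor's note (not part of the commit): the new `elif` exists because SD2's text encoder goes through `torch.nn.MultiheadAttention`, whose output projection is a `NonDynamicallyQuantizableLinear`. That class is a subclass of `torch.nn.Linear`, so the exact `type()` comparison used here does not treat it as a plain Linear, and without the new branch such layers would fall through to the unsupported-type path. A small self-contained check:

```python
import torch

# MultiheadAttention's output projection is a NonDynamicallyQuantizableLinear.
out_proj = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4).out_proj

print(type(out_proj).__name__)                # NonDynamicallyQuantizableLinear
print(type(out_proj) == torch.nn.Linear)      # False (exact type check, as in the diff)
print(isinstance(out_proj, torch.nn.Linear))  # True (still a Linear subclass)
print(type(out_proj) == torch.nn.modules.linear.NonDynamicallyQuantizableLinear)  # True
```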
@@ -242,6 +255,10 @@ def lora_Conv2d_load_state_dict(self: torch.nn.Conv2d, *args, **kwargs):
     return torch.nn.Conv2d_load_state_dict_before_lora(self, *args, **kwargs)
+def lora_NonDynamicallyQuantizableLinear_forward(self, input):
+    return lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))
+
+
 def list_available_loras():
     available_loras.clear()
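Editor's note (not part of the commit): the new forward hook follows the same pattern as the existing Linear/Conv2d hooks visible above: the original class `forward` is stashed on `torch.nn` under a `*_forward_before_lora` name, and the replacement passes its output through `lora_forward`. The sketch below shows only that wrapping pattern, with a no-op stand-in for `lora_forward`; the real patching happens in the extension's own script, so this wiring is an approximation, not the extension's code.

```python
import torch

# No-op stand-in for the extension's lora_forward, which would add the
# LoRA up/down projection products to the base output.
def toy_lora_forward(module, input, original_output):
    return original_output

# Stash the original forward on torch.nn, mirroring the *_forward_before_lora
# convention seen in the diff.
if not hasattr(torch.nn, "NonDynamicallyQuantizableLinear_forward_before_lora"):
    torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora = \
        torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward

def lora_NonDynamicallyQuantizableLinear_forward(self, input):
    return toy_lora_forward(self, input, torch.nn.NonDynamicallyQuantizableLinear_forward_before_lora(self, input))

torch.nn.modules.linear.NonDynamicallyQuantizableLinear.forward = lora_NonDynamicallyQuantizableLinear_forward

# Calling the attention out-projection directly now routes through the hook.
out_proj = torch.nn.MultiheadAttention(embed_dim=16, num_heads=4).out_proj
print(out_proj(torch.randn(2, 16)).shape)  # torch.Size([2, 16])
```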