diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 07:44:20 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-03-26 07:44:20 +0000 |
commit | 650ddc9dd3c1d126221682be8270f7fba1b5b6ce (patch) | |
tree | cd1c607f63bd6fe487ccc0335546b7c561b1fc9c /extensions-builtin/Lora/scripts | |
parent | b705c9b72b290afc825f2a96b2f8d01d72028062 (diff) | |
download | stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.tar.gz stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.tar.bz2 stable-diffusion-webui-gfx803-650ddc9dd3c1d126221682be8270f7fba1b5b6ce.zip |
Lora support for SD2
Diffstat (limited to 'extensions-builtin/Lora/scripts')
-rw-r--r-- | extensions-builtin/Lora/scripts/lora_script.py | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index dc329e81..0adab225 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -12,6 +12,8 @@ def unload():
    torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_lora
torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora
torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_lora
+ torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_lora
+ torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_lora
def before_ui():
@@ -31,10 +33,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_lora'):
torch.nn.Conv2d_load_state_dict_before_lora = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'MultiheadAttention_forward_before_lora'):
+ torch.nn.MultiheadAttention_forward_before_lora = torch.nn.MultiheadAttention.forward
+
+if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_lora'):
+ torch.nn.MultiheadAttention_load_state_dict_before_lora = torch.nn.MultiheadAttention._load_from_state_dict
+
torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Linear._load_from_state_dict = lora.lora_Linear_load_state_dict
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = lora.lora_Conv2d_load_state_dict
+torch.nn.MultiheadAttention.forward = lora.lora_MultiheadAttention_forward
+torch.nn.MultiheadAttention._load_from_state_dict = lora.lora_MultiheadAttention_load_state_dict
script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
|