author     AUTOMATIC1111 <16777216c@gmail.com>    2023-08-13 05:28:48 +0000
committer  GitHub <noreply@github.com>            2023-08-13 05:28:48 +0000
commit     da80d649fd6a6083be02aca5695367bd25abf0d5 (patch)
tree       cfb85edce888c9d3092a72ae62340c03afb915ce /extensions-builtin/Lora/scripts/lora_script.py
parent     61673451ff2b6ea39c8b9591b4a14d7f19a32e63 (diff)
parent     5881dcb8873b3f87b9c6545e9cb8d1d77023f4fe (diff)
Merge pull request #12503 from AUTOMATIC1111/extra-norm-module
Add Norm Module to lora ext and add "bias" support
Diffstat (limited to 'extensions-builtin/Lora/scripts/lora_script.py')
-rw-r--r--  extensions-builtin/Lora/scripts/lora_script.py | 16 ++++++++++++++++
1 file changed, 16 insertions(+), 0 deletions(-)
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 6ab8b6e7..dc307f8c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
 if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
     torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
 
+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+    torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+    torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+    torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
 if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
     torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
 torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
 torch.nn.Conv2d.forward = networks.network_Conv2d_forward
 torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
 torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
 torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
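For context, the patch follows the extension's existing monkey-patching pattern: the original torch.nn method is stashed on the torch.nn namespace once (guarded by hasattr), then replaced with a network-aware wrapper from networks.py. Below is a minimal, self-contained sketch of that pattern for GroupNorm only. The wrapper body here is a hypothetical stand-in: the real network_GroupNorm_forward lives in networks.py (not part of this diff) and applies any active LoRA weight/bias offsets before delegating to the saved original.

```python
import torch

# Stash the original method exactly once. The hasattr guard keeps the patch
# idempotent: if this script is re-executed (e.g. on a webui script reload),
# we must not overwrite the saved original with the already-patched forward,
# or the wrapper below would end up calling itself.
if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
    torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward


def network_GroupNorm_forward(self, input):
    # Hypothetical stand-in for networks.network_GroupNorm_forward: the real
    # implementation first applies active network (LoRA) weight/bias offsets
    # to this module, then delegates to the original forward saved above.
    return torch.nn.GroupNorm_forward_before_network(self, input)


# Install the wrapper class-wide, so every GroupNorm instance is hooked.
torch.nn.GroupNorm.forward = network_GroupNorm_forward

# Sanity check: a patched module still normalizes as before.
gn = torch.nn.GroupNorm(2, 4)
print(gn(torch.randn(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])
```

The same save/replace pair is what the diff adds for GroupNorm and LayerNorm, mirroring the hooks already in place for Linear, Conv2d, and MultiheadAttention.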