aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/scripts/lora_script.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-13 05:28:48 +0000
committerGitHub <noreply@github.com>2023-08-13 05:28:48 +0000
commitda80d649fd6a6083be02aca5695367bd25abf0d5 (patch)
treecfb85edce888c9d3092a72ae62340c03afb915ce /extensions-builtin/Lora/scripts/lora_script.py
parent61673451ff2b6ea39c8b9591b4a14d7f19a32e63 (diff)
parent5881dcb8873b3f87b9c6545e9cb8d1d77023f4fe (diff)
downloadstable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.tar.gz
stable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.tar.bz2
stable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.zip
Merge pull request #12503 from AUTOMATIC1111/extra-norm-module
Add Norm Module to lora ext and add "bias" support
Diffstat (limited to 'extensions-builtin/Lora/scripts/lora_script.py')
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 6ab8b6e7..dc307f8c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+ torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+ torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+ torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+ torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict