aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/scripts/lora_script.py
diff options
context:
space:
mode:
authorKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-08-12 18:27:39 +0000
committerKohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>2023-08-12 18:27:39 +0000
commitbd4da4474bef5c9c1f690c62b971704ee73d2860 (patch)
treef3624d0521366fb23c4e7861ea1d0a04b43483e6 /extensions-builtin/Lora/scripts/lora_script.py
parentb2080756fcdc328292fc38998c06ccf23e53bd7e (diff)
downloadstable-diffusion-webui-gfx803-bd4da4474bef5c9c1f690c62b971704ee73d2860.tar.gz
stable-diffusion-webui-gfx803-bd4da4474bef5c9c1f690c62b971704ee73d2860.tar.bz2
stable-diffusion-webui-gfx803-bd4da4474bef5c9c1f690c62b971704ee73d2860.zip
Add extra norm module into built-in lora ext
refer to LyCORIS 1.9.0.dev6 add new option and module for training norm layer (Which is reported to be good for style)
Diffstat (limited to 'extensions-builtin/Lora/scripts/lora_script.py')
-rw-r--r--extensions-builtin/Lora/scripts/lora_script.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 6ab8b6e7..dc307f8c 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -40,6 +40,18 @@ if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
+if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
+ torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
+
+if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
+ torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
+
+if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
+ torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
+
+if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
+ torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
+
if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
@@ -50,6 +62,10 @@ torch.nn.Linear.forward = networks.network_Linear_forward
torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
torch.nn.Conv2d.forward = networks.network_Conv2d_forward
torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
+torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
+torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
+torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
+torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict