path: root/extensions-builtin/Lora/lora_patches.py
author    AUTOMATIC1111 <16777216c@gmail.com>  2023-08-15 16:23:27 +0000
committer AUTOMATIC1111 <16777216c@gmail.com>  2023-08-15 16:23:40 +0000
commit    f01682ee01e81e8ef84fd6fffe8f7aa17233285d (patch)
tree      80f62099e6af5f77c7df8c092c37c71ed24750d9 /extensions-builtin/Lora/lora_patches.py
parent    7327be97aa9beeae881bf4649a56792bd284efd5 (diff)
store patches for Lora in a specialized module
Diffstat (limited to 'extensions-builtin/Lora/lora_patches.py')
-rw-r--r--  extensions-builtin/Lora/lora_patches.py  31
1 file changed, 31 insertions(+), 0 deletions(-)
diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py
new file mode 100644
index 00000000..b394d8e9
--- /dev/null
+++ b/extensions-builtin/Lora/lora_patches.py
@@ -0,0 +1,31 @@
+import torch
+
+import networks
+from modules import patches
+
+
+class LoraPatches:
+ def __init__(self):
+ self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
+ self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
+ self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
+ self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
+ self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward)
+ self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict)
+ self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward)
+ self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict)
+ self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward)
+ self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict)
+
+ def undo(self):
+ self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
+ self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
+ self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
+ self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
+ self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward')
+ self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict')
+ self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward')
+ self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict')
+ self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward')
+ self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict')
+
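
For readers unfamiliar with the modules.patches helper this commit builds on, the sketch below shows the general monkey-patching pattern LoraPatches relies on: patch() installs a replacement attribute on a class and hands back the original so the replacement can delegate to it, and undo() restores the original. This is a minimal, self-contained approximation written for illustration; the standalone patch()/undo() functions, their bookkeeping, and the traced_linear_forward example are assumptions, not the actual webui implementation.

    import torch

    # Bookkeeping for installed patches, keyed by (owner, class, attribute).
    # This dict and both helpers are illustrative stand-ins for modules.patches.
    _originals = {}

    def patch(key, obj, field, replacement):
        # Remember the original attribute, install the replacement, and
        # return the original so the caller can delegate to it (as the
        # networks.network_*_forward functions presumably do).
        _originals[(key, obj, field)] = getattr(obj, field)
        setattr(obj, field, replacement)
        return _originals[(key, obj, field)]

    def undo(key, obj, field):
        # Put the original attribute back and drop the bookkeeping entry.
        setattr(obj, field, _originals.pop((key, obj, field)))
        return None

    # Usage in the same shape as LoraPatches: wrap Linear.forward, then undo.
    def traced_linear_forward(self, input):
        print("patched Linear.forward called")
        return original_linear_forward(self, input)

    original_linear_forward = patch(__name__, torch.nn.Linear, 'forward', traced_linear_forward)
    layer = torch.nn.Linear(4, 2)
    layer(torch.zeros(1, 4))   # prints the trace, then runs the stock forward
    undo(__name__, torch.nn.Linear, 'forward')
    layer(torch.zeros(1, 4))   # stock forward again, no trace

Note how LoraPatches.undo() reassigns each saved attribute from the return value of patches.undo(); if the helper returns None after restoring the original (as sketched above), the stored references are deliberately cleared once the patches are lifted.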