diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-15 16:23:27 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-15 16:23:40 +0000 |
commit | f01682ee01e81e8ef84fd6fffe8f7aa17233285d (patch) | |
tree | 80f62099e6af5f77c7df8c092c37c71ed24750d9 /extensions-builtin/Lora/networks.py | |
parent | 7327be97aa9beeae881bf4649a56792bd284efd5 (diff) | |
download | stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.tar.gz stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.tar.bz2 stable-diffusion-webui-gfx803-f01682ee01e81e8ef84fd6fffe8f7aa17233285d.zip |
store patches for Lora in a specialized module
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r-- | extensions-builtin/Lora/networks.py | 32 |
1 files changed, 18 insertions, 14 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 22fdff4a..9fca36b6 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -2,6 +2,7 @@ import logging import os
import re
+import lora_patches
import network
import network_lora
import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def network_Linear_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Linear_forward_before_network)
+ return network_forward(self, input, originals.Linear_forward)
network_apply_weights(self)
- return torch.nn.Linear_forward_before_network(self, input)
+ return originals.Linear_forward(self, input)
def network_Linear_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Linear_load_state_dict(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+ return network_forward(self, input, originals.Conv2d_forward)
network_apply_weights(self)
- return torch.nn.Conv2d_forward_before_network(self, input)
+ return originals.Conv2d_forward(self, input)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.Conv2d_load_state_dict(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+ return network_forward(self, input, originals.GroupNorm_forward)
network_apply_weights(self)
- return torch.nn.GroupNorm_forward_before_network(self, input)
+ return originals.GroupNorm_forward(self, input)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.GroupNorm_load_state_dict(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
if shared.opts.lora_functional:
- return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+ return network_forward(self, input, originals.LayerNorm_forward)
network_apply_weights(self)
- return torch.nn.LayerNorm_forward_before_network(self, input)
+ return originals.LayerNorm_forward(self, input)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.LayerNorm_load_state_dict(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
network_apply_weights(self)
- return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
network_reset_cached_weight(self)
- return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+ return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)
def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params): if added:
params["Prompt"] += "\n" + "".join(added)
+
+originals: lora_patches.LoraPatches = None
+
extra_network_lora = None
available_networks = {}
|