author      AUTOMATIC1111 <16777216c@gmail.com>    2023-08-15 16:23:27 +0000
committer   AUTOMATIC1111 <16777216c@gmail.com>    2023-08-15 16:23:40 +0000
commit      f01682ee01e81e8ef84fd6fffe8f7aa17233285d (patch)
tree        80f62099e6af5f77c7df8c092c37c71ed24750d9 /extensions-builtin/Lora/scripts
parent      7327be97aa9beeae881bf4649a56792bd284efd5 (diff)
store patches for Lora in a specialized module
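For orientation, here is a minimal sketch of what the new specialized module (an extensions-builtin/Lora/lora_patches.py, imported below as lora_patches) could look like, assuming modules.patches exposes patch(owner, obj, field, replacement) and undo(owner, obj, field) helpers; the names and signatures are inferred from the diff, not taken from it:

import torch

import networks
from modules import patches


class LoraPatches:
    def __init__(self):
        # patch() is assumed to stash the original attribute and return it,
        # so the pre-Lora callables stay recoverable for undo().
        self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward)
        self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict)
        self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward)
        self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict)
        # ... same pattern for GroupNorm, LayerNorm and MultiheadAttention ...

    def undo(self):
        # Restore every patched attribute on torch.nn.
        self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward')
        self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict')
        self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward')
        self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict')
        # ... same pattern for GroupNorm, LayerNorm and MultiheadAttention ...

With a module like this in place, lora_script.py only needs networks.originals = lora_patches.LoraPatches() at import time and networks.originals.undo() in unload(), as the diff below shows.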
Diffstat (limited to 'extensions-builtin/Lora/scripts')
-rw-r--r--    extensions-builtin/Lora/scripts/lora_script.py    52
1 file changed, 5 insertions(+), 47 deletions(-)
diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py
index 4c6e774a..546fb55e 100644
--- a/extensions-builtin/Lora/scripts/lora_script.py
+++ b/extensions-builtin/Lora/scripts/lora_script.py
@@ -7,17 +7,14 @@ from fastapi import FastAPI
import network
import networks
import lora # noqa:F401
+import lora_patches
import extra_networks_lora
import ui_extra_networks_lora
-from modules import script_callbacks, ui_extra_networks, extra_networks, shared
+from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches
+
def unload():
- torch.nn.Linear.forward = torch.nn.Linear_forward_before_network
- torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network
- torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network
- torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network
- torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network
- torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network
+ networks.originals.undo()
def before_ui():
@@ -28,46 +25,7 @@ def before_ui():
extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco")
-if not hasattr(torch.nn, 'Linear_forward_before_network'):
- torch.nn.Linear_forward_before_network = torch.nn.Linear.forward
-
-if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'):
- torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict
-
-if not hasattr(torch.nn, 'Conv2d_forward_before_network'):
- torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward
-
-if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'):
- torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict
-
-if not hasattr(torch.nn, 'GroupNorm_forward_before_network'):
- torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward
-
-if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'):
- torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'LayerNorm_forward_before_network'):
- torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward
-
-if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'):
- torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict
-
-if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'):
- torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward
-
-if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'):
- torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict
-
-torch.nn.Linear.forward = networks.network_Linear_forward
-torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
-torch.nn.Conv2d.forward = networks.network_Conv2d_forward
-torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
-torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
-torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
-torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
-torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
-torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
-torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict
+networks.originals = lora_patches.LoraPatches()
script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules)
script_callbacks.on_script_unloaded(unload)
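The removed block above kept the pre-patch originals as ad-hoc attributes bolted onto the torch.nn classes themselves (Linear_forward_before_network and so on) and restored them by hand in unload(). The modules.patches helper imported at the top of the new version centralizes that bookkeeping. A minimal sketch of such a helper, keyed by an owner string so multiple callers can patch and unpatch the same attribute independently; this is an illustrative assumption, not the repository's verified implementation:

from collections import defaultdict

originals = defaultdict(dict)  # owner key -> {(obj, field): original attribute}


def patch(key, obj, field, replacement):
    # Save the current attribute under the caller's key, then install the replacement.
    patch_key = (obj, field)
    if patch_key in originals[key]:
        raise RuntimeError(f"patch for {field} by {key} is already applied")
    original = getattr(obj, field)
    originals[key][patch_key] = original
    setattr(obj, field, replacement)
    return original


def undo(key, obj, field):
    # Put the saved original back; returning None lets callers clear their stored reference,
    # which is what unload() effectively does via networks.originals.undo().
    original = originals[key].pop((obj, field))
    setattr(obj, field, original)
    return None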