From f01682ee01e81e8ef84fd6fffe8f7aa17233285d Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Tue, 15 Aug 2023 19:23:27 +0300
Subject: store patches for Lora in a specialized module

---
 extensions-builtin/Lora/networks.py | 32 ++++++++++++++++++--------------
 1 file changed, 18 insertions(+), 14 deletions(-)

(limited to 'extensions-builtin/Lora/networks.py')

diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 22fdff4a..9fca36b6 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -2,6 +2,7 @@ import logging
 import os
 import re

+import lora_patches
 import network
 import network_lora
 import network_hada
@@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):

 def network_Linear_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Linear_forward_before_network)
+        return network_forward(self, input, originals.Linear_forward)

     network_apply_weights(self)

-    return torch.nn.Linear_forward_before_network(self, input)
+    return originals.Linear_forward(self, input)


 def network_Linear_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)

-    return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Linear_load_state_dict(self, *args, **kwargs)


 def network_Conv2d_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.Conv2d_forward_before_network)
+        return network_forward(self, input, originals.Conv2d_forward)

     network_apply_weights(self)

-    return torch.nn.Conv2d_forward_before_network(self, input)
+    return originals.Conv2d_forward(self, input)


 def network_Conv2d_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)

-    return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.Conv2d_load_state_dict(self, *args, **kwargs)


 def network_GroupNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.GroupNorm_forward_before_network)
+        return network_forward(self, input, originals.GroupNorm_forward)

     network_apply_weights(self)

-    return torch.nn.GroupNorm_forward_before_network(self, input)
+    return originals.GroupNorm_forward(self, input)


 def network_GroupNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)

-    return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.GroupNorm_load_state_dict(self, *args, **kwargs)


 def network_LayerNorm_forward(self, input):
     if shared.opts.lora_functional:
-        return network_forward(self, input, torch.nn.LayerNorm_forward_before_network)
+        return network_forward(self, input, originals.LayerNorm_forward)

     network_apply_weights(self)

-    return torch.nn.LayerNorm_forward_before_network(self, input)
+    return originals.LayerNorm_forward(self, input)


 def network_LayerNorm_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)

-    return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.LayerNorm_load_state_dict(self, *args, **kwargs)


 def network_MultiheadAttention_forward(self, *args, **kwargs):
     network_apply_weights(self)

-    return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_forward(self, *args, **kwargs)


 def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
     network_reset_cached_weight(self)

-    return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs)
+    return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs)


 def list_available_networks():
@@ -552,6 +553,9 @@ def infotext_pasted(infotext, params):
     if added:
         params["Prompt"] += "\n" + "".join(added)

+
+originals: lora_patches.LoraPatches = None
+
 extra_network_lora = None

 available_networks = {}
--
cgit v1.2.3
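
For context: after this change, networks.py expects a companion lora_patches module that exposes a LoraPatches object whose attributes (Linear_forward, Linear_load_state_dict, Conv2d_forward, and so on) hold the original torch.nn methods, with the module-level `originals` global assigned from it when the extension starts up. That module is not included in this diff; the following is a minimal, hypothetical sketch of the idea, assuming the patching is done by directly reassigning methods on the torch.nn classes (the real module may instead route this through the webui's patching helpers):

    # lora_patches.py -- hypothetical sketch, not the actual module added by this commit
    import torch

    # circular import with networks.py is fine here: networks.* attributes are
    # only looked up inside the methods below, at call time
    import networks


    class LoraPatches:
        """Remembers the original torch.nn methods and installs the network_* wrappers."""

        def __init__(self):
            # keep the unpatched methods; networks.py calls these via the
            # module-level `originals` object (originals.Linear_forward, ...)
            self.Linear_forward = torch.nn.Linear.forward
            self.Linear_load_state_dict = torch.nn.Linear._load_from_state_dict
            self.Conv2d_forward = torch.nn.Conv2d.forward
            self.Conv2d_load_state_dict = torch.nn.Conv2d._load_from_state_dict
            self.GroupNorm_forward = torch.nn.GroupNorm.forward
            self.GroupNorm_load_state_dict = torch.nn.GroupNorm._load_from_state_dict
            self.LayerNorm_forward = torch.nn.LayerNorm.forward
            self.LayerNorm_load_state_dict = torch.nn.LayerNorm._load_from_state_dict
            self.MultiheadAttention_forward = torch.nn.MultiheadAttention.forward
            self.MultiheadAttention_load_state_dict = torch.nn.MultiheadAttention._load_from_state_dict

            # install the LoRA-aware replacements defined in networks.py
            torch.nn.Linear.forward = networks.network_Linear_forward
            torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict
            torch.nn.Conv2d.forward = networks.network_Conv2d_forward
            torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict
            torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward
            torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict
            torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward
            torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict
            torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward
            torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict

        def undo(self):
            # restore the originals, e.g. when the extension is disabled or reloaded
            torch.nn.Linear.forward = self.Linear_forward
            torch.nn.Linear._load_from_state_dict = self.Linear_load_state_dict
            torch.nn.Conv2d.forward = self.Conv2d_forward
            torch.nn.Conv2d._load_from_state_dict = self.Conv2d_load_state_dict
            torch.nn.GroupNorm.forward = self.GroupNorm_forward
            torch.nn.GroupNorm._load_from_state_dict = self.GroupNorm_load_state_dict
            torch.nn.LayerNorm.forward = self.LayerNorm_forward
            torch.nn.LayerNorm._load_from_state_dict = self.LayerNorm_load_state_dict
            torch.nn.MultiheadAttention.forward = self.MultiheadAttention_forward
            torch.nn.MultiheadAttention._load_from_state_dict = self.MultiheadAttention_load_state_dict

With something along these lines, the extension's startup code could assign networks.originals = lora_patches.LoraPatches() once, and every forward/state-dict call patched above would go through `originals` instead of attributes monkey-patched onto the torch.nn namespace, with originals.undo() available for clean unloading.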