From f01682ee01e81e8ef84fd6fffe8f7aa17233285d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 15 Aug 2023 19:23:27 +0300 Subject: store patches for Lora in a specialized module --- extensions-builtin/Lora/lora_patches.py | 31 +++++++++++++++ extensions-builtin/Lora/networks.py | 32 +++++++++------- extensions-builtin/Lora/scripts/lora_script.py | 52 +++----------------------- 3 files changed, 54 insertions(+), 61 deletions(-) create mode 100644 extensions-builtin/Lora/lora_patches.py (limited to 'extensions-builtin/Lora') diff --git a/extensions-builtin/Lora/lora_patches.py b/extensions-builtin/Lora/lora_patches.py new file mode 100644 index 00000000..b394d8e9 --- /dev/null +++ b/extensions-builtin/Lora/lora_patches.py @@ -0,0 +1,31 @@ +import torch + +import networks +from modules import patches + + +class LoraPatches: + def __init__(self): + self.Linear_forward = patches.patch(__name__, torch.nn.Linear, 'forward', networks.network_Linear_forward) + self.Linear_load_state_dict = patches.patch(__name__, torch.nn.Linear, '_load_from_state_dict', networks.network_Linear_load_state_dict) + self.Conv2d_forward = patches.patch(__name__, torch.nn.Conv2d, 'forward', networks.network_Conv2d_forward) + self.Conv2d_load_state_dict = patches.patch(__name__, torch.nn.Conv2d, '_load_from_state_dict', networks.network_Conv2d_load_state_dict) + self.GroupNorm_forward = patches.patch(__name__, torch.nn.GroupNorm, 'forward', networks.network_GroupNorm_forward) + self.GroupNorm_load_state_dict = patches.patch(__name__, torch.nn.GroupNorm, '_load_from_state_dict', networks.network_GroupNorm_load_state_dict) + self.LayerNorm_forward = patches.patch(__name__, torch.nn.LayerNorm, 'forward', networks.network_LayerNorm_forward) + self.LayerNorm_load_state_dict = patches.patch(__name__, torch.nn.LayerNorm, '_load_from_state_dict', networks.network_LayerNorm_load_state_dict) + self.MultiheadAttention_forward = patches.patch(__name__, torch.nn.MultiheadAttention, 'forward', networks.network_MultiheadAttention_forward) + self.MultiheadAttention_load_state_dict = patches.patch(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict', networks.network_MultiheadAttention_load_state_dict) + + def undo(self): + self.Linear_forward = patches.undo(__name__, torch.nn.Linear, 'forward') + self.Linear_load_state_dict = patches.undo(__name__, torch.nn.Linear, '_load_from_state_dict') + self.Conv2d_forward = patches.undo(__name__, torch.nn.Conv2d, 'forward') + self.Conv2d_load_state_dict = patches.undo(__name__, torch.nn.Conv2d, '_load_from_state_dict') + self.GroupNorm_forward = patches.undo(__name__, torch.nn.GroupNorm, 'forward') + self.GroupNorm_load_state_dict = patches.undo(__name__, torch.nn.GroupNorm, '_load_from_state_dict') + self.LayerNorm_forward = patches.undo(__name__, torch.nn.LayerNorm, 'forward') + self.LayerNorm_load_state_dict = patches.undo(__name__, torch.nn.LayerNorm, '_load_from_state_dict') + self.MultiheadAttention_forward = patches.undo(__name__, torch.nn.MultiheadAttention, 'forward') + self.MultiheadAttention_load_state_dict = patches.undo(__name__, torch.nn.MultiheadAttention, '_load_from_state_dict') + diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py index 22fdff4a..9fca36b6 100644 --- a/extensions-builtin/Lora/networks.py +++ b/extensions-builtin/Lora/networks.py @@ -2,6 +2,7 @@ import logging import os import re +import lora_patches import network import network_lora import network_hada @@ -418,74 +419,74 @@ def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): def network_Linear_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Linear_forward_before_network) + return network_forward(self, input, originals.Linear_forward) network_apply_weights(self) - return torch.nn.Linear_forward_before_network(self, input) + return originals.Linear_forward(self, input) def network_Linear_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Linear_load_state_dict_before_network(self, *args, **kwargs) + return originals.Linear_load_state_dict(self, *args, **kwargs) def network_Conv2d_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.Conv2d_forward_before_network) + return network_forward(self, input, originals.Conv2d_forward) network_apply_weights(self) - return torch.nn.Conv2d_forward_before_network(self, input) + return originals.Conv2d_forward(self, input) def network_Conv2d_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.Conv2d_load_state_dict_before_network(self, *args, **kwargs) + return originals.Conv2d_load_state_dict(self, *args, **kwargs) def network_GroupNorm_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.GroupNorm_forward_before_network) + return network_forward(self, input, originals.GroupNorm_forward) network_apply_weights(self) - return torch.nn.GroupNorm_forward_before_network(self, input) + return originals.GroupNorm_forward(self, input) def network_GroupNorm_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.GroupNorm_load_state_dict_before_network(self, *args, **kwargs) + return originals.GroupNorm_load_state_dict(self, *args, **kwargs) def network_LayerNorm_forward(self, input): if shared.opts.lora_functional: - return network_forward(self, input, torch.nn.LayerNorm_forward_before_network) + return network_forward(self, input, originals.LayerNorm_forward) network_apply_weights(self) - return torch.nn.LayerNorm_forward_before_network(self, input) + return originals.LayerNorm_forward(self, input) def network_LayerNorm_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.LayerNorm_load_state_dict_before_network(self, *args, **kwargs) + return originals.LayerNorm_load_state_dict(self, *args, **kwargs) def network_MultiheadAttention_forward(self, *args, **kwargs): network_apply_weights(self) - return torch.nn.MultiheadAttention_forward_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_forward(self, *args, **kwargs) def network_MultiheadAttention_load_state_dict(self, *args, **kwargs): network_reset_cached_weight(self) - return torch.nn.MultiheadAttention_load_state_dict_before_network(self, *args, **kwargs) + return originals.MultiheadAttention_load_state_dict(self, *args, **kwargs) def list_available_networks(): @@ -552,6 +553,9 @@ def infotext_pasted(infotext, params): if added: params["Prompt"] += "\n" + "".join(added) + +originals: lora_patches.LoraPatches = None + extra_network_lora = None available_networks = {} diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index 4c6e774a..546fb55e 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -7,17 +7,14 @@ from fastapi import FastAPI import network import networks import lora # noqa:F401 +import lora_patches import extra_networks_lora import ui_extra_networks_lora -from modules import script_callbacks, ui_extra_networks, extra_networks, shared +from modules import script_callbacks, ui_extra_networks, extra_networks, shared, patches + def unload(): - torch.nn.Linear.forward = torch.nn.Linear_forward_before_network - torch.nn.Linear._load_from_state_dict = torch.nn.Linear_load_state_dict_before_network - torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_network - torch.nn.Conv2d._load_from_state_dict = torch.nn.Conv2d_load_state_dict_before_network - torch.nn.MultiheadAttention.forward = torch.nn.MultiheadAttention_forward_before_network - torch.nn.MultiheadAttention._load_from_state_dict = torch.nn.MultiheadAttention_load_state_dict_before_network + networks.originals.undo() def before_ui(): @@ -28,46 +25,7 @@ def before_ui(): extra_networks.register_extra_network_alias(networks.extra_network_lora, "lyco") -if not hasattr(torch.nn, 'Linear_forward_before_network'): - torch.nn.Linear_forward_before_network = torch.nn.Linear.forward - -if not hasattr(torch.nn, 'Linear_load_state_dict_before_network'): - torch.nn.Linear_load_state_dict_before_network = torch.nn.Linear._load_from_state_dict - -if not hasattr(torch.nn, 'Conv2d_forward_before_network'): - torch.nn.Conv2d_forward_before_network = torch.nn.Conv2d.forward - -if not hasattr(torch.nn, 'Conv2d_load_state_dict_before_network'): - torch.nn.Conv2d_load_state_dict_before_network = torch.nn.Conv2d._load_from_state_dict - -if not hasattr(torch.nn, 'GroupNorm_forward_before_network'): - torch.nn.GroupNorm_forward_before_network = torch.nn.GroupNorm.forward - -if not hasattr(torch.nn, 'GroupNorm_load_state_dict_before_network'): - torch.nn.GroupNorm_load_state_dict_before_network = torch.nn.GroupNorm._load_from_state_dict - -if not hasattr(torch.nn, 'LayerNorm_forward_before_network'): - torch.nn.LayerNorm_forward_before_network = torch.nn.LayerNorm.forward - -if not hasattr(torch.nn, 'LayerNorm_load_state_dict_before_network'): - torch.nn.LayerNorm_load_state_dict_before_network = torch.nn.LayerNorm._load_from_state_dict - -if not hasattr(torch.nn, 'MultiheadAttention_forward_before_network'): - torch.nn.MultiheadAttention_forward_before_network = torch.nn.MultiheadAttention.forward - -if not hasattr(torch.nn, 'MultiheadAttention_load_state_dict_before_network'): - torch.nn.MultiheadAttention_load_state_dict_before_network = torch.nn.MultiheadAttention._load_from_state_dict - -torch.nn.Linear.forward = networks.network_Linear_forward -torch.nn.Linear._load_from_state_dict = networks.network_Linear_load_state_dict -torch.nn.Conv2d.forward = networks.network_Conv2d_forward -torch.nn.Conv2d._load_from_state_dict = networks.network_Conv2d_load_state_dict -torch.nn.GroupNorm.forward = networks.network_GroupNorm_forward -torch.nn.GroupNorm._load_from_state_dict = networks.network_GroupNorm_load_state_dict -torch.nn.LayerNorm.forward = networks.network_LayerNorm_forward -torch.nn.LayerNorm._load_from_state_dict = networks.network_LayerNorm_load_state_dict -torch.nn.MultiheadAttention.forward = networks.network_MultiheadAttention_forward -torch.nn.MultiheadAttention._load_from_state_dict = networks.network_MultiheadAttention_load_state_dict +networks.originals = lora_patches.LoraPatches() script_callbacks.on_model_loaded(networks.assign_network_names_to_compvis_modules) script_callbacks.on_script_unloaded(unload) -- cgit v1.2.3