aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/networks.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-06 07:50:06 +0000
committerGitHub <noreply@github.com>2024-01-06 07:50:06 +0000
commita4ee64050a15263c73884ed7797c5332c9f559c1 (patch)
tree50be0ea6b7575a39e87a9b9af45b95889ab0f347 /extensions-builtin/Lora/networks.py
parent942617f82830df279f69da95797dd17780451545 (diff)
parent44744d6005da5e424267698ee3279caa597dfebc (diff)
downloadstable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.tar.gz
stable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.tar.bz2
stable-diffusion-webui-gfx803-a4ee64050a15263c73884ed7797c5332c9f559c1.zip
Merge pull request #14547 from AUTOMATIC1111/lyco-forward
Implement general forward method for all method in built-in lora ext
Diffstat (limited to 'extensions-builtin/Lora/networks.py')
-rw-r--r--extensions-builtin/Lora/networks.py12
1 files changed, 6 insertions, 6 deletions
diff --git a/extensions-builtin/Lora/networks.py b/extensions-builtin/Lora/networks.py
index 72ebd624..32e10b62 100644
--- a/extensions-builtin/Lora/networks.py
+++ b/extensions-builtin/Lora/networks.py
@@ -458,23 +458,23 @@ def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn
self.network_current_names = wanted_names
-def network_forward(module, input, original_forward):
+def network_forward(org_module, input, original_forward):
"""
Old way of applying Lora by executing operations during layer's forward.
Stacking many loras this way results in big performance degradation.
"""
if len(loaded_networks) == 0:
- return original_forward(module, input)
+ return original_forward(org_module, input)
input = devices.cond_cast_unet(input)
- network_restore_weights_from_backup(module)
- network_reset_cached_weight(module)
+ network_restore_weights_from_backup(org_module)
+ network_reset_cached_weight(org_module)
- y = original_forward(module, input)
+ y = original_forward(org_module, input)
- network_layer_name = getattr(module, 'network_layer_name', None)
+ network_layer_name = getattr(org_module, 'network_layer_name', None)
for lora in loaded_networks:
module = lora.modules.get(network_layer_name, None)
if module is None: