path: root/extensions-builtin/Lora/lora.py
author    AUTOMATIC <16777216c@gmail.com>  2023-05-08 09:07:43 +0000
committer AUTOMATIC <16777216c@gmail.com>  2023-05-08 09:07:43 +0000
commit    ec0da07236d286f37c86f9cd92642e24381dd6a5 (patch)
tree      bdb0725577d94241c1cc26b5a85f0de209276081 /extensions-builtin/Lora/lora.py
parent    083dc3c76ab7dbc7b2b04f3396d4f5280b002906 (diff)
Lora: add an option to use old method of applying loras
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r--  extensions-builtin/Lora/lora.py | 56
1 file changed, 50 insertions(+), 6 deletions(-)
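
Note: the commit reads a new `shared.opts.lora_functional` flag inside `lora.py`; registering that option happens outside this file and is not part of this diff. As a rough sketch only, a checkbox-style option like this is typically added through the webui's `shared.options_templates` / `shared.options_section` / `shared.OptionInfo` helpers; the section name and label text below are placeholders, not taken from the commit:

# Hypothetical sketch -- the actual option registration is not in this diff.
from modules import shared

shared.options_templates.update(shared.options_section(('extra_networks', "Extra Networks"), {
    # Default False keeps the newer weight-patching path; True switches the
    # patched Linear/Conv2d forwards back to the old functional lora_forward() path.
    "lora_functional": shared.OptionInfo(False, "Lora: use old method of applying loras"),
}))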
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 83933639..d488b5ae 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target):
     return updown
 
 
+def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
+    weights_backup = getattr(self, "lora_weights_backup", None)
+
+    if weights_backup is None:
+        return
+
+    if isinstance(self, torch.nn.MultiheadAttention):
+        self.in_proj_weight.copy_(weights_backup[0])
+        self.out_proj.weight.copy_(weights_backup[1])
+    else:
+        self.weight.copy_(weights_backup)
+
+
 def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]):
     """
     Applies the currently selected set of Loras to the weights of torch layer self.
@@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
         self.lora_weights_backup = weights_backup
 
     if current_names != wanted_names:
-        if weights_backup is not None:
-            if isinstance(self, torch.nn.MultiheadAttention):
-                self.in_proj_weight.copy_(weights_backup[0])
-                self.out_proj.weight.copy_(weights_backup[1])
-            else:
-                self.weight.copy_(weights_backup)
+        lora_restore_weights_from_backup(self)
 
         for lora in loaded_loras:
             module = lora.modules.get(lora_layer_name, None)
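
For context, `lora_weights_backup` is populated earlier in `lora_apply_weights`, outside the hunks shown here: for `MultiheadAttention` layers the backup is a pair of tensors, otherwise a single weight tensor, which is why the new `lora_restore_weights_from_backup` branches on the layer type. A rough sketch of that backup step follows; the exact device/copy handling is assumed rather than taken from this diff:

# Rough sketch of the backup step elsewhere in lora_apply_weights (not in this diff).
weights_backup = getattr(self, "lora_weights_backup", None)
if weights_backup is None:
    if isinstance(self, torch.nn.MultiheadAttention):
        # MultiheadAttention needs both projection weights backed up.
        weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
    else:
        weights_backup = self.weight.to(devices.cpu, copy=True)

    self.lora_weights_backup = weights_backup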
@@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu
     setattr(self, "lora_current_names", wanted_names)
 
 
+def lora_forward(module, input, original_forward):
+    """
+    Old way of applying Lora by executing operations during layer's forward.
+    Stacking many loras this way results in big performance degradation.
+    """
+
+    if len(loaded_loras) == 0:
+        return original_forward(module, input)
+
+    input = devices.cond_cast_unet(input)
+
+    lora_restore_weights_from_backup(module)
+    lora_reset_cached_weight(module)
+
+    res = original_forward(module, input)
+
+    lora_layer_name = getattr(module, 'lora_layer_name', None)
+    for lora in loaded_loras:
+        module = lora.modules.get(lora_layer_name, None)
+        if module is None:
+            continue
+
+        module.up.to(device=devices.device)
+        module.down.to(device=devices.device)
+
+        res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+
+    return res
+
+
 def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
     setattr(self, "lora_current_names", ())
     setattr(self, "lora_weights_backup", None)
 
 
 def lora_Linear_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Linear_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Linear_forward_before_lora(self, input)
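
What `lora_forward` computes per layer is the usual LoRA sum: the original layer output plus `up(down(x))`, scaled by the lora multiplier and by `alpha / rank` when an alpha is stored. A self-contained sketch of that arithmetic for a Linear layer; the layer sizes and rank here are made up for illustration:

import torch

# Stand-in shapes for illustration only.
base = torch.nn.Linear(64, 64)
down = torch.nn.Linear(64, 8, bias=False)   # LoRA "down" projection, rank 8
up = torch.nn.Linear(8, 64, bias=False)     # LoRA "up" projection
multiplier, alpha = 1.0, 8.0

x = torch.randn(2, 64)
scale = alpha / up.weight.shape[1]          # same alpha / rank scaling used in lora_forward
res = base(x) + up(down(x)) * multiplier * scale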
@@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs):
 
 
 def lora_Conv2d_forward(self, input):
+    if shared.opts.lora_functional:
+        return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora)
+
     lora_apply_weights(self)
 
     return torch.nn.Conv2d_forward_before_lora(self, input)
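
These `lora_Linear_forward` / `lora_Conv2d_forward` wrappers only take effect because the extension monkeypatches them over the stock torch layers while stashing the originals as `Linear_forward_before_lora` / `Conv2d_forward_before_lora`; that wiring lives in the extension's script file, not in `lora.py`. A hedged sketch of the patching this diff relies on:

# Sketch of the monkeypatching done elsewhere in the extension (not in this diff).
if not hasattr(torch.nn, 'Linear_forward_before_lora'):
    torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward
if not hasattr(torch.nn, 'Conv2d_forward_before_lora'):
    torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward

torch.nn.Linear.forward = lora.lora_Linear_forward
torch.nn.Conv2d.forward = lora.lora_Conv2d_forward

With `lora_functional` enabled in settings, every patched forward routes through `lora_forward`; with it disabled (the default), behaviour is unchanged from before this commit.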