From ec0da07236d286f37c86f9cd92642e24381dd6a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 8 May 2023 12:07:43 +0300 Subject: Lora: add an option to use old method of applying loras --- extensions-builtin/Lora/lora.py | 56 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 50 insertions(+), 6 deletions(-) (limited to 'extensions-builtin/Lora/lora.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 83933639..d488b5ae 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -245,6 +245,19 @@ def lora_calc_updown(lora, module, target): return updown +def lora_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): + weights_backup = getattr(self, "lora_weights_backup", None) + + if weights_backup is None: + return + + if isinstance(self, torch.nn.MultiheadAttention): + self.in_proj_weight.copy_(weights_backup[0]) + self.out_proj.weight.copy_(weights_backup[1]) + else: + self.weight.copy_(weights_backup) + + def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.MultiheadAttention]): """ Applies the currently selected set of Loras to the weights of torch layer self. @@ -269,12 +282,7 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu self.lora_weights_backup = weights_backup if current_names != wanted_names: - if weights_backup is not None: - if isinstance(self, torch.nn.MultiheadAttention): - self.in_proj_weight.copy_(weights_backup[0]) - self.out_proj.weight.copy_(weights_backup[1]) - else: - self.weight.copy_(weights_backup) + lora_restore_weights_from_backup(self) for lora in loaded_loras: module = lora.modules.get(lora_layer_name, None) @@ -305,12 +313,45 @@ def lora_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.Mu setattr(self, "lora_current_names", wanted_names) +def lora_forward(module, input, original_forward): + """ + Old way of applying Lora by executing operations during layer's forward. + Stacking many loras this way results in big performance degradation. + """ + + if len(loaded_loras) == 0: + return original_forward(module, input) + + input = devices.cond_cast_unet(input) + + lora_restore_weights_from_backup(module) + lora_reset_cached_weight(module) + + res = original_forward(module, input) + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is None: + continue + + module.up.to(device=devices.device) + module.down.to(device=devices.device) + + res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0) + + return res + + def lora_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]): setattr(self, "lora_current_names", ()) setattr(self, "lora_weights_backup", None) def lora_Linear_forward(self, input): + if shared.opts.lora_functional: + return lora_forward(self, input, torch.nn.Linear_forward_before_lora) + lora_apply_weights(self) return torch.nn.Linear_forward_before_lora(self, input) @@ -323,6 +364,9 @@ def lora_Linear_load_state_dict(self, *args, **kwargs): def lora_Conv2d_forward(self, input): + if shared.opts.lora_functional: + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora) + lora_apply_weights(self) return torch.nn.Conv2d_forward_before_lora(self, input) -- cgit v1.2.3