From eb01d7f0e0fb46285985803296a25715165fb3f9 Mon Sep 17 00:00:00 2001 From: v0xie <28695009+v0xie@users.noreply.github.com> Date: Wed, 18 Oct 2023 04:56:53 -0700 Subject: faster by calculating R in updown and using cached R in forward --- extensions-builtin/Lora/network_oft.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) (limited to 'extensions-builtin') diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index 68efb1db..fd5b0c0f 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -58,17 +58,18 @@ class NetworkModuleOFT(network.NetworkModule): def calc_updown(self, orig_weight): # this works - R = self.R + # R = self.R + self.R = self.get_weight(self.multiplier()) - # this causes major deepfrying i.e. just doesn't work + # sending R to device causes major deepfrying i.e. just doesn't work # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) - if orig_weight.dim() == 4: - weight = torch.einsum("oihw, op -> pihw", orig_weight, R) - else: - weight = torch.einsum("oi, op -> pi", orig_weight, R) + # if orig_weight.dim() == 4: + # weight = torch.einsum("oihw, op -> pihw", orig_weight, R) + # else: + # weight = torch.einsum("oi, op -> pi", orig_weight, R) - updown = orig_weight @ R + updown = orig_weight @ self.R output_shape = self.oft_blocks.shape ## this works -- cgit v1.2.3