about summary refs log tree commit diff stats
path: root/extensions-builtin/Lora/network_oft.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/network_oft.py')
-rw-r--r--extensions-builtin/Lora/network_oft.py17
1 file changed, 14 insertions(+), 3 deletions(-)
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index f085eca5..68efb1db 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule):
return R
def calc_updown(self, orig_weight):
+ # this works
R = self.R
+
+ # this causes major deepfrying i.e. just doesn't work
+ # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype)
+
if orig_weight.dim() == 4:
weight = torch.einsum("oihw, op -> pihw", orig_weight, R)
else:
weight = torch.einsum("oi, op -> pi", orig_weight, R)
+
updown = orig_weight @ R
- output_shape = [orig_weight.size(0), R.size(1)]
- #output_shape = [R.size(0), orig_weight.size(1)]
+ output_shape = self.oft_blocks.shape
+
+ ## this works
+ # updown = orig_weight @ R
+ # output_shape = [orig_weight.size(0), R.size(1)]
+
return self.finalize_updown(updown, orig_weight, output_shape)
def forward(self, x, y=None):
x = self.org_forward(x)
if self.multiplier() == 0.0:
return x
- R = self.get_weight().to(x.device, dtype=x.dtype)
+ #R = self.get_weight().to(x.device, dtype=x.dtype)
+ R = self.R.to(x.device, dtype=x.dtype)
if x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = torch.matmul(x, R)