aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/network_oft.py
diff options
context:
space:
mode:
Diffstat (limited to 'extensions-builtin/Lora/network_oft.py')
-rw-r--r--extensions-builtin/Lora/network_oft.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 2af1bc4c..0a87958e 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -37,7 +37,7 @@ class NetworkModuleOFT(network.NetworkModule):
def apply_to(self):
self.org_forward = self.org_module[0].forward
self.org_module[0].forward = self.forward
-
+
def get_weight(self, oft_blocks, multiplier=None):
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
norm_Q = torch.norm(block_Q.flatten())
@@ -66,7 +66,7 @@ class NetworkModuleOFT(network.NetworkModule):
output_shape = self.oft_blocks.shape
return self.finalize_updown(updown, orig_weight, output_shape)
-
+
def forward(self, x, y=None):
x = self.org_forward(x)
if self.multiplier() == 0.0: