diff options
author | v0xie <28695009+v0xie@users.noreply.github.com> | 2023-10-18 11:27:44 +0000 |
---|---|---|
committer | v0xie <28695009+v0xie@users.noreply.github.com> | 2023-10-18 11:27:44 +0000 |
commit | 853e21d98eada4db9a9fd1ae8eda90cf763e2818 (patch) | |
tree | d507292825267486dfe4acccb51f604b9c80e30e /extensions-builtin/Lora/network_oft.py | |
parent | 1c6efdbba774d603c592debaccd6f5ad827bd1b2 (diff) | |
download | stable-diffusion-webui-gfx803-853e21d98eada4db9a9fd1ae8eda90cf763e2818.tar.gz stable-diffusion-webui-gfx803-853e21d98eada4db9a9fd1ae8eda90cf763e2818.tar.bz2 stable-diffusion-webui-gfx803-853e21d98eada4db9a9fd1ae8eda90cf763e2818.zip |
faster by using cached R in forward
Diffstat (limited to 'extensions-builtin/Lora/network_oft.py')
-rw-r--r-- | extensions-builtin/Lora/network_oft.py | 17 |
1 file changed, 14 insertions, 3 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py index f085eca5..68efb1db 100644 --- a/extensions-builtin/Lora/network_oft.py +++ b/extensions-builtin/Lora/network_oft.py @@ -57,21 +57,32 @@ class NetworkModuleOFT(network.NetworkModule): return R def calc_updown(self, orig_weight): + # this works R = self.R + + # this causes major deepfrying i.e. just doesn't work + # R = self.R.to(orig_weight.device, dtype=orig_weight.dtype) + if orig_weight.dim() == 4: weight = torch.einsum("oihw, op -> pihw", orig_weight, R) else: weight = torch.einsum("oi, op -> pi", orig_weight, R) + updown = orig_weight @ R - output_shape = [orig_weight.size(0), R.size(1)] - #output_shape = [R.size(0), orig_weight.size(1)] + output_shape = self.oft_blocks.shape + + ## this works + # updown = orig_weight @ R + # output_shape = [orig_weight.size(0), R.size(1)] + return self.finalize_updown(updown, orig_weight, output_shape) def forward(self, x, y=None): x = self.org_forward(x) if self.multiplier() == 0.0: return x - R = self.get_weight().to(x.device, dtype=x.dtype) + #R = self.get_weight().to(x.device, dtype=x.dtype) + R = self.R.to(x.device, dtype=x.dtype) if x.dim() == 4: x = x.permute(0, 2, 3, 1) x = torch.matmul(x, R) |