| author | v0xie <28695009+v0xie@users.noreply.github.com> | 2023-11-15 11:08:50 +0000 |
|---|---|---|
| committer | v0xie <28695009+v0xie@users.noreply.github.com> | 2023-11-15 11:08:50 +0000 |
| commit | d6d0b22e6657fc84039e82ee735a57101bfe7c17 (patch) | |
| tree | 9a77f9e0266b5ef5ed2ebd13837d0b3929fd169a /extensions-builtin | |
| parent | 7edd50f304ebf8a713839035d4e9eacaa98d3762 (diff) | |
| download | stable-diffusion-webui-gfx803-d6d0b22e6657fc84039e82ee735a57101bfe7c17.tar.gz stable-diffusion-webui-gfx803-d6d0b22e6657fc84039e82ee735a57101bfe7c17.tar.bz2 stable-diffusion-webui-gfx803-d6d0b22e6657fc84039e82ee735a57101bfe7c17.zip | |
fix: ignore calc_scale() for COFT which has very small alpha
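The reasoning behind the fix, as I read the surrounding Lora network module (the exact formula of `calc_scale()` is an assumption here, roughly `alpha / dim`): COFT reuses the network's `alpha` field as its small constraint value rather than as a LoRA-style scaling alpha, so feeding it through that scale all but zeroes the weight delta. A minimal sketch of the effect under that assumption:

```python
# Hypothetical sketch; calc_scale() is assumed to reduce to alpha / dim.
def calc_scale(alpha: float, dim: int) -> float:
    return alpha / dim if alpha and dim else 1.0

print(calc_scale(alpha=4.0, dim=8))   # 0.5, a typical LoRA-style alpha keeps the scale O(1)
print(calc_scale(alpha=1e-4, dim=8))  # 1.25e-05, COFT's tiny alpha would wipe out the update
```

Dropping the factor and using only `self.multiplier()` keeps the user-facing strength multiplier while leaving COFT's tiny alpha out of the delta.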
Diffstat (limited to 'extensions-builtin')
-rw-r--r-- | extensions-builtin/Lora/network_oft.py | 16 |
1 file changed, 5 insertions, 11 deletions
```diff
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 93402bb2..c45a8d23 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -99,12 +99,9 @@ class NetworkModuleOFT(network.NetworkModule):
         is_other_linear = type(self.sd_module) in [torch.nn.MultiheadAttention]
 
         if not is_other_linear:
-            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
-            #    orig_weight=orig_weight.permute(1, 0)
-
             oft_blocks = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
 
-            # without this line the results are significantly worse / less accurate
+            # ensure skew-symmetric matrix
             oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)
 
             R = oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
@@ -118,9 +115,6 @@ class NetworkModuleOFT(network.NetworkModule):
             )
             merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
 
-            #if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
-            #    orig_weight=orig_weight.permute(1, 0)
-
             updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
             output_shape = orig_weight.shape
         else:
@@ -132,10 +126,10 @@ class NetworkModuleOFT(network.NetworkModule):
         return self.finalize_updown(updown, orig_weight, output_shape)
 
     def calc_updown(self, orig_weight):
-        multiplier = self.multiplier() * self.calc_scale()
-        #if self.is_kohya:
-        #    return self.calc_updown_kohya(orig_weight, multiplier)
-        #else:
+        # if alpha is a very small number as in coft, calc_scale will return a almost zero number so we ignore it
+        #multiplier = self.multiplier() * self.calc_scale()
+        multiplier = self.multiplier()
+
         return self.calc_updown_kb(orig_weight, multiplier)
 
     # override to remove the multiplier/scale factor; it's already multiplied in get_weight
```
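On the retained `oft_blocks = oft_blocks - oft_blocks.transpose(1, 2)` line: subtracting the transpose forces each block to be skew-symmetric (Q^T = -Q), which is the form the OFT papers feed into, for example, a Cayley transform to obtain an orthogonal rotation. A minimal sketch of both properties follows; the Cayley step is shown for illustration only and is not asserted to be what this code path does:

```python
import torch

# One batched block, shaped like oft_blocks: (num_blocks, block_size, block_size)
A = torch.randn(1, 4, 4)
Q = A - A.transpose(1, 2)                 # same trick as the diff: forces Q^T == -Q
assert torch.allclose(Q.transpose(1, 2), -Q)

# Illustration only: the Cayley transform turns a skew-symmetric Q into an orthogonal R
I = torch.eye(4).expand_as(Q)
R = torch.linalg.solve(I - Q, I + Q)      # R = (I - Q)^-1 (I + Q)
assert torch.allclose(R @ R.transpose(1, 2), I, atol=1e-4)
```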