diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-12-13 17:43:24 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-12-13 17:43:24 +0000 |
commit | 265bc26c21264d63956e8f30f1ce31dec917fc76 (patch) | |
tree | d0372003336bff507a9548e03873f1ddef3ae242 /extensions-builtin/Lora | |
parent | 735c9e8059384d4f640e5582413c30871f83eac5 (diff) | |
download | stable-diffusion-webui-gfx803-265bc26c21264d63956e8f30f1ce31dec917fc76.tar.gz stable-diffusion-webui-gfx803-265bc26c21264d63956e8f30f1ce31dec917fc76.tar.bz2 stable-diffusion-webui-gfx803-265bc26c21264d63956e8f30f1ce31dec917fc76.zip |
Use self.scale instead of custom finalize
Diffstat (limited to 'extensions-builtin/Lora')
-rw-r--r-- | extensions-builtin/Lora/network_oft.py | 20 |
1 file changed, 2 insertions, 18 deletions
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index 44465f7a..e3ae61a2 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -21,6 +21,8 @@ class NetworkModuleOFT(network.NetworkModule):
         self.lin_module = None
         self.org_module: list[torch.Module] = [self.sd_module]

+        self.scale = 1.0
+
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
@@ -78,21 +80,3 @@ class NetworkModuleOFT(network.NetworkModule):
         print(torch.norm(updown))
         output_shape = orig_weight.shape
         return self.finalize_updown(updown, orig_weight, output_shape)
-
-    def finalize_updown(self, updown, orig_weight, output_shape, ex_bias=None):
-        if self.bias is not None:
-            updown = updown.reshape(self.bias.shape)
-            updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype)
-            updown = updown.reshape(output_shape)
-
-        if len(output_shape) == 4:
-            updown = updown.reshape(output_shape)
-
-        if orig_weight.size().numel() == updown.size().numel():
-            updown = updown.reshape(orig_weight.shape)
-
-        if ex_bias is not None:
-            ex_bias = ex_bias * self.multiplier()
-
-        # Ignore calc_scale, which is not used in OFT.
-        return updown * self.multiplier(), ex_bias