From 238adeaffb037dedbcefe41e7fd4814a1f17baa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 09:00:47 +0300 Subject: support specifying te and unet weights separately update lora code support full module --- extensions-builtin/Lora/network.py | 40 +++++++++++++++++++++++++++++++++++++- 1 file changed, 39 insertions(+), 1 deletion(-) (limited to 'extensions-builtin/Lora/network.py') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index 4ac63722..fe42dbdd 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -68,7 +68,9 @@ class Network: # LoraModule def __init__(self, name, network_on_disk: NetworkOnDisk): self.name = name self.network_on_disk = network_on_disk - self.multiplier = 1.0 + self.te_multiplier = 1.0 + self.unet_multiplier = 1.0 + self.dyn_dim = None self.modules = {} self.mtime = None @@ -88,6 +90,42 @@ class NetworkModule: self.sd_key = weights.sd_key self.sd_module = weights.sd_module + if hasattr(self.sd_module, 'weight'): + self.shape = self.sd_module.weight.shape + + self.dim = None + self.bias = weights.w.get("bias") + self.alpha = weights.w["alpha"].item() if "alpha" in weights.w else None + self.scale = weights.w["scale"].item() if "scale" in weights.w else None + + def multiplier(self): + if 'transformer' in self.sd_key[:20]: + return self.network.te_multiplier + else: + return self.network.unet_multiplier + + def calc_scale(self): + if self.scale is not None: + return self.scale + if self.dim is not None and self.alpha is not None: + return self.alpha / self.dim + + return 1.0 + + def finalize_updown(self, updown, orig_weight, output_shape): + if self.bias is not None: + updown = updown.reshape(self.bias.shape) + updown += self.bias.to(orig_weight.device, dtype=orig_weight.dtype) + updown = updown.reshape(output_shape) + + if len(output_shape) == 4: + updown = updown.reshape(output_shape) + + if orig_weight.size().numel() == updown.size().numel(): + updown = updown.reshape(orig_weight.shape) + + return updown * self.calc_scale() * self.multiplier() + def calc_updown(self, target): raise NotImplementedError() -- cgit v1.2.3