aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/network_norm.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-13 05:28:48 +0000
committerGitHub <noreply@github.com>2023-08-13 05:28:48 +0000
commitda80d649fd6a6083be02aca5695367bd25abf0d5 (patch)
treecfb85edce888c9d3092a72ae62340c03afb915ce /extensions-builtin/Lora/network_norm.py
parent61673451ff2b6ea39c8b9591b4a14d7f19a32e63 (diff)
parent5881dcb8873b3f87b9c6545e9cb8d1d77023f4fe (diff)
downloadstable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.tar.gz
stable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.tar.bz2
stable-diffusion-webui-gfx803-da80d649fd6a6083be02aca5695367bd25abf0d5.zip
Merge pull request #12503 from AUTOMATIC1111/extra-norm-module
Add Norm Module to lora ext and add "bias" support
Diffstat (limited to 'extensions-builtin/Lora/network_norm.py')
-rw-r--r--extensions-builtin/Lora/network_norm.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
new file mode 100644
index 00000000..ce450158
--- /dev/null
+++ b/extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,28 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+ def create_module(self, net: network.Network, weights: network.NetworkWeights):
+ if all(x in weights.w for x in ["w_norm", "b_norm"]):
+ return NetworkModuleNorm(net, weights)
+
+ return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+ def __init__(self, net: network.Network, weights: network.NetworkWeights):
+ super().__init__(net, weights)
+
+ self.w_norm = weights.w.get("w_norm")
+ self.b_norm = weights.w.get("b_norm")
+
+ def calc_updown(self, orig_weight):
+ output_shape = self.w_norm.shape
+ updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+ if self.b_norm is not None:
+ ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+ else:
+ ex_bias = None
+
+ return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)