author    Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-12 18:27:39 +0000
committer Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-08-12 18:27:39 +0000
commit    bd4da4474bef5c9c1f690c62b971704ee73d2860 (patch)
tree      f3624d0521366fb23c4e7861ea1d0a04b43483e6 /extensions-builtin/Lora/network_norm.py
parent    b2080756fcdc328292fc38998c06ccf23e53bd7e (diff)
Add extra norm module into built-in lora ext
Referring to LyCORIS 1.9.0.dev6, add a new option and module for training the norm layers (which is reported to be good for style).
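For context, norm training of this kind learns additive deltas for a normalization layer's affine parameters, so merging them into a model is just an add onto the layer's weight and bias. A minimal sketch of that idea in plain PyTorch (the names layer, w_norm, and b_norm are hypothetical illustrations, not the extension's actual merge path):

import torch
import torch.nn as nn

# Hypothetical illustration: fold learned norm deltas into a LayerNorm.
layer = nn.LayerNorm(768)
w_norm = torch.zeros(768)  # learned delta for layer.weight (gamma)
b_norm = torch.zeros(768)  # learned delta for layer.bias (beta)

with torch.no_grad():
    layer.weight += w_norm.to(layer.weight.device, dtype=layer.weight.dtype)
    layer.bias += b_norm.to(layer.bias.device, dtype=layer.bias.dtype)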
Diffstat (limited to 'extensions-builtin/Lora/network_norm.py')
-rw-r--r--  extensions-builtin/Lora/network_norm.py  29
1 file changed, 29 insertions, 0 deletions
diff --git a/extensions-builtin/Lora/network_norm.py b/extensions-builtin/Lora/network_norm.py
new file mode 100644
index 00000000..dab8b684
--- /dev/null
+++ b/extensions-builtin/Lora/network_norm.py
@@ -0,0 +1,29 @@
+import network
+
+
+class ModuleTypeNorm(network.ModuleType):
+    def create_module(self, net: network.Network, weights: network.NetworkWeights):
+        if all(x in weights.w for x in ["w_norm", "b_norm"]):
+            return NetworkModuleNorm(net, weights)
+
+        return None
+
+
+class NetworkModuleNorm(network.NetworkModule):
+    def __init__(self, net: network.Network, weights: network.NetworkWeights):
+        super().__init__(net, weights)
+        print("NetworkModuleNorm")
+
+        self.w_norm = weights.w.get("w_norm")
+        self.b_norm = weights.w.get("b_norm")
+
+    def calc_updown(self, orig_weight):
+        output_shape = self.w_norm.shape
+        updown = self.w_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+
+        if self.b_norm is not None:
+            ex_bias = self.b_norm.to(orig_weight.device, dtype=orig_weight.dtype)
+        else:
+            ex_bias = None
+
+        return self.finalize_updown(updown, orig_weight, output_shape, ex_bias)
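The finalize_updown call is defined in the extension's network module and is not part of this diff; it packages the weight delta and extra bias for the code that patches networks into the model. As a hedged sketch of how the returned values would typically be consumed (apply_norm_updown and its signature are hypothetical, not this repo's API), the delta is added to the norm layer's weight and the extra bias to its bias:

import torch

def apply_norm_updown(module, updown, ex_bias):
    # Hypothetical sketch, not the extension's merge code: add the
    # learned deltas onto the normalization layer's parameters.
    with torch.no_grad():
        module.weight += updown
        if ex_bias is not None and module.bias is not None:
            module.bias += ex_bias

Note the asymmetry in the module itself: create_module only returns a NetworkModuleNorm when both w_norm and b_norm are present, yet calc_updown still guards against b_norm being None.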