aboutsummaryrefslogtreecommitdiffstats
path: root/extensions-builtin/Lora/extra_networks_lora.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-07-17 06:00:47 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-07-17 06:00:47 +0000
commit238adeaffb037dedbcefe41e7fd4814a1f17baa2 (patch)
tree901557b4f4814effe18d23900317f7dbabf3e162 /extensions-builtin/Lora/extra_networks_lora.py
parent46466f09d0b0c14118033dee6af0f876059776d3 (diff)
downloadstable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.tar.gz
stable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.tar.bz2
stable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.zip
support specifying te and unet weights separately
update lora code support full module
Diffstat (limited to 'extensions-builtin/Lora/extra_networks_lora.py')
-rw-r--r--extensions-builtin/Lora/extra_networks_lora.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py
index 8a6639cf..084c41d0 100644
--- a/extensions-builtin/Lora/extra_networks_lora.py
+++ b/extensions-builtin/Lora/extra_networks_lora.py
@@ -14,14 +14,28 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork):
params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
- multipliers = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
for params in params_list:
assert params.items
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+ names.append(params.positional[0])
- networks.load_networks(names, multipliers)
+ te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+ te_multiplier = float(params.named.get("te", te_multiplier))
+
+ unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0
+ unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+
+ networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
network_hashes = []