diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-17 06:00:47 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-17 06:00:47 +0000 |
commit | 238adeaffb037dedbcefe41e7fd4814a1f17baa2 (patch) | |
tree | 901557b4f4814effe18d23900317f7dbabf3e162 /extensions-builtin/Lora/extra_networks_lora.py | |
parent | 46466f09d0b0c14118033dee6af0f876059776d3 (diff) | |
download | stable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.tar.gz stable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.tar.bz2 stable-diffusion-webui-gfx803-238adeaffb037dedbcefe41e7fd4814a1f17baa2.zip |
support specifying te and unet weights separately
update lora code
support full module
Diffstat (limited to 'extensions-builtin/Lora/extra_networks_lora.py')
-rw-r--r-- | extensions-builtin/Lora/extra_networks_lora.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py index 8a6639cf..084c41d0 100644 --- a/extensions-builtin/Lora/extra_networks_lora.py +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -14,14 +14,28 @@ class ExtraNetworkLora(extra_networks.ExtraNetwork): params_list.append(extra_networks.ExtraNetworkParams(items=[additional, shared.opts.extra_networks_default_multiplier]))
names = []
- multipliers = []
+ te_multipliers = []
+ unet_multipliers = []
+ dyn_dims = []
for params in params_list:
assert params.items
- names.append(params.items[0])
- multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0)
+ names.append(params.positional[0])
- networks.load_networks(names, multipliers)
+ te_multiplier = float(params.positional[1]) if len(params.positional) > 1 else 1.0
+ te_multiplier = float(params.named.get("te", te_multiplier))
+
+ unet_multiplier = float(params.positional[2]) if len(params.positional) > 2 else 1.0
+ unet_multiplier = float(params.named.get("unet", unet_multiplier))
+
+ dyn_dim = int(params.positional[3]) if len(params.positional) > 3 else None
+ dyn_dim = int(params.named["dyn"]) if "dyn" in params.named else dyn_dim
+
+ te_multipliers.append(te_multiplier)
+ unet_multipliers.append(unet_multiplier)
+ dyn_dims.append(dyn_dim)
+
+ networks.load_networks(names, te_multipliers, unet_multipliers, dyn_dims)
if shared.opts.lora_add_hashes_to_infotext:
network_hashes = []
|