about summary refs log tree commit diff stats
path: root/extensions-builtin/Lora/lora.py
diff options
context:
space:
mode:
author AUTOMATIC <16777216c@gmail.com> 2023-01-25 08:29:46 +0000
committer AUTOMATIC <16777216c@gmail.com> 2023-01-25 08:29:46 +0000
commit 1bfec873fa13d803f3d4ac2a12bf6983838233fe (patch)
tree a51fc3389fd2165397ef408d1402322e57331a65 /extensions-builtin/Lora/lora.py
parent 48a15821de768fea76e66f26df83df3fddf18f4b (diff)
downloadstable-diffusion-webui-gfx803-1bfec873fa13d803f3d4ac2a12bf6983838233fe.tar.gz
stable-diffusion-webui-gfx803-1bfec873fa13d803f3d4ac2a12bf6983838233fe.tar.bz2
stable-diffusion-webui-gfx803-1bfec873fa13d803f3d4ac2a12bf6983838233fe.zip
add an experimental option to apply loras to outputs rather than inputs
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r-- extensions-builtin/Lora/lora.py | 5
1 file changed, 4 insertions(+), 1 deletion(-)
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 137e58f7..cb8f1d36 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,7 +166,10 @@ def lora_forward(module, input, res):
for lora in loaded_loras:
module = lora.modules.get(lora_layer_name, None)
if module is not None:
- res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ if shared.opts.lora_apply_to_outputs and res.shape == input.shape:
+ res = res + module.up(module.down(res)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
+ else:
+ res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
return res