author | Vladimir Mandic <mandic00@live.com> | 2023-01-23 17:25:07 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-23 17:25:07 +0000 |
commit | efa7287be0a018dcb92e362460cbe19d42d70b03 (patch) | |
tree | 5ca63fd1273dbd396453a29cbb7ee913c3d29880 /extensions-builtin/Lora/lora.py | |
parent | 925dd09c91e7338aef72e4ec99d67b8b57280215 (diff) | |
parent | c6f20f72629f3c417f10db2289d131441c6832f5 (diff) | |
download | stable-diffusion-webui-gfx803-efa7287be0a018dcb92e362460cbe19d42d70b03.tar.gz stable-diffusion-webui-gfx803-efa7287be0a018dcb92e362460cbe19d42d70b03.tar.bz2 stable-diffusion-webui-gfx803-efa7287be0a018dcb92e362460cbe19d42d70b03.zip |
Merge branch 'AUTOMATIC1111:master' into interrogate
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r-- | extensions-builtin/Lora/lora.py | 19 |
1 file changed, 12 insertions, 7 deletions
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index da1797dc..137e58f7 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -57,6 +57,7 @@ class LoraUpDownModule:
     def __init__(self):
         self.up = None
         self.down = None
+        self.alpha = None
 
 
 def assign_lora_names_to_compvis_modules(sd_model):
@@ -92,6 +93,15 @@ def load_lora(name, filename):
             keys_failed_to_match.append(key_diffusers)
             continue
 
+        lora_module = lora.modules.get(key, None)
+        if lora_module is None:
+            lora_module = LoraUpDownModule()
+            lora.modules[key] = lora_module
+
+        if lora_key == "alpha":
+            lora_module.alpha = weight.item()
+            continue
+
         if type(sd_module) == torch.nn.Linear:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.Conv2d:
@@ -104,17 +114,12 @@ def load_lora(name, filename):
 
         module.to(device=devices.device, dtype=devices.dtype)
 
-        lora_module = lora.modules.get(key, None)
-        if lora_module is None:
-            lora_module = LoraUpDownModule()
-            lora.modules[key] = lora_module
-
         if lora_key == "lora_up.weight":
             lora_module.up = module
         elif lora_key == "lora_down.weight":
             lora_module.down = module
         else:
-            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight'
+            assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight, lora_down.weight or alpha'
 
     if len(keys_failed_to_match) > 0:
         print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}")
@@ -161,7 +166,7 @@ def lora_forward(module, input, res):
     for lora in loaded_loras:
         module = lora.modules.get(lora_layer_name, None)
         if module is not None:
-            res = res + module.up(module.down(input)) * lora.multiplier
+            res = res + module.up(module.down(input)) * lora.multiplier * (module.alpha / module.up.weight.shape[1] if module.alpha else 1.0)
 
     return res
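
The net effect of the patch: load_lora() now also records the per-module alpha scalar found in LoRA checkpoints, and lora_forward() scales the LoRA contribution by alpha / rank (rank read from module.up.weight.shape[1]) on top of the user-set multiplier. A minimal standalone sketch of that scaling rule, not the webui code itself; the lora_delta helper and the toy dimensions are illustrative only:

# Sketch of the alpha/rank scaling applied to the LoRA delta (assumed names).
import torch

def lora_delta(up, down, x, multiplier=1.0, alpha=None):
    # rank is the input width of the up-projection, i.e. up.weight.shape[1]
    scale = multiplier * (alpha / up.weight.shape[1] if alpha else 1.0)
    return up(down(x)) * scale

# Usage: a rank-4 LoRA pair around a 16->16 linear layer.
down = torch.nn.Linear(16, 4, bias=False)  # lora_down.weight has shape (4, 16)
up = torch.nn.Linear(4, 16, bias=False)    # lora_up.weight has shape (16, 4)
x = torch.randn(2, 16)
res = x + lora_delta(up, down, x, multiplier=0.8, alpha=4.0)  # alpha/rank == 1.0 here

When alpha equals the rank (a common checkpoint convention), the extra factor is 1.0 and the behaviour matches the pre-patch code; smaller alpha values proportionally weaken the LoRA contribution, and checkpoints without an alpha key fall back to the old formula.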