| author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-08 05:45:26 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-08 05:45:26 +0000 |
| commit | b15bbef798f5aba047f0e6955ce94fe589071b44 (patch) | |
| tree | cd369de588e77d79fdffc514c944666266a42cf3 /extensions-builtin/Lora/lora.py | |
| parent | 67c884196d4627903f6598989251ec5b2c46a4ce (diff) | |
| parent | c3eced22fc7b9da4fbb2f55f2d53a7e5e511cfbd (diff) | |
Merge pull request #10089 from AUTOMATIC1111/LoraFix
Fix some Loras not working
Diffstat (limited to 'extensions-builtin/Lora/lora.py')
-rw-r--r-- | extensions-builtin/Lora/lora.py | 6 |
1 file changed, 5 insertions, 1 deletion
```diff
diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py
index 94ec021b..83c1c6fd 100644
--- a/extensions-builtin/Lora/lora.py
+++ b/extensions-builtin/Lora/lora.py
@@ -166,8 +166,10 @@ def load_lora(name, filename):
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
         elif type(sd_module) == torch.nn.MultiheadAttention:
             module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False)
-        elif type(sd_module) == torch.nn.Conv2d:
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (1, 1):
             module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False)
+        elif type(sd_module) == torch.nn.Conv2d and weight.shape[2:] == (3, 3):
+            module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (3, 3), bias=False)
         else:
             print(f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}')
             continue
```
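The first hunk keys the replacement module off the LoRA weight's kernel size: previously any `Conv2d` was assumed to be 1x1, so 3x3 conv Lora weights could not load. The shape test relies on `torch.Size` being a tuple subclass, so a sliced shape compares equal to a plain tuple; a standalone illustration with made-up shapes:

```python
import torch

# torch.Size is a tuple subclass, so slicing a weight's shape and comparing
# the result to a plain tuple is a clean kernel-size check. Shapes are made up.
down = torch.zeros(4, 320, 3, 3)  # (rank, in_channels, kh, kw): a 3x3 LoRA down weight
up = torch.zeros(320, 4, 1, 1)    # (out_channels, rank, kh, kw): a 1x1 LoRA up weight

print(down.shape[2:] == (3, 3))   # True -> new Conv2d(..., (3, 3)) branch
print(up.shape[2:] == (1, 1))     # True -> existing Conv2d(..., (1, 1)) branch
```

The second hunk then teaches `lora_calc_updown` to fuse such a kernel pair into a single update: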
```diff
@@ -233,6 +235,8 @@ def lora_calc_updown(lora, module, target):
     if up.shape[2:] == (1, 1) and down.shape[2:] == (1, 1):
         updown = (up.squeeze(2).squeeze(2) @ down.squeeze(2).squeeze(2)).unsqueeze(2).unsqueeze(3)
+    elif up.shape[2:] == (3, 3) or down.shape[2:] == (3, 3):
+        updown = torch.nn.functional.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
     else:
         updown = up @ down
```
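The new `elif` above folds a conv LoRA pair into one `updown` kernel. For the typical layout (3x3 `down`, 1x1 `up`), convolving the permuted `down` with `up` mixes the rank channels per spatial position, which is exactly the composition of the two convolutions. A standalone sketch with hypothetical shapes that checks the identity numerically:

```python
import torch
import torch.nn.functional as F

# Check the kernel fusion used in lora_calc_updown: for a 3x3 `down` and a
# 1x1 `up`, conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3) yields
# a single kernel equivalent to convolving with `down` and then with `up`.
# All shapes are hypothetical (rank 4, 8 -> 16 channels).
torch.manual_seed(0)
down = torch.randn(4, 8, 3, 3)   # (rank, in_channels, 3, 3)
up = torch.randn(16, 4, 1, 1)    # (out_channels, rank, 1, 1)

updown = F.conv2d(down.permute(1, 0, 2, 3), up).permute(1, 0, 2, 3)
print(updown.shape)  # torch.Size([16, 8, 3, 3])

x = torch.randn(1, 8, 32, 32)
two_step = F.conv2d(F.conv2d(x, down, padding=1), up)  # down, then up
one_step = F.conv2d(x, updown, padding=1)              # fused kernel
print(torch.allclose(two_step, one_step, atol=1e-4))   # True
```

Because `up` is 1x1, it only mixes channels and never touches spatial positions, so the fusion reduces to a per-position matrix product over the rank dimension, mirroring the `up @ down` matmul used for linear layers.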