author     v0xie <28695009+v0xie@users.noreply.github.com>  2023-11-04 00:52:55 +0000
committer  v0xie <28695009+v0xie@users.noreply.github.com>  2023-11-04 00:52:55 +0000
commit     fe1967a4c4a02eccfa45b65ee19a5b0773ced31c (patch)
tree       3b98961130dcd6a1e3fceeea4ddb75bf24810439 /extensions-builtin/Lora/network_oft.py
parent     d727ddfccdc6d474767be9dc3bf504150e81a8a5 (diff)
skip multihead attn for now
Diffstat (limited to 'extensions-builtin/Lora/network_oft.py')
-rw-r--r--  extensions-builtin/Lora/network_oft.py  |  54
1 file changed, 37 insertions(+), 17 deletions(-)
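
Note on the commit title: torch.nn.MultiheadAttention keeps its query/key/value projections in a single fused in_proj_weight, which is rectangular rather than square, so the block-diagonal merge below is not applied to it and the module is skipped for now. A minimal standalone sketch of the shapes involved (plain PyTorch, not code from this extension; embed_dim and num_heads are arbitrary example values):

import torch

mha = torch.nn.MultiheadAttention(embed_dim=320, num_heads=8)

# Fused q/k/v projection: (3 * embed_dim, embed_dim) -- rectangular, not square.
print(mha.in_proj_weight.shape)   # torch.Size([960, 320])

# Output projection is square: (embed_dim, embed_dim).
print(mha.out_proj.weight.shape)  # torch.Size([320, 320])
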
diff --git a/extensions-builtin/Lora/network_oft.py b/extensions-builtin/Lora/network_oft.py
index e102eafc..979a2047 100644
--- a/extensions-builtin/Lora/network_oft.py
+++ b/extensions-builtin/Lora/network_oft.py
@@ -18,6 +18,7 @@ class NetworkModuleOFT(network.NetworkModule):
         super().__init__(net, weights)
         self.lin_module = None
+        self.org_module: list[torch.Module] = [self.sd_module]
         # kohya-ss
         if "oft_blocks" in weights.w.keys():
             self.is_kohya = True
@@ -30,7 +31,7 @@ class NetworkModuleOFT(network.NetworkModule):
             # alpha is rank if alpha is 0 or None
             if self.alpha is None:
                 pass
-            self.dim = self.oft_blocks.shape[0] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
+            self.dim = self.oft_blocks.shape[1] # FIXME: almost certainly incorrect, assumes tensor is shape [*, m, n]
         else:
             raise ValueError("oft_blocks or oft_diag must be in weights dict")
@@ -46,6 +47,12 @@ class NetworkModuleOFT(network.NetworkModule):
         #    raise ValueError("Linear sd_module must have out_features or embed_dim")
         elif is_other_linear:
             self.out_dim = self.sd_module.embed_dim
+            #self.org_weight = self.org_module[0].weight
+#            if hasattr(self.sd_module, "in_proj_weight"):
+#                self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
+#            if hasattr(self.sd_module, "out_proj_weight"):
+#                self.out_proj_dim = self.sd_module.out_proj_weight.shape[0]
+#            self.in_proj_dim = self.sd_module.in_proj_weight.shape[1]
         elif is_conv:
             self.out_dim = self.sd_module.out_channels
         else:
@@ -58,10 +65,9 @@ class NetworkModuleOFT(network.NetworkModule):
             self.constraint = self.alpha * self.out_dim
         #elif is_linear or is_conv:
         else:
-            self.num_blocks, self.block_size = factorization(self.out_dim, self.dim)
+            self.block_size, self.num_blocks = factorization(self.out_dim, self.dim)
             self.constraint = None
-        self.org_module: list[torch.Module] = [self.sd_module]
         # if is_other_linear:
         #     weight = self.oft_blocks.reshape(self.oft_blocks.shape[0], -1)
@@ -110,25 +116,39 @@ class NetworkModuleOFT(network.NetworkModule):
     def calc_updown(self, orig_weight):
         multiplier = self.multiplier() * self.calc_scale()
-        R = self.get_weight(self.oft_blocks, multiplier)
-        #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
-        merged_weight = self.merge_weight(R, orig_weight)
+        is_other_linear = type(self.sd_module) in [ torch.nn.MultiheadAttention]
+        if self.is_kohya and not is_other_linear:
+            R = self.get_weight(self.oft_blocks, multiplier)
+            #R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+            merged_weight = self.merge_weight(R, orig_weight)
+        elif not self.is_kohya and not is_other_linear:
+            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+                orig_weight=orig_weight.permute(1, 0)
+            R = self.oft_blocks.to(orig_weight.device, dtype=orig_weight.dtype)
+            merged_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
+            #orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.block_size, n=self.num_blocks)
+            merged_weight = torch.einsum(
+                'k n m, k n ... -> k m ...',
+                R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
+                merged_weight
+            )
+            merged_weight = rearrange(merged_weight, 'k m ... -> (k m) ...')
+            if is_other_linear and orig_weight.shape[0] != orig_weight.shape[1]:
+                orig_weight=orig_weight.permute(1, 0)
+                #merged_weight=merged_weight.permute(1, 0)
+            updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+            #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
+            output_shape = orig_weight.shape
+        else:
+            # skip for now
+            updown = torch.zeros([orig_weight.shape[1], orig_weight.shape[1]], device=orig_weight.device, dtype=orig_weight.dtype)
+            output_shape = (orig_weight.shape[1], orig_weight.shape[1])
         #if self.lin_module is not None:
         #    R = self.lin_module.weight.to(orig_weight.device, dtype=orig_weight.dtype)
         #    weight = torch.mul(torch.mul(R, multiplier), orig_weight)
         #else:
-        #    orig_weight = rearrange(orig_weight, '(k n) ... -> k n ...', k=self.num_blocks, n=self.block_size)
-        #    weight = torch.einsum(
-        #        'k n m, k n ... -> k m ...',
-        #        R * multiplier + torch.eye(self.block_size, device=orig_weight.device),
-        #        orig_weight
-        #    )
-        #    weight = rearrange(weight, 'k m ... -> (k m) ...')
-
-        updown = merged_weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        #updown = weight.to(orig_weight.device, dtype=orig_weight.dtype) - orig_weight
-        output_shape = orig_weight.shape
+
         orig_weight = orig_weight
         return self.finalize_updown(updown, orig_weight, output_shape)
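
For readers following the new non-kohya branch of calc_updown above, here is a minimal, self-contained sketch of the block-diagonal rotation merge it performs; the helper name apply_block_rotation and the example dimensions are illustrative assumptions, not code from the extension:

import torch
from einops import rearrange

def apply_block_rotation(weight, R, multiplier):
    # weight: (k * n, ...) original weight, viewed as k blocks of size n
    # R:      (k, n, n)    learned per-block rotation parameters (oft_blocks / oft_diag)
    k, n, _ = R.shape
    blocks = rearrange(weight, '(k n) ... -> k n ...', k=k, n=n)
    # Scale the learned blocks and add the identity so that multiplier == 0 is a no-op.
    rot = R * multiplier + torch.eye(n, device=weight.device, dtype=weight.dtype)
    merged = torch.einsum('k n m, k n ... -> k m ...', rot, blocks)
    return rearrange(merged, 'k m ... -> (k m) ...')

# Example: a 320x320 Linear weight split into 4 blocks of 80.
w = torch.randn(320, 320)
R = torch.zeros(4, 80, 80)      # zero blocks -> pure identity rotation
out = apply_block_rotation(w, R, multiplier=1.0)
assert torch.allclose(out, w)   # identity rotation leaves the weight unchanged

Subtracting the original weight from the merged result then gives the updown delta that finalize_updown applies.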