diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-07 13:39:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-07 13:39:51 +0000 |
commit | f7c787eb7c295c27439f4fbdf78c26b8389560be (patch) | |
tree | 699c9721baa119af3f8f6e888fa25373f46c6042 /modules/hypernetwork.py | |
parent | 97bc0b9504572d2df80598d0b694703bcd626de6 (diff) | |
download | stable-diffusion-webui-gfx803-f7c787eb7c295c27439f4fbdf78c26b8389560be.tar.gz stable-diffusion-webui-gfx803-f7c787eb7c295c27439f4fbdf78c26b8389560be.tar.bz2 stable-diffusion-webui-gfx803-f7c787eb7c295c27439f4fbdf78c26b8389560be.zip |
make it possible to use hypernetworks without opt split attention
Diffstat (limited to 'modules/hypernetwork.py')
-rw-r--r-- | modules/hypernetwork.py | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index c5cf4afa..c7b86682 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -4,7 +4,12 @@ import sys
import traceback
import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
return res
-def apply(self, x, context=None, mask=None, original=None):
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
+ h = self.heads
+
+ q = self.to_q(x)
+ context = default(context, x)
- if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
- if context.shape[1] == 77 and CrossAttention.noise_cond:
- context = context + (torch.randn_like(context) * 0.1)
- h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
- k = self.to_k(h_k(context))
- v = self.to_v(h_v(context))
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k = self.to_k(hypernetwork_layers[0](context))
+ v = self.to_v(hypernetwork_layers[1](context))
else:
k = self.to_k(context)
v = self.to_v(context)
+
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+ sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+ if mask is not None:
+ mask = rearrange(mask, 'b ... -> b (...)')
+ max_neg_value = -torch.finfo(sim.dtype).max
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
+ sim.masked_fill_(~mask, max_neg_value)
+
+ # attention, what we cannot get enough of
+ attn = sim.softmax(dim=-1)
+
+ out = einsum('b i j, b j d -> b i d', attn, v)
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+ return self.to_out(out)
|