aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-07 07:17:52 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-07 07:17:52 +0000
commitbad7cb29cecac51c5c0f39afec332b007ed73133 (patch)
tree4d042c9fd673bcf110b0ac746db8205ce16c9bae /modules/sd_hijack_optimizations.py
parent2995107fa24cfd72b0a991e18271dcde148c2807 (diff)
downloadstable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.tar.gz
stable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.tar.bz2
stable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.zip
added support for hypernetworks (???)
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py17
1 files changed, 15 insertions, 2 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index ea4cfdfc..d9cca485 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
+from modules import shared
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
@@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
+
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+
+ k_in *= self.scale
+
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))