diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-07 07:17:52 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-07 07:17:52 +0000 |
commit | bad7cb29cecac51c5c0f39afec332b007ed73133 (patch) | |
tree | 4d042c9fd673bcf110b0ac746db8205ce16c9bae /modules/sd_hijack_optimizations.py | |
parent | 2995107fa24cfd72b0a991e18271dcde148c2807 (diff) | |
download | stable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.tar.gz stable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.tar.bz2 stable-diffusion-webui-gfx803-bad7cb29cecac51c5c0f39afec332b007ed73133.zip |
added support for hypernetworks (???)
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index ea4cfdfc..d9cca485 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -5,6 +5,8 @@ from torch import einsum from ldm.util import default
from einops import rearrange
+from modules import shared
+
# see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
def split_cross_attention_forward_v1(self, x, context=None, mask=None):
@@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x)
context = default(context, x)
- k_in = self.to_k(context) * self.scale
- v_in = self.to_v(context)
+
+ hypernetwork = shared.selected_hypernetwork()
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is not None:
+ k_in = self.to_k(hypernetwork_layers[0](context))
+ v_in = self.to_v(hypernetwork_layers[1](context))
+ else:
+ k_in = self.to_k(context)
+ v_in = self.to_v(context)
+
+ k_in *= self.scale
+
del context, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
|