author | Greendayle <Greendayle> | 2022-10-07 16:31:49 +0000
---|---|---
committer | Greendayle <Greendayle> | 2022-10-07 16:31:49 +0000
commit | 537da7a304adff95fb2ed8337f7a764d08f67c46 (patch)
tree | 4a8b2c23d7c870314083d70e2d82edd9acbe677c /modules/sd_hijack_optimizations.py
parent | 4320f386d9641c7c234589c4cb0c0c6cbeb156ad (diff)
parent | f7c787eb7c295c27439f4fbdf78c26b8389560be (diff)
Merge branch 'master' into dev/deepdanbooru
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 25
1 file changed, 15 insertions, 10 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 9c079e57..d9cca485 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -5,6 +5,8 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
+from modules import shared
+
 
 # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion
 def split_cross_attention_forward_v1(self, x, context=None, mask=None):
@@ -42,8 +44,19 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
     q_in = self.to_q(x)
     context = default(context, x)
-    k_in = self.to_k(context) * self.scale
-    v_in = self.to_v(context)
+
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k_in = self.to_k(hypernetwork_layers[0](context))
+        v_in = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k_in = self.to_k(context)
+        v_in = self.to_v(context)
+
+    k_in *= self.scale
+
     del context, x
 
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
@@ -92,14 +105,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     return self.to_out(r2)
 
 
-def nonlinearity_hijack(x):
-    # swish
-    t = torch.sigmoid(x)
-    x *= t
-    del t
-
-    return x
-
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)
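For readers skimming the merge: the net effect of the second hunk is that, when a hypernetwork is selected, the attention keys and values are computed from hypernetwork-transformed context rather than the raw context, and the `* self.scale` factor is folded into a single in-place multiply on `k_in` shared by both branches. Below is a minimal, self-contained sketch of that control flow; `CrossAttentionSketch` and `TinyHypernetwork` are illustrative stand-ins, not the repository's actual `CrossAttention` or `Hypernetwork` classes, and only the lookup keyed on `context.shape[2]` mirrors the diff verbatim.

    import torch
    import torch.nn as nn


    class TinyHypernetwork:
        """Illustrative stand-in: one (key, value) layer pair per context width."""
        def __init__(self, dim):
            self.layers = {dim: (nn.Linear(dim, dim), nn.Linear(dim, dim))}


    class CrossAttentionSketch(nn.Module):
        """Only the k/v projection path touched by the merged hunk."""
        def __init__(self, query_dim, context_dim, inner_dim):
            super().__init__()
            self.scale = inner_dim ** -0.5
            self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
            self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
            self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        def project_kv(self, context, hypernetwork=None):
            # As in the diff: pick the layer pair keyed by the context
            # embedding width, or fall back to the plain projections.
            layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)

            if layers is not None:
                k_in = self.to_k(layers[0](context))  # keys from transformed context
                v_in = self.to_v(layers[1](context))  # values from their own layer
            else:
                k_in = self.to_k(context)
                v_in = self.to_v(context)

            # The diff moves the scaling out of `to_k(context) * self.scale`
            # so both branches share one in-place multiply.
            k_in *= self.scale
            return k_in, v_in


    if __name__ == "__main__":
        attn = CrossAttentionSketch(query_dim=320, context_dim=768, inner_dim=320)
        ctx = torch.randn(1, 77, 768)  # e.g. a batch of text-encoder embeddings
        k, v = attn.project_kv(ctx, hypernetwork=TinyHypernetwork(768))
        print(k.shape, v.shape)  # torch.Size([1, 77, 320]) twice

As for the deletion in the third hunk: `nonlinearity_hijack` computed swish in place, `x * torch.sigmoid(x)`, the same function `torch.nn.functional.silu` provides.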