aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-11 08:09:51 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-11 08:10:17 +0000
commit948533950c9db5069a874d925fadd50bac00fdb5 (patch)
tree854b8016f8f8f387b6e038357a11a2c8def59dc2
parent5e2627a1a63e4c9f87e6e604ecc24e9936f149de (diff)
downloadstable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.gz
stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.bz2
stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.zip
replace duplicate code with a function
-rw-r--r--modules/hypernetwork.py23
-rw-r--r--modules/sd_hijack_optimizations.py44
2 files changed, 29 insertions, 38 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 498bc9d8..7bbc443e 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -64,21 +64,26 @@ def load_hypernetwork(filename):
shared.loaded_hypernetwork = None
+def apply_hypernetwork(hypernetwork, context):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
def attention_CrossAttention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 18408e62..25cb67a4 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -8,7 +8,8 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared
+from modules import shared, hypernetwork
+
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -26,16 +27,10 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
- del context, x
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+ del context, context_k, context_v, x
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
@@ -59,22 +54,16 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
return self.to_out(r2)
-# taken from https://github.com/Doggettx/stable-diffusion
+# taken from https://github.com/Doggettx/stable-diffusion and modified
def split_cross_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
k_in *= self.scale
@@ -130,14 +119,11 @@ def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
- if hypernetwork_layers is not None:
- k_in = self.to_k(hypernetwork_layers[0](context))
- v_in = self.to_v(hypernetwork_layers[1](context))
- else:
- k_in = self.to_k(context)
- v_in = self.to_v(context)
+
+ context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k_in = self.to_k(context_k)
+ v_in = self.to_v(context_v)
+
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)