aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetwork.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-10-11 08:09:51 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-10-11 08:10:17 +0000
commit948533950c9db5069a874d925fadd50bac00fdb5 (patch)
tree854b8016f8f8f387b6e038357a11a2c8def59dc2 /modules/hypernetwork.py
parent5e2627a1a63e4c9f87e6e604ecc24e9936f149de (diff)
downloadstable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.gz
stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.bz2
stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.zip
replace duplicate code with a function
Diffstat (limited to 'modules/hypernetwork.py')
-rw-r--r--modules/hypernetwork.py23
1 files changed, 14 insertions, 9 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py
index 498bc9d8..7bbc443e 100644
--- a/modules/hypernetwork.py
+++ b/modules/hypernetwork.py
@@ -64,21 +64,26 @@ def load_hypernetwork(filename):
shared.loaded_hypernetwork = None
+def apply_hypernetwork(hypernetwork, context):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
def attention_CrossAttention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))