diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 08:09:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-11 08:10:17 +0000 |
commit | 948533950c9db5069a874d925fadd50bac00fdb5 (patch) | |
tree | 854b8016f8f8f387b6e038357a11a2c8def59dc2 /modules/hypernetwork.py | |
parent | 5e2627a1a63e4c9f87e6e604ecc24e9936f149de (diff) | |
download | stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.gz stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.tar.bz2 stable-diffusion-webui-gfx803-948533950c9db5069a874d925fadd50bac00fdb5.zip |
replace duplicate code with a function
Diffstat (limited to 'modules/hypernetwork.py')
-rw-r--r-- | modules/hypernetwork.py | 23 |
1 files changed, 14 insertions, 9 deletions
diff --git a/modules/hypernetwork.py b/modules/hypernetwork.py index 498bc9d8..7bbc443e 100644 --- a/modules/hypernetwork.py +++ b/modules/hypernetwork.py @@ -64,21 +64,26 @@ def load_hypernetwork(filename): shared.loaded_hypernetwork = None
+def apply_hypernetwork(hypernetwork, context):
+ hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+ if hypernetwork_layers is None:
+ return context, context
+
+ context_k = hypernetwork_layers[0](context)
+ context_v = hypernetwork_layers[1](context)
+ return context_k, context_v
+
+
def attention_CrossAttention_forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
- hypernetwork = shared.loaded_hypernetwork
- hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
-
- if hypernetwork_layers is not None:
- k = self.to_k(hypernetwork_layers[0](context))
- v = self.to_v(hypernetwork_layers[1](context))
- else:
- k = self.to_k(context)
- v = self.to_v(context)
+ context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context)
+ k = self.to_k(context_k)
+ v = self.to_v(context_v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|