aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sub_quadratic_attention.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-18 07:26:35 +0000
committerGitHub <noreply@github.com>2023-05-18 07:26:35 +0000
commit97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e (patch)
tree7a24bdd31580fe0e4bf8d4205b57b55df0a2568d /modules/sub_quadratic_attention.py
parent484948f5c0b755a921c02cccbcacb2684a86a814 (diff)
parentbb431df52bf3dc5e233e42907f2d8f56e4fb6c0c (diff)
downloadstable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.tar.gz
stable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.tar.bz2
stable-diffusion-webui-gfx803-97e1cf69c04a3c62aa1bb19a14ffc948d9cc6c4e.zip
Merge branch 'dev' into master
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r--modules/sub_quadratic_attention.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 05595323..497568eb 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -179,7 +179,7 @@ def efficient_dot_product_attention(
chunk_idx,
min(query_chunk_size, q_tokens)
)
-
+
summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)
summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk
compute_query_chunk_attn: ComputeQueryChunkAttn = partial(
@@ -201,14 +201,15 @@ def efficient_dot_product_attention(
key=key,
value=value,
)
-
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
+
+ res = torch.zeros_like(query)
+ for i in range(math.ceil(q_tokens / query_chunk_size)):
+ attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
+ )
+
+ res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores
+
return res