aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLouis Del Valle <92354925+nero-dv@users.noreply.github.com>2023-05-11 03:05:18 +0000
committerGitHub <noreply@github.com>2023-05-11 03:05:18 +0000
commitc8732dfa6f763332962d97ff040af156e24a9e62 (patch)
tree92723cd92da2d1557571778b1c44c81182eb8ea4
parent8aa87c564a79965013715d56a5f90d2a34d5d6ee (diff)
downloadstable-diffusion-webui-gfx803-c8732dfa6f763332962d97ff040af156e24a9e62.tar.gz
stable-diffusion-webui-gfx803-c8732dfa6f763332962d97ff040af156e24a9e62.tar.bz2
stable-diffusion-webui-gfx803-c8732dfa6f763332962d97ff040af156e24a9e62.zip
Update sub_quadratic_attention.py
1. Determine the number of query chunks. 2. Calculate the final shape of the res tensor. 3. Initialize the tensor with the calculated shape and dtype (usually the same dtype as the input tensors). The tensor can be initialized as a zero-filled tensor with the correct shape and dtype; the attention scores for each query chunk are then computed and written into the corresponding slice of that tensor.
-rw-r--r--modules/sub_quadratic_attention.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index 05595323..f80c1600 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -202,13 +202,22 @@ def efficient_dot_product_attention(
value=value,
)
- # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
- # and pass slices to be mutated, instead of torch.cat()ing the returned slices
- res = torch.cat([
- compute_query_chunk_attn(
+ # slices of res tensor are mutable, modifications made
+ # to the slices will affect the original tensor.
+ # if output of compute_query_chunk_attn function has same number of
+ # dimensions as input query tensor, we initialize tensor like this:
+ num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
+ query_shape = get_query_chunk(0).shape
+ res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
+ res_dtype = get_query_chunk(0).dtype
+ res = torch.zeros(res_shape, dtype=res_dtype)
+
+ for i in range(num_query_chunks):
+ attn_scores = compute_query_chunk_attn(
query=get_query_chunk(i * query_chunk_size),
key=key,
value=value,
- ) for i in range(math.ceil(q_tokens / query_chunk_size))
- ], dim=1)
+ )
+ res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores
+
return res