diff options
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- | modules/sub_quadratic_attention.py | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 05595323..cc38debd 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -201,14 +201,15 @@ def efficient_dot_product_attention( key=key, value=value, ) - - # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, - # and pass slices to be mutated, instead of torch.cat()ing the returned slices - res = torch.cat([ - compute_query_chunk_attn( + + res = torch.zeros_like(query) + for i in range(math.ceil(q_tokens / query_chunk_size)): + attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, - ) for i in range(math.ceil(q_tokens / query_chunk_size)) - ], dim=1) + ) + + res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores + return res |