diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-05-11 04:21:18 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-11 04:21:18 +0000 |
commit | c9e5b921061d842ef64efcf50431253b3002e1ed (patch) | |
tree | 92723cd92da2d1557571778b1c44c81182eb8ea4 /modules/sub_quadratic_attention.py | |
parent | 8aa87c564a79965013715d56a5f90d2a34d5d6ee (diff) | |
parent | c8732dfa6f763332962d97ff040af156e24a9e62 (diff) | |
download | stable-diffusion-webui-gfx803-c9e5b921061d842ef64efcf50431253b3002e1ed.tar.gz stable-diffusion-webui-gfx803-c9e5b921061d842ef64efcf50431253b3002e1ed.tar.bz2 stable-diffusion-webui-gfx803-c9e5b921061d842ef64efcf50431253b3002e1ed.zip |
Merge pull request #10266 from nero-dv/dev
Update sub_quadratic_attention.py
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- | modules/sub_quadratic_attention.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 05595323..f80c1600 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -202,13 +202,22 @@ def efficient_dot_product_attention( value=value, ) - # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, - # and pass slices to be mutated, instead of torch.cat()ing the returned slices - res = torch.cat([ - compute_query_chunk_attn( + # slices of res tensor are mutable, modifications made + # to the slices will affect the original tensor. + # if output of compute_query_chunk_attn function has same number of + # dimensions as input query tensor, we initialize tensor like this: + num_query_chunks = int(np.ceil(q_tokens / query_chunk_size)) + query_shape = get_query_chunk(0).shape + res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:]) + res_dtype = get_query_chunk(0).dtype + res = torch.zeros(res_shape, dtype=res_dtype) + + for i in range(num_query_chunks): + attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, - ) for i in range(math.ceil(q_tokens / query_chunk_size)) - ], dim=1) + ) + res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores + return res |