diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-05-11 04:45:05 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-11 04:45:05 +0000 |
commit | e334758ec281eaf7723c806713721d12bb568e24 (patch) | |
tree | 1f34358bb006da9aa4baee64aaecec2bdfd333b3 /modules/sub_quadratic_attention.py | |
parent | c9e5b921061d842ef64efcf50431253b3002e1ed (diff) | |
download | stable-diffusion-webui-gfx803-e334758ec281eaf7723c806713721d12bb568e24.tar.gz stable-diffusion-webui-gfx803-e334758ec281eaf7723c806713721d12bb568e24.tar.bz2 stable-diffusion-webui-gfx803-e334758ec281eaf7723c806713721d12bb568e24.zip |
repair #10266
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- | modules/sub_quadratic_attention.py | 18 |
1 files changed, 5 insertions, 13 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index f80c1600..cc38debd 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -201,23 +201,15 @@ def efficient_dot_product_attention( key=key, value=value, ) - - # slices of res tensor are mutable, modifications made - # to the slices will affect the original tensor. - # if output of compute_query_chunk_attn function has same number of - # dimensions as input query tensor, we initialize tensor like this: - num_query_chunks = int(np.ceil(q_tokens / query_chunk_size)) - query_shape = get_query_chunk(0).shape - res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:]) - res_dtype = get_query_chunk(0).dtype - res = torch.zeros(res_shape, dtype=res_dtype) - - for i in range(num_query_chunks): + + res = torch.zeros_like(query) + for i in range(math.ceil(q_tokens / query_chunk_size)): attn_scores = compute_query_chunk_attn( query=get_query_chunk(i * query_chunk_size), key=key, value=value, ) - res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores + + res[:, i * query_chunk_size:i * query_chunk_size + attn_scores.shape[1], :] = attn_scores return res |