diff options
author | brkirch <brkirch@users.noreply.github.com> | 2023-01-05 09:37:17 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-06 05:15:24 +0000 |
commit | b119815333026164f2bd7d1ca71f3e4f7a9afd0d (patch) | |
tree | 09b1c11b76c4a2e9f222375c0f4b746f26a14047 | |
parent | 3bfe2bb5498241c4873cdd71b3f0a5bac5f64d7f (diff) | |
download | stable-diffusion-webui-gfx803-b119815333026164f2bd7d1ca71f3e4f7a9afd0d.tar.gz stable-diffusion-webui-gfx803-b119815333026164f2bd7d1ca71f3e4f7a9afd0d.tar.bz2 stable-diffusion-webui-gfx803-b119815333026164f2bd7d1ca71f3e4f7a9afd0d.zip |
Use narrow instead of dynamic_slice
-rw-r--r-- | modules/sub_quadratic_attention.py | 34 |
1 files changed, 19 insertions, 15 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index b11dc1c7..95924d24 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -5,6 +5,7 @@ # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint import math from typing import Optional, NamedTuple, Protocol, List -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int ) -> Tensor: - slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] - return x[slicing] + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) class AttnChunk(NamedTuple): exp_values: Tensor @@ -76,15 +77,17 @@ def _query_chunk_attention( _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> AttnChunk: - key_chunk = dynamic_slice( + key_chunk = narrow_trunc( key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) - value_chunk = dynamic_slice( + value_chunk = narrow_trunc( value, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, v_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) return summarize_chunk(query, key_chunk, value_chunk) @@ -161,10 +164,11 @@ def efficient_dot_product_attention( kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) def get_query_chunk(chunk_idx: int) -> Tensor: - return dynamic_slice( + return narrow_trunc( query, - (0, chunk_idx, 0), - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + 1, + chunk_idx, + min(query_chunk_size, q_tokens) ) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) |