diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-09 17:08:48 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-09 17:08:48 +0000 |
commit | cdfcbd995932ffa728db0cc00a5f97665c752103 (patch) | |
tree | f0d92f4b35ccae48937f71f6c72d81bd2c03628b | |
parent | 89c3663080658b84b506c12563a729b6fc65ae10 (diff) | |
download | stable-diffusion-webui-gfx803-cdfcbd995932ffa728db0cc00a5f97665c752103.tar.gz stable-diffusion-webui-gfx803-cdfcbd995932ffa728db0cc00a5f97665c752103.tar.bz2 stable-diffusion-webui-gfx803-cdfcbd995932ffa728db0cc00a5f97665c752103.zip |
Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code
add some whitespace between functions to be in line with other code in the repo
-rw-r--r-- | modules/sub_quadratic_attention.py | 19 |
1 files changed, 11 insertions, 8 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 93381bae..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,14 +15,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math - -try: - from typing import Protocol -except: - from typing_extensions import Protocol - from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, dim: int, @@ -31,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -44,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... + def _summarize_chunk( query: Tensor, key: Tensor, @@ -72,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -112,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, |