diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-09 19:45:39 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-09 19:45:39 +0000 |
commit | 18c001792a3f034245c2a9c38cb568d31c147fed (patch) | |
tree | f43e79374dd7e74074dcbf48fc579f1b12b4d1a8 /modules/sub_quadratic_attention.py | |
parent | 72497895b9b1948f86d9309fe897cbb70c20ba7e (diff) | |
parent | 2b94ec78869db7d2beaad23bdff47340416edf85 (diff) | |
download | stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.tar.gz stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.tar.bz2 stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.zip |
Merge branch 'master' into varsize
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- | modules/sub_quadratic_attention.py | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index fea7aaac..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,8 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, Protocol, List +from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, @@ -25,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -38,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... + def _summarize_chunk( query: Tensor, key: Tensor, @@ -66,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -106,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -125,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, |