about | summary | refs | log | tree | commit | diff | stats
path: root/modules/sub_quadratic_attention.py
diff options
context:
space:
mode:
author: AUTOMATIC1111 <16777216c@gmail.com> 2023-01-09 19:45:39 +0000
committer: GitHub <noreply@github.com> 2023-01-09 19:45:39 +0000
commit: 18c001792a3f034245c2a9c38cb568d31c147fed (patch)
tree: f43e79374dd7e74074dcbf48fc579f1b12b4d1a8 /modules/sub_quadratic_attention.py
parent: 72497895b9b1948f86d9309fe897cbb70c20ba7e (diff)
parent: 2b94ec78869db7d2beaad23bdff47340416edf85 (diff)
download: stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.tar.gz
stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.tar.bz2
stable-diffusion-webui-gfx803-18c001792a3f034245c2a9c38cb568d31c147fed.zip
Merge branch 'master' into varsize
Diffstat (limited to 'modules/sub_quadratic_attention.py')
-rw-r--r-- modules/sub_quadratic_attention.py 15
1 files changed, 12 insertions, 3 deletions
diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py
index fea7aaac..55052815 100644
--- a/modules/sub_quadratic_attention.py
+++ b/modules/sub_quadratic_attention.py
@@ -15,7 +15,8 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math
-from typing import Optional, NamedTuple, Protocol, List
+from typing import Optional, NamedTuple, List
+
def narrow_trunc(
input: Tensor,
@@ -25,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)
+
class AttnChunk(NamedTuple):
exp_values: Tensor
exp_weights_sum: Tensor
max_score: Tensor
-class SummarizeChunk(Protocol):
+
+class SummarizeChunk:
@staticmethod
def __call__(
query: Tensor,
@@ -38,7 +41,8 @@ class SummarizeChunk(Protocol):
value: Tensor,
) -> AttnChunk: ...
-class ComputeQueryChunkAttn(Protocol):
+
+class ComputeQueryChunkAttn:
@staticmethod
def __call__(
query: Tensor,
@@ -46,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
value: Tensor,
) -> Tensor: ...
+
def _summarize_chunk(
query: Tensor,
key: Tensor,
@@ -66,6 +71,7 @@ def _summarize_chunk(
max_score = max_score.squeeze(-1)
return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
+
def _query_chunk_attention(
query: Tensor,
key: Tensor,
@@ -106,6 +112,7 @@ def _query_chunk_attention(
all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
return all_values / all_weights
+
# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
query: Tensor,
@@ -125,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
hidden_states_slice = torch.bmm(attn_probs, value)
return hidden_states_slice
+
class ScannedChunk(NamedTuple):
chunk_idx: int
attn_chunk: AttnChunk
+
def efficient_dot_product_attention(
query: Tensor,
key: Tensor,