From d782a95967c9eea753df3333cd1954b6ec73eba0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 27 Dec 2022 08:50:55 -0500 Subject: Add Birch-san's sub-quadratic attention implementation --- modules/sub_quadratic_attention.py | 201 +++++++++++++++++++++++++++++++++++++ 1 file changed, 201 insertions(+) create mode 100644 modules/sub_quadratic_attention.py (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py new file mode 100644 index 00000000..b11dc1c7 --- /dev/null +++ b/modules/sub_quadratic_attention.py @@ -0,0 +1,201 @@ +# original source: +# https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py +# license: +# unspecified +# credit: +# Amin Rezaei (original author) +# Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# implementation of: +# Self-attention Does Not Need O(n2) Memory": +# https://arxiv.org/abs/2112.05682v2 + +from functools import partial +import torch +from torch import Tensor +from torch.utils.checkpoint import checkpoint +import math +from typing import Optional, NamedTuple, Protocol, List + +def dynamic_slice( + x: Tensor, + starts: List[int], + sizes: List[int], +) -> Tensor: + slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] + return x[slicing] + +class AttnChunk(NamedTuple): + exp_values: Tensor + exp_weights_sum: Tensor + max_score: Tensor + +class SummarizeChunk(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> AttnChunk: ... + +class ComputeQueryChunkAttn(Protocol): + @staticmethod + def __call__( + query: Tensor, + key: Tensor, + value: Tensor, + ) -> Tensor: ... 
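For reference, the dynamic_slice helper above simply builds one slice(start, start + size) per dimension, so it is equivalent to ordinary multi-dimensional slicing. A standalone sketch (not part of the patch; the tensor and the starts/sizes values are made up for illustration, and a tuple of slices is used here, which behaves the same as the list the helper builds):

import torch

x = torch.arange(24).reshape(2, 4, 3)  # (batch * heads, tokens, channels_per_head)

# dynamic_slice(x, starts=(0, 1, 0), sizes=(2, 2, 3)) selects x[0:2, 1:3, 0:3],
# i.e. a fixed-size window along the token dimension.
starts, sizes = (0, 1, 0), (2, 2, 3)
window = x[tuple(slice(s, s + n) for s, n in zip(starts, sizes))]
assert torch.equal(window, x[0:2, 1:3, 0:3])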
+ +def _summarize_chunk( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> AttnChunk: + attn_weights = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + max_score, _ = torch.max(attn_weights, -1, keepdim=True) + max_score = max_score.detach() + exp_weights = torch.exp(attn_weights - max_score) + exp_values = torch.bmm(exp_weights, value) + max_score = max_score.squeeze(-1) + return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + +def _query_chunk_attention( + query: Tensor, + key: Tensor, + value: Tensor, + summarize_chunk: SummarizeChunk, + kv_chunk_size: int, +) -> Tensor: + batch_x_heads, k_tokens, k_channels_per_head = key.shape + _, _, v_channels_per_head = value.shape + + def chunk_scanner(chunk_idx: int) -> AttnChunk: + key_chunk = dynamic_slice( + key, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, k_channels_per_head) + ) + value_chunk = dynamic_slice( + value, + (0, chunk_idx, 0), + (batch_x_heads, kv_chunk_size, v_channels_per_head) + ) + return summarize_chunk(query, key_chunk, value_chunk) + + chunks: List[AttnChunk] = [ + chunk_scanner(chunk) for chunk in torch.arange(0, k_tokens, kv_chunk_size) + ] + acc_chunk = AttnChunk(*map(torch.stack, zip(*chunks))) + chunk_values, chunk_weights, chunk_max = acc_chunk + + global_max, _ = torch.max(chunk_max, 0, keepdim=True) + max_diffs = torch.exp(chunk_max - global_max) + chunk_values *= torch.unsqueeze(max_diffs, -1) + chunk_weights *= max_diffs + + all_values = chunk_values.sum(dim=0) + all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) + return all_values / all_weights + +# TODO: refactor CrossAttention#get_attention_scores to share code with this +def _get_attention_scores_no_kv_chunking( + query: Tensor, + key: Tensor, + value: Tensor, + scale: float, +) -> Tensor: + attn_scores = torch.baddbmm( + torch.empty(1, 1, 1, device=query.device, dtype=query.dtype), + query, + key.transpose(1,2), + alpha=scale, + beta=0, + ) + attn_probs = attn_scores.softmax(dim=-1) + del attn_scores + hidden_states_slice = torch.bmm(attn_probs, value) + return hidden_states_slice + +class ScannedChunk(NamedTuple): + chunk_idx: int + attn_chunk: AttnChunk + +def efficient_dot_product_attention( + query: Tensor, + key: Tensor, + value: Tensor, + query_chunk_size=1024, + kv_chunk_size: Optional[int] = None, + kv_chunk_size_min: Optional[int] = None, + use_checkpoint=True, +): + """Computes efficient dot-product attention given query, key, and value. + This is efficient version of attention presented in + https://arxiv.org/abs/2112.05682v2 which comes with O(sqrt(n)) memory requirements. + Args: + query: queries for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + key: keys for calculating attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + value: values to be used in attention with shape of + `[batch * num_heads, tokens, channels_per_head]`. + query_chunk_size: int: query chunks size + kv_chunk_size: Optional[int]: key/value chunks size. if None: defaults to sqrt(key_tokens) + kv_chunk_size_min: Optional[int]: key/value minimum chunk size. only considered when kv_chunk_size is None. changes `sqrt(key_tokens)` into `max(sqrt(key_tokens), kv_chunk_size_min)`, to ensure our chunk sizes don't get too small (smaller chunks = more chunks = less concurrent work done). 
+ use_checkpoint: bool: whether to use checkpointing (recommended True for training, False for inference) + Returns: + Output of shape `[batch * num_heads, query_tokens, channels_per_head]`. + """ + batch_x_heads, q_tokens, q_channels_per_head = query.shape + _, k_tokens, _ = key.shape + scale = q_channels_per_head ** -0.5 + + kv_chunk_size = min(kv_chunk_size or int(math.sqrt(k_tokens)), k_tokens) + if kv_chunk_size_min is not None: + kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) + + def get_query_chunk(chunk_idx: int) -> Tensor: + return dynamic_slice( + query, + (0, chunk_idx, 0), + (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + ) + + summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) + summarize_chunk: SummarizeChunk = partial(checkpoint, summarize_chunk) if use_checkpoint else summarize_chunk + compute_query_chunk_attn: ComputeQueryChunkAttn = partial( + _get_attention_scores_no_kv_chunking, + scale=scale + ) if k_tokens <= kv_chunk_size else ( + # fast-path for when there's just 1 key-value chunk per query chunk (this is just sliced attention btw) + partial( + _query_chunk_attention, + kv_chunk_size=kv_chunk_size, + summarize_chunk=summarize_chunk, + ) + ) + + if q_tokens <= query_chunk_size: + # fast-path for when there's just 1 query chunk + return compute_query_chunk_attn( + query=query, + key=key, + value=value, + ) + + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, + # and pass slices to be mutated, instead of torch.cat()ing the returned slices + res = torch.cat([ + compute_query_chunk_attn( + query=get_query_chunk(i * query_chunk_size), + key=key, + value=value, + ) for i in range(math.ceil(q_tokens / query_chunk_size)) + ], dim=1) + return res -- cgit v1.2.3 From b119815333026164f2bd7d1ca71f3e4f7a9afd0d Mon Sep 17 00:00:00 2001 From: brkirch Date: Thu, 5 Jan 2023 04:37:17 -0500 Subject: Use narrow instead of dynamic_slice --- modules/sub_quadratic_attention.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index b11dc1c7..95924d24 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -5,6 +5,7 @@ # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) +# brkirch (modified to use torch.narrow instead of dynamic_slice implementation) # implementation of: # Self-attention Does Not Need O(n2) Memory": # https://arxiv.org/abs/2112.05682v2 @@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint import math from typing import Optional, NamedTuple, Protocol, List -def dynamic_slice( - x: Tensor, - starts: List[int], - sizes: List[int], +def narrow_trunc( + input: Tensor, + dim: int, + start: int, + length: int ) -> Tensor: - slicing = [slice(start, start + size) for start, size in zip(starts, sizes)] - return x[slicing] + return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) class AttnChunk(NamedTuple): exp_values: Tensor @@ -76,15 +77,17 @@ def _query_chunk_attention( _, _, v_channels_per_head = value.shape def chunk_scanner(chunk_idx: int) -> AttnChunk: - key_chunk = dynamic_slice( + key_chunk = narrow_trunc( key, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, k_channels_per_head) + 1, + chunk_idx, + kv_chunk_size 
) - value_chunk = dynamic_slice( + value_chunk = narrow_trunc( value, - (0, chunk_idx, 0), - (batch_x_heads, kv_chunk_size, v_channels_per_head) + 1, + chunk_idx, + kv_chunk_size ) return summarize_chunk(query, key_chunk, value_chunk) @@ -161,10 +164,11 @@ def efficient_dot_product_attention( kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min) def get_query_chunk(chunk_idx: int) -> Tensor: - return dynamic_slice( + return narrow_trunc( query, - (0, chunk_idx, 0), - (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head) + 1, + chunk_idx, + min(query_chunk_size, q_tokens) ) summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale) -- cgit v1.2.3 From c18add68ef7d2de3617cbbaff864b0c74cfdf6c0 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 6 Jan 2023 16:42:47 -0500 Subject: Added license --- html/licenses.html | 29 ++++++++++++++++++++++++++++- modules/sd_hijack_optimizations.py | 1 + modules/sub_quadratic_attention.py | 2 +- 3 files changed, 30 insertions(+), 2 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/html/licenses.html b/html/licenses.html index 9eeaa072..570630eb 100644 --- a/html/licenses.html +++ b/html/licenses.html @@ -184,7 +184,7 @@ SOFTWARE.

SwinIR

-Code added by contirubtors, most likely copied from this repository. +Code added by contributors, most likely copied from this repository.
                                  Apache License
@@ -390,3 +390,30 @@ SOFTWARE.
    limitations under the License.
 
+

Memory Efficient Attention

+The sub-quadratic cross attention optimization uses modified code from the Memory Efficient Attention package that Alex Birch optimized for 3D tensors. This license is updated to reflect that. +
+MIT License
+
+Copyright (c) 2023 Alex Birch
+Copyright (c) 2023 Amin Rezaei
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+
+ diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index b416e9ac..cdc63ed7 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -216,6 +216,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 +# The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface def sub_quad_attention_forward(self, x, context=None, mask=None): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 95924d24..fea7aaac 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -1,7 +1,7 @@ # original source: # https://github.com/AminRezaei0x443/memory-efficient-attention/blob/1bc0d9e6ac5f82ea43a375135c4e1d3896ee1694/memory_efficient_attention/attention_torch.py # license: -# unspecified +# MIT License (see Memory Efficient Attention under the Licenses section in the web UI interface for the full license) # credit: # Amin Rezaei (original author) # Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks) -- cgit v1.2.3 From 984b86dd0abf0da7f6b116864c791a2bfe8859ef Mon Sep 17 00:00:00 2001 From: ProGamerGov Date: Sat, 7 Jan 2023 13:08:21 -0700 Subject: Add fallback for Protocol import --- modules/sub_quadratic_attention.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index fea7aaac..93381bae 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,7 +15,13 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math -from typing import Optional, NamedTuple, Protocol, List + +try: + from typing import Protocol +except: + from typing_extensions import Protocol + +from typing import Optional, NamedTuple, List def narrow_trunc( input: Tensor, -- cgit v1.2.3 From cdfcbd995932ffa728db0cc00a5f97665c752103 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 20:08:48 +0300 Subject: Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code add some whitespace between functions to be in line with other code in the repo --- modules/sub_quadratic_attention.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 93381bae..55052815 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -15,14 +15,9 @@ import torch from torch import Tensor from torch.utils.checkpoint import checkpoint import math - -try: - from typing import Protocol -except: - from typing_extensions import Protocol - from typing import Optional, NamedTuple, List + def narrow_trunc( input: Tensor, dim: int, @@ -31,12 +26,14 @@ def narrow_trunc( ) -> Tensor: return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start) + class AttnChunk(NamedTuple): exp_values: Tensor exp_weights_sum: Tensor 
max_score: Tensor -class SummarizeChunk(Protocol): + +class SummarizeChunk: @staticmethod def __call__( query: Tensor, @@ -44,7 +41,8 @@ class SummarizeChunk(Protocol): value: Tensor, ) -> AttnChunk: ... -class ComputeQueryChunkAttn(Protocol): + +class ComputeQueryChunkAttn: @staticmethod def __call__( query: Tensor, @@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol): value: Tensor, ) -> Tensor: ... + def _summarize_chunk( query: Tensor, key: Tensor, @@ -72,6 +71,7 @@ def _summarize_chunk( max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) + def _query_chunk_attention( query: Tensor, key: Tensor, @@ -112,6 +112,7 @@ def _query_chunk_attention( all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0) return all_values / all_weights + # TODO: refactor CrossAttention#get_attention_scores to share code with this def _get_attention_scores_no_kv_chunking( query: Tensor, @@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking( hidden_states_slice = torch.bmm(attn_probs, value) return hidden_states_slice + class ScannedChunk(NamedTuple): chunk_idx: int attn_chunk: AttnChunk + def efficient_dot_product_attention( query: Tensor, key: Tensor, -- cgit v1.2.3 From e3b53fd295aca784253dfc8668ec87b537a72f43 Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 25 Jan 2023 00:23:10 -0500 Subject: Add UI setting for upcasting attention to float32 Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers. In order to make upcasting cross attention layer optimizations possible it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast also. --- modules/devices.py | 6 +- modules/processing.py | 2 +- modules/sd_hijack_optimizations.py | 159 +++++++++++++++++++++++-------------- modules/shared.py | 1 + modules/sub_quadratic_attention.py | 4 +- 5 files changed, 108 insertions(+), 64 deletions(-) (limited to 'modules/sub_quadratic_attention.py') diff --git a/modules/devices.py b/modules/devices.py index 0981ef80..6b36622c 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -108,6 +108,10 @@ def autocast(disable=False): return torch.autocast("cuda") +def without_autocast(disable=False): + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + + class NansException(Exception): pass @@ -125,7 +129,7 @@ def test_for_nans(x, where): message = "A tensor with all NaNs was produced in Unet." if not shared.cmd_opts.no_half: - message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this." elif where == "vae": message = "A tensor with all NaNs was produced in VAE." 
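The commit message above describes the upcasting approach; the following is a rough standalone sketch of that pattern (the helper name upcast_attention_scores and its structure are illustrative and not the web UI's actual code, which applies the same idea inside each attention optimization in the sd_hijack_optimizations.py hunks below): q and k are optionally cast to float32, autocast is disabled while the scores are computed, and the result is cast back to the original dtype.

import contextlib
import torch

def upcast_attention_scores(q, k, scale, upcast=True):
    # Illustrative helper (not from the patch): optionally upcast q/k to float32 and
    # compute attention scores with autocast disabled, then restore the input dtype.
    dtype = q.dtype
    if upcast:
        q, k = q.float(), k.float()
    no_autocast = (torch.autocast("cuda", enabled=False)
                   if torch.is_autocast_enabled() else contextlib.nullcontext())
    with no_autocast:
        scores = torch.baddbmm(
            torch.empty(1, 1, 1, device=q.device, dtype=q.dtype),
            q, k.transpose(1, 2),
            alpha=scale, beta=0,
        )
        probs = scores.softmax(dim=-1)
    return probs.to(dtype)

q = torch.randn(2, 77, 40)
k = torch.randn(2, 77, 40)
probs = upcast_attention_scores(q, k, scale=40 ** -0.5)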
diff --git a/modules/processing.py b/modules/processing.py index 2d186ba0..a850082d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" - with devices.autocast(disable=devices.unet_needs_upcast): + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts) x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))] diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 74452709..c02d954c 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -9,7 +9,7 @@ from torch import einsum from ldm.util import default from einops import rearrange -from modules import shared, errors +from modules import shared, errors, devices from modules.hypernetworks import hypernetwork from .sub_quadratic_attention import efficient_dot_product_attention @@ -52,18 +52,25 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - for i in range(0, q.shape[0], 2): - end = i + 2 - s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) - s1 *= self.scale + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v.float() - s2 = s1.softmax(dim=-1) - del s1 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + for i in range(0, q.shape[0], 2): + end = i + 2 + s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end]) + s1 *= self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) + del s2 + del q, k, v - r1[i:end] = einsum('b i j, b j d -> b i d', s2, v[i:end]) - del s2 - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -82,45 +89,52 @@ def split_cross_attention_forward(self, x, context=None, mask=None): k_in = self.to_k(context_k) v_in = self.to_v(context_v) - k_in *= self.scale - - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = get_available_vram() - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - if mem_required > mem_free_total: - steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + dtype = q_in.dtype + if shared.opts.upcast_attn: + q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float() - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use 
lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') - - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - - s2 = s1.softmax(dim=-1, dtype=q.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k_in = k_in * self.scale + + del context, x + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) + + mem_free_total = get_available_vram() + + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 + + if mem_required > mem_free_total: + steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required / 64 / gb:0.1f}GB free, Have:{mem_free_total / gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v - del q, k, v + r1 = r1.to(dtype) r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): context = default(context, x) context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) - k = self.to_k(context_k) * self.scale + k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - r = einsum_op(q, k, v) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float() + + with devices.without_autocast(disable=not shared.opts.upcast_attn): + k = k * self.scale + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + r = einsum_op(q, k, v) + r = r.to(dtype) return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h)) # -- End of code from https://github.com/invoke-ai/InvokeAI -- @@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training) + x = x.to(dtype) + x = x.unflatten(0, (-1, 
h)).transpose(1,2).flatten(start_dim=2) out_proj, dropout = self.to_out @@ -268,15 +296,16 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_ query_chunk_size = q_tokens kv_chunk_size = k_tokens - return efficient_dot_product_attention( - q, - k, - v, - query_chunk_size=q_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min = kv_chunk_size_min, - use_checkpoint=use_checkpoint, - ) + with devices.without_autocast(disable=q.dtype == v.dtype): + return efficient_dot_product_attention( + q, + k, + v, + query_chunk_size=q_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min = kv_chunk_size_min, + use_checkpoint=use_checkpoint, + ) def get_xformers_flash_attention_op(q, k, v): @@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None): q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in)) del q_in, k_in, v_in + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) + out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) @@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x): v = self.v(h_) b, c, h, w = q.shape q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v)) + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() q = q.contiguous() k = k.contiguous() v = v.contiguous() out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v)) + out = out.to(dtype) out = rearrange(out, 'b (h w) c -> b c h w', h=h) out = self.proj_out(out) return x + out diff --git a/modules/shared.py b/modules/shared.py index 4ce1209b..6a0b96cb 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }), "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}), "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), + "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/modules/sub_quadratic_attention.py b/modules/sub_quadratic_attention.py index 55052815..05595323 100644 --- a/modules/sub_quadratic_attention.py +++ b/modules/sub_quadratic_attention.py @@ -67,7 +67,7 @@ def _summarize_chunk( max_score, _ = torch.max(attn_weights, -1, keepdim=True) max_score = max_score.detach() exp_weights = torch.exp(attn_weights - max_score) - exp_values = torch.bmm(exp_weights, value) + exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype) max_score = max_score.squeeze(-1) return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score) @@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking( ) attn_probs = attn_scores.softmax(dim=-1) del attn_scores - hidden_states_slice = torch.bmm(attn_probs, value) + hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, 
value.to(attn_probs.dtype)).to(value.dtype) return hidden_states_slice -- cgit v1.2.3
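After these commits, efficient_dot_product_attention remains the entry point used by the sub-quadratic optimization. A minimal usage sketch, assuming the web UI repository root is on the Python path; the tensor sizes and chunk settings are illustrative and follow the docstring's [batch * num_heads, tokens, channels_per_head] layout:

import torch
from modules.sub_quadratic_attention import efficient_dot_product_attention

batch_x_heads, q_tokens, k_tokens, channels = 2 * 8, 1024, 1024, 40
q = torch.randn(batch_x_heads, q_tokens, channels)
k = torch.randn(batch_x_heads, k_tokens, channels)
v = torch.randn(batch_x_heads, k_tokens, channels)

out = efficient_dot_product_attention(
    q, k, v,
    query_chunk_size=1024,
    kv_chunk_size=None,       # None -> roughly sqrt(k_tokens) keys per chunk
    kv_chunk_size_min=None,
    use_checkpoint=False,     # per the docstring: True for training, False for inference
)
assert out.shape == (batch_x_heads, q_tokens, channels)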