From fec0a895119a124a295e3dad5205de5766031dc7 Mon Sep 17 00:00:00 2001 From: Pam Date: Tue, 7 Mar 2023 00:33:13 +0500 Subject: scaled dot product attention --- modules/sd_hijack.py | 4 ++++ modules/sd_hijack_optimizations.py | 42 ++++++++++++++++++++++++++++++++++++++ modules/shared.py | 1 + 3 files changed, 47 insertions(+) (limited to 'modules') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 79476783..76cb9120 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -42,6 +42,10 @@ def apply_optimizations(): ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward optimization_method = 'xformers' + elif cmd_opts.opt_sdp_attention and (hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(getattr(torch.nn.functional, "scaled_dot_product_attention"))): + print("Applying scaled dot product cross attention optimization.") + ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward + optimization_method = 'sdp' elif cmd_opts.opt_sub_quad_attention: print("Applying sub-quadratic cross attention optimization.") ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index c02d954c..a324a592 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -346,6 +346,48 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) +# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py +# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface +def scaled_dot_product_attention_forward(self, x, context=None, mask=None): + batch_size, sequence_length, inner_dim = x.shape + + if mask is not None: + mask = self.prepare_attention_mask(mask, sequence_length, batch_size) + mask = mask.view(batch_size, self.heads, -1, mask.shape[-1]) + + h = self.heads + q_in = self.to_q(x) + context = default(context, x) + + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) + k_in = self.to_k(context_k) + v_in = self.to_v(context_v) + + head_dim = inner_dim // h + q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2) + + del q_in, k_in, v_in + + dtype = q.dtype + if shared.opts.upcast_attn: + q, k = q.float(), k.float() + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + hidden_states = torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim) + hidden_states = hidden_states.to(dtype) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/shared.py b/modules/shared.py index 805f9cc1..12d0756b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -69,6 +69,7 @@ parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size fo parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None) parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.") parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find") +parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*") parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) -- cgit v1.2.3