path: root/modules/sd_hijack_optimizations.py
author     Pam <pamhome21@gmail.com>    2023-03-06 19:33:13 +0000
committer  Pam <pamhome21@gmail.com>    2023-03-06 19:33:13 +0000
commit     fec0a895119a124a295e3dad5205de5766031dc7 (patch)
tree       000a8ea99831b164435454761d1e24830317bc89 /modules/sd_hijack_optimizations.py
parent     0cc0ee1bcb4c24a8c9715f66cede06601bfc00c8 (diff)
scaled dot product attention
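
The commit adds an attention forward built on PyTorch 2.0's torch.nn.functional.scaled_dot_product_attention, a fused implementation of softmax(QK^T / sqrt(d)) V that can dispatch to flash-attention or memory-efficient kernels. As a minimal standalone sketch of the call itself (tensor names and sizes below are illustrative, not taken from the patch):

import torch
import torch.nn.functional as F

# Illustrative sizes only: 2 prompts, 8 heads, 77 tokens, 64 dims per head
batch, heads, seq_len, head_dim = 2, 8, 77, 64
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# Fused equivalent of softmax(q @ k.transpose(-2, -1) / sqrt(head_dim)) @ v
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([2, 8, 77, 64]) -- same (batch, heads, seq_len, head_dim) layout

This (batch, heads, seq_len, head_dim) layout is the one the new forward builds with view/transpose before making the call.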
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--  modules/sd_hijack_optimizations.py  42
1 file changed, 42 insertions(+), 0 deletions(-)
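
The diffstat is limited to modules/sd_hijack_optimizations.py, so the code that actually swaps this forward into the model is not part of the view below. As a rough, hypothetical sketch of how the web UI's hijack layer typically wires in an optimized forward (the function name apply_sdp_attention_optimization and the exact module paths are assumptions for illustration, not taken from this commit; the hasattr guard matters because scaled_dot_product_attention only exists in PyTorch 2.0+):

import torch
import ldm.modules.attention
from modules import sd_hijack_optimizations

def apply_sdp_attention_optimization():
    # scaled_dot_product_attention is only available in PyTorch >= 2.0
    if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
        return False
    # Replace the original CrossAttention.forward with the optimized implementation
    ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
    return True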
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index c02d954c..a324a592 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -346,6 +346,48 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
+# Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py
+# The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface
+def scaled_dot_product_attention_forward(self, x, context=None, mask=None):
+    batch_size, sequence_length, inner_dim = x.shape
+
+    if mask is not None:
+        mask = self.prepare_attention_mask(mask, sequence_length, batch_size)
+        mask = mask.view(batch_size, self.heads, -1, mask.shape[-1])
+
+    h = self.heads
+    q_in = self.to_q(x)
+    context = default(context, x)
+
+    context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
+    k_in = self.to_k(context_k)
+    v_in = self.to_v(context_v)
+
+    head_dim = inner_dim // h
+    q = q_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+    k = k_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+    v = v_in.view(batch_size, -1, h, head_dim).transpose(1, 2)
+
+    del q_in, k_in, v_in
+
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
+    # the output of sdp = (batch, num_heads, seq_len, head_dim)
+    hidden_states = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+    )
+
+    hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, h * head_dim)
+    hidden_states = hidden_states.to(dtype)
+
+    # linear proj
+    hidden_states = self.to_out[0](hidden_states)
+    # dropout
+    hidden_states = self.to_out[1](hidden_states)
+    return hidden_states
+
 def cross_attention_attnblock_forward(self, x):
     h_ = x
     h_ = self.norm(h_)