aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_hijack_optimizations.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-23 13:01:53 +0000
committerAUTOMATIC <16777216c@gmail.com>2023-01-23 13:01:53 +0000
commit3fa482076a5f07d81d37d58a31b8c4fe3a740843 (patch)
tree4f31fe62e676bc662d56b1674cfcf4b78f4444e2 /modules/sd_hijack_optimizations.py
parent194cbd065e4644e986889b78a5a949e075b610e8 (diff)
parent3262e825cc542ff634e6ba2e3a162eafdc6c1bba (diff)
downloadstable-diffusion-webui-gfx803-3fa482076a5f07d81d37d58a31b8c4fe3a740843.tar.gz
stable-diffusion-webui-gfx803-3fa482076a5f07d81d37d58a31b8c4fe3a740843.tar.bz2
stable-diffusion-webui-gfx803-3fa482076a5f07d81d37d58a31b8c4fe3a740843.zip
Merge remote-tracking branch 'takuma104/xformers-flash-attention'
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r--modules/sd_hijack_optimizations.py26
1 files changed, 24 insertions, 2 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 4fa54329..9967359b 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -290,7 +290,19 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
+
+ if shared.cmd_opts.xformers_flash_attention:
+ op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = op
+ if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ # print('xformers_attention_forward', q.shape, k.shape, v.shape)
+            # Flash Attention is not available for the input arguments.
+ # Fallback to default xFormers' backend.
+ op = None
+ else:
+ op = None
+
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -365,7 +377,17 @@ def xformers_attnblock_forward(self, x):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- out = xformers.ops.memory_efficient_attention(q, k, v)
+ if shared.cmd_opts.xformers_flash_attention:
+ op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = op
+ if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
+ # print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
+            # Flash Attention is not available for the input arguments.
+ # Fallback to default xFormers' backend.
+ op = None
+ else:
+ op = None
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out