author    | AUTOMATIC <16777216c@gmail.com> | 2023-01-23 13:40:20 +0000
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-23 13:40:20 +0000
commit    | 59146621e256269b85feb536edeb745da20daf68 (patch)
tree      | 4210d6c2d709be8472dcb3584afb54e4284fd09f /modules/sd_hijack_optimizations.py
parent    | 3fa482076a5f07d81d37d58a31b8c4fe3a740843 (diff)
better support for xformers flash attention on older versions of torch
Diffstat (limited to 'modules/sd_hijack_optimizations.py')
-rw-r--r-- | modules/sd_hijack_optimizations.py | 42
1 file changed, 18 insertions, 24 deletions
diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py
index 9967359b..74452709 100644
--- a/modules/sd_hijack_optimizations.py
+++ b/modules/sd_hijack_optimizations.py
@@ -9,7 +9,7 @@ from torch import einsum
from ldm.util import default
from einops import rearrange
-from modules import shared
+from modules import shared, errors
from modules.hypernetworks import hypernetwork
from .sub_quadratic_attention import efficient_dot_product_attention
@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
)
+def get_xformers_flash_attention_op(q, k, v):
+ if not shared.cmd_opts.xformers_flash_attention:
+ return None
+
+ try:
+ flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+ fw, bw = flash_attention_op
+ if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+ return flash_attention_op
+ except Exception as e:
+ errors.display_once(e, "enabling flash attention")
+
+ return None
+
+
def xformers_attention_forward(self, x, context=None, mask=None):
h = self.heads
q_in = self.to_q(x)
@@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in
- if shared.cmd_opts.xformers_flash_attention:
- op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
- fw, bw = op
- if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
- # print('xformers_attention_forward', q.shape, k.shape, v.shape)
- # Flash Attention is not availabe for the input arguments.
- # Fallback to default xFormers' backend.
- op = None
- else:
- op = None
-
- out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
+ out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b n h d -> b n (h d)', h=h)
return self.to_out(out)
@@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
q = q.contiguous()
k = k.contiguous()
v = v.contiguous()
- if shared.cmd_opts.xformers_flash_attention:
- op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
- fw, bw = op
- if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
- # print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
- # Flash Attention is not availabe for the input arguments.
- # Fallback to default xFormers' backend.
- op = None
- else:
- op = None
- out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
+ out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
out = rearrange(out, 'b (h w) c -> b c h w', h=h)
out = self.proj_out(out)
return x + out
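
For reference, the sketch below restates the pattern this commit introduces, outside of the webui: probe whether xformers' flash-attention op supports the given tensors and otherwise return None so `memory_efficient_attention` falls back to its default backend, with the probe wrapped in try/except so that older torch/xformers builds that lack `MemoryEfficientAttentionFlashAttentionOp` or the `fmha.Inputs` signature degrade gracefully instead of raising. This is a minimal standalone sketch, assuming xformers is installed and a CUDA device is available; the helper name `pick_flash_attention_op` and the example tensor shape are illustrative and not part of the commit.

```python
# Minimal standalone sketch of the op-selection pattern used by this commit.
# Assumes xformers is installed and a CUDA device is available; the helper name
# and the example tensor shape are illustrative, not part of the webui code.
import torch
import xformers.ops


def pick_flash_attention_op(q, k, v):
    try:
        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
        fw, bw = op
        # On older torch/xformers builds this attribute or the Inputs signature
        # may be missing; any exception simply means "let xformers choose".
        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
            return op
    except Exception:
        pass
    return None


# Example call in the [batch, tokens, heads, head_dim] layout used above.
q = k = v = torch.randn(2, 4096, 8, 40, device="cuda", dtype=torch.float16)
out = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None, op=pick_flash_attention_op(q, k, v)
)
```

Factoring the probe into a single helper, as the commit does with `get_xformers_flash_attention_op`, keeps both attention forward functions on one code path and confines version-specific failures to a single try/except reported via `errors.display_once`.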