aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorNuullll <vfirst218@gmail.com>2024-01-06 08:32:18 +0000
committerNuullll <vfirst218@gmail.com>2024-01-06 08:32:18 +0000
commit16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 (patch)
treea45a71289b1aac4584c46d3e46c7c102a25e69cd
parent8b6848c6dbee95f055b98b33804b12bd188ac625 (diff)
downloadstable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.tar.gz
stable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.tar.bz2
stable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.zip
[IPEX] Fix SDPA attn_mask dtype
-rw-r--r--modules/xpu_specific.py2
1 files changed, 2 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index f7687a66..4e11125b 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention(
# cast to same dtype first
key = key.to(query.dtype)
value = value.to(query.dtype)
+ if attn_mask is not None and attn_mask.dtype != torch.bool:
+ attn_mask = attn_mask.to(query.dtype)
N = query.shape[:-2] # Batch size
L = query.size(-2) # Target sequence length