diff options
author | Nuullll <vfirst218@gmail.com> | 2024-01-06 08:32:18 +0000 |
---|---|---|
committer | Nuullll <vfirst218@gmail.com> | 2024-01-06 08:32:18 +0000 |
commit | 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 (patch) | |
tree | a45a71289b1aac4584c46d3e46c7c102a25e69cd /modules/xpu_specific.py | |
parent | 8b6848c6dbee95f055b98b33804b12bd188ac625 (diff) | |
download | stable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.tar.gz stable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.tar.bz2 stable-diffusion-webui-gfx803-16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642.zip |
[IPEX] Fix SDPA attn_mask dtype
Diffstat (limited to 'modules/xpu_specific.py')
-rw-r--r-- | modules/xpu_specific.py | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66..4e11125b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length |