From 16b4d2cf3f51f1d88b97d1d459dec59d3a2d0642 Mon Sep 17 00:00:00 2001 From: Nuullll Date: Sat, 6 Jan 2024 16:32:18 +0800 Subject: [IPEX] Fix SDPA attn_mask dtype --- modules/xpu_specific.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/xpu_specific.py') diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index f7687a66..4e11125b 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -41,6 +41,8 @@ def torch_xpu_scaled_dot_product_attention( # cast to same dtype first key = key.to(query.dtype) value = value.to(query.dtype) + if attn_mask is not None and attn_mask.dtype != torch.bool: + attn_mask = attn_mask.to(query.dtype) N = query.shape[:-2] # Batch size L = query.size(-2) # Target sequence length -- cgit v1.2.3