aboutsummaryrefslogtreecommitdiffstats
path: root/modules/xpu_specific.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-07 05:21:43 +0000
committerGitHub <noreply@github.com>2024-01-07 05:21:43 +0000
commit71e00571379dd77f6a8695e893a2499251715f9f (patch)
tree7b35e955593635a2128237514720432228e58d31 /modules/xpu_specific.py
parentb00b429477f8962001ddb556b9d543c5dcf34418 (diff)
parent818d6a11e709bf07d48606bdccab944c46a5f4b0 (diff)
downloadstable-diffusion-webui-gfx803-71e00571379dd77f6a8695e893a2499251715f9f.tar.gz
stable-diffusion-webui-gfx803-71e00571379dd77f6a8695e893a2499251715f9f.tar.bz2
stable-diffusion-webui-gfx803-71e00571379dd77f6a8695e893a2499251715f9f.zip
Merge pull request #14562 from Nuullll/fix-ipex-xpu-generator
[IPEX] Fix xpu generator
Diffstat (limited to 'modules/xpu_specific.py')
-rw-r--r--modules/xpu_specific.py20
1 files changed, 16 insertions, 4 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 4e11125b..2971dbc3 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
return torch.reshape(result, (*N, L, Ev))
+def is_xpu_device(device: str | torch.device = None):
+ if device is None:
+ return False
+ if isinstance(device, str):
+ return device.startswith("xpu")
+ return device.type == "xpu"
+
+
if has_xpu:
- # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
- CondFunc('torch.Generator',
- lambda orig_func, device=None: torch.xpu.Generator(device),
- lambda orig_func, device=None: device is not None and device.type == "xpu")
+ try:
+ # torch.Generator supports "xpu" device since 2.1
+ torch.Generator("xpu")
+ except RuntimeError:
+ # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+ CondFunc('torch.Generator',
+ lambda orig_func, device=None: torch.xpu.Generator(device),
+ lambda orig_func, device=None: is_xpu_device(device))
# W/A for some OPs that could not handle different input dtypes
CondFunc('torch.nn.functional.layer_norm',