| author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-01-07 05:21:43 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-01-07 05:21:43 +0000 |
| commit | 71e00571379dd77f6a8695e893a2499251715f9f (patch) | |
| tree | 7b35e955593635a2128237514720432228e58d31 /modules/xpu_specific.py | |
| parent | b00b429477f8962001ddb556b9d543c5dcf34418 (diff) | |
| parent | 818d6a11e709bf07d48606bdccab944c46a5f4b0 (diff) | |
Merge pull request #14562 from Nuullll/fix-ipex-xpu-generator
[IPEX] Fix xpu generator
Diffstat (limited to 'modules/xpu_specific.py')
| -rw-r--r-- | modules/xpu_specific.py | 20 |

1 file changed, 16 insertions, 4 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py
index 4e11125b..2971dbc3 100644
--- a/modules/xpu_specific.py
+++ b/modules/xpu_specific.py
@@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention(
     return torch.reshape(result, (*N, L, Ev))


+def is_xpu_device(device: str | torch.device = None):
+    if device is None:
+        return False
+    if isinstance(device, str):
+        return device.startswith("xpu")
+    return device.type == "xpu"
+
+
 if has_xpu:
-    # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device
-    CondFunc('torch.Generator',
-        lambda orig_func, device=None: torch.xpu.Generator(device),
-        lambda orig_func, device=None: device is not None and device.type == "xpu")
+    try:
+        # torch.Generator supports "xpu" device since 2.1
+        torch.Generator("xpu")
+    except RuntimeError:
+        # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for torch < 2.1)
+        CondFunc('torch.Generator',
+            lambda orig_func, device=None: torch.xpu.Generator(device),
+            lambda orig_func, device=None: is_xpu_device(device))

     # W/A for some OPs that could not handle different input dtypes
     CondFunc('torch.nn.functional.layer_norm',
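The patch makes two changes. First, the device check goes through the new `is_xpu_device` helper, which accepts the device as either a plain string ("xpu", "xpu:0") or a `torch.device`; the replaced lambda assumed a `torch.device` object. Second, the `CondFunc` hijack of `torch.Generator` is now installed only when a one-time probe, `torch.Generator("xpu")`, raises `RuntimeError`, i.e. on torch builds older than 2.1 where the native API cannot target XPU. Below is a minimal sketch of the helper, runnable without XPU hardware; it assumes only an importable `torch` on Python 3.10+, and the assertions are illustrative additions rather than part of the commit.

```python
# Minimal, self-contained restatement of the new is_xpu_device helper so its
# behavior can be checked without XPU hardware. The assertions below are
# illustrative additions, not part of the commit.
import torch


def is_xpu_device(device: str | torch.device = None):
    if device is None:
        return False
    if isinstance(device, str):
        return device.startswith("xpu")
    return device.type == "xpu"


# The replaced predicate, `device is not None and device.type == "xpu"`,
# assumed a torch.device object and raised AttributeError when a caller
# passed the device as a plain string; the helper handles both forms.
assert is_xpu_device("xpu")
assert is_xpu_device("xpu:0")
assert is_xpu_device(torch.device("xpu"))  # constructing the device object needs no XPU runtime
assert not is_xpu_device(None)
assert not is_xpu_device("cuda:0")
```

Probing the API with try/except instead of comparing `torch.__version__` makes the workaround self-retiring: on a torch build that already supports XPU generators the probe succeeds and the hijack is never installed.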