diff options
author | Nuullll <vfirst218@gmail.com> | 2024-01-06 11:09:56 +0000 |
---|---|---|
committer | Nuullll <vfirst218@gmail.com> | 2024-01-06 11:09:56 +0000 |
commit | 73786c047f14d6ae658b2c12f493f05486ba1789 (patch) | |
tree | 0f2aa95217f704409b9bbf6cd3fced76c38ba402 | |
parent | b00b429477f8962001ddb556b9d543c5dcf34418 (diff) | |
download | stable-diffusion-webui-gfx803-73786c047f14d6ae658b2c12f493f05486ba1789.tar.gz stable-diffusion-webui-gfx803-73786c047f14d6ae658b2c12f493f05486ba1789.tar.bz2 stable-diffusion-webui-gfx803-73786c047f14d6ae658b2c12f493f05486ba1789.zip |
[IPEX] Fix torch.Generator hijack
-rw-r--r-- | modules/xpu_specific.py | 20 |
1 file changed, 16 insertions, 4 deletions
diff --git a/modules/xpu_specific.py b/modules/xpu_specific.py index 4e11125b..1137891a 100644 --- a/modules/xpu_specific.py +++ b/modules/xpu_specific.py @@ -94,11 +94,23 @@ def torch_xpu_scaled_dot_product_attention( return torch.reshape(result, (*N, L, Ev)) +def is_xpu_device(device: str | torch.device = None): + if device is None: + return False + if isinstance(device, str): + return device.startswith("xpu") + return device.type == "xpu" + + if has_xpu: - # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device - CondFunc('torch.Generator', - lambda orig_func, device=None: torch.xpu.Generator(device), - lambda orig_func, device=None: device is not None and device.type == "xpu") + try: + # torch.Generator supports "xpu" device since 2.1 + torch.Generator("xpu") + except: + # W/A for https://github.com/intel/intel-extension-for-pytorch/issues/452: torch.Generator API doesn't support XPU device (for IPEX < 2.1) + CondFunc('torch.Generator', + lambda orig_func, device=None: torch.xpu.Generator(device), + lambda orig_func, device=None: is_xpu_device(device)) # W/A for some OPs that could not handle different input dtypes CondFunc('torch.nn.functional.layer_norm', |