diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py index 1d4eb563..be599736 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared +from modules import errors, shared, xpu_specific if sys.platform == "darwin": from modules import mac_specific @@ -30,6 +30,9 @@ def get_optimal_device_name(): if has_mps(): return "mps" + if xpu_specific.has_ipex: + return xpu_specific.get_xpu_device_string() + return "cpu" @@ -100,11 +103,15 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() + if xpu_specific.has_xpu: + return torch.autocast("xpu") + return torch.autocast("cuda") def without_autocast(disable=False): - return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + device_type = "xpu" if xpu_specific.has_xpu else "cuda" + return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): |