diff options
author | Nuullll <vfirst218@gmail.com> | 2023-12-02 06:00:46 +0000 |
---|---|---|
committer | Nuullll <vfirst218@gmail.com> | 2023-12-02 06:00:46 +0000 |
commit | 7499148ad4dbd3444215c843d02453f68c459707 (patch) | |
tree | 8c6e4360f7a56fcd540bd41edb548b2e69b08e84 /modules/devices.py | |
parent | 8b40f475a31109cc6ecbdc0d14a0cee9e0303291 (diff) | |
download | stable-diffusion-webui-gfx803-7499148ad4dbd3444215c843d02453f68c459707.tar.gz stable-diffusion-webui-gfx803-7499148ad4dbd3444215c843d02453f68c459707.tar.bz2 stable-diffusion-webui-gfx803-7499148ad4dbd3444215c843d02453f68c459707.zip |
Disable ipex autocast due to its bad perf
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/modules/devices.py b/modules/devices.py index be599736..37ecca78 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,11 +3,18 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared, xpu_specific +from modules import errors, shared if sys.platform == "darwin": from modules import mac_specific +if shared.cmd_opts.use_ipex: + from modules import xpu_specific + + +def has_xpu() -> bool: + return shared.cmd_opts.use_ipex and xpu_specific.has_xpu + def has_mps() -> bool: if sys.platform != "darwin": @@ -30,7 +37,7 @@ def get_optimal_device_name(): if has_mps(): return "mps" - if xpu_specific.has_ipex: + if has_xpu(): return xpu_specific.get_xpu_device_string() return "cpu" @@ -57,6 +64,9 @@ def torch_gc(): if has_mps(): mac_specific.torch_mps_gc() + if has_xpu(): + xpu_specific.torch_xpu_gc() + def enable_tf32(): if torch.cuda.is_available(): @@ -103,15 +113,11 @@ def autocast(disable=False): if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() - if xpu_specific.has_xpu: - return torch.autocast("xpu") - return torch.autocast("cuda") def without_autocast(disable=False): - device_type = "xpu" if xpu_specific.has_xpu else "cuda" - return torch.autocast(device_type, enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() + return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext() class NansException(Exception): |