diff options
author | 源文雨 <41315874+fumiama@users.noreply.github.com> | 2022-11-12 03:02:40 +0000 |
---|---|---|
committer | 源文雨 <41315874+fumiama@users.noreply.github.com> | 2022-11-12 03:02:40 +0000 |
commit | 76ab31e18898d4c2aacb9725cfbe25b230bff974 (patch) | |
tree | b2f9ef01fb9c903835573ced67523818eb3c9ea7 /modules/devices.py | |
parent | 7ba3923d5b494b7756d0b12f33acb3716d830b9a (diff) | |
download | stable-diffusion-webui-gfx803-76ab31e18898d4c2aacb9725cfbe25b230bff974.tar.gz stable-diffusion-webui-gfx803-76ab31e18898d4c2aacb9725cfbe25b230bff974.tar.bz2 stable-diffusion-webui-gfx803-76ab31e18898d4c2aacb9725cfbe25b230bff974.zip |
Fix wrong mps selection below MasOS 12.3
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py index 7511e1dc..9a3d29d7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,8 +3,15 @@ import contextlib import torch from modules import errors -# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility -has_mps = getattr(torch, 'has_mps', False) +# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# check `getattr` and try it for compatibility +def has_mps() -> bool: + if getattr(torch, 'has_mps', False): return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False cpu = torch.device("cpu") @@ -25,7 +32,7 @@ def get_optimal_device(): else: return torch.device("cuda") - if has_mps: + if has_mps(): return torch.device("mps") return cpu |