aboutsummaryrefslogtreecommitdiffstats
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authormissionfloyd <missionfloyd@users.noreply.github.com>2023-07-12 08:57:57 +0000
committerGitHub <noreply@github.com>2023-07-12 08:57:57 +0000
commite0218c4f22396a1be8aa4fde3db17c6fc85904eb (patch)
tree9cc918c09da3afaa4dfd93a823e5485006b9e5c6 /modules/mac_specific.py
parent3fee3c34f1b01d21770ab0a226b432cdd8444792 (diff)
parent15adff3d6d5e8ba186b3df6eee8a8d774c8f3879 (diff)
downloadstable-diffusion-webui-gfx803-e0218c4f22396a1be8aa4fde3db17c6fc85904eb.tar.gz
stable-diffusion-webui-gfx803-e0218c4f22396a1be8aa4fde3db17c6fc85904eb.tar.bz2
stable-diffusion-webui-gfx803-e0218c4f22396a1be8aa4fde3db17c6fc85904eb.zip
Merge branch 'dev' into img2img-save
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py35
1 files changed, 27 insertions, 8 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index d74c6b95..2c2f15ca 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,20 +1,39 @@
+import logging
+
import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
+log = logging.getLogger()
+
-# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
-# check `getattr` and try it for compatibility
+# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
+# use check `getattr` and try it for compatibility.
+# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
+# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
- if not getattr(torch, 'has_mps', False):
- return False
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
+ if not getattr(torch, 'has_mps', False):
+ return False
+ try:
+ torch.zeros(1).to(torch.device("mps"))
+ return True
+ except Exception:
+ return False
+ else:
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
+
+
+has_mps = check_for_mps()
+
+
+def torch_mps_gc() -> None:
try:
- torch.zeros(1).to(torch.device("mps"))
- return True
+ from torch.mps import empty_cache
+ empty_cache()
except Exception:
- return False
-has_mps = check_for_mps()
+ log.warning("MPS garbage collection failed", exc_info=True)
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784