diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-19 04:59:39 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-19 04:59:39 +0000 |
commit | 0a334b447ff0c41519bb9e280050736913ad9cf8 (patch) | |
tree | e27963f76b7357ff0cb7b2c3fdcb720ab64f0e50 /modules/mac_specific.py | |
parent | 6094310704f4b3853bfa5d05d9c1ace58b2deee7 (diff) | |
parent | c2b975485708791b29d44d79ee1a48d3abd838b7 (diff) | |
download | stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.tar.gz stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.tar.bz2 stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.zip |
Merge branch 'dev' into allow-no-venv-install
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r-- | modules/mac_specific.py | 39 |
1 files changed, 31 insertions, 8 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d74c6b95..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,20 +1,43 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger(__name__) + -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +has_mps = check_for_mps() + + +def torch_mps_gc() -> None: try: - torch.zeros(1).to(torch.device("mps")) - return True + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return + from torch.mps import empty_cache + empty_cache() except Exception: - return False -has_mps = check_for_mps() + log.warning("MPS garbage collection failed", exc_info=True) # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 |