From daf41a273485e865c9c9ef458b2c26be4422bcb2 Mon Sep 17 00:00:00 2001 From: Hao-Wu Date: Thu, 6 Jul 2023 15:37:10 +0800 Subject: Fix warning of 'has_mps' is deprecated from PyTorch --- modules/mac_specific.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) (limited to 'modules/mac_specific.py') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index d74c6b95..735847f5 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -4,16 +4,21 @@ from modules.sd_hijack_utils import CondFunc from packaging import version -# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. -# check `getattr` and try it for compatibility +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 def check_for_mps() -> bool: - if not getattr(torch, 'has_mps', False): - return False - try: - torch.zeros(1).to(torch.device("mps")) - return True - except Exception: - return False + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() has_mps = check_for_mps() -- cgit v1.2.3 From b85fc7187d953828340d4e3af34af46d9fc70b9e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 21:18:34 +0300 Subject: Fix MPS cache cleanup Importing torch does not import torch.mps so the call failed. --- modules/mac_specific.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) (limited to 'modules/mac_specific.py') diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..2c2f15ca 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger() + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,19 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': -- cgit v1.2.3