diff options
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r-- | modules/mac_specific.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..9ceb43ba 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger(__name__) + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,23 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from modules.shared import state + if state.current_latent is not None: + log.debug("`current_latent` is set, skipping MPS garbage collection") + return + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': |