diff options
author | Aarni Koskela <akx@iki.fi> | 2023-07-10 18:18:34 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-07-11 09:51:05 +0000 |
commit | b85fc7187d953828340d4e3af34af46d9fc70b9e (patch) | |
tree | 5a706bd757e03227c3cd1ae1c5a026eae65107ab | |
parent | 7b833291b3ef4690ef158ee3415c2e93948acb2d (diff) | |
download | stable-diffusion-webui-gfx803-b85fc7187d953828340d4e3af34af46d9fc70b9e.tar.gz stable-diffusion-webui-gfx803-b85fc7187d953828340d4e3af34af46d9fc70b9e.tar.bz2 stable-diffusion-webui-gfx803-b85fc7187d953828340d4e3af34af46d9fc70b9e.zip |
Fix MPS cache cleanup
Importing torch does not import torch.mps, so the call failed.
-rw-r--r-- | modules/devices.py | 5 | ||||
-rw-r--r-- | modules/mac_specific.py | 14 |
2 files changed, 17 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py index c5ad950f..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,8 +54,9 @@ def torch_gc(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() - elif has_mps() and hasattr(torch.mps, 'empty_cache'): - torch.mps.empty_cache() + + if has_mps(): + mac_specific.torch_mps_gc() def enable_tf32(): diff --git a/modules/mac_specific.py b/modules/mac_specific.py index 735847f5..2c2f15ca 100644 --- a/modules/mac_specific.py +++ b/modules/mac_specific.py @@ -1,8 +1,12 @@ +import logging + import torch import platform from modules.sd_hijack_utils import CondFunc from packaging import version +log = logging.getLogger() + # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, # use check `getattr` and try it for compatibility. @@ -19,9 +23,19 @@ def check_for_mps() -> bool: return False else: return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + has_mps = check_for_mps() +def torch_mps_gc() -> None: + try: + from torch.mps import empty_cache + empty_cache() + except Exception: + log.warning("MPS garbage collection failed", exc_info=True) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/89784 def cumsum_fix(input, cumsum_func, *args, **kwargs): if input.device.type == 'mps': |