From ba70a220e3176153ba2a559acb9e5aa692dce7ca Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 5 Jun 2023 22:20:29 +0300 Subject: Remove a bunch of unused/vestigial code As found by Vulture and some eyes --- modules/devices.py | 7 ------- 1 file changed, 7 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 1ed6ffdc..620ed1a6 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -15,13 +15,6 @@ def has_mps() -> bool: else: return mac_specific.has_mps -def extract_device_id(args, name): - for x in range(len(args)): - if name in args[x]: - return args[x + 1] - - return None - def get_cuda_device_string(): from modules import shared -- cgit v1.2.3 From da8916f92649fc4d947cb46d9d8f8ea1621b2a59 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 8 Jul 2023 17:13:18 +0300 Subject: added torch.mps.empty_cache() to torch_gc() changed a bunch of places that use torch.cuda.empty_cache() to use torch_gc() instead --- modules/devices.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 620ed1a6..c5ad950f 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -49,10 +49,13 @@ def get_device_for(task): def torch_gc(): + if torch.cuda.is_available(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() + elif has_mps() and hasattr(torch.mps, 'empty_cache'): + torch.mps.empty_cache() def enable_tf32(): -- cgit v1.2.3 From b85fc7187d953828340d4e3af34af46d9fc70b9e Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 10 Jul 2023 21:18:34 +0300 Subject: Fix MPS cache cleanup Importing torch does not import torch.mps so the call failed. --- modules/devices.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index c5ad950f..57e51da3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -54,8 +54,9 @@ def torch_gc(): with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() - elif has_mps() and hasattr(torch.mps, 'empty_cache'): - torch.mps.empty_cache() + + if has_mps(): + mac_specific.torch_mps_gc() def enable_tf32(): -- cgit v1.2.3