diff options
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py index be542f8f..655ca1d3 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): @@ -147,6 +151,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." + message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) |