diff options
author | Spaceginner <ivan.demian2009@gmail.com> | 2023-01-27 12:35:54 +0000 |
---|---|---|
committer | Spaceginner <ivan.demian2009@gmail.com> | 2023-01-27 12:35:54 +0000 |
commit | 56c83e453a2ac333a0888ab3835ad4c82feacc25 (patch) | |
tree | bf7090e3b8faf0ab02e3fe5bd43ac1cde2dc62dc /modules/devices.py | |
parent | 9ecf1e827c5966e11495a0c066a127defbba9bcc (diff) | |
parent | 63391419c11c1749a3d83dade19235a836c509f9 (diff) | |
download | stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.tar.gz stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.tar.bz2 stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.zip |
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 14 |
1 files changed, 10 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..4687944e 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -34,14 +34,18 @@ def get_cuda_device_string(): return "cuda" -def get_optimal_device(): +def get_optimal_device_name(): if torch.cuda.is_available(): - return torch.device(get_cuda_device_string()) + return get_cuda_device_string() if has_mps(): - return torch.device("mps") + return "mps" + + return "cpu" - return cpu + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) def get_device_for(task): @@ -139,6 +143,8 @@ def test_for_nans(x, where): else: message = "A tensor with all NaNs was produced." + message += " Use --disable-nan-check commandline argument to disable this check." + raise NansException(message) |