aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2023-01-27 08:28:12 +0000
committerAUTOMATIC <16777216c@gmail.com>2023-01-27 08:28:12 +0000
commitd2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 (patch)
tree056355bca8b5ff3071f4aec4a0c4d725f026413a /modules/devices.py
parent7a14c8ab45da8a681792a6331d48a88dd684a0a9 (diff)
downloadstable-diffusion-webui-gfx803-d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2.tar.gz
stable-diffusion-webui-gfx803-d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2.tar.bz2
stable-diffusion-webui-gfx803-d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2.zip
remove the need to place configs near models
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py12
1 files changed, 8 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 6b36622c..2d5f797a 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,14 +34,18 @@ def get_cuda_device_string():
return "cuda"
-def get_optimal_device():
+def get_optimal_device_name():
if torch.cuda.is_available():
- return torch.device(get_cuda_device_string())
+ return get_cuda_device_string()
if has_mps():
- return torch.device("mps")
+ return "mps"
+
+ return "cpu"
- return cpu
+
+def get_optimal_device():
+ return torch.device(get_optimal_device_name())
def get_device_for(task):