path: root/modules/devices.py
author     Dynamic <bradje@naver.com>    2022-10-23 13:36:56 +0000
committer  GitHub <noreply@github.com>   2022-10-23 13:36:56 +0000
commit     660ae690bd7107b78aac6413e1370f8cd72676bc (patch)
tree       b666cfd0872687ccd293a41d9d0a90fcdfe1ea0a /modules/devices.py
parent     21364c5c39b269497944b56dd6664792d779333b (diff)
parent     6bd6154a92eb05c80d66df661a38f8b70cc13729 (diff)
download   stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.tar.gz (.tar.bz2, .zip)
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py | 19 +++++++++++++++----
1 file changed, 15 insertions(+), 4 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index eb422583..dc1f3cdd 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,7 +1,6 @@
+import sys, os, shlex
 import contextlib
-
 import torch
-
 from modules import errors
 
 # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
@@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False)
 
 cpu = torch.device("cpu")
 
+def extract_device_id(args, name):
+    for x in range(len(args)):
+        if name in args[x]: return args[x+1]
+    return None
 
 def get_optimal_device():
     if torch.cuda.is_available():
-        return torch.device("cuda")
+        from modules import shared
+
+        device_id = shared.cmd_opts.device_id
+
+        if device_id is not None:
+            cuda_device = f"cuda:{device_id}"
+            return torch.device(cuda_device)
+        else:
+            return torch.device("cuda")
 
     if has_mps:
         return torch.device("mps")
@@ -34,7 +45,7 @@ def enable_tf32():
 
 errors.run(enable_tf32, "Enabling TF32")
 
-device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device()
+device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None
 
 dtype = torch.float16
 dtype_vae = torch.float16
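
For context: the second hunk threads an optional device ordinal into CUDA device selection. Below is a minimal standalone sketch of that flow, reusing the function bodies from the diff; the argv-scanning harness and the device_id parameter are illustrative assumptions, since the merged code reads shared.cmd_opts.device_id instead of taking an argument.

import sys
import torch

def extract_device_id(args, name):
    # Return the token that follows the first argument containing `name`;
    # None when the flag is absent. Caveat carried over from the patch:
    # if the flag is the last token, args[x + 1] raises IndexError.
    for x in range(len(args)):
        if name in args[x]:
            return args[x + 1]
    return None

def get_optimal_device(device_id=None):
    # CUDA first, honoring an explicit ordinal such as "cuda:1";
    # then Apple MPS on builds that expose it; finally CPU.
    if torch.cuda.is_available():
        if device_id is not None:
            return torch.device(f"cuda:{device_id}")
        return torch.device("cuda")
    if getattr(torch, 'has_mps', False):
        return torch.device("mps")
    return torch.device("cpu")

if __name__ == "__main__":
    # e.g. `python pick_device.py --device-id 1` -> cuda:1 on a multi-GPU box
    print(get_optimal_device(extract_device_id(sys.argv, '--device-id')))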
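
The last hunk stops resolving devices at import time: the per-component globals become None so they can be assigned after command-line options are parsed. A hypothetical initializer showing that deferred step follows; the real assignment happens elsewhere in the webui and its per-component policy may differ.

def initialize_devices():
    # Assumed late-initialization hook, run once cmd_opts is available.
    # Putting every component on the same device is an illustrative
    # simplification, not necessarily the repository's exact policy.
    global device, device_interrogate, device_gfpgan, device_bsrgan
    global device_esrgan, device_scunet, device_codeformer
    chosen = get_optimal_device()
    device = device_interrogate = device_gfpgan = device_bsrgan = chosen
    device_esrgan = device_scunet = device_codeformer = chosen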
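
enable_tf32 itself is untouched by this merge (the third hunk only shifts its surroundings), but for orientation, a typical TF32 toggle looks like the sketch below; these are the standard PyTorch flags such helpers commonly set, not necessarily this file's exact body.

def enable_tf32():
    # TF32 trades a little float32 precision for large matmul/conv
    # speedups on Ampere-and-newer GPUs; both flags are standard
    # PyTorch backend settings.
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True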