diff options
author | Dynamic <bradje@naver.com> | 2022-10-23 13:36:56 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-23 13:36:56 +0000 |
commit | 660ae690bd7107b78aac6413e1370f8cd72676bc (patch) | |
tree | b666cfd0872687ccd293a41d9d0a90fcdfe1ea0a /modules/devices.py | |
parent | 21364c5c39b269497944b56dd6664792d779333b (diff) | |
parent | 6bd6154a92eb05c80d66df661a38f8b70cc13729 (diff) | |
download | stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.tar.gz stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.tar.bz2 stable-diffusion-webui-gfx803-660ae690bd7107b78aac6413e1370f8cd72676bc.zip |
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 19 |
1 files changed, 15 insertions, 4 deletions
diff --git a/modules/devices.py b/modules/devices.py index eb422583..dc1f3cdd 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,22 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + from modules import shared + + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") @@ -34,7 +45,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") -device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = get_optimal_device() +device = device_interrogate = device_gfpgan = device_bsrgan = device_esrgan = device_scunet = device_codeformer = None dtype = torch.float16 dtype_vae = torch.float16 |