diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2022-10-22 10:58:00 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-10-22 10:58:00 +0000 |
commit | e80bdcab91df0d91fa268991bee1d0143e81920a (patch) | |
tree | 347f8cbcdf644885fcf3481ed7a2dc55f8942c6e /modules/devices.py | |
parent | 5aa9525046b7520d39fe8fc8c5c6cc10ab4d5fdb (diff) | |
parent | 1fa53dab2c5a857b9773f904fadf853dac1f1bd6 (diff) | |
download | stable-diffusion-webui-gfx803-e80bdcab91df0d91fa268991bee1d0143e81920a.tar.gz stable-diffusion-webui-gfx803-e80bdcab91df0d91fa268991bee1d0143e81920a.tar.bz2 stable-diffusion-webui-gfx803-e80bdcab91df0d91fa268991bee1d0143e81920a.zip |
Merge pull request #3377 from Extraltodeus/cuda-device-id-selection
Implementation of CUDA device id selection (--device-id 0/1/2)
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/modules/devices.py b/modules/devices.py index eb422583..8a159282 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,7 +1,6 @@ +import sys, os, shlex import contextlib - import torch - from modules import errors # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility @@ -9,10 +8,26 @@ has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") +def extract_device_id(args, name): + for x in range(len(args)): + if name in args[x]: return args[x+1] + return None def get_optimal_device(): if torch.cuda.is_available(): - return torch.device("cuda") + # CUDA device selection support: + if "shared" not in sys.modules: + commandline_args = os.environ.get('COMMANDLINE_ARGS', "") #re-parse the commandline arguments because using the shared.py module creates an import loop. + sys.argv += shlex.split(commandline_args) + device_id = extract_device_id(sys.argv, '--device-id') + else: + device_id = shared.cmd_opts.device_id + + if device_id is not None: + cuda_device = f"cuda:{device_id}" + return torch.device(cuda_device) + else: + return torch.device("cuda") if has_mps: return torch.device("mps") |