diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-11-27 10:08:54 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-11-27 10:08:54 +0000 |
commit | 5b2c316890b7b8af95f0d0334d1fd34b9a687b99 (patch) | |
tree | 619a16137b796d40f0496ed57ab55a5492895bd6 /modules/devices.py | |
parent | 997ac57020b734894dd9fb19301e80bc52d7de72 (diff) | |
download | stable-diffusion-webui-gfx803-5b2c316890b7b8af95f0d0334d1fd34b9a687b99.tar.gz stable-diffusion-webui-gfx803-5b2c316890b7b8af95f0d0334d1fd34b9a687b99.tar.bz2 stable-diffusion-webui-gfx803-5b2c316890b7b8af95f0d0334d1fd34b9a687b99.zip |
eliminate duplicated code from #5095
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 30 |
1 files changed, 11 insertions, 19 deletions
diff --git a/modules/devices.py b/modules/devices.py index 93d82bbc..dd50fe24 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -24,17 +24,18 @@ def extract_device_id(args, name): return None -def get_optimal_device(): - if torch.cuda.is_available(): - from modules import shared +def get_cuda_device_string(): + from modules import shared + + if shared.cmd_opts.device_id is not None: + return f"cuda:{shared.cmd_opts.device_id}" - device_id = shared.cmd_opts.device_id + return "cuda" - if device_id is not None: - cuda_device = f"cuda:{device_id}" - return torch.device(cuda_device) - else: - return torch.device("cuda") + +def get_optimal_device(): + if torch.cuda.is_available(): + return torch.device(get_cuda_device_string()) if has_mps(): return torch.device("mps") @@ -44,16 +45,7 @@ def get_optimal_device(): def torch_gc(): if torch.cuda.is_available(): - from modules import shared - - device_id = shared.cmd_opts.device_id - - if device_id is not None: - cuda_device = f"cuda:{device_id}" - else: - cuda_device = "cuda" - - with torch.cuda.device(cuda_device): + with torch.cuda.device(get_cuda_device_string()): torch.cuda.empty_cache() torch.cuda.ipc_collect() |