From b5d1af11b7dc718d4d91d379c75e46f4bd2e2fe6 Mon Sep 17 00:00:00 2001 From: Abdullah Barhoum Date: Sun, 11 Sep 2022 07:11:27 +0200 Subject: Modular device management --- modules/devices.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) create mode 100644 modules/devices.py (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py new file mode 100644 index 00000000..25008a04 --- /dev/null +++ b/modules/devices.py @@ -0,0 +1,12 @@ +import torch + + +# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility +has_mps = getattr(torch, 'has_mps', False) + +def get_optimal_device(): + if torch.cuda.is_available(): + return torch.device("cuda") + if has_mps: + return torch.device("mps") + return torch.device("cpu") -- cgit v1.2.3