import torch
def _mps_is_available():
    """Return True if PyTorch's MPS (Apple Metal) backend is usable here.

    ``torch.backends.mps.is_available()`` reports whether MPS is currently
    available, not merely compiled in. Use it when this PyTorch exposes it.
    The legacy ``torch.has_mps`` attribute only reflects build-time support.
    It is deprecated in newer releases and missing from older ones, so it is
    used only as a fallback, read via ``getattr`` and defaulting to False.
    """
    mps_backend = getattr(getattr(torch, "backends", None), "mps", None)
    is_available = getattr(mps_backend, "is_available", None)
    if callable(is_available):
        return bool(is_available())
    return bool(getattr(torch, "has_mps", False))


# True when the MPS backend can be used by this PyTorch install.
has_mps = _mps_is_available()
# Shared CPU device object; also what get_optimal_device falls back to.
cpu = torch.device("cpu")


def get_optimal_device():
    """Return the preferred torch device for computation.

    Preference order is CUDA (when ``torch.cuda.is_available()``), then MPS
    (gated on the module-level ``has_mps`` flag), then CPU. In the CPU case
    the module's shared ``cpu`` device object is returned, not a new one.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("mps") if has_mps else cpu