author     AUTOMATIC <16777216c@gmail.com>    2022-11-12 07:00:49 +0000
committer  AUTOMATIC <16777216c@gmail.com>    2022-11-12 07:00:49 +0000
commit     0ab0a50f9ae14bd7ce7ec518323ebd31c7971155 (patch)
tree       c0520e5e6e116087a91df46411825889bf6a08d0 /modules/devices.py
parent     c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8 (diff)
change formatting to match the main program in devices.py
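The reformatting follows the style used in the rest of the webui: compound one-line statements such as `if ...: return ...` are split across separate lines, and top-level definitions are separated by two blank lines. The extract_device_id hunk in the diff below is a representative before/after (taken directly from the diff, not new code):

# before (compact style)
def extract_device_id(args, name):
    for x in range(len(args)):
        if name in args[x]: return args[x+1]
    return None

# after (one statement per line, matching the main program)
def extract_device_id(args, name):
    for x in range(len(args)):
        if name in args[x]:
            return args[x + 1]

    return None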
Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py  21
1 file changed, 16 insertions, 5 deletions
diff --git a/modules/devices.py b/modules/devices.py
index bd3e4ffb..67165bf6 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,23 +3,27 @@ import contextlib
 import torch
 from modules import errors
 
+
 # has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
 # check `getattr` and try it for compatibility
 def has_mps() -> bool:
-    if not getattr(torch, 'has_mps', False): return False
+    if not getattr(torch, 'has_mps', False):
+        return False
     try:
         torch.zeros(1).to(torch.device("mps"))
         return True
     except Exception:
         return False
 
-cpu = torch.device("cpu")
 
 def extract_device_id(args, name):
     for x in range(len(args)):
-        if name in args[x]: return args[x+1]
+        if name in args[x]:
+            return args[x + 1]
+
     return None
 
+
 def get_optimal_device():
     if torch.cuda.is_available():
         from modules import shared
@@ -52,10 +56,12 @@ def enable_tf32():
 
 errors.run(enable_tf32, "Enabling TF32")
 
+cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 
+
 def randn(seed, shape):
     # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
@@ -89,6 +95,11 @@ def autocast(disable=False):
 
     return torch.autocast("cuda")
 
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+def mps_contiguous(input_tensor, device):
+    return input_tensor.contiguous() if device.type == 'mps' else input_tensor
+
+
+def mps_contiguous_to(input_tensor, device):
+    return mps_contiguous(input_tensor, device).to(device)
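For context, a minimal sketch of how the helpers touched by this commit are typically called from other webui modules. This is illustrative caller code, not part of the commit: it assumes the code runs inside the webui process (so modules.shared is already initialised by the time get_optimal_device is called), and the tensor shape and variable names are made up for the example.

import torch

from modules import devices

if devices.has_mps():
    print("MPS backend is usable on this machine")

target = devices.get_optimal_device()                # cuda, mps, or cpu depending on the machine
sample = torch.zeros(1, 4, 64, 64)                   # hypothetical tensor, shape chosen for illustration
sample = devices.mps_contiguous_to(sample, target)   # calls .contiguous() only when target is an MPS device

mps_contiguous_to wraps mps_contiguous and then moves the tensor to the target device, so callers can use one call regardless of whether the backend is MPS, CUDA, or CPU.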