| author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-04 16:56:35 +0000 |
| --- | --- | --- |
| committer | GitHub <noreply@github.com> | 2023-01-04 16:56:35 +0000 |
| commit | eeb1de4388773ba92b9920a4f64eb91add2e02ca (patch) | |
| tree | 22f5d5e7417f24599a415fd64c9f1652495ce5a3 /modules/devices.py | |
| parent | d85c2cb2d59f64cbb510a9e5596596de2e4f4dcc (diff) | |
| parent | b7deea47eeb033052062621b0005d4321b53bff7 (diff) | |
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules/devices.py')
| -rw-r--r-- | modules/devices.py | 115 |
| --- | --- | --- |

1 file changed, 84 insertions(+), 31 deletions(-)
```diff
diff --git a/modules/devices.py b/modules/devices.py
index 7511e1dc..800510b7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -2,72 +2,95 @@ import sys, os, shlex
 import contextlib
 import torch
 from modules import errors
+from packaging import version
 
-# has_mps is only available in nightly pytorch (for now), `getattr` for compatibility
-has_mps = getattr(torch, 'has_mps', False)
-
-cpu = torch.device("cpu")
+# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
+# check `getattr` and try it for compatibility
+def has_mps() -> bool:
+    if not getattr(torch, 'has_mps', False):
+        return False
+    try:
+        torch.zeros(1).to(torch.device("mps"))
+        return True
+    except Exception:
+        return False
 
 
 def extract_device_id(args, name):
     for x in range(len(args)):
-        if name in args[x]: return args[x+1]
+        if name in args[x]:
+            return args[x + 1]
+
     return None
 
 
-def get_optimal_device():
-    if torch.cuda.is_available():
-        from modules import shared
-
-        device_id = shared.cmd_opts.device_id
+def get_cuda_device_string():
+    from modules import shared
 
-        if device_id is not None:
-            cuda_device = f"cuda:{device_id}"
-            return torch.device(cuda_device)
-        else:
-            return torch.device("cuda")
+    if shared.cmd_opts.device_id is not None:
+        return f"cuda:{shared.cmd_opts.device_id}"
 
-    if has_mps:
+    return "cuda"
+
+
+def get_optimal_device():
+    if torch.cuda.is_available():
+        return torch.device(get_cuda_device_string())
+
+    if has_mps():
         return torch.device("mps")
 
     return cpu
 
 
+def get_device_for(task):
+    from modules import shared
+
+    if task in shared.cmd_opts.use_cpu:
+        return cpu
+
+    return get_optimal_device()
+
+
 def torch_gc():
     if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-        torch.cuda.ipc_collect()
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 
 def enable_tf32():
     if torch.cuda.is_available():
+
+        # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't
+        # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407
+        if any([torch.cuda.get_device_capability(devid) == (7, 5) for devid in range(0, torch.cuda.device_count())]):
+            torch.backends.cudnn.benchmark = True
+
         torch.backends.cuda.matmul.allow_tf32 = True
         torch.backends.cudnn.allow_tf32 = True
 
 
 errors.run(enable_tf32, "Enabling TF32")
 
-device = device_interrogate = device_gfpgan = device_swinir = device_esrgan = device_scunet = device_codeformer = None
+cpu = torch.device("cpu")
+device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
 
 
 def randn(seed, shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
+    torch.manual_seed(seed)
     if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        generator.manual_seed(seed)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
-
-    torch.manual_seed(seed)
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
 
 
 def randn_without_seed(shape):
-    # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used.
     if device.type == 'mps':
-        generator = torch.Generator(device=cpu)
-        noise = torch.randn(shape, generator=generator, device=cpu).to(device)
-        return noise
-
+        return torch.randn(shape, device=cpu).to(device)
     return torch.randn(shape, device=device)
@@ -82,6 +105,36 @@ def autocast(disable=False):
     return torch.autocast("cuda")
 
 
 # MPS workaround for https://github.com/pytorch/pytorch/issues/79383
-def mps_contiguous(input_tensor, device): return input_tensor.contiguous() if device.type == 'mps' else input_tensor
-def mps_contiguous_to(input_tensor, device): return mps_contiguous(input_tensor, device).to(device)
+orig_tensor_to = torch.Tensor.to
+def tensor_to_fix(self, *args, **kwargs):
+    if self.device.type != 'mps' and \
+       ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
+       (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
+        self = self.contiguous()
+    return orig_tensor_to(self, *args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
+orig_layer_norm = torch.nn.functional.layer_norm
+def layer_norm_fix(*args, **kwargs):
+    if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
+        args = list(args)
+        args[0] = args[0].contiguous()
+    return orig_layer_norm(*args, **kwargs)
+
+
+# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
+orig_tensor_numpy = torch.Tensor.numpy
+def numpy_fix(self, *args, **kwargs):
+    if self.requires_grad:
+        self = self.detach()
+    return orig_tensor_numpy(self, *args, **kwargs)
+
+
+# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
+if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
+    torch.Tensor.to = tensor_to_fix
+    torch.nn.functional.layer_norm = layer_norm_fix
+    torch.Tensor.numpy = numpy_fix
```
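The new `has_mps()` helper replaces the old bare attribute check with a live probe: the backend is only trusted after a tensor has actually been created on it. Below is a minimal standalone sketch of that try-it-and-see pattern, assuming only that `torch` is importable; the `backend_usable` name is illustrative and not part of this codebase.

```python
import torch

def backend_usable(device_name: str) -> bool:
    # Allocating a tiny tensor on the device proves the backend both
    # exists and actually works; a bare getattr() check cannot do that.
    try:
        torch.zeros(1).to(torch.device(device_name))
        return True
    except Exception:
        return False

print(backend_usable("cpu"))  # True on any machine
print(backend_usable("mps"))  # True only where the Metal backend works
```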
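The new `get_device_for(task)` hook routes individual tasks onto the CPU when the user lists them under `--use-cpu`. A hedged sketch of that routing logic in isolation, with `cmd_use_cpu` standing in for `shared.cmd_opts.use_cpu` (a hypothetical stand-in; the real value is parsed from the command line, and the real `get_optimal_device` also probes MPS and honors `--device-id`):

```python
import torch

# stand-in for shared.cmd_opts.use_cpu, e.g. from `--use-cpu interrogate`
cmd_use_cpu = {"interrogate"}

cpu = torch.device("cpu")

def get_optimal_device() -> torch.device:
    # simplified: prefer CUDA when present, otherwise fall back to CPU
    return torch.device("cuda") if torch.cuda.is_available() else cpu

def get_device_for(task: str) -> torch.device:
    # tasks the user pinned via --use-cpu are forced onto the CPU
    if task in cmd_use_cpu:
        return cpu
    return get_optimal_device()

print(get_device_for("interrogate"))  # cpu
print(get_device_for("esrgan"))       # cuda if available, else cpu
```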
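The three `*_fix` workarounds all follow the same wrap-and-delegate pattern: save the original function, install a wrapper that massages its arguments, and fall through to the saved original. A minimal sketch of that pattern using the `numpy_fix` case, assuming a current `torch` build (where the underlying MPS bug may already be fixed); the restore step at the end is illustrative and not something the diff itself does.

```python
import torch

orig_tensor_numpy = torch.Tensor.numpy

def numpy_fix(self, *args, **kwargs):
    # Tensors that require grad cannot be exported to NumPy directly;
    # detach first, then delegate to the saved original method.
    if self.requires_grad:
        self = self.detach()
    return orig_tensor_numpy(self, *args, **kwargs)

torch.Tensor.numpy = numpy_fix

t = torch.ones(3, requires_grad=True)
print(t.numpy())  # works; the unpatched call raises RuntimeError

torch.Tensor.numpy = orig_tensor_numpy  # undo the patch
```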