diff options
author | brkirch <brkirch@users.noreply.github.com> | 2022-11-17 08:52:17 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2022-11-21 07:07:19 +0000 |
commit | e247b7400a592c0a19c197cd080aeec38ee02b68 (patch) | |
tree | 1f76c6ed9f55b2ad362b2eb68586dbd437c31e7c /modules/devices.py | |
parent | a5106a7cdc24153332e4eb1d28e66ea1d7f1ef79 (diff) | |
download | stable-diffusion-webui-gfx803-e247b7400a592c0a19c197cd080aeec38ee02b68.tar.gz stable-diffusion-webui-gfx803-e247b7400a592c0a19c197cd080aeec38ee02b68.tar.bz2 stable-diffusion-webui-gfx803-e247b7400a592c0a19c197cd080aeec38ee02b68.zip |
Add fixes for PyTorch 1.12.1
Fix typo "MasOS" -> "macOS"
If MPS is available and PyTorch is an earlier version than 1.13:
* Monkey patch torch.Tensor.to to ensure all tensors sent to MPS are contiguous
* Monkey patch torch.nn.functional.layer_norm to ensure input tensor is contiguous (required for this program to work with MPS on unmodified PyTorch 1.12.1)
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 28 |
1 files changed, 27 insertions, 1 deletions
diff --git a/modules/devices.py b/modules/devices.py index a87d0d4c..6e8277e5 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -2,9 +2,10 @@ import sys, os, shlex import contextlib import torch from modules import errors +from packaging import version -# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+. +# has_mps is only available in nightly pytorch (for now) and macOS 12.3+. # check `getattr` and try it for compatibility def has_mps() -> bool: if not getattr(torch, 'has_mps', False): @@ -94,3 +95,28 @@ def autocast(disable=False): return contextlib.nullcontext() return torch.autocast("cuda") + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/79383 +orig_tensor_to = torch.Tensor.to +def tensor_to_fix(self, *args, **kwargs): + if self.device.type != 'mps' and \ + ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \ + (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')): + self = self.contiguous() + return orig_tensor_to(self, *args, **kwargs) + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 +orig_layer_norm = torch.nn.functional.layer_norm +def layer_norm_fix(*args, **kwargs): + if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps': + args = list(args) + args[0] = args[0].contiguous() + return orig_layer_norm(*args, **kwargs) + + +# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working +if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix |