diff options
author | noodleanon <122053346+noodleanon@users.noreply.github.com> | 2023-01-07 14:18:09 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-07 14:18:09 +0000 |
commit | 50e25362794d46cd9a55c70e953a8b4126fd42f7 (patch) | |
tree | ea528f29a7c967de32f08217c50d994eebb277b3 /modules/devices.py | |
parent | eadd1bf06adbd7263875640a6446d3b0184d1561 (diff) | |
parent | 151233399c4b79934bdbb7c12a97eeb6499572fb (diff) | |
download | stable-diffusion-webui-gfx803-50e25362794d46cd9a55c70e953a8b4126fd42f7.tar.gz stable-diffusion-webui-gfx803-50e25362794d46cd9a55c70e953a8b4126fd42f7.tar.bz2 stable-diffusion-webui-gfx803-50e25362794d46cd9a55c70e953a8b4126fd42f7.zip |
Merge branch 'AUTOMATIC1111:master' into img2img-api-scripts
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 28 |
1 files changed, 23 insertions, 5 deletions
diff --git a/modules/devices.py b/modules/devices.py index 800510b7..caeb0276 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -133,8 +133,26 @@ def numpy_fix(self, *args, **kwargs): return orig_tensor_numpy(self, *args, **kwargs) -# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working -if has_mps() and version.parse(torch.__version__) < version.parse("1.13"): - torch.Tensor.to = tensor_to_fix - torch.nn.functional.layer_norm = layer_norm_fix - torch.Tensor.numpy = numpy_fix +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +orig_cumsum = torch.cumsum +orig_Tensor_cumsum = torch.Tensor.cumsum +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if any(output_dtype == broken_dtype for broken_dtype in [torch.bool, torch.int8, torch.int16, torch.int64]): + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + return cumsum_func(input, *args, **kwargs) + + +if has_mps(): + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + torch.Tensor.to = tensor_to_fix + torch.nn.functional.layer_norm = layer_norm_fix + torch.Tensor.numpy = numpy_fix + elif version.parse(torch.__version__) > version.parse("1.13.1"): + if not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.Tensor([1,1]).to(torch.device("mps")).cumsum(0, dtype=torch.int16)): + torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) ) + torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) + orig_narrow = torch.narrow + torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) |