aboutsummaryrefslogtreecommitdiffstats
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-05-09 07:28:24 +0000
committerGitHub <noreply@github.com>2023-05-09 07:28:24 +0000
commitea05ddfec879f31ca2a7e171ed4a00ce6b7eb06b (patch)
tree3adef432c993a36d2f44d958e579f3f543ea5607 /modules/mac_specific.py
parent2b96a7b694d3392f76940dfe5df895a2833400fb (diff)
parentde401d8ffb46515a7cb4749f564d6a23085b4a5e (diff)
downloadstable-diffusion-webui-gfx803-ea05ddfec879f31ca2a7e171ed4a00ce6b7eb06b.tar.gz
stable-diffusion-webui-gfx803-ea05ddfec879f31ca2a7e171ed4a00ce6b7eb06b.tar.bz2
stable-diffusion-webui-gfx803-ea05ddfec879f31ca2a7e171ed4a00ce6b7eb06b.zip
Merge pull request #10201 from brkirch/mps-nan-fixes
Fix MPS on PyTorch 2.0.1, Intel Macs
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py9
1 file changed, 7 insertions, 2 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 6fe8dea0..40ce2101 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -54,6 +54,11 @@ if has_mps:
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
- if version.parse(torch.__version__) == version.parse("2.0"):
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
- CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
+ CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+
+ # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
+ if platform.processor() == 'i386':
+ for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
+ CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') \ No newline at end of file