aboutsummaryrefslogtreecommitdiffstats
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-01 13:39:51 +0000
committerGitHub <noreply@github.com>2024-01-01 13:39:51 +0000
commitdfd64382211317cc46ad337c373492bfc420fa18 (patch)
tree3b1b2f5f3648da07430f54d1c155ce379a6fa3f7 /modules/mac_specific.py
parent3d15e58b0a30f2ef1e731f9e429f4d3cf1c259c5 (diff)
parent0ce67cb61806cf43f4d726d4705a4f6fdc2540e6 (diff)
downloadstable-diffusion-webui-gfx803-dfd64382211317cc46ad337c373492bfc420fa18.tar.gz
stable-diffusion-webui-gfx803-dfd64382211317cc46ad337c373492bfc420fa18.tar.bz2
stable-diffusion-webui-gfx803-dfd64382211317cc46ad337c373492bfc420fa18.zip
Merge branch 'dev' into feat/interrupted-end
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 89256c5b..d96d86d7 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,7 @@
import logging
import torch
+from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
return cumsum_func(input, *args, **kwargs)
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+ try:
+ return orig_func(*args, **kwargs)
+ except RuntimeError as e:
+ if "not implemented for" in str(e) and "Half" in str(e):
+ input_tensor = args[0]
+ return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+ else:
+ print(f"An unexpected RuntimeError occurred: {str(e)}")
+
if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
@@ -77,6 +89,9 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+ # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+ CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']: