aboutsummaryrefslogtreecommitdiffstats
path: root/modules/mac_specific.py
diff options
context:
space:
mode:
authorfuchen.ljl <yjqqqqdx_01@163.com>2023-12-06 12:42:04 +0000
committerGitHub <noreply@github.com>2023-12-06 12:42:04 +0000
commitc2bdbb67b66de06f1163de3f10c290213cd6bdb0 (patch)
tree0fcb3010a72ad253862f317ea18fdeb46b05a322 /modules/mac_specific.py
parent4d56383025f2cbd00dc6296161e31a896624ab75 (diff)
parentf92d61497a426a19818625c3ccdaae9beeb82b31 (diff)
downloadstable-diffusion-webui-gfx803-c2bdbb67b66de06f1163de3f10c290213cd6bdb0.tar.gz
stable-diffusion-webui-gfx803-c2bdbb67b66de06f1163de3f10c290213cd6bdb0.tar.bz2
stable-diffusion-webui-gfx803-c2bdbb67b66de06f1163de3f10c290213cd6bdb0.zip
Merge branch 'dev' into kingljl-patch-memory-leak
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--modules/mac_specific.py15
1 files changed, 15 insertions, 0 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index 89256c5b..d96d86d7 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -1,6 +1,7 @@
import logging
import torch
+from torch import Tensor
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version
@@ -51,6 +52,17 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
return cumsum_func(input, *args, **kwargs)
+# MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+def interpolate_with_fp32_fallback(orig_func, *args, **kwargs) -> Tensor:
+ try:
+ return orig_func(*args, **kwargs)
+ except RuntimeError as e:
+ if "not implemented for" in str(e) and "Half" in str(e):
+ input_tensor = args[0]
+ return orig_func(input_tensor.to(torch.float32), *args[1:], **kwargs).to(input_tensor.dtype)
+ else:
+ print(f"An unexpected RuntimeError occurred: {str(e)}")
+
if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
@@ -77,6 +89,9 @@ if has_mps:
# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')
+ # MPS workaround for https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/14046
+ CondFunc('torch.nn.functional.interpolate', interpolate_with_fp32_fallback, None)
+
# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']: