aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-03-11 22:35:17 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-03-11 22:35:17 +0000
commita4cb96d4ae82741be9f0d072a37af3ae39521379 (patch)
tree7415bc6123a43d77b5d62a4db43cb6d8ed2b7e72
parent27e319dc4f09a2f040043948e5c52965976f8491 (diff)
downloadstable-diffusion-webui-gfx803-a4cb96d4ae82741be9f0d072a37af3ae39521379.tar.gz
stable-diffusion-webui-gfx803-a4cb96d4ae82741be9f0d072a37af3ae39521379.tar.bz2
stable-diffusion-webui-gfx803-a4cb96d4ae82741be9f0d072a37af3ae39521379.zip
Remove test, use bool tensor fix by default
The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default.
-rw-r--r--modules/mac_specific.py3
1 files changed, 1 insertions, 2 deletions
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
- elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+ elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
- cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)