author     AUTOMATIC1111 <16777216c@gmail.com>      2023-03-12 05:14:26 +0000
committer  GitHub <noreply@github.com>              2023-03-12 05:14:26 +0000
commit     bbc4b0478ab24e69c94060d81ef778fcbe087b57 (patch)
tree       de0d23e27ad718deb15cb297e3e8f096f75a54ef /modules/mac_specific.py
parent     55ccc8fe6f7e229b9b6fa724168541610bfa7631 (diff)
parent     a4cb96d4ae82741be9f0d072a37af3ae39521379 (diff)
Merge pull request #8518 from brkirch/remove-bool-test
Fix image generation on macOS 13.3 betas
Diffstat (limited to 'modules/mac_specific.py')
-rw-r--r--  modules/mac_specific.py | 3
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/modules/mac_specific.py b/modules/mac_specific.py
index ddcea53b..18e6ff72 100644
--- a/modules/mac_specific.py
+++ b/modules/mac_specific.py
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
         output_dtype = kwargs.get('dtype', input.dtype)
         if output_dtype == torch.int64:
             return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
-        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
+        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
             return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
     return cumsum_func(input, *args, **kwargs)
@@ -45,7 +45,6 @@ if has_mps:
         CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
     elif version.parse(torch.__version__) > version.parse("1.13.1"):
         cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
-        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
         CondFunc('torch.cumsum', cumsum_fix_func, None)
         CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
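For context, the change drops the separate cumsum_needs_bool_fix probe: after this commit, a bool-dtype cumsum on MPS is always widened to int32 first, while int8/int16 still depend on the runtime probe that remains in mac_specific.py. Below is a minimal standalone sketch of the same workaround pattern, assuming an MPS-capable PyTorch build; it monkey-patches torch.cumsum directly rather than using the repo's CondFunc hijack helper, and names such as patched_cumsum are illustrative only.

import torch

# Illustrative sketch of the MPS cumsum workaround, not the repo's CondFunc-based hook.
if torch.backends.mps.is_available():
    mps = torch.device("mps")

    # Probe kept by this commit (copied from mac_specific.py): does this build
    # mis-compute cumsum for small integer dtypes on MPS?
    cumsum_needs_int_fix = not torch.Tensor([1, 2]).to(mps).equal(
        torch.ShortTensor([1, 1]).to(mps).cumsum(0)
    )

    _orig_cumsum = torch.cumsum

    def patched_cumsum(input, *args, **kwargs):
        output_dtype = kwargs.get('dtype', input.dtype)
        if input.device.type != 'mps':
            return _orig_cumsum(input, *args, **kwargs)
        if output_dtype == torch.int64:
            # int64 cumsum is unsupported on MPS: compute on CPU, move the result back.
            return _orig_cumsum(input.cpu(), *args, **kwargs).to(input.device)
        if output_dtype == torch.bool or (cumsum_needs_int_fix and output_dtype in (torch.int8, torch.int16)):
            # Bool is widened unconditionally after this commit; int8/int16 only when the probe says so.
            return _orig_cumsum(input.to(torch.int32), *args, **kwargs).to(torch.int64)
        return _orig_cumsum(input, *args, **kwargs)

    torch.cumsum = patched_cumsum

The design change visible in the diff is simply that bool output no longer waits on a runtime detection result: it is routed through int32 unconditionally, which is what restores image generation on the macOS 13.3 betas mentioned in the commit message.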