diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-16 19:59:46 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-16 19:59:46 +0000 |
commit | 9991967f40120b88a1dc925fdf7d747d5e016888 (patch) | |
tree | fd622c4c330bf948edef259badee58a2d59939f5 /modules/devices.py | |
parent | 52f6e94338f31c286361802b08ee5210b8244141 (diff) | |
download | stable-diffusion-webui-gfx803-9991967f40120b88a1dc925fdf7d747d5e016888.tar.gz stable-diffusion-webui-gfx803-9991967f40120b88a1dc925fdf7d747d5e016888.tar.bz2 stable-diffusion-webui-gfx803-9991967f40120b88a1dc925fdf7d747d5e016888.zip |
Add a check and explanation for tensor with all NaNs.
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py index caeb0276..6f034948 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -106,6 +106,33 @@ def autocast(disable=False): return torch.autocast("cuda") +class NansException(Exception): + pass + + +def test_for_nans(x, where): + from modules import shared + + if not torch.all(torch.isnan(x)).item(): + return + + if where == "unet": + message = "A tensor with all NaNs was produced in Unet." + + if not shared.cmd_opts.no_half: + message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this." + + elif where == "vae": + message = "A tensor with all NaNs was produced in VAE." + + if not shared.cmd_opts.no_half and not shared.cmd_opts.no_half_vae: + message += " This could be because there's not enough precision to represent the picture. Try adding --no-half-vae commandline argument to fix this." + else: + message = "A tensor with all NaNs was produced." + + raise NansException(message) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 orig_tensor_to = torch.Tensor.to def tensor_to_fix(self, *args, **kwargs): @@ -156,3 +183,4 @@ if has_mps(): torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) ) orig_narrow = torch.narrow torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() ) + |