author | InvincibleDude <81354513+InvincibleDude@users.noreply.github.com> | 2023-01-29 11:36:10 +0000
committer | GitHub <noreply@github.com> | 2023-01-29 11:36:10 +0000
commit | ee3d63b6beb88e63542976a1095d4c8aa97388bd (patch)
tree | cf891130682c343107ae0a0f8cec309aea16807a /modules/devices.py
parent | 44c0e6b993d00bb2f441f0fde409bcb79136f034 (diff)
parent | 00dab8f10defbbda579a1bc89c8d4e972c58a20d (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 33
1 file changed, 25 insertions, 8 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 524ec7af..655ca1d3 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -34,14 +34,18 @@ def get_cuda_device_string():
     return "cuda"
 
 
-def get_optimal_device():
+def get_optimal_device_name():
     if torch.cuda.is_available():
-        return torch.device(get_cuda_device_string())
+        return get_cuda_device_string()
 
     if has_mps():
-        return torch.device("mps")
+        return "mps"
+
+    return "cpu"
+
 
-    return cpu
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
 
 
 def get_device_for(task):
@@ -79,6 +83,16 @@ cpu = torch.device("cpu")
 device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
 dtype = torch.float16
 dtype_vae = torch.float16
+dtype_unet = torch.float16
+unet_needs_upcast = False
+
+
+def cond_cast_unet(input):
+    return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+    return input.float() if unet_needs_upcast else input
 
 
 def randn(seed, shape):
@@ -106,6 +120,10 @@ def autocast(disable=False):
     return torch.autocast("cuda")
 
 
+def without_autocast(disable=False):
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
 class NansException(Exception):
     pass
 
@@ -123,7 +141,7 @@ def test_for_nans(x, where):
         message = "A tensor with all NaNs was produced in Unet."
 
         if not shared.cmd_opts.no_half:
-            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
 
     elif where == "vae":
         message = "A tensor with all NaNs was produced in VAE."
@@ -133,6 +151,8 @@ def test_for_nans(x, where):
     else:
         message = "A tensor with all NaNs was produced."
 
+    message += " Use --disable-nan-check commandline argument to disable this check."
+
     raise NansException(message)
 
 
@@ -187,6 +207,3 @@ if has_mps():
         cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
         torch.cumsum = lambda input, *args, **kwargs: ( cumsum_fix(input, orig_cumsum, *args, **kwargs) )
         torch.Tensor.cumsum = lambda self, *args, **kwargs: ( cumsum_fix(self, orig_Tensor_cumsum, *args, **kwargs) )
-        orig_narrow = torch.narrow
-        torch.narrow = lambda *args, **kwargs: ( orig_narrow(*args, **kwargs).clone() )
-
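Read together, the additions implement one feature: an optional fp32 "upcast" path for an fp16 UNet. `dtype_unet` and `unet_needs_upcast` hold the decision at module level, the two `cond_cast_*` helpers apply it to individual tensors, and `without_autocast` lets a block escape an enclosing autocast region so it actually runs at full precision. The following minimal sketch (not part of the commit) reproduces the helpers with placeholder state; `unet_needs_upcast = True` and the demo tensor are assumptions for illustration only.

```python
import contextlib

import torch

# Placeholder module state mirroring the patch; in the patch itself
# unet_needs_upcast defaults to False and is set elsewhere.
dtype_unet = torch.float16
unet_needs_upcast = True


def cond_cast_unet(input):
    # When the upcast path is active, cast tensors to the UNet dtype
    # before feeding them to fp16 modules; otherwise pass through.
    return input.to(dtype_unet) if unet_needs_upcast else input


def cond_cast_float(input):
    # Counterpart: promote tensors to fp32 only when upcasting is active.
    return input.float() if unet_needs_upcast else input


def without_autocast(disable=False):
    # Escape an enclosing autocast region so the wrapped block runs at
    # full precision; a no-op (nullcontext) when autocast isn't active.
    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()


latents = torch.randn(1, 4, 8, 8)      # hypothetical demo tensor
print(cond_cast_unet(latents).dtype)   # torch.float16 (upcast active)
print(cond_cast_float(latents).dtype)  # torch.float32 (upcast active)

with without_autocast():               # nullcontext outside autocast
    doubled = latents * 2.0
```

Keeping the flags at module level means every caller shares one upcast decision instead of threading a dtype argument through the call stack, which is presumably why the helpers read global state rather than taking parameters.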