diff options
author | brkirch <brkirch@users.noreply.github.com> | 2023-01-27 15:19:43 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2023-01-28 09:16:25 +0000 |
commit | ada17dbd7c4c68a4e559848d2e6f2a7799722806 (patch) | |
tree | ced66b899aba64a4e5d7b66a3bc8cdb796e0cf16 /modules/devices.py | |
parent | c4b9b07db6272768428fa8efeb7d7a9f22eca0b1 (diff) | |
download | stable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.tar.gz stable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.tar.bz2 stable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.zip |
Refactor conditional casting, fix upscalers
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..0100e4af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -83,6 +83,14 @@ dtype_unet = torch.float16 unet_needs_upcast = False +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + def randn(seed, shape): torch.manual_seed(seed) if device.type == 'mps': |