From ada17dbd7c4c68a4e559848d2e6f2a7799722806 Mon Sep 17 00:00:00 2001 From: brkirch Date: Fri, 27 Jan 2023 10:19:43 -0500 Subject: Refactor conditional casting, fix upscalers --- modules/devices.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index 6b36622c..0100e4af 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -83,6 +83,14 @@ dtype_unet = torch.float16 unet_needs_upcast = False +def cond_cast_unet(input): + return input.to(dtype_unet) if unet_needs_upcast else input + + +def cond_cast_float(input): + return input.float() if unet_needs_upcast else input + + def randn(seed, shape): torch.manual_seed(seed) if device.type == 'mps': -- cgit v1.2.3