aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2023-01-27 15:19:43 +0000
committerbrkirch <brkirch@users.noreply.github.com>2023-01-28 09:16:25 +0000
commitada17dbd7c4c68a4e559848d2e6f2a7799722806 (patch)
treeced66b899aba64a4e5d7b66a3bc8cdb796e0cf16 /modules/devices.py
parentc4b9b07db6272768428fa8efeb7d7a9f22eca0b1 (diff)
downloadstable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.tar.gz
stable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.tar.bz2
stable-diffusion-webui-gfx803-ada17dbd7c4c68a4e559848d2e6f2a7799722806.zip
Refactor conditional casting, fix upscalers
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 6b36622c..0100e4af 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -83,6 +83,14 @@ dtype_unet = torch.float16
unet_needs_upcast = False
+def cond_cast_unet(input):
+ return input.to(dtype_unet) if unet_needs_upcast else input
+
+
+def cond_cast_float(input):
+ return input.float() if unet_needs_upcast else input
+
+
def randn(seed, shape):
torch.manual_seed(seed)
if device.type == 'mps':