aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
authorbrkirch <brkirch@users.noreply.github.com>2022-10-04 11:42:53 +0000
committerGitHub <noreply@github.com>2022-10-04 11:42:53 +0000
commite9e2a7ec9ac704f133f586eb34176e388c93c87c (patch)
tree969e0e595bd36987ae9de9ae302085ef555bba15 /modules/devices.py
parentdc9c5a97742e3a34d37da7108642d8adc0dc5858 (diff)
parentd5bba20a58f43a9f984bb67b4e17f48661f6b818 (diff)
downloadstable-diffusion-webui-gfx803-e9e2a7ec9ac704f133f586eb34176e388c93c87c.tar.gz
stable-diffusion-webui-gfx803-e9e2a7ec9ac704f133f586eb34176e388c93c87c.tar.bz2
stable-diffusion-webui-gfx803-e9e2a7ec9ac704f133f586eb34176e388c93c87c.zip
Merge branch 'master' into cpu-cmdline-opt
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py
index b7899632..0158b11f 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -1,3 +1,5 @@
+import contextlib
+
import torch
from modules import errors
@@ -56,3 +58,11 @@ def randn_without_seed(shape):
return torch.randn(shape, device=device)
+
+def autocast():
+ from modules import shared
+
+ if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+ return contextlib.nullcontext()
+
+ return torch.autocast("cuda")