diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-12 13:34:13 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-12 13:34:13 +0000 |
commit | b70b51cc7248f57dcf16add86701139432c21e5b (patch) | |
tree | 8dbe1cbdd064872383fecbd8230601e8ccab484c | |
parent | 11e648f6c75de2fb22460d34a618dbb3aa6df0bc (diff) | |
download | stable-diffusion-webui-gfx803-b70b51cc7248f57dcf16add86701139432c21e5b.tar.gz stable-diffusion-webui-gfx803-b70b51cc7248f57dcf16add86701139432c21e5b.tar.bz2 stable-diffusion-webui-gfx803-b70b51cc7248f57dcf16add86701139432c21e5b.zip |
Allow TF32 in CUDA for increased performance #279
-rw-r--r-- | modules/devices.py | 11 | ||||
-rw-r--r-- | modules/errors.py | 10 |
2 files changed, 21 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py index f88e807e..a93a245b 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -1,6 +1,8 @@ import torch # has_mps is only available in nightly pytorch (for now), `getattr` for compatibility +from modules import errors + has_mps = getattr(torch, 'has_mps', False) cpu = torch.device("cpu") @@ -20,3 +22,12 @@ def torch_gc(): if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.ipc_collect() + + +def enable_tf32(): + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + + +errors.run(enable_tf32, "Enabling TF32") diff --git a/modules/errors.py b/modules/errors.py new file mode 100644 index 00000000..372dc51a --- /dev/null +++ b/modules/errors.py @@ -0,0 +1,10 @@ +import sys
+import traceback
+
+
+def run(code, task):
+ try:
+ code()
+ except Exception as e:
+ print(f"{task}: {type(e).__name__}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
|