aboutsummaryrefslogtreecommitdiffstats
path: root/modules/npu_specific.py
diff options
context:
space:
mode:
authorwangshuai09 <391746016@qq.com>2024-01-27 09:21:32 +0000
committerwangshuai09 <391746016@qq.com>2024-01-29 11:25:06 +0000
commitec124607f47371a6cfd61a795f86a7f1cbd44651 (patch)
treed60205d2f58c80a0cc0bb8a079b9f33e7bc93f53 /modules/npu_specific.py
parentcf2772fab0af5573da775e7437e6acdca424f26e (diff)
downloadstable-diffusion-webui-gfx803-ec124607f47371a6cfd61a795f86a7f1cbd44651.tar.gz
stable-diffusion-webui-gfx803-ec124607f47371a6cfd61a795f86a7f1cbd44651.tar.bz2
stable-diffusion-webui-gfx803-ec124607f47371a6cfd61a795f86a7f1cbd44651.zip
Add NPU Support
Diffstat (limited to 'modules/npu_specific.py')
-rw-r--r--modules/npu_specific.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/modules/npu_specific.py b/modules/npu_specific.py
new file mode 100644
index 00000000..d8aebf9c
--- /dev/null
+++ b/modules/npu_specific.py
@@ -0,0 +1,34 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+ if importlib.util.find_spec("torch_npu") is None:
+ return False
+ import torch_npu
+ torch_npu.npu.set_device(0)
+
+ try:
+ # Will raise a RuntimeError if no NPU is found
+ _ = torch.npu.device_count()
+ return torch.npu.is_available()
+ except RuntimeError:
+ return False
+
+
+def get_npu_device_string():
+ if shared.cmd_opts.device_id is not None:
+ return f"npu:{shared.cmd_opts.device_id}"
+ return "npu:0"
+
+
+def torch_npu_gc():
+ # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue
+ torch.npu.set_device(0)
+ with torch.npu.device(get_npu_device_string()):
+ torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()