aboutsummaryrefslogtreecommitdiffstats
path: root/modules/npu_specific.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-31 19:39:29 +0000
committerGitHub <noreply@github.com>2024-01-31 19:39:29 +0000
commit96b550430a986fa49670249aabdd42cd182fb6c8 (patch)
tree44cbaa680fd3afeea5799fa75af4d52872f646ba /modules/npu_specific.py
parentce168ab5dbc8b54b7245f352a2eaa55a37019b91 (diff)
parentcc3f604310458eed7d26456c1b3934d582283ffe (diff)
downloadstable-diffusion-webui-gfx803-96b550430a986fa49670249aabdd42cd182fb6c8.tar.gz
stable-diffusion-webui-gfx803-96b550430a986fa49670249aabdd42cd182fb6c8.tar.bz2
stable-diffusion-webui-gfx803-96b550430a986fa49670249aabdd42cd182fb6c8.zip
Merge pull request #14801 from wangshuai09/npu_support
Add NPU Support
Diffstat (limited to 'modules/npu_specific.py')
-rw-r--r--modules/npu_specific.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/modules/npu_specific.py b/modules/npu_specific.py
new file mode 100644
index 00000000..94100691
--- /dev/null
+++ b/modules/npu_specific.py
@@ -0,0 +1,31 @@
+import importlib
+import torch
+
+from modules import shared
+
+
+def check_for_npu():
+ if importlib.util.find_spec("torch_npu") is None:
+ return False
+ import torch_npu
+
+ try:
+ # Will raise a RuntimeError if no NPU is found
+ _ = torch_npu.npu.device_count()
+ return torch.npu.is_available()
+ except RuntimeError:
+ return False
+
+
+def get_npu_device_string():
+ if shared.cmd_opts.device_id is not None:
+ return f"npu:{shared.cmd_opts.device_id}"
+ return "npu:0"
+
+
+def torch_npu_gc():
+ with torch.npu.device(get_npu_device_string()):
+ torch.npu.empty_cache()
+
+
+has_npu = check_for_npu()