aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_torch_utils.py
diff options
context:
space:
mode:
authorwangshuai09 <391746016@qq.com>2024-01-30 11:15:41 +0000
committerGitHub <noreply@github.com>2024-01-30 11:15:41 +0000
commit74ff85a1a1ee4cce432b1c7d33c1eda831f68d48 (patch)
tree99b70e0fef8422c8f603bf7faa1a393091cb2a8b /test/test_torch_utils.py
parentec124607f47371a6cfd61a795f86a7f1cbd44651 (diff)
parentce168ab5dbc8b54b7245f352a2eaa55a37019b91 (diff)
downloadstable-diffusion-webui-gfx803-74ff85a1a1ee4cce432b1c7d33c1eda831f68d48.tar.gz
stable-diffusion-webui-gfx803-74ff85a1a1ee4cce432b1c7d33c1eda831f68d48.tar.bz2
stable-diffusion-webui-gfx803-74ff85a1a1ee4cce432b1c7d33c1eda831f68d48.zip
Merge branch 'dev' into npu_support
Diffstat (limited to 'test/test_torch_utils.py')
-rw-r--r--test/test_torch_utils.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/test/test_torch_utils.py b/test/test_torch_utils.py
new file mode 100644
index 00000000..23ccb93a
--- /dev/null
+++ b/test/test_torch_utils.py
@@ -0,0 +1,19 @@
+import types
+
+import pytest
+import torch
+
+from modules import torch_utils
+
+
+@pytest.mark.parametrize("wrapped", [True, False])
+def test_get_param(wrapped):
+ mod = torch.nn.Linear(1, 1)
+ cpu = torch.device("cpu")
+ mod.to(dtype=torch.float16, device=cpu)
+ if wrapped:
+ # more or less how spandrel wraps a thing
+ mod = types.SimpleNamespace(model=mod)
+ p = torch_utils.get_param(mod)
+ assert p.dtype == torch.float16
+ assert p.device == cpu