diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:37:03 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:37:03 +0000 |
commit | 4ad0c0c0a805da4bac03cff86ea17c25a1291546 (patch) | |
tree | 9821621545c6989205074d7bd23137eacbbad0e2 /modules/modelloader.py | |
parent | c756133541da478a35a74cda416d114a8973cf8e (diff) | |
download | stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.tar.gz stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.tar.bz2 stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.zip |
Verify architecture for loaded Spandrel models
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r-- | modules/modelloader.py | 13 |
1 files changed, 12 insertions, 1 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932..f4182559 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: |