diff options
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r-- | modules/modelloader.py | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb79..f4182559 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,15 +1,21 @@ from __future__ import annotations +import logging import os import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +183,24 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model |