diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-30 22:04:47 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-12-30 22:04:47 +0000 |
commit | c0ca6348e8489651df861a101142805c213c66a0 (patch) | |
tree | ccfc8698b72d5cbffe499156d4ab7da67071fadb | |
parent | 3be90740316f8fbb950b31d440458a5e8ed4beb3 (diff) | |
download | stable-diffusion-webui-gfx803-c0ca6348e8489651df861a101142805c213c66a0.tar.gz stable-diffusion-webui-gfx803-c0ca6348e8489651df861a101142805c213c66a0.tar.bz2 stable-diffusion-webui-gfx803-c0ca6348e8489651df861a101142805c213c66a0.zip |
load_spandrel_model: always return a model descriptor
-rw-r--r-- | modules/modelloader.py | 23 |
1 files changed, 13 insertions, 10 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682..8bcee08c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -142,17 +145,17 @@ def load_spandrel_model( half: bool = False, dtype: str | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor |