From b0f59342346b1c8b405f97c0e0bb01c6ae05c601 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 25 Dec 2023 14:43:51 +0200 Subject: Use Spandrel for upscaling and face restoration architectures (aside from GFPGAN and LDSR) --- modules/modelloader.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index 098bcb79..30116932 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil import importlib @@ -10,6 +11,9 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +logger = logging.getLogger(__name__) + + def load_file_from_url( url: str, *, @@ -177,3 +181,15 @@ def load_upscalers(): # Special case for UpscalerNone keeps it at the beginning of the list. key=lambda x: x.name.lower() if not isinstance(x.scaler, (UpscalerNone, UpscalerLanczos, UpscalerNearest)) else "" ) + + +def load_spandrel_model(path, *, device, half: bool = False, dtype=None): + import spandrel + model = spandrel.ModelLoader(device=device).load_from_file(path) + if half: + model = model.model.half() + if dtype: + model = model.model.to(dtype=dtype) + model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) + return model -- cgit v1.2.3 From 4ad0c0c0a805da4bac03cff86ea17c25a1291546 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 16:37:03 +0200 Subject: Verify architecture for loaded Spandrel models --- modules/modelloader.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932..f4182559 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: -- cgit v1.2.3 From 5fbb13e0da8eb2e26bd2c45ec8ffbb2de669ef47 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 20:46:44 +0200 Subject: Remove `cleanup_models` code --- modules/modelloader.py | 50 -------------------------------------------------- 1 file changed, 50 deletions(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559..5f7aec3e 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -2,7 +2,6 @@ from __future__ import annotations import logging import os -import shutil import importlib from urllib.parse import urlparse @@ -10,7 +9,6 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone -from modules.paths import script_path, models_path logger = logging.getLogger(__name__) @@ -96,54 +94,6 @@ def friendly_name(file: str): return model_name -def cleanup_models(): - # This code could probably be more efficient if we used a tuple list or something to store the src/destinations - # and then enumerate that, but this works for now. In the future, it'd be nice to just have every "model" scaler - # somehow auto-register and just do these things... - root_path = script_path - src_path = models_path - dest_path = os.path.join(models_path, "Stable-diffusion") - move_files(src_path, dest_path, ".ckpt") - move_files(src_path, dest_path, ".safetensors") - src_path = os.path.join(root_path, "ESRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path) - src_path = os.path.join(models_path, "BSRGAN") - dest_path = os.path.join(models_path, "ESRGAN") - move_files(src_path, dest_path, ".pth") - src_path = os.path.join(root_path, "gfpgan") - dest_path = os.path.join(models_path, "GFPGAN") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "SwinIR") - dest_path = os.path.join(models_path, "SwinIR") - move_files(src_path, dest_path) - src_path = os.path.join(root_path, "repositories/latent-diffusion/experiments/pretrained_models/") - dest_path = os.path.join(models_path, "LDSR") - move_files(src_path, dest_path) - - -def move_files(src_path: str, dest_path: str, ext_filter: str = None): - try: - os.makedirs(dest_path, exist_ok=True) - if os.path.exists(src_path): - for file in os.listdir(src_path): - fullpath = os.path.join(src_path, file) - if os.path.isfile(fullpath): - if ext_filter is not None: - if ext_filter not in file: - continue - print(f"Moving {file} from {src_path} to {dest_path}.") - try: - shutil.move(fullpath, dest_path) - except Exception: - pass - if len(os.listdir(src_path)) == 0: - print(f"Removing empty folder: {src_path}") - shutil.rmtree(src_path, True) - except Exception: - pass - - def load_upscalers(): # We can only do this 'magic' method to dynamically load upscalers if they are referenced, # so we'll try to import any _model.py files before looking in __subclasses__ -- cgit v1.2.3 From af050dcaa75ef40b6b1c3da3361f32fe52786aeb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sat, 30 Dec 2023 21:05:59 +0200 Subject: Soften Spandrel model-architecture check to just a warning --- modules/modelloader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index f4182559..6b7d697f 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -196,7 +196,9 @@ def load_spandrel_model( import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) if expected_architecture and model.architecture != expected_architecture: - raise TypeError(f"Model {path} is not a {expected_architecture} model") + logger.warning( + f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + ) if half: model = model.model.half() if dtype: -- cgit v1.2.3 From c0ca6348e8489651df861a101142805c213c66a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:04:47 +0200 Subject: load_spandrel_model: always return a model descriptor --- modules/modelloader.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index 0b89d682..8bcee08c 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,8 +1,9 @@ from __future__ import annotations +import importlib import logging import os -import importlib +from typing import TYPE_CHECKING from urllib.parse import urlparse import torch @@ -10,6 +11,8 @@ import torch from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone +if TYPE_CHECKING: + import spandrel logger = logging.getLogger(__name__) @@ -142,17 +145,17 @@ def load_spandrel_model( half: bool = False, dtype: str | None = None, expected_architecture: str | None = None, -): +) -> spandrel.ModelDescriptor: import spandrel - model = spandrel.ModelLoader(device=device).load_from_file(path) - if expected_architecture and model.architecture != expected_architecture: + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( - f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})", + f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) if half: - model = model.model.half() + model_descriptor.model.half() if dtype: - model = model.model.to(dtype=dtype) - model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype) - return model + model_descriptor.model.to(dtype=dtype) + model_descriptor.model.eval() + logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + return model_descriptor -- cgit v1.2.3 From 777af661a21821994993df3ef566b01df2bb61a0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 00:09:51 +0200 Subject: Be more clear about Spandrel model nomenclature --- modules/modelloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index 8bcee08c..a7194137 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -143,7 +143,7 @@ def load_spandrel_model( *, device: str | torch.device | None, half: bool = False, - dtype: str | None = None, + dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel -- cgit v1.2.3 From 2cacbc124c49f45da5b66b79d9b0a3ab943472eb Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Sun, 31 Dec 2023 19:52:32 +0200 Subject: load_spandrel_model: make `half` `prefer_half` As discussed with the Spandrel folks, it's good to heed Spandrel's "supports half precision" flag to avoid e.g. black blotches and what-not. --- modules/modelloader.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index a7194137..e100bb24 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -139,23 +139,31 @@ def load_upscalers(): def load_spandrel_model( - path: str, + path: str | os.PathLike, *, device: str | torch.device | None, - half: bool = False, + prefer_half: bool = False, dtype: str | torch.dtype | None = None, expected_architecture: str | None = None, ) -> spandrel.ModelDescriptor: import spandrel - model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path) + model_descriptor = spandrel.ModelLoader(device=device).load_from_file(str(path)) if expected_architecture and model_descriptor.architecture != expected_architecture: logger.warning( f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})", ) - if half: - model_descriptor.model.half() + half = False + if prefer_half: + if model_descriptor.supports_half: + model_descriptor.model.half() + half = True + else: + logger.info("Model %s does not support half precision, ignoring --half", path) if dtype: model_descriptor.model.to(dtype=dtype) model_descriptor.model.eval() - logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype) + logger.debug( + "Loaded %s from %s (device=%s, half=%s, dtype=%s)", + model_descriptor, path, device, half, dtype, + ) return model_descriptor -- cgit v1.2.3