diff options
author | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:37:03 +0000 |
---|---|---|
committer | Aarni Koskela <akx@iki.fi> | 2023-12-30 14:37:03 +0000 |
commit | 4ad0c0c0a805da4bac03cff86ea17c25a1291546 (patch) | |
tree | 9821621545c6989205074d7bd23137eacbbad0e2 | |
parent | c756133541da478a35a74cda416d114a8973cf8e (diff) | |
download | stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.tar.gz stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.tar.bz2 stable-diffusion-webui-gfx803-4ad0c0c0a805da4bac03cff86ea17c25a1291546.zip |
Verify architecture for loaded Spandrel models
-rw-r--r-- | extensions-builtin/ScuNET/scripts/scunet_model.py | 2 | ||||
-rw-r--r-- | extensions-builtin/SwinIR/scripts/swinir_model.py | 1 | ||||
-rw-r--r-- | modules/codeformer_model.py | 1 | ||||
-rw-r--r-- | modules/esrgan_model.py | 1 | ||||
-rw-r--r-- | modules/gfpgan_model.py | 1 | ||||
-rw-r--r-- | modules/hat_model.py | 1 | ||||
-rw-r--r-- | modules/modelloader.py | 13 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 7 |
8 files changed, 22 insertions, 5 deletions
diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 18cf8e1a..5f3dd08b 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -121,7 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = modelloader.load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - return modelloader.load_spandrel_model(filename, device=device) + return modelloader.load_spandrel_model(filename, device=device, expected_architecture='SCUNet') def on_ui_settings(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 85c18b9e..aae159af 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -75,6 +75,7 @@ class UpscalerSwinIR(Upscaler): filename, device=self._get_device(), dtype=devices.dtype, + expected_architecture="SwinIR", ) if getattr(opts, 'SWIN_torch_compile', False): try: diff --git a/modules/codeformer_model.py b/modules/codeformer_model.py index ceda4bab..44b84618 100644 --- a/modules/codeformer_model.py +++ b/modules/codeformer_model.py @@ -37,6 +37,7 @@ class FaceRestorerCodeFormer(face_restoration_utils.CommonFaceRestoration): return modelloader.load_spandrel_model(
model_path,
device=devices.device_codeformer,
+ expected_architecture='CodeFormer',
).model
raise ValueError("No codeformer model found")
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a7c7c9e3..70041ab0 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -49,6 +49,7 @@ class UpscalerESRGAN(Upscaler): return modelloader.load_spandrel_model(
filename,
device=('cpu' if devices.device_esrgan.type == 'mps' else None),
+ expected_architecture='ESRGAN',
)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index a356b56f..48f8ad5e 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -37,6 +37,7 @@ class FaceRestorerGFPGAN(face_restoration_utils.CommonFaceRestoration): net = modelloader.load_spandrel_model(
model_path,
device=self.get_device(),
+ expected_architecture='GFPGAN',
).model
net.different_w = True # see https://github.com/chaiNNer-org/spandrel/pull/81
return net
diff --git a/modules/hat_model.py b/modules/hat_model.py index 553e1941..7f2abb41 100644 --- a/modules/hat_model.py +++ b/modules/hat_model.py @@ -39,4 +39,5 @@ class UpscalerHAT(Upscaler): return modelloader.load_spandrel_model(
path,
device=devices.device_esrgan, # TODO: should probably be device_hat
+ expected_architecture='HAT',
)
diff --git a/modules/modelloader.py b/modules/modelloader.py index 30116932..f4182559 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -6,6 +6,8 @@ import shutil import importlib from urllib.parse import urlparse +import torch + from modules import shared from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone from modules.paths import script_path, models_path @@ -183,9 +185,18 @@ def load_upscalers(): ) -def load_spandrel_model(path, *, device, half: bool = False, dtype=None): +def load_spandrel_model( + path: str, + *, + device: str | torch.device | None, + half: bool = False, + dtype: str | None = None, + expected_architecture: str | None = None, +): import spandrel model = spandrel.ModelLoader(device=device).load_from_file(path) + if expected_architecture and model.architecture != expected_architecture: + raise TypeError(f"Model {path} is not a {expected_architecture} model") if half: model = model.model.half() if dtype: diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 332d8f4b..2a2be5ad 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -1,9 +1,9 @@ import os
-from modules.upscaler_utils import upscale_with_model
-from modules.upscaler import Upscaler, UpscalerData
-from modules.shared import cmd_opts, opts
from modules import modelloader, errors
+from modules.shared import cmd_opts, opts
+from modules.upscaler import Upscaler, UpscalerData
+from modules.upscaler_utils import upscale_with_model
class UpscalerRealESRGAN(Upscaler):
@@ -40,6 +40,7 @@ class UpscalerRealESRGAN(Upscaler): info.local_data_path,
device=self.device,
half=(not cmd_opts.no_half and not cmd_opts.upcast_sampling),
+ expected_architecture="RealESRGAN",
)
return upscale_with_model(
mod,
|