From 89352a2f52c6be51318192cedd86c8a342966a49 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:34:26 +0300 Subject: Move `load_file_from_url` to modelloader --- modules/esrgan_model.py | 4 +--- modules/modelloader.py | 29 ++++++++++++++++++++++++++--- modules/realesrgan_model.py | 4 ++-- 3 files changed, 29 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..f1a98c07 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -3,7 +3,6 @@ import os import numpy as np import torch from PIL import Image -from basicsr.utils.download_util import load_file_from_url import modules.esrgan_model_arch as arch from modules import modelloader, images, devices @@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if "http" in path: - filename = load_file_from_url( + filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, file_name=f"{self.model_name}.pth", - progress=True, ) else: filename = path diff --git a/modules/modelloader.py b/modules/modelloader.py index be23071a..a69c8a4f 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2d27b321..0d9c2e48 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,7 +2,6 @@ import os import numpy as np from PIL import Image -from basicsr.utils.download_util import load_file_from_url from realesrgan import RealESRGANer from modules.upscaler import Upscaler, UpscalerData @@ -10,6 +9,7 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors + class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -71,7 +71,7 @@ class UpscalerRealESRGAN(Upscaler): return None if info.local_data_path.startswith("http"): - info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True) + info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) return info except Exception: -- cgit v1.2.3 From 0afbc0c2355ead3a0ce7149a6d678f1f2e2fbfee Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:41:36 +0300 Subject: Fix up `if "http" in ...:` to be more sensible startswiths --- extensions-builtin/ScuNET/scripts/scunet_model.py | 4 ++-- extensions-builtin/SwinIR/scripts/swinir_model.py | 4 ++-- modules/esrgan_model.py | 4 ++-- modules/gfpgan_model.py | 2 +- modules/modelloader.py | 2 +- 5 files changed, 8 insertions(+), 8 deletions(-) (limited to 'modules') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 2785b551..64f50829 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -27,7 +27,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): scalers = [] add_model2 = True for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -118,7 +118,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') - if "http" in path: + if path.startswith("http"): filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index a5b0e2eb..4551761d 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -27,7 +27,7 @@ class UpscalerSwinIR(Upscaler): scalers = [] model_files = self.find_models(ext_filter=[".pt", ".pth"]) for model in model_files: - if "http" in model: + if model.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(model) @@ -48,7 +48,7 @@ class UpscalerSwinIR(Upscaler): return img def load_model(self, path, scale=4): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=path, model_dir=self.model_download_path, diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index f1a98c07..0666a2c2 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -133,7 +133,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scalers.append(scaler_data) for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -150,7 +150,7 @@ class UpscalerESRGAN(Upscaler): return img def load_model(self, path: str): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index e239a09d..804fb53d 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): return None models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN") - if len(models) == 1 and "http" in models[0]: + if len(models) == 1 and models[0].startswith("http"): model_file = models[0] elif len(models) != 0: latest_file = max(models, key=os.path.getctime) diff --git a/modules/modelloader.py b/modules/modelloader.py index a69c8a4f..b2f0bb71 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) -- cgit v1.2.3 From e3a973a68df3cfe13039dae33d19cf2c02a741e0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:45:07 +0300 Subject: Add TODO comments to sus model loads --- extensions-builtin/ScuNET/scripts/scunet_model.py | 1 + modules/esrgan_model.py | 1 + 2 files changed, 2 insertions(+) (limited to 'modules') diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index 64f50829..da74a829 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -119,6 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): def load_model(self, path: str): device = devices.get_device_for('scunet') if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 0666a2c2..a20e8d91 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -151,6 +151,7 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, -- cgit v1.2.3 From bf67a5dcf44c3dbd88d1913478d4e02477915f33 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:38:51 +0300 Subject: Upscaler.load_model: don't return None, just use exceptions --- extensions-builtin/LDSR/scripts/ldsr_model.py | 13 +++----- extensions-builtin/ScuNET/scripts/scunet_model.py | 16 ++++----- extensions-builtin/SwinIR/scripts/swinir_model.py | 40 +++++++++++------------ modules/esrgan_model.py | 14 ++++---- modules/realesrgan_model.py | 33 +++++++++---------- 5 files changed, 52 insertions(+), 64 deletions(-) (limited to 'modules') diff --git a/extensions-builtin/LDSR/scripts/ldsr_model.py b/extensions-builtin/LDSR/scripts/ldsr_model.py index bf9b6de2..bd78dece 100644 --- a/extensions-builtin/LDSR/scripts/ldsr_model.py +++ b/extensions-builtin/LDSR/scripts/ldsr_model.py @@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler): yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml") - try: - return LDSR(model, yaml) - except Exception: - errors.report("Error importing LDSR", exc_info=True) - return None + return LDSR(model, yaml) def do_upscale(self, img, path): - ldsr = self.load_model(path) - if ldsr is None: - print("NO LDSR!") + try: + ldsr = self.load_model(path) + except Exception: + errors.report(f"Failed loading LDSR model {path}", exc_info=True) return img ddim_steps = shared.opts.ldsr_steps return ldsr.super_resolution(img, ddim_steps, self.scale) diff --git a/extensions-builtin/ScuNET/scripts/scunet_model.py b/extensions-builtin/ScuNET/scripts/scunet_model.py index da74a829..ffef26b2 100644 --- a/extensions-builtin/ScuNET/scripts/scunet_model.py +++ b/extensions-builtin/ScuNET/scripts/scunet_model.py @@ -1,4 +1,3 @@ -import os.path import sys import PIL.Image @@ -8,7 +7,7 @@ from tqdm import tqdm import modules.upscaler from modules import devices, modelloader, script_callbacks, errors -from scunet_model_arch import SCUNet as net +from scunet_model_arch import SCUNet from modules.modelloader import load_file_from_url from modules.shared import opts @@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler): torch.cuda.empty_cache() - model = self.load_model(selected_file) - if model is None: - print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr) + try: + model = self.load_model(selected_file) + except Exception as e: + print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr) return img device = devices.get_device_for('scunet') @@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler): filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth") else: filename = path - if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None: - print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr) - return None - - model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) + model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64) model.load_state_dict(torch.load(filename), strict=True) model.eval() for _, v in model.named_parameters(): diff --git a/extensions-builtin/SwinIR/scripts/swinir_model.py b/extensions-builtin/SwinIR/scripts/swinir_model.py index 4551761d..3ce622d9 100644 --- a/extensions-builtin/SwinIR/scripts/swinir_model.py +++ b/extensions-builtin/SwinIR/scripts/swinir_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -7,8 +7,8 @@ from tqdm import tqdm from modules import modelloader, devices, script_callbacks, shared from modules.shared import opts, state -from swinir_model_arch import SwinIR as net -from swinir_model_arch_v2 import Swin2SR as net2 +from swinir_model_arch import SwinIR +from swinir_model_arch_v2 import Swin2SR from modules.upscaler import Upscaler, UpscalerData @@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler): self.scalers = scalers def do_upscale(self, img, model_file): - model = self.load_model(model_file) - if model is None: + try: + model = self.load_model(model_file) + except Exception as e: + print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr) return img model = model.to(device_swinir, dtype=devices.dtype) img = upscale(img, model) @@ -56,25 +58,23 @@ class UpscalerSwinIR(Upscaler): ) else: filename = path - if filename is None or not os.path.exists(filename): - return None if filename.endswith(".v2.pth"): - model = net2( - upscale=scale, - in_chans=3, - img_size=64, - window_size=8, - img_range=1.0, - depths=[6, 6, 6, 6, 6, 6], - embed_dim=180, - num_heads=[6, 6, 6, 6, 6, 6], - mlp_ratio=2, - upsampler="nearest+conv", - resi_connection="1conv", + model = Swin2SR( + upscale=scale, + in_chans=3, + img_size=64, + window_size=8, + img_range=1.0, + depths=[6, 6, 6, 6, 6, 6], + embed_dim=180, + num_heads=[6, 6, 6, 6, 6, 6], + mlp_ratio=2, + upsampler="nearest+conv", + resi_connection="1conv", ) params = None else: - model = net( + model = SwinIR( upscale=scale, in_chans=3, img_size=64, diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a20e8d91..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -6,9 +6,8 @@ from PIL import Image import modules.esrgan_model_arch as arch from modules import modelloader, images, devices -from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts - +from modules.upscaler import Upscaler, UpscalerData def mod2normal(state_dict): @@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): - model = self.load_model(selected_model) - if model is None: + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) return img model.to(devices.device_esrgan) img = esrgan_upscale(model, img) @@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ) else: filename = path - if not os.path.exists(filename) or filename is None: - print(f"Unable to load {self.model_path} from {filename}") - return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 0d9c2e48..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts from modules import modelloader, errors - class UpscalerRealESRGAN(Upscaler): def __init__(self, path): self.name = "RealESRGAN" @@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable: return img - info = self.load_model(path) - if not os.path.exists(info.local_data_path): - print(f"Unable to load RealESRGAN model: {info.name}") + try: + info = self.load_model(path) + except Exception: + errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True) return img upsampler = RealESRGANer( @@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image def load_model(self, path): - try: - info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None) - - if info is None: - print(f"Unable to find model info: {path}") - return None - - if info.local_data_path.startswith("http"): - info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path) - - return info - except Exception: - errors.report("Error making Real-ESRGAN models list", exc_info=True) - return None + for scaler in self.scalers: + if scaler.data_path == path: + if scaler.local_data_path.startswith("http"): + scaler.local_data_path = modelloader.load_file_from_url( + scaler.data_path, + model_dir=self.model_download_path, + ) + if not os.path.exists(scaler.local_data_path): + raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}") + return scaler + raise ValueError(f"Unable to find model info: {path}") def load_models(self, _): return get_realesrgan_models(self) -- cgit v1.2.3