From f741a98baccae100fcfb40c017b5c35c5cba1b0c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 08:43:42 +0300 Subject: imports cleanup for ruff --- modules/esrgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index f4369257..85aa6934 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -6,7 +6,7 @@ from PIL import Image from basicsr.utils.download_util import load_file_from_url import modules.esrgan_model_arch as arch -from modules import shared, modelloader, images, devices +from modules import modelloader, images, devices from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts -- cgit v1.2.3 From a5121e7a0623db328a9462d340d389ed6737374a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 10 May 2023 11:37:18 +0300 Subject: fixes for B007 --- modules/esrgan_model.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 85aa6934..a009eb42 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -16,9 +16,7 @@ def mod2normal(state_dict): # this code is copied from https://github.com/victorca25/iNNfer if 'conv_first.weight' in state_dict: crt_net = {} - items = [] - for k, v in state_dict.items(): - items.append(k) + items = list(state_dict) crt_net['model.0.weight'] = state_dict['conv_first.weight'] crt_net['model.0.bias'] = state_dict['conv_first.bias'] @@ -52,9 +50,7 @@ def resrgan2normal(state_dict, nb=23): if "conv_first.weight" in state_dict and "body.0.rdb1.conv1.weight" in state_dict: re8x = 0 crt_net = {} - items = [] - for k, v in state_dict.items(): - items.append(k) + items = list(state_dict) crt_net['model.0.weight'] = state_dict['conv_first.weight'] crt_net['model.0.bias'] = state_dict['conv_first.bias'] -- cgit v1.2.3 From df6fffb054f8d3444baa887151a4874506a68be1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 19 May 2023 09:09:00 +0300 Subject: change upscalers to download models into user-specified directory (from commandline args) rather than the default models/<...> --- modules/esrgan_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a009eb42..2fced999 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -154,7 +154,7 @@ class UpscalerESRGAN(Upscaler): if "http" in path: filename = load_file_from_url( url=self.model_url, - model_dir=self.model_path, + model_dir=self.model_download_path, file_name=f"{self.model_name}.pth", progress=True, ) -- cgit v1.2.3 From 89352a2f52c6be51318192cedd86c8a342966a49 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:34:26 +0300 Subject: Move `load_file_from_url` to modelloader --- modules/esrgan_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..f1a98c07 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -3,7 +3,6 @@ import os import numpy as np import torch from PIL import Image -from basicsr.utils.download_util import load_file_from_url import modules.esrgan_model_arch as arch from modules import modelloader, images, devices @@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if "http" in path: - filename = load_file_from_url( + filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, file_name=f"{self.model_name}.pth", - progress=True, ) else: filename = path -- cgit v1.2.3 From 0afbc0c2355ead3a0ce7149a6d678f1f2e2fbfee Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:41:36 +0300 Subject: Fix up `if "http" in ...:` to be more sensible startswiths --- modules/esrgan_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index f1a98c07..0666a2c2 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -133,7 +133,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4) scalers.append(scaler_data) for file in model_paths: - if "http" in file: + if file.startswith("http"): name = self.model_name else: name = modelloader.friendly_name(file) @@ -150,7 +150,7 @@ class UpscalerESRGAN(Upscaler): return img def load_model(self, path: str): - if "http" in path: + if path.startswith("http"): filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, -- cgit v1.2.3 From e3a973a68df3cfe13039dae33d19cf2c02a741e0 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:45:07 +0300 Subject: Add TODO comments to sus model loads --- modules/esrgan_model.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 0666a2c2..a20e8d91 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -151,6 +151,7 @@ class UpscalerESRGAN(Upscaler): def load_model(self, path: str): if path.startswith("http"): + # TODO: this doesn't use `path` at all? filename = modelloader.load_file_from_url( url=self.model_url, model_dir=self.model_download_path, -- cgit v1.2.3 From bf67a5dcf44c3dbd88d1913478d4e02477915f33 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:38:51 +0300 Subject: Upscaler.load_model: don't return None, just use exceptions --- modules/esrgan_model.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) (limited to 'modules/esrgan_model.py') diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index a20e8d91..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,4 +1,4 @@ -import os +import sys import numpy as np import torch @@ -6,9 +6,8 @@ from PIL import Image import modules.esrgan_model_arch as arch from modules import modelloader, images, devices -from modules.upscaler import Upscaler, UpscalerData from modules.shared import opts - +from modules.upscaler import Upscaler, UpscalerData def mod2normal(state_dict): @@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data) def do_upscale(self, img, selected_model): - model = self.load_model(selected_model) - if model is None: + try: + model = self.load_model(selected_model) + except Exception as e: + print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr) return img model.to(devices.device_esrgan) img = esrgan_upscale(model, img) @@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler): ) else: filename = path - if not os.path.exists(filename) or filename is None: - print(f"Unable to load {self.model_path} from {filename}") - return None state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None) -- cgit v1.2.3