diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-06-27 06:20:49 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-27 06:20:49 +0000 |
commit | 3cd4fd51ef916aba8b978490569a5da82795a839 (patch) | |
tree | cdae174b0c146f11bbda49f52bbcb389cc2f9893 /modules | |
parent | d4f9250c5aef0fd3afb6ed8a6bc3515fa2fcf635 (diff) | |
parent | 2667f47ffbf7c641a7e77abbdddf5e81bf144199 (diff) | |
download | stable-diffusion-webui-gfx803-3cd4fd51ef916aba8b978490569a5da82795a839.tar.gz stable-diffusion-webui-gfx803-3cd4fd51ef916aba8b978490569a5da82795a839.tar.bz2 stable-diffusion-webui-gfx803-3cd4fd51ef916aba8b978490569a5da82795a839.zip |
Merge pull request #10823 from akx/model-loady
Upscaler model loading cleanup
Diffstat (limited to 'modules')
-rw-r--r-- | modules/esrgan_model.py | 23 | ||||
-rw-r--r-- | modules/gfpgan_model.py | 2 | ||||
-rw-r--r-- | modules/modelloader.py | 31 | ||||
-rw-r--r-- | modules/realesrgan_model.py | 33 |
4 files changed, 53 insertions, 36 deletions
diff --git a/modules/esrgan_model.py b/modules/esrgan_model.py index 2fced999..02a1727d 100644 --- a/modules/esrgan_model.py +++ b/modules/esrgan_model.py @@ -1,15 +1,13 @@ -import os
+import sys
import numpy as np
import torch
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
-from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts
-
+from modules.upscaler import Upscaler, UpscalerData
def mod2normal(state_dict):
@@ -134,7 +132,7 @@ class UpscalerESRGAN(Upscaler): scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
scalers.append(scaler_data)
for file in model_paths:
- if "http" in file:
+ if file.startswith("http"):
name = self.model_name
else:
name = modelloader.friendly_name(file)
@@ -143,26 +141,25 @@ class UpscalerESRGAN(Upscaler): self.scalers.append(scaler_data)
def do_upscale(self, img, selected_model):
- model = self.load_model(selected_model)
- if model is None:
+ try:
+ model = self.load_model(selected_model)
+ except Exception as e:
+ print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
return img
model.to(devices.device_esrgan)
img = esrgan_upscale(model, img)
return img
def load_model(self, path: str):
- if "http" in path:
- filename = load_file_from_url(
+ if path.startswith("http"):
+ # TODO: this doesn't use `path` at all?
+ filename = modelloader.load_file_from_url(
url=self.model_url,
model_dir=self.model_download_path,
file_name=f"{self.model_name}.pth",
- progress=True,
)
else:
filename = path
- if not os.path.exists(filename) or filename is None:
- print(f"Unable to load {self.model_path} from {filename}")
- return None
state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)
diff --git a/modules/gfpgan_model.py b/modules/gfpgan_model.py index 6ecd295c..8e0f13bd 100644 --- a/modules/gfpgan_model.py +++ b/modules/gfpgan_model.py @@ -25,7 +25,7 @@ def gfpgann(): return None
models = modelloader.load_models(model_path, model_url, user_path, ext_filter="GFPGAN")
- if len(models) == 1 and "http" in models[0]:
+ if len(models) == 1 and models[0].startswith("http"):
model_file = models[0]
elif len(models) != 0:
latest_file = max(models, key=os.path.getctime)
diff --git a/modules/modelloader.py b/modules/modelloader.py index 75f01247..098bcb79 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) @@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) diff --git a/modules/realesrgan_model.py b/modules/realesrgan_model.py index 2d27b321..0700b853 100644 --- a/modules/realesrgan_model.py +++ b/modules/realesrgan_model.py @@ -2,7 +2,6 @@ import os import numpy as np
from PIL import Image
-from basicsr.utils.download_util import load_file_from_url
from realesrgan import RealESRGANer
from modules.upscaler import Upscaler, UpscalerData
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler): if not self.enable:
return img
- info = self.load_model(path)
- if not os.path.exists(info.local_data_path):
- print(f"Unable to load RealESRGAN model: {info.name}")
+ try:
+ info = self.load_model(path)
+ except Exception:
+ errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
return img
upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler): return image
def load_model(self, path):
- try:
- info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)
-
- if info is None:
- print(f"Unable to find model info: {path}")
- return None
-
- if info.local_data_path.startswith("http"):
- info.local_data_path = load_file_from_url(url=info.data_path, model_dir=self.model_download_path, progress=True)
-
- return info
- except Exception:
- errors.report("Error making Real-ESRGAN models list", exc_info=True)
- return None
+ for scaler in self.scalers:
+ if scaler.data_path == path:
+ if scaler.local_data_path.startswith("http"):
+ scaler.local_data_path = modelloader.load_file_from_url(
+ scaler.data_path,
+ model_dir=self.model_download_path,
+ )
+ if not os.path.exists(scaler.local_data_path):
+ raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
+ return scaler
+ raise ValueError(f"Unable to find model info: {path}")
def load_models(self, _):
return get_realesrgan_models(self)
|