From 165ab44f03cc17dc3e4c35b3e5976f3d646c7ac7 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 10:18:15 +0300 Subject: Use os.makedirs(..., exist_ok=True) --- modules/modelloader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index be23071a..75f01247 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -95,8 +95,7 @@ def cleanup_models(): def move_files(src_path: str, dest_path: str, ext_filter: str = None): try: - if not os.path.exists(dest_path): - os.makedirs(dest_path) + os.makedirs(dest_path, exist_ok=True) if os.path.exists(src_path): for file in os.listdir(src_path): fullpath = os.path.join(src_path, file) -- cgit v1.2.3 From 89352a2f52c6be51318192cedd86c8a342966a49 Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:34:26 +0300 Subject: Move `load_file_from_url` to modelloader --- modules/modelloader.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index be23071a..a69c8a4f 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) -- cgit v1.2.3 From 0afbc0c2355ead3a0ce7149a6d678f1f2e2fbfee Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 29 May 2023 09:41:36 +0300 Subject: Fix up `if "http" in ...:` to be more sensible startswiths --- modules/modelloader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/modelloader.py') diff --git a/modules/modelloader.py b/modules/modelloader.py index a69c8a4f..b2f0bb71 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -82,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) -- cgit v1.2.3