diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-19 04:59:39 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-19 04:59:39 +0000 |
commit | 0a334b447ff0c41519bb9e280050736913ad9cf8 (patch) | |
tree | e27963f76b7357ff0cb7b2c3fdcb720ab64f0e50 /modules/modelloader.py | |
parent | 6094310704f4b3853bfa5d05d9c1ace58b2deee7 (diff) | |
parent | c2b975485708791b29d44d79ee1a48d3abd838b7 (diff) | |
download | stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.tar.gz stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.tar.bz2 stable-diffusion-webui-gfx803-0a334b447ff0c41519bb9e280050736913ad9cf8.zip |
Merge branch 'dev' into allow-no-venv-install
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r-- | modules/modelloader.py | 31 |
1 files changed, 27 insertions, 4 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py index 75f01247..098bcb79 100644 --- a/modules/modelloader.py +++ b/modules/modelloader.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import os import shutil import importlib @@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale from modules.paths import script_path, models_path +def load_file_from_url( + url: str, + *, + model_dir: str, + progress: bool = True, + file_name: str | None = None, +) -> str: + """Download a file from `url` into `model_dir`, using the file present if possible. + + Returns the path to the downloaded file. + """ + os.makedirs(model_dir, exist_ok=True) + if not file_name: + parts = urlparse(url) + file_name = os.path.basename(parts.path) + cached_file = os.path.abspath(os.path.join(model_dir, file_name)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + from torch.hub import download_url_to_file + download_url_to_file(url, cached_file, progress=progress) + return cached_file + + def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list: """ A one-and done loader to try finding the desired models in specified directories. @@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None if model_url is not None and len(output) == 0: if download_name is not None: - from basicsr.utils.download_util import load_file_from_url - dl = load_file_from_url(model_url, places[0], True, download_name) - output.append(dl) + output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name)) else: output.append(model_url) @@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None def friendly_name(file: str): - if "http" in file: + if file.startswith("http"): file = urlparse(file).path file = os.path.basename(file) |