aboutsummaryrefslogtreecommitdiffstats
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-05-29 06:34:26 +0000
committerAarni Koskela <akx@iki.fi>2023-06-13 09:38:28 +0000
commit89352a2f52c6be51318192cedd86c8a342966a49 (patch)
treefde621a62c1778d10aa1834597ec367de429642f /modules/modelloader.py
parent59419bd64a1581caccaac04dceb66c1c069a2db1 (diff)
downloadstable-diffusion-webui-gfx803-89352a2f52c6be51318192cedd86c8a342966a49.tar.gz
stable-diffusion-webui-gfx803-89352a2f52c6be51318192cedd86c8a342966a49.tar.bz2
stable-diffusion-webui-gfx803-89352a2f52c6be51318192cedd86c8a342966a49.zip
Move `load_file_from_url` to modelloader
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py29
1 files changed, 26 insertions, 3 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index be23071a..a69c8a4f 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)