aboutsummaryrefslogtreecommitdiffstats
path: root/modules/modelloader.py
diff options
context:
space:
mode:
authorBeinsezii <39478211+Beinsezii@users.noreply.github.com>2023-06-27 22:29:47 +0000
committerGitHub <noreply@github.com>2023-06-27 22:29:47 +0000
commit9d8af4bd6aaf09b8a94dc10dd5e99c82e23dec38 (patch)
treed86d470db2289a31e589a7bd777a38784605313c /modules/modelloader.py
parent1d7c51fb9f757b5dcdc506f8fc003e6047151567 (diff)
parentfab73f2e7d388ca99cdb3d5de7f36c0b9a1a3b1c (diff)
downloadstable-diffusion-webui-gfx803-9d8af4bd6aaf09b8a94dc10dd5e99c82e23dec38.tar.gz
stable-diffusion-webui-gfx803-9d8af4bd6aaf09b8a94dc10dd5e99c82e23dec38.tar.bz2
stable-diffusion-webui-gfx803-9d8af4bd6aaf09b8a94dc10dd5e99c82e23dec38.zip
Merge branch 'AUTOMATIC1111:dev' into dev
Diffstat (limited to 'modules/modelloader.py')
-rw-r--r--modules/modelloader.py34
1 files changed, 28 insertions, 6 deletions
diff --git a/modules/modelloader.py b/modules/modelloader.py
index be23071a..098bcb79 100644
--- a/modules/modelloader.py
+++ b/modules/modelloader.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path
+def load_file_from_url(
+ url: str,
+ *,
+ model_dir: str,
+ progress: bool = True,
+ file_name: str | None = None,
+) -> str:
+ """Download a file from `url` into `model_dir`, using the file present if possible.
+
+ Returns the path to the downloaded file.
+ """
+ os.makedirs(model_dir, exist_ok=True)
+ if not file_name:
+ parts = urlparse(url)
+ file_name = os.path.basename(parts.path)
+ cached_file = os.path.abspath(os.path.join(model_dir, file_name))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ from torch.hub import download_url_to_file
+ download_url_to_file(url, cached_file, progress=progress)
+ return cached_file
+
+
def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
"""
A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
if model_url is not None and len(output) == 0:
if download_name is not None:
- from basicsr.utils.download_util import load_file_from_url
- dl = load_file_from_url(model_url, places[0], True, download_name)
- output.append(dl)
+ output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
else:
output.append(model_url)
@@ -59,7 +82,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
def friendly_name(file: str):
- if "http" in file:
+ if file.startswith("http"):
file = urlparse(file).path
file = os.path.basename(file)
@@ -95,8 +118,7 @@ def cleanup_models():
def move_files(src_path: str, dest_path: str, ext_filter: str = None):
try:
- if not os.path.exists(dest_path):
- os.makedirs(dest_path)
+ os.makedirs(dest_path, exist_ok=True)
if os.path.exists(src_path):
for file in os.listdir(src_path):
fullpath = os.path.join(src_path, file)