aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py56
1 files changed, 17 insertions, 39 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 4b9000a4..caa85d5e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -12,10 +12,10 @@ from modules import shared, modelloader
from modules.paths import models_path
model_dir = "Stable-diffusion"
-model_path = os.path.join(models_path, model_dir)
+model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_name = "sd-v1-4.ckpt"
model_url = "https://drive.yerf.org/wl/?id=EBfTrmcCCUAGaQBXVIj5lJmEhjoP1tgl&mode=grid&download=1"
-user_dir = None
+user_dir: (str | None) = None
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
@@ -30,26 +30,8 @@ except Exception:
pass
-def modeltitle(path, h):
- abspath = os.path.abspath(path)
-
- if abspath.startswith(model_dir):
- name = abspath.replace(model_dir, '')
- else:
- name = os.path.basename(path)
-
- if name.startswith("\\") or name.startswith("/"):
- name = name[1:]
-
- return f'{name} [{h}]'
-
-
def setup_model(dirname):
- global model_path
- global model_name
- global model_url
global user_dir
- global model_list
user_dir = dirname
if not os.path.exists(model_path):
os.makedirs(model_path)
@@ -62,21 +44,16 @@ def checkpoint_tiles():
def list_models():
- global model_path
- global model_url
- global model_name
- global user_dir
checkpoints_list.clear()
- model_list = modelloader.load_models(model_path=model_path,model_url=model_url,command_path= user_dir,
- ext_filter=[".ckpt"], download_name=model_name)
- print(f"Model list: {model_list}")
- model_dir = os.path.abspath(model_path)
+ model_list = modelloader.load_models(model_path=model_path, model_url=model_url, command_path=user_dir, ext_filter=[".ckpt"], download_name=model_name)
- def modeltitle(path, h):
+ def modeltitle(path, shorthash):
abspath = os.path.abspath(path)
- if abspath.startswith(model_dir):
- name = abspath.replace(model_dir, '')
+ if user_dir is not None and abspath.startswith(user_dir):
+ name = abspath.replace(user_dir, '')
+ elif abspath.startswith(model_path):
+ name = abspath.replace(model_path, '')
else:
name = os.path.basename(path)
@@ -85,29 +62,30 @@ def list_models():
shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
- return f'{name} [{h}]', shortname
+ return f'{name} [{shorthash}]', shortname
cmd_ckpt = shared.cmd_opts.ckpt
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
- title, model_name = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
+ title, short_model_name = modeltitle(cmd_ckpt, h)
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
for filename in model_list:
h = model_hash(filename)
- title, model_name = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
+ title, short_model_name = modeltitle(filename, h)
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)
+
def get_closet_checkpoint_match(searchString):
applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title))
- if len(applicable)>0:
+ if len(applicable) > 0:
return applicable[0]
return None
+
def model_hash(filename):
try:
- print(f"Opening: {filename}")
with open(filename, "rb") as file:
import hashlib
m = hashlib.sha256()
@@ -128,7 +106,7 @@ def select_checkpoint():
if len(checkpoints_list) == 0:
print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr)
print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr)
- print(f" - directory {os.path.abspath(shared.cmd_opts.stablediffusion_models_path)}", file=sys.stderr)
+ print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
exit(1)