aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorBernard Maltais <bmaltais@gmail.com>2022-09-28 01:08:07 +0000
committerBernard Maltais <bmaltais@gmail.com>2022-09-28 01:08:07 +0000
commit591c138e32d8a5789053b3ab6f5881aaf8f002bf (patch)
treef1ac7fd1fb6416bf5511550fec53e4c337298536 /modules/sd_models.py
parente258f89080b8ff38f040dc786290da9144860d38 (diff)
downloadstable-diffusion-webui-gfx803-591c138e32d8a5789053b3ab6f5881aaf8f002bf.tar.gz
stable-diffusion-webui-gfx803-591c138e32d8a5789053b3ab6f5881aaf8f002bf.tar.bz2
stable-diffusion-webui-gfx803-591c138e32d8a5789053b3ab6f5881aaf8f002bf.zip
-Add gradio dropdown list to select checkpoints to merge
-Update the name of the model feilds -Update the associated variable names
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index dc81b0dc..9decc911 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -10,7 +10,7 @@ from ldm.util import instantiate_from_config
from modules import shared
-CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash'])
+CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
try:
@@ -45,7 +45,8 @@ def list_models():
if os.path.exists(cmd_ckpt):
h = model_hash(cmd_ckpt)
title = modeltitle(cmd_ckpt, h)
- checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h)
+ model_name = title.rsplit(".",1)[0] # remove extension if present
+ checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, model_name)
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found: {cmd_ckpt}", file=sys.stderr)
@@ -53,7 +54,8 @@ def list_models():
for filename in glob.glob(model_dir + '/**/*.ckpt', recursive=True):
h = model_hash(filename)
title = modeltitle(filename, h)
- checkpoints_list[title] = CheckpointInfo(filename, title, h)
+ model_name = title.rsplit(".",1)[0] # remove extension if present
+ checkpoints_list[title] = CheckpointInfo(filename, title, h, model_name)
def model_hash(filename):