aboutsummaryrefslogtreecommitdiffstats
path: root/modules/sd_models.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-05 04:47:34 +0000
committerGitHub <noreply@github.com>2023-08-05 04:47:34 +0000
commit0ae2767ae6bb775de448b0d8cda1806edb2aef67 (patch)
tree16057eb7ffc9e081652f68d6ffbf350fa61db1de /modules/sd_models.py
parente64263653a3cdce0a46d0578d08dcc962865441f (diff)
parentc09bc2c60856ca1ab2243386176badf909affdbe (diff)
downloadstable-diffusion-webui-gfx803-0ae2767ae6bb775de448b0d8cda1806edb2aef67.tar.gz
stable-diffusion-webui-gfx803-0ae2767ae6bb775de448b0d8cda1806edb2aef67.tar.bz2
stable-diffusion-webui-gfx803-0ae2767ae6bb775de448b0d8cda1806edb2aef67.zip
Merge pull request #12181 from AUTOMATIC1111/hires_checkpoint
Hires fix change checkpoint
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--modules/sd_models.py27
1 files changed, 17 insertions, 10 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ba15b451..1608b37f 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -67,6 +67,7 @@ class CheckpointInfo:
self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
@@ -87,6 +88,7 @@ class CheckpointInfo:
checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -107,14 +109,8 @@ def setup_model():
enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -137,11 +133,14 @@ def list_models():
elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -151,6 +150,11 @@ def get_closet_checkpoint_match(search_string):
if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None
@@ -585,7 +589,10 @@ def reload_model_weights(sd_model=None, info=None):
timer.record("find config")
if sd_model is None or checkpoint_config != sd_model.used_config:
- del sd_model
+ if sd_model is not None:
+ sd_model.to(device="meta")
+
+ devices.torch_gc()
load_model(checkpoint_info, already_loaded_state_dict=state_dict)
return model_data.sd_model