diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-05 04:52:29 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-05 04:52:29 +0000 |
commit | 22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f (patch) | |
tree | 69e3bf4d53f4113f192116c252a2e410bb5b1f90 /modules/sd_models.py | |
parent | 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 (diff) | |
parent | 0ae2767ae6bb775de448b0d8cda1806edb2aef67 (diff) | |
download | stable-diffusion-webui-gfx803-22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f.tar.gz stable-diffusion-webui-gfx803-22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f.tar.bz2 stable-diffusion-webui-gfx803-22ecb78b51f7e6f0234cbc0efbde4ee9a2dc466f.zip |
Merge branch 'dev' into multiple_loaded_models
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 29 |
1 files changed, 17 insertions, 12 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 3c451a4b..f6051604 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -66,8 +66,9 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None
self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
+ self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]'
- self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
+ self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
def register(self):
checkpoints_list[self.title] = self
@@ -86,6 +87,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
+ self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
self.register()
return self.shorthash
@@ -106,14 +108,8 @@ def setup_model(): enable_midas_autodownload()
-def checkpoint_tiles():
- def convert(name):
- return int(name) if name.isdigit() else name.lower()
-
- def alphanumeric_key(key):
- return [convert(c) for c in re.split('([0-9]+)', key)]
-
- return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key)
+def checkpoint_tiles(use_short=False):
+ return [x.short_title if use_short else x.title for x in checkpoints_list.values()]
def list_models():
@@ -136,11 +132,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
- for filename in sorted(model_list, key=str.lower):
+ for filename in model_list:
checkpoint_info = CheckpointInfo(filename)
checkpoint_info.register()
+re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$")
+
+
def get_closet_checkpoint_match(search_string):
checkpoint_info = checkpoint_aliases.get(search_string, None)
if checkpoint_info is not None:
@@ -150,6 +149,11 @@ def get_closet_checkpoint_match(search_string): if found:
return found[0]
+ search_string_without_checksum = re.sub(re_strip_checksum, '', search_string)
+ found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title))
+ if found:
+ return found[0]
+
return None
@@ -302,12 +306,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model)
model.load_state_dict(state_dict, strict=False)
- del state_dict
timer.record("apply weights to model")
if shared.opts.sd_checkpoint_cache > 0:
# cache newly loaded model
- checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+ checkpoints_loaded[checkpoint_info] = state_dict
+
+ del state_dict
if shared.cmd_opts.opt_channelslast:
model.to(memory_format=torch.channels_last)
|