diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-01 13:55:55 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-01 13:55:55 +0000 |
commit | 0c9b1e7969b2adb39c38d17b481a689483059459 (patch) | |
tree | 721ae5bc4143f21c8034819755977dc4d293640d /modules/sd_models.py | |
parent | 151b8ed3a62714793e2a212ac609a03dda0b1e26 (diff) | |
parent | 6a0d498c8ec5287a75e2a3bc8a4fa79e82e64c18 (diff) | |
download | stable-diffusion-webui-gfx803-0c9b1e7969b2adb39c38d17b481a689483059459.tar.gz stable-diffusion-webui-gfx803-0c9b1e7969b2adb39c38d17b481a689483059459.tar.bz2 stable-diffusion-webui-gfx803-0c9b1e7969b2adb39c38d17b481a689483059459.zip |
Merge branch 'dev' into multiple_loaded_models
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 77195f2f..40a450df 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,8 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
+from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -32,6 +33,8 @@ class CheckpointInfo: self.filename = filename
abspath = os.path.abspath(filename)
+ self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
+
if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
elif abspath.startswith(model_path):
@@ -42,6 +45,19 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"):
name = name[1:]
+ def read_metadata():
+ metadata = read_metadata_from_safetensors(filename)
+ self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None)
+
+ return metadata
+
+ self.metadata = {}
+ if self.is_safetensors:
+ try:
+ self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata)
+ except Exception as e:
+ errors.display(e, f"reading metadata for {filename}")
+
self.name = name
self.name_for_extra = os.path.splitext(os.path.basename(filename))[0]
self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
@@ -54,15 +70,6 @@ class CheckpointInfo: self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
- self.metadata = {}
-
- _, ext = os.path.splitext(self.filename)
- if ext.lower() == ".safetensors":
- try:
- self.metadata = read_metadata_from_safetensors(filename)
- except Exception as e:
- errors.display(e, f"reading checkpoint metadata: {filename}")
-
def register(self):
checkpoints_list[self.title] = self
for id in self.ids:
@@ -78,7 +85,7 @@ class CheckpointInfo: if self.shorthash not in self.ids:
self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
- checkpoints_list.pop(self.title)
+ checkpoints_list.pop(self.title, None)
self.title = f'{self.name} [{self.shorthash}]'
self.register()
|