path: root/modules/sd_models.py
author    AUTOMATIC1111 <16777216c@gmail.com>  2023-02-19 09:30:58 +0000
committer GitHub <noreply@github.com>  2023-02-19 09:30:58 +0000
commit    cfc9849f3f64977936769b6479d6b2231ecbfc5b (patch)
tree      51b29ea3b255fe43b0fe7560f9f7fdd23a475427 /modules/sd_models.py
parent    5afd9e82c3829348c58803cd85b02c87308fffae (diff)
parent    d99bd04b3f8c7753e31aa6dea6109785c4bb92c9 (diff)
Merge branch 'master' into 6866-fix-hires-prompt-matrix
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r--  modules/sd_models.py  18
1 file changed, 13 insertions(+), 5 deletions(-)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 300387a9..127e9663 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -59,13 +59,17 @@ class CheckpointInfo:
     def calculate_shorthash(self):
         self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name)
+        if self.sha256 is None:
+            return
+
         self.shorthash = self.sha256[0:10]
         if self.shorthash not in self.ids:
-            self.ids += [self.shorthash, self.sha256]
-            self.register()
+            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']
+        checkpoints_list.pop(self.title)
         self.title = f'{self.name} [{self.shorthash}]'
+        self.register()
         return self.shorthash
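The hunk above does two things: it bails out early when hashes.sha256() returns no value, and it re-registers the checkpoint under a title that now embeds the short hash, popping the stale key first. A minimal sketch of that guard-and-reregister pattern, using a hypothetical registry and compute_sha256 stand-in rather than the project's actual helpers:

registry = {}  # stand-in for checkpoints_list: maps title -> checkpoint record

def update_shorthash(record, compute_sha256):
    sha256 = compute_sha256(record["filename"])
    if sha256 is None:            # hash unavailable: leave the record untouched
        return None

    shorthash = sha256[:10]
    old_title = record["title"]
    record["title"] = f'{record["name"]} [{shorthash}]'

    registry.pop(old_title, None)          # drop the stale key first
    registry[record["title"]] = record     # re-register under the hash-qualified title
    return shorthash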
@@ -101,7 +105,7 @@ def checkpoint_tiles():
 def list_models():
     checkpoints_list.clear()
     checkpoint_alisases.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])
+    model_list = modelloader.load_models(model_path=model_path, model_url="https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.safetensors", command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], download_name="v1-5-pruned-emaonly.safetensors", ext_blacklist=[".vae.ckpt", ".vae.safetensors"])
     cmd_ckpt = shared.cmd_opts.ckpt
     if os.path.exists(cmd_ckpt):
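This hunk passes a model_url and download_name to modelloader.load_models, so an install with no local checkpoints can fetch the v1-5-pruned-emaonly.safetensors file by itself, and .vae.ckpt joins .vae.safetensors on the extension blacklist. A rough sketch of the scan-then-fallback-download idea those arguments enable; the helper below is illustrative only and is not modelloader's real implementation:

import os
import urllib.request

def find_or_fetch_models(model_dir, model_url, download_name, ext_filter, ext_blacklist):
    os.makedirs(model_dir, exist_ok=True)
    found = [
        os.path.join(model_dir, name) for name in os.listdir(model_dir)
        if name.endswith(tuple(ext_filter)) and not name.endswith(tuple(ext_blacklist))
    ]
    if not found and model_url:
        target = os.path.join(model_dir, download_name)
        urllib.request.urlretrieve(model_url, target)  # fetch the default checkpoint
        found.append(target)
    return found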
@@ -158,7 +162,7 @@ def select_checkpoint():
         print(f" - directory {model_path}", file=sys.stderr)
         if shared.cmd_opts.ckpt_dir is not None:
             print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr)
-        print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr)
+        print("Can't run without a checkpoint. Find and place a .ckpt or .safetensors file into any of those locations. The program will exit.", file=sys.stderr)
         exit(1)
     checkpoint_info = next(iter(checkpoints_list.values()))
@@ -350,6 +354,9 @@ def repair_config(sd_config):
         sd_config.model.params.unet_config.params.use_fp16 = True
+sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
+sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'
+
 def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
     from modules import lowvram, sd_hijack
     checkpoint_info = checkpoint_info or select_checkpoint()
@@ -370,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
     checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
+    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict
     timer.record("find config")
@@ -382,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
     sd_model = None
     try:
-        with sd_disable_initialization.DisableInitialization():
+        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
     except Exception as e:
         pass
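The last three hunks work together: two known state-dict keys (one for the SD1 text-encoder layout, one for SD2) are probed to decide whether the checkpoint already bundles CLIP weights, and if so, DisableInitialization is asked to skip initializing CLIP, presumably because any freshly initialized weights would be overwritten when the state dict is loaded. A small sketch of that detection step, reusing the key names from the diff but standing outside the webui codebase:

# Keys mirroring sd1_clip_weight / sd2_clip_weight above; if either is present,
# the checkpoint ships its own text-encoder weights.
SD1_CLIP_KEY = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
SD2_CLIP_KEY = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

def clip_weights_present(state_dict):
    return SD1_CLIP_KEY in state_dict or SD2_CLIP_KEY in state_dict

# Usage: a pretend checkpoint whose dict contains the SD1 key.
state_dict = {SD1_CLIP_KEY: object()}
disable_clip = clip_weights_present(state_dict)  # True -> skip CLIP initialization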