diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-21 14:35:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-21 14:35:51 +0000 |
commit | ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 (patch) | |
tree | c6898c0d0966c6170c5aa2211860a3ff28a07bfb /modules | |
parent | 3d898044e5e55dca1698e9b5b7d3558b5b78675a (diff) | |
download | stable-diffusion-webui-gfx803-ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40.tar.gz stable-diffusion-webui-gfx803-ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40.tar.bz2 stable-diffusion-webui-gfx803-ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40.zip |
loading SD VAE, see PR #3303
Diffstat (limited to 'modules')
-rw-r--r-- | modules/sd_models.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index b1c91b0d..d99dbce8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd
+vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
+
+
def load_model_weights(model, checkpoint_info):
checkpoint_file = checkpoint_info.filename
sd_model_hash = checkpoint_info.hash
@@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file):
print(f"Loading VAE weights from: {vae_file}")
vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
- vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
+ vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
model.first_stage_model.load_state_dict(vae_dict)
model.first_stage_model.to(devices.dtype_vae)
|