diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-19 07:39:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-19 07:39:51 +0000 |
commit | 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 (patch) | |
tree | 0e81a16c42f716c704d6aa63458f7c3c1894c56e /modules/sd_vae.py | |
parent | c7e50425f63c07242068f8dcccce70a4ef28a17f (diff) | |
download | stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.tar.gz stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.tar.bz2 stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.zip |
allow baking in VAE in checkpoint merger tab
do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r-- | modules/sd_vae.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index da1bf15c..4ce238b8 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file): return None, None +def load_vae_dict(filename, map_location): + vae_ckpt = sd_models.read_state_dict(filename, map_location=map_location) + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + return vae_dict_1 + + def load_vae(model, vae_file=None, vae_source="from unknown source"): global vae_dict, loaded_vae_file # save_settings = False @@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) - vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} + vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location) _load_vae_dict(model, vae_dict_1) if cache_enabled: |