diff options
author | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-03 04:10:53 +0000 |
---|---|---|
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-19 04:41:41 +0000 |
commit | abc1e79a5da24a1ea0f4bceedcdf225f32010aa8 (patch) | |
tree | d6a840a8d7af4c5fdc86e9bff53cb844be98c9ea | |
parent | 8ab4927452b04dcd30847eaf92ea7a9f3b9c74e1 (diff) | |
download | stable-diffusion-webui-gfx803-abc1e79a5da24a1ea0f4bceedcdf225f32010aa8.tar.gz stable-diffusion-webui-gfx803-abc1e79a5da24a1ea0f4bceedcdf225f32010aa8.tar.bz2 stable-diffusion-webui-gfx803-abc1e79a5da24a1ea0f4bceedcdf225f32010aa8.zip |
Fix base VAE caching was done after loading VAE, also add safeguard
-rw-r--r-- | modules/sd_models.py | 1 | ||||
-rw-r--r-- | modules/sd_vae.py | 19 |
2 files changed, 9 insertions, 11 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..e4dba62c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file
model.sd_checkpoint_info = checkpoint_info
+ sd_vae.clear_loaded_vae()
sd_vae.load_vae(model, vae_file)
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 7a79239f..dd69a5e6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -15,7 +15,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_dict = {"auto": "auto", "None": None, None: None} default_vae_list = ["auto", "None"] @@ -39,6 +39,7 @@ def get_base_vae(model): def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: + assert not loaded_vae_file, "Trying to store non-base VAE!" base_vae = model.first_stage_model.state_dict().copy() checkpoint_info = model.sd_checkpoint_info @@ -50,9 +51,11 @@ def delete_base_vae(): def restore_base_vae(model): + global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") load_vae_dict(model, base_vae) + loaded_vae_file = None delete_base_vae() @@ -140,10 +143,10 @@ def load_vae(model, vae_file=None): if vae_file: print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} load_vae_dict(model, vae_dict_1) - store_base_vae(model) # If vae used is not in dict, update it # It will be removed on refresh though @@ -157,15 +160,6 @@ def load_vae(model, vae_file=None): loaded_vae_file = vae_file - """ - # Save current VAE to VAE settings, maybe? will it work? - if save_settings: - if vae_file is None: - vae_opt = "None" - - # shared.opts.sd_vae = vae_opt - """ - first_load = False @@ -174,6 +168,9 @@ def load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) +def clear_loaded_vae(): + global loaded_vae_file + loaded_vae_file = None def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack |