diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-21 04:10:19 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-21 04:10:19 +0000 |
commit | 5a3fe7a8d1c554982fb3226bda3ee6c4b10bb56c (patch) | |
tree | 01aae591c92fa1e4f0b57a488cc30fc09f2b63a8 | |
parent | 42b72fe2463bc06a97935bc7a7770a9d562269d8 (diff) | |
parent | be301f224d26ac4363ce3bd8bcb510b00bd6db27 (diff) | |
download | stable-diffusion-webui-gfx803-5a3fe7a8d1c554982fb3226bda3ee6c4b10bb56c.tar.gz stable-diffusion-webui-gfx803-5a3fe7a8d1c554982fb3226bda3ee6c4b10bb56c.tar.bz2 stable-diffusion-webui-gfx803-5a3fe7a8d1c554982fb3226bda3ee6c4b10bb56c.zip |
Merge pull request #12685 from Uminosachi/fix-vae-mismatch
Fix SD VAE switch error after model reuse
-rw-r--r-- | modules/sd_models.py | 13 | ||||
-rw-r--r-- | modules/sd_vae.py | 4 |
2 files changed, 14 insertions, 3 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1..27d15e66 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -485,8 +485,12 @@ class SdModelData: return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae.base_vae = getattr(v, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = v.sd_checkpoint_info
try:
self.loaded_sd_models.remove(v)
@@ -660,13 +664,14 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256
print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
+ sd_vae.reload_vae_weights(already_loaded)
return model_data.sd_model
elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")
@@ -678,6 +683,10 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae.base_vae = getattr(sd_model, "base_vae", None)
+ sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_model.sd_checkpoint_info
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index dbade067..ee118656 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -192,7 +192,7 @@ def load_vae_dict(filename, map_location): def load_vae(model, vae_file=None, vae_source="from unknown source"): - global vae_dict, loaded_vae_file + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -230,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside |