diff options
author | Uminosachi <49424133+Uminosachi@users.noreply.github.com> | 2023-08-20 06:00:14 +0000 |
---|---|---|
committer | Uminosachi <49424133+Uminosachi@users.noreply.github.com> | 2023-08-20 06:00:14 +0000 |
commit | 042e1d5d0b1fc0bfd358e3a90db7d163934bd238 (patch) | |
tree | b6f967d9ab5278b76a217adb132756e407371446 /modules/sd_models.py | |
parent | 9d2299ed0bd6c81cae8a7ba4ca22d6a14fb27bef (diff) | |
download | stable-diffusion-webui-gfx803-042e1d5d0b1fc0bfd358e3a90db7d163934bd238.tar.gz stable-diffusion-webui-gfx803-042e1d5d0b1fc0bfd358e3a90db7d163934bd238.tar.bz2 stable-diffusion-webui-gfx803-042e1d5d0b1fc0bfd358e3a90db7d163934bd238.zip |
Fix SD VAE switch error after model reuse
Diffstat (limited to 'modules/sd_models.py')
-rw-r--r-- | modules/sd_models.py | 22 |
1 files changed, 20 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1..2c976561 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,6 +462,7 @@ class SdModelData: def __init__(self):
self.sd_model = None
self.loaded_sd_models = []
+ self.loaded_vae_states = {}
self.was_loaded_at_least_once = False
self.lock = threading.Lock()
@@ -485,16 +486,27 @@ class SdModelData: return self.sd_model
- def set_sd_model(self, v):
+ def set_sd_model(self, v, already_loaded=False):
self.sd_model = v
+ if already_loaded:
+ sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {})
+ sd_vae.base_vae = sd_vae_state.get("base_vae", None)
+ sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None)
try:
self.loaded_sd_models.remove(v)
+ self.loaded_vae_states.pop(v.sd_model_hash, {}).clear()
except ValueError:
pass
if v is not None:
self.loaded_sd_models.insert(0, v)
+ self.loaded_vae_states[v.sd_model_hash] = dict(
+ base_vae=sd_vae.base_vae,
+ loaded_vae_file=sd_vae.loaded_vae_file,
+ checkpoint_info=sd_vae.checkpoint_info,
+ )
model_data = SdModelData()
@@ -649,6 +661,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
model_data.loaded_sd_models.pop()
+ model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear()
send_model_to_trash(loaded_model)
timer.record("send model to trash")
@@ -660,7 +673,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded)
timer.record("send model to device")
- model_data.set_sd_model(already_loaded)
+ model_data.set_sd_model(already_loaded, already_loaded=True)
if not SkipWritingToConfig.skip:
shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
@@ -678,6 +691,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop()
model_data.sd_model = sd_model
+ sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {})
+ sd_vae.base_vae = sd_vae_state.get("base_vae", None)
+ sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None)
+ sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None)
+
print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
return sd_model
else:
|