diff options
author | Philpax <me@philpax.me> | 2023-01-05 04:00:58 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-05 04:00:58 +0000 |
commit | 83ca8dd0c96e3cc8dd444e8d980c07718dc647ee (patch) | |
tree | cbadd09c9305099db7763532b7aaf909d366d326 /modules/sd_vae.py | |
parent | fa931733f6acc94e058a1d3d4655846e33ae34be (diff) | |
parent | 5f4fa942b8ec3ed3b15a352903489d6f9e6eb46e (diff) | |
download | stable-diffusion-webui-gfx803-83ca8dd0c96e3cc8dd444e8d980c07718dc647ee.tar.gz stable-diffusion-webui-gfx803-83ca8dd0c96e3cc8dd444e8d980c07718dc647ee.tar.bz2 stable-diffusion-webui-gfx803-83ca8dd0c96e3cc8dd444e8d980c07718dc647ee.zip |
Merge branch 'AUTOMATIC1111:master' into fix-sd-arch-switch-in-override-settings
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r-- | modules/sd_vae.py | 31 |
1 files changed, 25 insertions, 6 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 3856418e..ac71d62d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,6 @@ import torch import os +import collections from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path @@ -30,6 +31,7 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -149,13 +151,30 @@ def load_vae(model, vae_file=None): global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + if vae_file: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") - store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - _load_vae_dict(model, vae_dict_1) + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" + print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though |