From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- modules/sd_vae.py | 121 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 modules/sd_vae.py (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py new file mode 100644 index 00000000..82764e55 --- /dev/null +++ b/modules/sd_vae.py @@ -0,0 +1,121 @@ +import torch +import os +from collections import namedtuple +from modules import shared, devices +from modules.paths import models_path +import glob + +model_dir = "Stable-diffusion" +model_path = os.path.abspath(os.path.join(models_path, model_dir)) +vae_dir = "VAE" +vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} +default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_list = ["auto", "None"] +default_vae_values = [default_vae_dict[x] for x in default_vae_list] +vae_dict = dict(default_vae_dict) +vae_list = list(default_vae_list) +first_load = True + +def get_filename(filepath): + return os.path.splitext(os.path.basename(filepath))[0] + +def refresh_vae_list(vae_path=vae_path, model_path=model_path): + global vae_dict, vae_list + res = {} + candidates = [ + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + ] + if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): + candidates.append(shared.cmd_opts.vae_path) + for filepath in candidates: + name = get_filename(filepath) + res[name] = filepath + vae_list.clear() + vae_list.extend(default_vae_list) + vae_list.extend(list(res.keys())) + vae_dict.clear() + vae_dict.update(default_vae_dict) + vae_dict.update(res) + return vae_list + +def load_vae(model, checkpoint_file, vae_file="auto"): + global first_load, vae_dict, vae_list + # save_settings = False + + # if vae_file argument is provided, it takes priority + if vae_file and vae_file not in default_vae_list: + if not os.path.isfile(vae_file): + vae_file = "auto" + # save_settings = True + print("VAE provided as function argument doesn't exist") + # for the first load, if vae-path is provided, it takes priority and failure is reported + if first_load and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + # save_settings = True + # print("Using VAE provided as command line argument") + else: + print("VAE provided as command line argument doesn't exist") + # else, we load from settings + if vae_file == "auto" and shared.opts.sd_vae is not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + # vae-path cmd arg takes priority for auto + if vae_file == "auto" and shared.cmd_opts.vae_path is not None: + if os.path.isfile(shared.cmd_opts.vae_path): + vae_file = shared.cmd_opts.vae_path + print("Using VAE provided as command line argument") + # if still not found, try look for ".vae.pt" beside model + model_path = 
os.path.splitext(checkpoint_file)[0] + if vae_file == "auto": + vae_file_try = model_path + ".vae.pt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # if still not found, try look for ".vae.ckpt" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.ckpt" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print("Using VAE found beside selected model") + # No more fallbacks for auto + if vae_file == "auto": + vae_file = None + # Last check, just because + if vae_file and not os.path.exists(vae_file): + vae_file = None + + if vae_file: + print(f"Loading VAE weights from: {vae_file}") + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + model.first_stage_model.load_state_dict(vae_dict_1) + + # If vae used is not in dict, update it + # It will be removed on refresh though + if vae_file is not None: + vae_opt = get_filename(vae_file) + if vae_opt not in vae_dict: + vae_dict[vae_opt] = vae_file + vae_list.append(vae_opt) + + """ + # Save current VAE to VAE settings, maybe? will it work? + if save_settings: + if vae_file is None: + vae_opt = "None" + + # shared.opts.sd_vae = vae_opt + """ + + first_load = False + model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From e1b2ea6e0012ecc988385fc523d8fb50ea5d6be5 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 22:11:45 +0700 Subject: Change VAE search order and thus priority --- modules/sd_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 82764e55..0767b925 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -25,10 +25,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): candidates.append(shared.cmd_opts.vae_path) -- cgit v1.2.3 From b96d0c4e9ecec3c856b9b4ec795dbd0d34fcac51 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 14:42:28 +0700 Subject: Fix typo from previous commit --- modules/sd_vae.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0767b925..2ce44d5f 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -27,8 +27,8 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True) - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): 
candidates.append(shared.cmd_opts.vae_path) -- cgit v1.2.3 From 726769da35970f4c100fa7edf11850f9dc059c41 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Mon, 31 Oct 2022 15:19:34 +0700 Subject: Checkpoint cache by combination key of checkpoint and vae --- modules/sd_models.py | 27 ++++++++++++++++----------- modules/sd_vae.py | 8 +++++++- 2 files changed, 23 insertions(+), 12 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 91ad4b5e..850f7b7b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -160,11 +160,15 @@ def get_state_dict_from_checkpoint(pl_sd): vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -def load_model_weights(model, checkpoint_info, force=False): +def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash - if force or checkpoint_info not in checkpoints_loaded: + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + + checkpoint_key = (checkpoint_info, vae_file) + + if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location) @@ -185,24 +189,25 @@ def load_model_weights(model, checkpoint_info, force=False): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, checkpoint_file) + sd_vae.load_vae(model, vae_file) model.first_stage_model.to(devices.dtype_vae) if shared.opts.sd_checkpoint_cache > 0: - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU else: - print(f"Loading weights [{sd_model_hash}] from cache") - checkpoints_loaded.move_to_end(checkpoint_info) - model.load_state_dict(checkpoints_loaded[checkpoint_info]) + vae_name = sd_vae.get_filename(vae_file) + print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") + checkpoints_loaded.move_to_end(checkpoint_key) + model.load_state_dict(checkpoints_loaded[checkpoint_key]) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info -def load_model(checkpoint_info=None, force=False): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -223,7 +228,7 @@ def load_model(checkpoint_info=None, force=False): do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) @@ -250,7 +255,7 @@ def reload_model_weights(sd_model, info=None, force=False): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - load_model(checkpoint_info, force=force) + load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -260,7 +265,7 @@ def reload_model_weights(sd_model, info=None, force=False): 
sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info, force=force) + load_model_weights(sd_model, checkpoint_info) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 2ce44d5f..e9239326 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -43,7 +43,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list -def load_vae(model, checkpoint_file, vae_file="auto"): +def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -94,6 +94,12 @@ def load_vae(model, checkpoint_file, vae_file="auto"): if vae_file and not os.path.exists(vae_file): vae_file = None + return vae_file + +def load_vae(model, vae_file): + global first_load, vae_dict, vae_list + # save_settings = False + if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) -- cgit v1.2.3 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 ++++---- modules/sd_vae.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---- webui.py | 4 +-- 3 files changed, 98 insertions(+), 18 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = shared.sd_model - if 
sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e9239326..78e14e8a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,26 +1,65 @@ import torch import os from collections import namedtuple -from modules import shared, devices +from modules import shared, devices, script_callbacks from modules.paths import models_path import glob + model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) vae_dir = "VAE" vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + default_vae_dict = {"auto": "auto", "None": "None"} default_vae_list = ["auto", "None"] + + default_vae_values = [default_vae_dict[x] for x in default_vae_list] vae_dict = dict(default_vae_dict) vae_list = list(default_vae_list) first_load = True + +base_vae = None +loaded_vae_file = None +checkpoint_info = None + + +def get_base_vae(model): + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + return base_vae + return None + + +def store_base_vae(model): + global base_vae, checkpoint_info + if checkpoint_info != model.sd_checkpoint_info: + base_vae = model.first_stage_model.state_dict().copy() + checkpoint_info = model.sd_checkpoint_info + + +def delete_base_vae(): + global base_vae, checkpoint_info + base_vae = None + checkpoint_info = None + + +def restore_base_vae(model): + global base_vae, checkpoint_info + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + load_vae_dict(model, base_vae) + delete_base_vae() + + def get_filename(filepath): return os.path.splitext(os.path.basename(filepath))[0] + def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} @@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list + def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"): return vae_file -def load_vae(model, vae_file): - global first_load, vae_dict, vae_list + +def load_vae(model, vae_file=None): + global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict_1) + load_vae_dict(model, vae_dict_1) - # If vae used is not in dict, update it - # It will be removed on refresh though - if vae_file is not None: + # If vae used is not in dict, update it + # It will be removed on refresh though vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + loaded_vae_file = vae_file + """ # Save current VAE to VAE settings, maybe? will it work? 
if save_settings: @@ -124,4 +166,45 @@ def load_vae(model, vae_file): """ first_load = False + + +# don't call this from outside +def load_vae_dict(model, vae_dict_1=None): + if vae_dict_1: + store_base_vae(model) + model.first_stage_model.load_state_dict(vae_dict_1) + else: + restore_base_vae() model.first_stage_model.to(devices.dtype_vae) + + +def reload_vae_weights(sd_model=None, vae_file="auto"): + from modules import lowvram, devices, sd_hijack + + if not sd_model: + sd_model = shared.sd_model + + checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_file = checkpoint_info.filename + vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if loaded_vae_file == vae_file: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + load_vae(sd_model, vae_file) + + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + + print(f"VAE Weights loaded.") + return sd_model diff --git a/webui.py b/webui.py index 7cb4691b..034777a2 100644 --- a/webui.py +++ b/webui.py @@ -81,9 +81,7 @@ def initialize(): modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) - # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, - # so for now this reloads the whole model too - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(force=True)), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From a5409a6e4bc3eaa9757a7505d4564ad8e0d899ea Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 14:37:22 +0700 Subject: Save VAE provided by cmd_opts.vae_path --- modules/sd_vae.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 78e14e8a..71e7a6e6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -78,27 +78,24 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_list.extend(default_vae_list) vae_list.extend(list(res.keys())) vae_dict.clear() - vae_dict.update(default_vae_dict) vae_dict.update(res) + vae_dict.update(default_vae_dict) return vae_list def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list - # save_settings = False - # if vae_file argument is provided, it takes priority + # if vae_file argument is provided, it takes priority, but not saved if vae_file and vae_file not in default_vae_list: if not os.path.isfile(vae_file): vae_file = "auto" - # save_settings = True print("VAE provided as function argument doesn't exist") - # for the first load, if vae-path is provided, it takes priority and failure is reported + # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported if first_load and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = 
shared.cmd_opts.vae_path - # save_settings = True - # print("Using VAE provided as command line argument") + shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") # else, we load from settings -- cgit v1.2.3 From 8ab4927452b04dcd30847eaf92ea7a9f3b9c74e1 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 22:54:09 +0700 Subject: Fix model wasn't restored even when choosing "None" --- modules/sd_vae.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..7a79239f 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -50,8 +50,8 @@ def delete_base_vae(): def restore_base_vae(model): - global base_vae, checkpoint_info if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + print("Restoring base VAE") load_vae_dict(model, base_vae) delete_base_vae() @@ -143,6 +143,7 @@ def load_vae(model, vae_file=None): vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} load_vae_dict(model, vae_dict_1) + store_base_vae(model) # If vae used is not in dict, update it # It will be removed on refresh though @@ -150,6 +151,9 @@ def load_vae(model, vae_file=None): if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + # shared.opts.data['sd_vae'] = vae_opt + else: + restore_base_vae(model) loaded_vae_file = vae_file @@ -166,12 +170,8 @@ def load_vae(model, vae_file=None): # don't call this from outside -def load_vae_dict(model, vae_dict_1=None): - if vae_dict_1: - store_base_vae(model) - model.first_stage_model.load_state_dict(vae_dict_1) - else: - restore_base_vae() +def load_vae_dict(model, vae_dict_1): + model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From abc1e79a5da24a1ea0f4bceedcdf225f32010aa8 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 11:10:53 +0700 Subject: Fix base VAE caching was done after loading VAE, also add safeguard --- modules/sd_models.py | 1 + modules/sd_vae.py | 19 ++++++++----------- 2 files changed, 9 insertions(+), 11 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..e4dba62c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 7a79239f..dd69a5e6 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -15,7 +15,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} -default_vae_dict = {"auto": "auto", "None": "None"} +default_vae_dict = {"auto": "auto", "None": None, None: None} default_vae_list = ["auto", "None"] @@ -39,6 +39,7 @@ def get_base_vae(model): def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: + assert not loaded_vae_file, "Trying to store non-base VAE!" 
base_vae = model.first_stage_model.state_dict().copy() checkpoint_info = model.sd_checkpoint_info @@ -50,9 +51,11 @@ def delete_base_vae(): def restore_base_vae(model): + global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") load_vae_dict(model, base_vae) + loaded_vae_file = None delete_base_vae() @@ -140,10 +143,10 @@ def load_vae(model, vae_file=None): if vae_file: print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} load_vae_dict(model, vae_dict_1) - store_base_vae(model) # If vae used is not in dict, update it # It will be removed on refresh though @@ -157,15 +160,6 @@ def load_vae(model, vae_file=None): loaded_vae_file = vae_file - """ - # Save current VAE to VAE settings, maybe? will it work? - if save_settings: - if vae_file is None: - vae_opt = "None" - - # shared.opts.sd_vae = vae_opt - """ - first_load = False @@ -174,6 +168,9 @@ def load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) +def clear_loaded_vae(): + global loaded_vae_file + loaded_vae_file = None def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.3 From c7be83bf0240498d9382e2afeaa3f0677d26c7f6 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 11:11:14 +0700 Subject: Misc Misc --- modules/sd_models.py | 1 + modules/sd_vae.py | 3 +-- modules/shared.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4dba62c..cd7fe37a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index dd69a5e6..13bf3d31 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -154,8 +154,7 @@ def load_vae(model, vae_file=None): if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) - # shared.opts.data['sd_vae'] = vae_opt - else: + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file diff --git a/modules/shared.py b/modules/shared.py index 17132e42..a9daf800 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -335,7 +335,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, 
refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), -- cgit v1.2.3 From 9fdc343dcaee70f1a0ff15c0cc668dbd487abc61 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 17 Nov 2022 18:04:10 +0700 Subject: Fix model caching requiring deepcopy --- modules/sd_vae.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 13bf3d31..5b4709b5 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -4,6 +4,7 @@ from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path import glob +from copy import deepcopy model_dir = "Stable-diffusion" @@ -40,7 +41,7 @@ def store_base_vae(model): global base_vae, checkpoint_info if checkpoint_info != model.sd_checkpoint_info: assert not loaded_vae_file, "Trying to store non-base VAE!" - base_vae = model.first_stage_model.state_dict().copy() + base_vae = deepcopy(model.first_stage_model.state_dict()) checkpoint_info = model.sd_checkpoint_info -- cgit v1.2.3 From 028b67b6357b5a00ccbd6ea72d2f244a6664162b Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sat, 19 Nov 2022 01:27:54 +0700 Subject: Use underscore naming for "private" functions in sd_vae --- modules/sd_vae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 5b4709b5..d82a7bad 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -55,7 +55,7 @@ def restore_base_vae(model): global loaded_vae_file if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: print("Restoring base VAE") - load_vae_dict(model, base_vae) + _load_vae_dict(model, base_vae) loaded_vae_file = None delete_base_vae() @@ -147,7 +147,7 @@ def load_vae(model, vae_file=None): store_base_vae(model) vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - load_vae_dict(model, vae_dict_1) + _load_vae_dict(model, vae_dict_1) # If vae used is not in dict, update it # It will be removed on refresh though @@ -164,7 +164,7 @@ def load_vae(model, vae_file=None): # don't call this from outside -def load_vae_dict(model, vae_dict_1): +def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From 0663706d4405b4f76ce653097f4f8989ee8b8684 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Thu, 3 Nov 2022 13:47:03 +0700 Subject: Option to use selected VAE as default fallback instead of primary option --- modules/sd_vae.py | 25 ++++++++++++++++--------- modules/shared.py | 1 + webui.py | 1 + 3 files changed, 18 insertions(+), 9 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..0b5f0213 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -83,7 +83,19 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): return vae_list -def resolve_vae(checkpoint_file, vae_file="auto"): +def get_vae_from_settings(vae_file="auto"): + # else, we load from settings, if not set to be default + if vae_file == "auto" and shared.opts.sd_vae is 
not None: + # if saved VAE settings isn't recognized, fallback to auto + vae_file = vae_dict.get(shared.opts.sd_vae, "auto") + # if VAE selected but not found, fallback to auto + if vae_file not in default_vae_values and not os.path.isfile(vae_file): + vae_file = "auto" + print("Selected VAE doesn't exist") + return vae_file + + +def resolve_vae(checkpoint_file=None, vae_file="auto"): global first_load, vae_dict, vae_list # if vae_file argument is provided, it takes priority, but not saved @@ -98,14 +110,9 @@ def resolve_vae(checkpoint_file, vae_file="auto"): shared.opts.data['sd_vae'] = get_filename(vae_file) else: print("VAE provided as command line argument doesn't exist") - # else, we load from settings - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print("Selected VAE doesn't exist") + # fallback to selector in settings, if vae selector not set to act as default fallback + if not shared.opts.sd_vae_as_default: + vae_file = get_vae_from_settings(vae_file) # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): diff --git a/modules/shared.py b/modules/shared.py index 17132e42..b84767f0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -336,6 +336,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": modules.sd_models.checkpoint_tiles()}, refresh=sd_models.list_models), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": list(sd_vae.vae_list)}, refresh=sd_vae.refresh_vae_list), + "sd_vae_as_default": OptionInfo(False, "Use selected VAE as default fallback instead"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), diff --git a/webui.py b/webui.py index f4f1d74d..2cd3bae9 100644 --- a/webui.py +++ b/webui.py @@ -82,6 +82,7 @@ def initialize(): modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From 2c5ca706a7e624d268545ba3318ba230b7b33477 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:55:47 +0700 Subject: Remove no longer necessary parts and add vae_file safeguard --- 
modules/sd_models.py | 10 ++-------- modules/sd_vae.py | 1 + 2 files changed, 3 insertions(+), 8 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 80addf03..c59151e0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): cache_enabled = shared.opts.sd_checkpoint_cache > 0 - if cache_enabled: - sd_vae.restore_base_vae(model) - - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - if cache_enabled and checkpoint_info in checkpoints_loaded: # use checkpoint cache - vae_name = sd_vae.get_filename(vae_file) if vae_file else None - vae_message = f" with {vae_name} VAE" if vae_name else "" - print(f"Loading weights [{sd_model_hash}]{vae_message} from cache") + print(f"Loading weights [{sd_model_hash}] from cache") model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file @@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) sd_vae.load_vae(model, vae_file) diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 71e7a6e6..8bdb2c17 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -139,6 +139,7 @@ def load_vae(model, vae_file=None): # save_settings = False if vae_file: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} -- cgit v1.2.3 From 271fd2d700a59e80d9dc9f23ad3ef08c988e8b24 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 13 Nov 2022 10:58:15 +0700 Subject: More verbose messages --- modules/sd_vae.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 8bdb2c17..fa8de905 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -89,15 +89,15 @@ def resolve_vae(checkpoint_file, vae_file="auto"): # if vae_file argument is provided, it takes priority, but not saved if vae_file and vae_file not in default_vae_list: if not os.path.isfile(vae_file): + print(f"VAE provided as function argument doesn't exist: {vae_file}") vae_file = "auto" - print("VAE provided as function argument doesn't exist") # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported if first_load and shared.cmd_opts.vae_path is not None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path shared.opts.data['sd_vae'] = get_filename(vae_file) else: - print("VAE provided as command line argument doesn't exist") + print(f"VAE provided as command line argument doesn't exist: {vae_file}") # else, we load from settings if vae_file == "auto" and shared.opts.sd_vae is not None: # if saved VAE settings isn't recognized, fallback to auto @@ -105,25 +105,25 @@ def resolve_vae(checkpoint_file, vae_file="auto"): # if VAE selected but not found, fallback to auto if vae_file not in default_vae_values and not os.path.isfile(vae_file): vae_file = "auto" - print("Selected VAE doesn't exist") + print(f"Selected VAE doesn't exist: {vae_file}") # vae-path cmd arg takes priority for auto if vae_file == "auto" and shared.cmd_opts.vae_path is not 
None: if os.path.isfile(shared.cmd_opts.vae_path): vae_file = shared.cmd_opts.vae_path - print("Using VAE provided as command line argument") + print(f"Using VAE provided as command line argument: {vae_file}") # if still not found, try look for ".vae.pt" beside model model_path = os.path.splitext(checkpoint_file)[0] if vae_file == "auto": vae_file_try = model_path + ".vae.pt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # if still not found, try look for ".vae.ckpt" beside model if vae_file == "auto": vae_file_try = model_path + ".vae.ckpt" if os.path.isfile(vae_file_try): vae_file = vae_file_try - print("Using VAE found beside selected model") + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None -- cgit v1.2.3 From 3bf5591efe9a9f219c6088be322a87adc4f48f95 Mon Sep 17 00:00:00 2001 From: Yuval Aboulafia Date: Sat, 24 Dec 2022 21:35:29 +0200 Subject: fix F541 f-string without any placeholders --- extensions-builtin/LDSR/ldsr_model_arch.py | 2 +- modules/codeformer/vqgan_arch.py | 4 ++-- modules/hypernetworks/hypernetwork.py | 4 ++-- modules/images.py | 2 +- modules/interrogate.py | 2 +- modules/safe.py | 8 ++++---- modules/sd_models.py | 8 ++++---- modules/sd_vae.py | 2 +- modules/textual_inversion/textual_inversion.py | 2 +- scripts/prompts_from_file.py | 2 +- 10 files changed, 18 insertions(+), 18 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/extensions-builtin/LDSR/ldsr_model_arch.py b/extensions-builtin/LDSR/ldsr_model_arch.py index f5bd8ae4..0ad49f4e 100644 --- a/extensions-builtin/LDSR/ldsr_model_arch.py +++ b/extensions-builtin/LDSR/ldsr_model_arch.py @@ -26,7 +26,7 @@ class LDSR: global cached_ldsr_model if shared.opts.ldsr_cached and cached_ldsr_model is not None: - print(f"Loading model from cache") + print("Loading model from cache") model: torch.nn.Module = cached_ldsr_model else: print(f"Loading model from {self.modelPath}") diff --git a/modules/codeformer/vqgan_arch.py b/modules/codeformer/vqgan_arch.py index c06c590c..e7293683 100644 --- a/modules/codeformer/vqgan_arch.py +++ b/modules/codeformer/vqgan_arch.py @@ -382,7 +382,7 @@ class VQAutoEncoder(nn.Module): self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) logger.info(f'vqgan is loaded from: {model_path} [params]') else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): @@ -431,7 +431,7 @@ class VQGANDiscriminator(nn.Module): elif 'params' in chkpt: self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) else: - raise ValueError(f'Wrong params!') + raise ValueError('Wrong params!') def forward(self, x): return self.main(x) \ No newline at end of file diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c406ffb3..9d3034ae 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -277,7 +277,7 @@ def load_hypernetwork(filename): print(traceback.format_exc(), file=sys.stderr) else: if shared.loaded_hypernetwork is not None: - print(f"Unloading hypernetwork") + print("Unloading hypernetwork") shared.loaded_hypernetwork = None @@ -417,7 +417,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, initial_step = hypernetwork.step or 0 if initial_step >= steps: - shared.state.textinfo = f"Model has already been 
trained beyond specified max steps" + shared.state.textinfo = "Model has already been trained beyond specified max steps" return hypernetwork, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) diff --git a/modules/images.py b/modules/images.py index 809ad9f7..31d4528d 100644 --- a/modules/images.py +++ b/modules/images.py @@ -599,7 +599,7 @@ def read_info_from_image(image): Negative prompt: {json_info["uc"]} Steps: {json_info["steps"]}, Sampler: {sampler}, CFG scale: {json_info["scale"]}, Seed: {json_info["seed"]}, Size: {image.width}x{image.height}, Clip skip: 2, ENSD: 31337""" except Exception: - print(f"Error parsing NovelAI image generation parameters:", file=sys.stderr) + print("Error parsing NovelAI image generation parameters:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return geninfo, items diff --git a/modules/interrogate.py b/modules/interrogate.py index 0068b81c..46935210 100644 --- a/modules/interrogate.py +++ b/modules/interrogate.py @@ -172,7 +172,7 @@ class InterrogateModels: res += ", " + match except Exception: - print(f"Error interrogating", file=sys.stderr) + print("Error interrogating", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) res += "" diff --git a/modules/safe.py b/modules/safe.py index 479c8b86..1d4c20b9 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -137,15 +137,15 @@ def load_with_extra(filename, extra_handler=None, *args, **kwargs): except pickle.UnpicklingError: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - print(f"-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) - print(f"You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) + print("-----> !!!! The file is most likely corrupted !!!! <-----", file=sys.stderr) + print("You can skip this check with --disable-safe-unpickle commandline argument, but that is not going to help you.\n\n", file=sys.stderr) return None except Exception: print(f"Error verifying pickled file from {filename}:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - print(f"\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) - print(f"You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) + print("\nThe file may be malicious, so the program is not going to read it.", file=sys.stderr) + print("You can skip this check with --disable-safe-unpickle commandline argument.\n\n", file=sys.stderr) return None return unsafe_torch_load(filename, *args, **kwargs) diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ca06211..ecdd91c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -117,13 +117,13 @@ def select_checkpoint(): return checkpoint_info if len(checkpoints_list) == 0: - print(f"No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) + print("No checkpoints found. When searching for checkpoints, looked at:", file=sys.stderr) if shared.cmd_opts.ckpt is not None: print(f" - file {os.path.abspath(shared.cmd_opts.ckpt)}", file=sys.stderr) print(f" - directory {model_path}", file=sys.stderr) if shared.cmd_opts.ckpt_dir is not None: print(f" - directory {os.path.abspath(shared.cmd_opts.ckpt_dir)}", file=sys.stderr) - print(f"Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. 
The program will exit.", file=sys.stderr) + print("Can't run without a checkpoint. Find and place a .ckpt file into any of those locations. The program will exit.", file=sys.stderr) exit(1) checkpoint_info = next(iter(checkpoints_list.values())) @@ -324,7 +324,7 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print(f"Model loaded.") + print("Model loaded.") return sd_model @@ -359,5 +359,5 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print(f"Weights loaded.") + print("Weights loaded.") return sd_model diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 25638a83..3856418e 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -208,5 +208,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print(f"VAE Weights loaded.") + print("VAE Weights loaded.") return sd_model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index daf3997b..f6112578 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -263,7 +263,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ initial_step = embedding.step or 0 if initial_step >= steps: - shared.state.textinfo = f"Model has already been trained beyond specified max steps" + shared.state.textinfo = "Model has already been trained beyond specified max steps" return embedding, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 6e118ddb..e8386ed2 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -140,7 +140,7 @@ class Script(scripts.Script): try: args = cmdargs(line) except Exception: - print(f"Error parsing line [line] as commandline:", file=sys.stderr) + print(f"Error parsing line {line} as commandline:", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) args = {"prompt": line} else: -- cgit v1.2.3 From 893933e05ad267778111b4fad6d1ecb80937afdf Mon Sep 17 00:00:00 2001 From: hitomi Date: Sun, 25 Dec 2022 20:49:25 +0800 Subject: Add memory cache for VAE weights --- modules/sd_vae.py | 31 +++++++++++++++++++++++++------ modules/shared.py | 1 + 2 files changed, 26 insertions(+), 6 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 3856418e..ac71d62d 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,5 +1,6 @@ import torch import os +import collections from collections import namedtuple from modules import shared, devices, script_callbacks from modules.paths import models_path @@ -30,6 +31,7 @@ base_vae = None loaded_vae_file = None checkpoint_info = None +checkpoints_loaded = collections.OrderedDict() def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: @@ -149,13 +151,30 @@ def load_vae(model, vae_file=None): global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False + cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 + if vae_file: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") - store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v 
in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - _load_vae_dict(model, vae_dict_1) + if cache_enabled and vae_file in checkpoints_loaded: + # use vae checkpoint cache + print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + store_base_vae(model) + _load_vae_dict(model, checkpoints_loaded[vae_file]) + else: + assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" + print(f"Loading VAE weights from: {vae_file}") + store_base_vae(model) + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _load_vae_dict(model, vae_dict_1) + + if cache_enabled: + # cache newly loaded vae + checkpoints_loaded[vae_file] = vae_dict_1.copy() + + # clean up cache if limit is reached + if cache_enabled: + while len(checkpoints_loaded) > shared.opts.sd_vae_checkpoint_cache + 1: # we need to count the current model + checkpoints_loaded.popitem(last=False) # LRU # If vae used is not in dict, update it # It will be removed on refresh though diff --git a/modules/shared.py b/modules/shared.py index d4ddeea0..671d30e1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -356,6 +356,7 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), -- cgit v1.2.3 From cb255faec6e5f6b47b7632e6b7d450b9e2f6678b Mon Sep 17 00:00:00 2001 From: Lee Bousfield Date: Sun, 8 Jan 2023 10:17:50 -0700 Subject: Add support for loading VAEs from safetensor files --- modules/sd_vae.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index ac71d62d..9fcfd9db 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,4 +1,5 @@ import torch +import safetensors.torch import os import collections from collections import namedtuple @@ -72,8 +73,10 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): candidates = [ *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), + *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True) + *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), + *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), ] if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): 
candidates.append(shared.cmd_opts.vae_path) @@ -137,6 +140,12 @@ def resolve_vae(checkpoint_file=None, vae_file="auto"): if os.path.isfile(vae_file_try): vae_file = vae_file_try print(f"Using VAE found similar to selected model: {vae_file}") + # if still not found, try look for ".vae.safetensors" beside model + if vae_file == "auto": + vae_file_try = model_path + ".vae.safetensors" + if os.path.isfile(vae_file_try): + vae_file = vae_file_try + print(f"Using VAE found similar to selected model: {vae_file}") # No more fallbacks for auto if vae_file == "auto": vae_file = None @@ -163,8 +172,14 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} + _, extension = os.path.splitext(vae_file) + if extension.lower() == ".safetensors": + vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) + else: + vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) + if "state_dict" in vae_ckpt: + vae_ckpt = vae_ckpt["state_dict"] + vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) if cache_enabled: -- cgit v1.2.3 From 49c4509ce2302350210ff650fd26373518c46a79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 19:58:35 +0300 Subject: use existing function for loading VAE weights from file --- modules/sd_vae.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) (limited to 'modules/sd_vae.py') diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 9fcfd9db..0a49daa1 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -3,7 +3,7 @@ import safetensors.torch import os import collections from collections import namedtuple -from modules import shared, devices, script_callbacks +from modules import shared, devices, script_callbacks, sd_models from modules.paths import models_path import glob from copy import deepcopy @@ -172,13 +172,8 @@ def load_vae(model, vae_file=None): assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" print(f"Loading VAE weights from: {vae_file}") store_base_vae(model) - _, extension = os.path.splitext(vae_file) - if extension.lower() == ".safetensors": - vae_ckpt = safetensors.torch.load_file(vae_file, device=shared.weight_load_location) - else: - vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - if "state_dict" in vae_ckpt: - vae_ckpt = vae_ckpt["state_dict"] + + vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys} _load_vae_dict(model, vae_dict_1) @@ -210,10 +205,12 @@ def _load_vae_dict(model, vae_dict_1): model.first_stage_model.load_state_dict(vae_dict_1) model.first_stage_model.to(devices.dtype_vae) + def clear_loaded_vae(): global loaded_vae_file loaded_vae_file = None + def reload_vae_weights(sd_model=None, vae_file="auto"): from modules import lowvram, devices, sd_hijack -- cgit v1.2.3
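
Taken together, the commits above converge on a fixed VAE selection order: an explicit function argument wins, then --vae-path on the first load (which is also written back to the sd_vae setting), then the settings dropdown (skipped when sd_vae_as_default makes it a fallback only), then --vae-path again for "auto", then any .vae.pt / .vae.ckpt / .vae.safetensors file sitting beside the checkpoint, and finally None, which means the checkpoint's own base VAE. The sketch below is a minimal, self-contained restatement of that order for reference, not the module itself: resolve_vae_sketch, SIBLING_VAE_EXTENSIONS, and the plain keyword arguments standing in for shared.opts / shared.cmd_opts are illustrative names, the recognized-VAE mapping is passed in explicitly, and no weights are read.

import os

# illustrative name; the real module spells out the three extensions inline
SIBLING_VAE_EXTENSIONS = (".vae.pt", ".vae.ckpt", ".vae.safetensors")


def resolve_vae_sketch(checkpoint_file, vae_file="auto", cmd_vae_path=None,
                       first_load=True, settings_vae="auto",
                       vae_as_default=False, vae_dict=None):
    """Mirror of the selection order above; returns a filepath or None."""
    # name -> filepath, seeded like default_vae_dict in the patches above
    vae_dict = {"auto": "auto", "None": None, **(vae_dict or {})}

    # 1. an explicit argument takes priority, but is never saved
    if vae_file not in ("auto", "None", None) and not os.path.isfile(vae_file):
        print(f"VAE provided as function argument doesn't exist: {vae_file}")
        vae_file = "auto"

    # 2. on the first load, --vae-path takes priority (the real code also
    #    writes the chosen name back to the sd_vae setting)
    if first_load and cmd_vae_path is not None:
        if os.path.isfile(cmd_vae_path):
            vae_file = cmd_vae_path
        else:
            print(f"VAE provided as command line argument doesn't exist: {cmd_vae_path}")

    # 3. the settings dropdown, unless it acts as a default fallback only
    if vae_file == "auto" and not vae_as_default and settings_vae is not None:
        vae_file = vae_dict.get(settings_vae, "auto")
        if vae_file not in ("auto", "None", None) and not os.path.isfile(vae_file):
            print(f"Selected VAE doesn't exist: {vae_file}")
            vae_file = "auto"

    # 4. --vae-path also resolves "auto" on reloads after the first
    if vae_file == "auto" and cmd_vae_path is not None and os.path.isfile(cmd_vae_path):
        vae_file = cmd_vae_path

    # 5. for "auto", look for a sibling file next to the checkpoint
    if vae_file == "auto":
        base = os.path.splitext(checkpoint_file)[0]
        for ext in SIBLING_VAE_EXTENSIONS:
            if os.path.isfile(base + ext):
                vae_file = base + ext
                print(f"Using VAE found similar to selected model: {vae_file}")
                break

    # 6. no more fallbacks: "auto"/"None"/missing files mean the base VAE
    if vae_file in ("auto", "None", None) or not os.path.exists(vae_file):
        vae_file = None
    return vae_file

As a usage example, resolve_vae_sketch("models/Stable-diffusion/v1-5.ckpt", settings_vae="vae-ft-mse", vae_dict={"vae-ft-mse": "models/VAE/vae-ft-mse.pt"}) returns the mapped path when that file exists and otherwise falls back to a sibling file or None. Loading stays orthogonal to resolution, as the later commits arrange: load_vae stores the checkpoint's base VAE once before the first swap, keeps an LRU cache of VAE state dicts when sd_vae_checkpoint_cache is set, and restore_base_vae puts the original weights back when "None" is selected.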