diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-14 16:56:09 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-14 16:56:09 +0000 |
commit | a5bbcd215304e0c83ab2b9fe7f172f88536d7629 (patch) | |
tree | 5c44a862f9aada5a47b364de06e8ad6d6ddcd310 /modules | |
parent | 69781031e7473e020b3af4461fdceb20130e56ab (diff) | |
download | stable-diffusion-webui-gfx803-a5bbcd215304e0c83ab2b9fe7f172f88536d7629.tar.gz stable-diffusion-webui-gfx803-a5bbcd215304e0c83ab2b9fe7f172f88536d7629.tar.bz2 stable-diffusion-webui-gfx803-a5bbcd215304e0c83ab2b9fe7f172f88536d7629.zip |
fix bug with "Ignore selected VAE for..." option completely disabling VAE election
rework VAE resolving code to be more simple
Diffstat (limited to 'modules')
-rw-r--r-- | modules/sd_models.py | 6 | ||||
-rw-r--r-- | modules/sd_vae.py | 194 | ||||
-rw-r--r-- | modules/shared.py | 4 |
3 files changed, 82 insertions, 122 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd
-def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
+def load_model_weights(model, checkpoint_info: CheckpointInfo):
sd_model_hash = checkpoint_info.calculate_shorthash()
cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae()
sd_vae.clear_loaded_vae()
- vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
- sd_vae.load_vae(model, vae_file)
+ vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
+ sd_vae.load_vae(model, vae_file, vae_source)
def enable_midas_autodownload():
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + for filepath in candidates: name = get_filename(filepath) - res[name] = filepath - vae_list.clear() - vae_list.extend(default_vae_list) - vae_list.extend(list(res.keys())) - vae_dict.clear() - vae_dict.update(res) - vae_dict.update(default_vae_dict) - return vae_list - - -def get_vae_from_settings(vae_file="auto"): - # else, we load from settings, if not set to be default - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print(f"Selected VAE doesn't exist: {vae_file}") - return vae_file - - -def resolve_vae(checkpoint_file=None, vae_file="auto"): - global first_load, vae_dict, vae_list - - # if vae_file argument is provided, it takes priority, but not saved - if vae_file and vae_file not in default_vae_list: - if not os.path.isfile(vae_file): - print(f"VAE provided as function argument doesn't exist: {vae_file}") - vae_file = "auto" - # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported - if first_load and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - shared.opts.data['sd_vae'] = get_filename(vae_file) - else: - print(f"VAE provided as command line argument doesn't exist: {vae_file}") - # fallback to selector in settings, if vae selector not set to act as default fallback - if not shared.opts.sd_vae_as_default: - vae_file = get_vae_from_settings(vae_file) - # vae-path cmd arg takes priority for auto - if vae_file == "auto" and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - print(f"Using VAE provided as command line argument: {vae_file}") - # if still not found, try look for ".vae.pt" beside model - model_path = os.path.splitext(checkpoint_file)[0] - if vae_file == "auto": - vae_file_try = model_path + ".vae.pt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.ckpt" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.ckpt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.safetensors" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.safetensors" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # No more fallbacks for auto - if vae_file == "auto": - vae_file = None - # Last check, just because - if vae_file and not os.path.exists(vae_file): - vae_file = None - - return vae_file - - -def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if shared.opts.sd_vae != "Automatic": + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -165,12 +120,12 @@ def load_vae(model, vae_file=None): if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache - print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) _load_vae_dict(model, checkpoints_loaded[vae_file]) else: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) @@ -191,14 +146,12 @@ def load_vae(model, vae_file=None): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file - vae_list.append(vae_opt) + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file - first_load = False - # don't call this from outside def _load_vae_dict(model, vae_dict_1): @@ -211,7 +164,10 @@ def clear_loaded_vae(): loaded_vae_file = None -def reload_vae_weights(sd_model=None, vae_file="auto"): +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack if not sd_model: @@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_info = sd_model.sd_checkpoint_info checkpoint_file = checkpoint_info.filename - vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" if loaded_vae_file == vae_file: return @@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file) + load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) @@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("VAE Weights loaded.") + print("VAE weights loaded.") return sd_model diff --git a/modules/shared.py b/modules/shared.py index e0ec3136..9756adea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
-parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
+parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
"sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
"sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
- "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
+ "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
"sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
"sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
"sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
|