diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-24 08:09:04 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-24 08:09:04 +0000 |
commit | 189229bbf9276fb73e48c783856b02fc57ab5c9b (patch) | |
tree | 728b1ab97fec6d18a1ec687ba552ca83b0dcf109 /modules/sd_vae.py | |
parent | 31f2be3dcedf85c036c5f784c640208d122b62ed (diff) | |
parent | b6c02174050b2c5dd98bf24c797e85ff269516f5 (diff) | |
download | stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.tar.gz stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.tar.bz2 stable-diffusion-webui-gfx803-189229bbf9276fb73e48c783856b02fc57ab5c9b.zip |
Merge branch 'dev' into release_candidate
Diffstat (limited to 'modules/sd_vae.py')
-rw-r--r-- | modules/sd_vae.py | 105 |
1 files changed, 87 insertions, 18 deletions
diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e4ff2994..669097da 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,9 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from dataclasses import dataclass + +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks, lowvram, sd_hijack, hashes + import glob from copy import deepcopy @@ -16,6 +19,23 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + +def get_loaded_vae_name(): + if loaded_vae_file is None: + return None + + return os.path.basename(loaded_vae_file) + + +def get_loaded_vae_hash(): + if loaded_vae_file is None: + return None + + sha256 = hashes.sha256(loaded_vae_file, 'vae') + + return sha256[0:10] if sha256 else None + + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -83,6 +103,8 @@ def refresh_vae_list(): name = get_filename(filepath) vae_dict[name] = filepath + vae_dict.update(dict(sorted(vae_dict.items(), key=lambda item: shared.natural_sort_key(item[0])))) + def find_vae_near_checkpoint(checkpoint_file): checkpoint_path = os.path.basename(checkpoint_file).rsplit('.', 1)[0] @@ -93,27 +115,74 @@ def find_vae_near_checkpoint(checkpoint_file): return None -def resolve_vae(checkpoint_file): - if shared.cmd_opts.vae_path is not None: - return shared.cmd_opts.vae_path, 'from commandline argument' +@dataclass +class VaeResolution: + vae: str = None + source: str = None + resolved: bool = True - is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + def tuple(self): + return self.vae, self.source - vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) - if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or is_automatic): - return vae_near_checkpoint, 'found near the checkpoint' +def is_automatic(): + return shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config + + +def resolve_vae_from_setting() -> VaeResolution: if shared.opts.sd_vae == "None": - return None, None + return VaeResolution() vae_from_options = vae_dict.get(shared.opts.sd_vae, None) if vae_from_options is not None: - return vae_from_options, 'specified in settings' + return VaeResolution(vae_from_options, 'specified in settings') - if not is_automatic: + if not is_automatic(): print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") - return None, None + return VaeResolution(resolved=False) + + +def resolve_vae_from_user_metadata(checkpoint_file) -> VaeResolution: + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return VaeResolution() + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return VaeResolution(vae_from_metadata, "from user metadata") + + return VaeResolution(resolved=False) + + +def resolve_vae_near_checkpoint(checkpoint_file) -> VaeResolution: + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (not shared.opts.sd_vae_overrides_per_model_preferences or is_automatic): + return VaeResolution(vae_near_checkpoint, 'found near the checkpoint') + + return VaeResolution(resolved=False) + + +def resolve_vae(checkpoint_file) -> VaeResolution: + if shared.cmd_opts.vae_path is not None: + return VaeResolution(shared.cmd_opts.vae_path, 'from commandline argument') + + if shared.opts.sd_vae_overrides_per_model_preferences and not is_automatic(): + return resolve_vae_from_setting() + + res = resolve_vae_from_user_metadata(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_near_checkpoint(checkpoint_file) + if res.resolved: + return res + + res = resolve_vae_from_setting() + + return res def load_vae_dict(filename, map_location): @@ -123,7 +192,7 @@ def load_vae_dict(filename, map_location): def load_vae(model, vae_file=None, vae_source="from unknown source"): - global vae_dict, loaded_vae_file + global vae_dict, base_vae, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -161,6 +230,8 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"): restore_base_vae(model) loaded_vae_file = vae_file + model.base_vae = base_vae + model.loaded_vae_file = loaded_vae_file # don't call this from outside @@ -178,8 +249,6 @@ unspecified = object() def reload_vae_weights(sd_model=None, vae_file=unspecified): - from modules import lowvram, devices, sd_hijack - if not sd_model: sd_model = shared.sd_model @@ -187,14 +256,14 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): checkpoint_file = checkpoint_info.filename if vae_file == unspecified: - vae_file, vae_source = resolve_vae(checkpoint_file) + vae_file, vae_source = resolve_vae(checkpoint_file).tuple() else: vae_source = "from function argument" if loaded_vae_file == vae_file: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if sd_model.lowvram: lowvram.send_everything_to_cpu() else: sd_model.to(devices.cpu) @@ -206,7 +275,7 @@ def reload_vae_weights(sd_model=None, vae_file=unspecified): sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) print("VAE weights loaded.") |