From 36966e3200943dbf890b5338cfa939df552d3c47 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 15:38:58 +0700
Subject: Fix #4035

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index f86dc3ed..a29c8c1a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -201,7 +201,7 @@ def load_model_weights(model, checkpoint_info):
 
         if shared.opts.sd_checkpoint_cache > 0:
             checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1:
                 checkpoints_loaded.popitem(last=False) # LRU
     else:
         print(f"Loading weights [{sd_model_hash}] from cache")
-- 
cgit v1.2.3

From bf7a699845675eefdabb9cfa40c55398976274ae Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Mon, 31 Oct 2022 16:27:27 +0700
Subject: Fix #4035 for real now

---
 modules/sd_models.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index a29c8c1a..b2dd005a 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -165,6 +165,9 @@ def load_model_weights(model, checkpoint_info):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
+    if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+        checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
+
     if checkpoint_info not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
@@ -198,16 +201,14 @@ def load_model_weights(model, checkpoint_info):
             model.first_stage_model.load_state_dict(vae_dict)
 
         model.first_stage_model.to(devices.dtype_vae)
-
-        if shared.opts.sd_checkpoint_cache > 0:
-            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1:
-                checkpoints_loaded.popitem(last=False) # LRU
     else:
         print(f"Loading weights [{sd_model_hash}] from cache")
-        checkpoints_loaded.move_to_end(checkpoint_info)
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
 
+    if shared.opts.sd_checkpoint_cache > 0:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+            checkpoints_loaded.popitem(last=False) # LRU
+
     model.sd_model_hash = sd_model_hash
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
-- 
cgit v1.2.3

From 99043f33606d3057f83ea52a403e10cd29d1f7e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Fri, 4 Nov 2022 11:20:42 +0300
Subject: fix one of previous merges breaking the program

---
 modules/sd_models.py | 2 ++
 1 file changed, 2 insertions(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 63e07a12..34c57bfa 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -167,6 +167,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         sd_vae.restore_base_vae(model)
         checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
 
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
+
     if checkpoint_info not in checkpoints_loaded:
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-- 
cgit v1.2.3
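Note (editorial, not part of the patch series): the three commits above all tune the same structure, an OrderedDict used as an LRU cache of checkpoint state dicts. A minimal self-contained sketch of that pattern, with illustrative names rather than the actual webui API:

    from collections import OrderedDict

    checkpoints_loaded = OrderedDict()  # insertion order doubles as recency order

    def remember(key, state_dict, cache_size):
        checkpoints_loaded[key] = state_dict
        # the "+ 1" from the fix above: the currently loaded model also has an
        # entry in the dict, so it must not count against the cache budget
        while len(checkpoints_loaded) > cache_size + 1:
            checkpoints_loaded.popitem(last=False)  # evict least recently used

    def recall(key):
        checkpoints_loaded.move_to_end(key)  # mark as most recently used
        return checkpoints_loaded[key]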
From 3b51d239ac9201228c6032fc109111e347e8e6b0 Mon Sep 17 00:00:00 2001
From: cluder <1590330+cluder@users.noreply.github.com>
Date: Wed, 9 Nov 2022 04:54:21 +0100
Subject: - do not use ckpt cache if disabled
 - cache model after it has been loaded from file

---
 modules/sd_models.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 34c57bfa..720c2a96 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -163,13 +163,21 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
 
-    if shared.opts.sd_checkpoint_cache > 0 and hasattr(model, "sd_checkpoint_info"):
+    cache_enabled = shared.opts.sd_checkpoint_cache > 0
+
+    if cache_enabled:
         sd_vae.restore_base_vae(model)
-        checkpoints_loaded[model.sd_checkpoint_info] = model.state_dict().copy()
 
     vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
 
-    if checkpoint_info not in checkpoints_loaded:
+    if cache_enabled and checkpoint_info in checkpoints_loaded:
+        # use checkpoint cache
+        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
+        vae_message = f" with {vae_name} VAE" if vae_name else ""
+        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        model.load_state_dict(checkpoints_loaded[checkpoint_info])
+    else:
+        # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
         pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
@@ -180,6 +188,10 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         del pl_sd
         model.load_state_dict(sd, strict=False)
         del sd
+
+        if cache_enabled:
+            # cache newly loaded model
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
 
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
@@ -199,13 +211,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
 
         model.first_stage_model.to(devices.dtype_vae)
 
-    else:
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
-        model.load_state_dict(checkpoints_loaded[checkpoint_info])
-
-    if shared.opts.sd_checkpoint_cache > 0:
+    # clean up cache if limit is reached
+    if cache_enabled:
         while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
             checkpoints_loaded.popitem(last=False) # LRU
 
-- 
cgit v1.2.3

From eebf49592ad2c0933e58b06a098b92e48d47e4fe Mon Sep 17 00:00:00 2001
From: cluder <1590330+cluder@users.noreply.github.com>
Date: Wed, 9 Nov 2022 07:17:09 +0100
Subject: restore #4035 behavior
 - if checkpoint cache is set to 1, keep 2 models in cache (current +1 more)

---
 modules/sd_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 720c2a96..80addf03 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -213,7 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
 
     # clean up cache if limit is reached
     if cache_enabled:
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model
            checkpoints_loaded.popitem(last=False) # LRU
 
     model.sd_model_hash = sd_model_hash
-- 
cgit v1.2.3
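Note (editorial): the two cluder commits above settle on a cache-check-first flow that caches right after a file load and trims once at the end, at which point the current model is already in the dict. A hedged sketch of that control flow; load_from_file and apply_to_model are hypothetical stand-ins for the real torch plumbing:

    from collections import OrderedDict

    checkpoints_loaded = OrderedDict()

    def load_weights(key, cache_size, load_from_file, apply_to_model):
        cache_enabled = cache_size > 0
        if cache_enabled and key in checkpoints_loaded:
            apply_to_model(checkpoints_loaded[key])     # use checkpoint cache
        else:
            state_dict = load_from_file(key)            # load from file
            apply_to_model(state_dict)
            if cache_enabled:
                checkpoints_loaded[key] = state_dict    # cache newly loaded model
        if cache_enabled:
            # trim last, counting the entry for the current model (hence "+ 1")
            while len(checkpoints_loaded) > cache_size + 1:
                checkpoints_loaded.popitem(last=False)  # LRU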
From 2c5ca706a7e624d268545ba3318ba230b7b33477 Mon Sep 17 00:00:00 2001
From: Muhammad Rizqi Nur
Date: Sun, 13 Nov 2022 10:55:47 +0700
Subject: Remove no longer necessary parts and add vae_file safeguard

---
 modules/sd_models.py | 10 ++--------
 modules/sd_vae.py    |  1 +
 2 files changed, 3 insertions(+), 8 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 80addf03..c59151e0 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -165,16 +165,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
 
     cache_enabled = shared.opts.sd_checkpoint_cache > 0
 
-    if cache_enabled:
-        sd_vae.restore_base_vae(model)
-
-    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
-
     if cache_enabled and checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
-        vae_name = sd_vae.get_filename(vae_file) if vae_file else None
-        vae_message = f" with {vae_name} VAE" if vae_name else ""
-        print(f"Loading weights [{sd_model_hash}]{vae_message} from cache")
+        print(f"Loading weights [{sd_model_hash}] from cache")
         model.load_state_dict(checkpoints_loaded[checkpoint_info])
     else:
         # load from file
@@ -220,6 +213,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
     model.sd_model_checkpoint = checkpoint_file
     model.sd_checkpoint_info = checkpoint_info
 
+    vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file)
     sd_vae.load_vae(model, vae_file)
 
 
diff --git a/modules/sd_vae.py b/modules/sd_vae.py
index 71e7a6e6..8bdb2c17 100644
--- a/modules/sd_vae.py
+++ b/modules/sd_vae.py
@@ -139,6 +139,7 @@ def load_vae(model, vae_file=None):
     # save_settings = False
 
     if vae_file:
+        assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
         print(f"Loading VAE weights from: {vae_file}")
         vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
         vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
-- 
cgit v1.2.3
From 0efffbb407a9d07eae6850374099775385ce176c Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Mon, 21 Nov 2022 14:04:25 +0100
Subject: Supporting `*.safetensors` format.

If a model file exists with extension `.safetensors` then we can load it
more safely than with PyTorch weights.
---
 modules/sd_models.py | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 80addf03..0164cc1b 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -45,7 +45,7 @@ def checkpoint_tiles():
 
 def list_models():
     checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
 
     def modeltitle(path, shorthash):
         abspath = os.path.abspath(path)
@@ -180,7 +180,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+        if checkpoint_file.endswith(".safetensors"):
+            try:
+                from safetensors.torch import load_file
+            except ImportError as e:
+                raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
+            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
+        else:
+            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
 
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
-- 
cgit v1.2.3

From 1e506657e1cb732a5f0e567ba2585fba2bbb1327 Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Sat, 26 Nov 2022 13:28:44 -0500
Subject: no-half support for SD 2.0

---
 modules/sd_models.py | 3 +++
 1 file changed, 3 insertions(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index c59151e0..0e0bd79e 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -244,6 +244,9 @@ def load_model(checkpoint_info=None):
 
     do_inpainting_hijack()
 
+    if shared.cmd_opts.no_half:
+        sd_config.model.params.unet_config.params.use_fp16 = False
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
 
-- 
cgit v1.2.3
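Note (editorial): the safety claim in the safetensors commit comes from the file format rather than the loader code. A .ckpt file is a Python pickle, so torch.load can execute arbitrary code embedded in it, while a .safetensors file holds raw tensor data only. A small round trip using the real safetensors.torch API (the file name is made up):

    import torch
    from safetensors.torch import save_file, load_file

    tensors = {"weight": torch.zeros((2, 2))}  # must be a flat dict of tensors
    save_file(tensors, "example.safetensors")
    restored = load_file("example.safetensors", device="cpu")
    assert torch.equal(restored["weight"], tensors["weight"])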
From 6074175faa751dde933aa8e15cd687ca4e4b4a23 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 27 Nov 2022 14:46:40 +0300
Subject: add safetensors to requirements

---
 modules/sd_models.py      | 11 +++++------
 requirements.txt          |  1 +
 requirements_versions.txt |  1 +
 3 files changed, 7 insertions(+), 6 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index ae36841a..77236480 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -5,6 +5,7 @@ import gc
 from collections import namedtuple
 import torch
 import re
+import safetensors.torch
 from omegaconf import OmegaConf
 
 from ldm.util import instantiate_from_config
@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        if checkpoint_file.endswith(".safetensors"):
-            try:
-                from safetensors.torch import load_file
-            except ImportError as e:
-                raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
-            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
+        _, extension = os.path.splitext(checkpoint_file)
+        if extension.lower() == ".safetensors":
+            pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
         else:
             pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
+
         if "global_step" in pl_sd:
             print(f"Global Step: {pl_sd['global_step']}")
 
diff --git a/requirements.txt b/requirements.txt
index e4e5ec64..5f3d9623 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -29,3 +29,4 @@ lark
 inflection
 GitPython
 torchsde
+safetensors
diff --git a/requirements_versions.txt b/requirements_versions.txt
index 8d557fe3..035fa82f 100644
--- a/requirements_versions.txt
+++ b/requirements_versions.txt
@@ -26,3 +26,4 @@ lark==1.1.2
 inflection==0.5.1
 GitPython==3.1.27
 torchsde==0.2.5
+safetensors==0.2.5
-- 
cgit v1.2.3
From dac9b6f15de5e675053d9490a20e0457dcd1a23e Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Sun, 27 Nov 2022 15:51:29 +0300
Subject: add safetensors support for model merging #4869

---
 modules/extras.py    | 26 ++++++++++++++------------
 modules/sd_models.py | 26 +++++++++++++++-----------
 modules/ui.py        |  7 ++++++-
 3 files changed, 35 insertions(+), 24 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/extras.py b/modules/extras.py
index 71b93a06..3d65d90a 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -20,6 +20,7 @@ import modules.codeformer_model
 import piexif
 import piexif.helper
 import gradio as gr
+import safetensors.torch
 
 
 class LruCache(OrderedDict):
@@ -249,7 +250,7 @@ def run_pnginfo(image):
     return '', geninfo, info
 
 
-def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name):
+def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
     def weighted_sum(theta0, theta1, alpha):
         return ((1 - alpha) * theta0) + (alpha * theta1)
 
@@ -264,19 +265,15 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
     teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
 
     print(f"Loading {primary_model_info.filename}...")
-    primary_model = torch.load(primary_model_info.filename, map_location='cpu')
-    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
+    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
 
     print(f"Loading {secondary_model_info.filename}...")
-    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')
-    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)
+    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
 
     if teritary_model_info is not None:
         print(f"Loading {teritary_model_info.filename}...")
-        teritary_model = torch.load(teritary_model_info.filename, map_location='cpu')
-        theta_2 = sd_models.get_state_dict_from_checkpoint(teritary_model)
+        theta_2 = sd_models.read_state_dict(teritary_model_info.filename, map_location='cpu')
     else:
-        teritary_model = None
         theta_2 = None
 
     theta_funcs = {
@@ -295,7 +292,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
                     theta_1[key] = theta_func1(theta_1[key], t2)
                 else:
                     theta_1[key] = torch.zeros_like(theta_1[key])
-        del theta_2, teritary_model
+        del theta_2
 
     for key in tqdm.tqdm(theta_0.keys()):
         if 'model' in key and key in theta_1:
@@ -314,12 +311,17 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
 
     ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
 
-    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.ckpt'
-    filename = filename if custom_name == '' else (custom_name + '.ckpt')
+    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
+    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
 
     output_modelname = os.path.join(ckpt_dir, filename)
 
     print(f"Saving to {output_modelname}...")
-    torch.save(primary_model, output_modelname)
+
+    _, extension = os.path.splitext(output_modelname)
+    if extension.lower() == ".safetensors":
+        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
+    else:
+        torch.save(theta_0, output_modelname)
 
     sd_models.list_models()
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 77236480..a1ea5611 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -160,6 +160,20 @@ def get_state_dict_from_checkpoint(pl_sd):
     return pl_sd
 
 
+def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
+    _, extension = os.path.splitext(checkpoint_file)
+    if extension.lower() == ".safetensors":
+        pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location)
+    else:
+        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)
+
+    if print_global_state and "global_step" in pl_sd:
+        print(f"Global Step: {pl_sd['global_step']}")
+
+    sd = get_state_dict_from_checkpoint(pl_sd)
+    return sd
+
+
 def load_model_weights(model, checkpoint_info, vae_file="auto"):
     checkpoint_file = checkpoint_info.filename
     sd_model_hash = checkpoint_info.hash
@@ -174,17 +188,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
         # load from file
         print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
 
-        _, extension = os.path.splitext(checkpoint_file)
-        if extension.lower() == ".safetensors":
-            pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
-        else:
-            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
-
-        if "global_step" in pl_sd:
-            print(f"Global Step: {pl_sd['global_step']}")
-
-        sd = get_state_dict_from_checkpoint(pl_sd)
-        del pl_sd
+        sd = read_state_dict(checkpoint_file)
         model.load_state_dict(sd, strict=False)
         del sd
 
diff --git a/modules/ui.py b/modules/ui.py
index de2b5544..aa13978d 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1164,7 +1164,11 @@ def create_ui(wrap_gradio_gpu_call):
                 custom_name = gr.Textbox(label="Custom Name (Optional)")
                 interp_amount = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label='Multiplier (M) - set to 0 to get model A', value=0.3)
                 interp_method = gr.Radio(choices=["Weighted sum", "Add difference"], value="Weighted sum", label="Interpolation Method")
-                save_as_half = gr.Checkbox(value=False, label="Save as float16")
+
+                with gr.Row():
+                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format")
+                    save_as_half = gr.Checkbox(value=False, label="Save as float16")
+
                 modelmerger_merge = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
 
             with gr.Column(variant='panel'):
@@ -1692,6 +1696,7 @@ def create_ui(wrap_gradio_gpu_call):
                 interp_amount,
                 save_as_half,
                 custom_name,
+                checkpoint_format,
             ],
             outputs=[
                 submit_result,
-- 
cgit v1.2.3
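Note (editorial): the two interpolation modes that run_modelmerger wires up reduce to per-tensor arithmetic over the loaded state dicts. A sketch of the math only; the real code iterates over matching keys and handles missing ones:

    import torch

    def weighted_sum(theta0, theta1, alpha):
        # linear interpolation between model A and model B
        return (1 - alpha) * theta0 + alpha * theta1

    def add_difference(theta0, theta1, theta2, alpha):
        # add the (B - C) delta onto A, scaled by alpha
        return theta0 + alpha * (theta1 - theta2)

    a, b, c = torch.tensor([1.0]), torch.tensor([3.0]), torch.tensor([2.0])
    print(weighted_sum(a, b, 0.3))       # tensor([1.6000])
    print(add_difference(a, b, c, 1.0))  # tensor([2.])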
From 0376da180c81a11880a2587903d69d85541051e7 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Mon, 28 Nov 2022 08:39:59 +0300
Subject: make it possible to save nai model using safetensors

---
 modules/sd_models.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index a1ea5611..283cf1cd 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -144,8 +144,8 @@ def transform_checkpoint_dict_key(k):
 
 
 def get_state_dict_from_checkpoint(pl_sd):
-    if "state_dict" in pl_sd:
-        pl_sd = pl_sd["state_dict"]
+    pl_sd = pl_sd.pop("state_dict", pl_sd)
+    pl_sd.pop("state_dict", None)
 
     sd = {}
     for k, v in pl_sd.items():
-- 
cgit v1.2.3
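Note (editorial): some checkpoints, NovelAI-style ones judging by the commit subject, nest their weights under a "state_dict" key, and indexing that key could leave a stray "state_dict" entry in the dict that is eventually saved, which a flat-dict writer such as safetensors.torch.save_file would reject. The pop-based rewrite removes the key in either case. Its behavior, isolated:

    def get_state_dict(pl_sd):
        pl_sd = pl_sd.pop("state_dict", pl_sd)  # unwrap if nested, else keep as-is
        pl_sd.pop("state_dict", None)           # drop any leftover inner key
        return pl_sd

    print(get_state_dict({"state_dict": {"w": 1.0}}))  # {'w': 1.0}
    print(get_state_dict({"w": 1.0}))                  # {'w': 1.0}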