From 5ba04f9ec050a66e918571f07e8863f157f05b44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Dec 2022 13:45:58 +0100 Subject: Attempting to solve slow loads for `safetensors`. Fixes #5893 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..cd938656 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + device = map_location or shared.weight_load_location + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3 From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 11:27:02 -0500 Subject: validate textual inversion embeddings --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..ebd4dff7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -325,6 +325,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model + return sd_model -- cgit v1.2.3 From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 00:38:09 +0300 Subject: fix the issue with training on SD2.0 --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ebd4dff7..bff8d6c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + model.logvar = model.logvar.to(devices.device) # fix for training + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) -- cgit v1.2.3 From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:39:14 +0300 Subject: call script callbacks for reloaded model after loading embeddings --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index bff8d6c9..b98b05fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -324,12 +324,12 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model - return sd_model -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/sd_models.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model -- cgit v1.2.3 From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:47:42 +0300 Subject: find configs for models at runtime rather than when starting --- modules/sd_models.py | 31 ++++++++++++++++++------------- 1 file changed, 18 insertions(+), 13 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6846b74a..6dca4ddf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() @@ -48,6 +48,14 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def find_checkpoint_config(info): + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return shared.cmd_opts.config + + def list_models(): checkpoints_list.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) @@ -73,7 +81,7 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -81,12 +89,7 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) - config = basename + ".yaml" - if not os.path.exists(config): - config = shared.cmd_opts.config - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) def get_closet_checkpoint_match(searchString): @@ -282,9 +285,10 @@ def enable_midas_autodownload(): def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + checkpoint_config = find_checkpoint_config(checkpoint_info) - if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_info.config}") + if checkpoint_config != shared.cmd_opts.config: + print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -292,7 +296,7 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_info.config) + sd_config = OmegaConf.load(checkpoint_config) if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... @@ -302,7 +306,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None): sd_model = shared.sd_model current_checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_config = find_checkpoint_config(current_checkpoint_info) if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 14:29:13 +0300 Subject: fix broken inpainting model --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dca4ddf..a568823d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -305,9 +305,6 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None - # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) - if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False -- cgit v1.2.3 From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 15:09:53 +0300 Subject: use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ee918f24..76a89e88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None if extension.lower() == ".safetensors": device = map_location or shared.weight_load_location if device is None: - device = "cuda:0" if torch.cuda.is_available() else "cpu" + device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3 From 552d7b90bf483c160cd20740f7acd7fccbc02e6f Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Mon, 9 Jan 2023 18:34:26 -0500 Subject: allow model load if previous model failed --- modules/sd_models.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 76a89e88..0a6d55ca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -49,6 +49,9 @@ def checkpoint_tiles(): def find_checkpoint_config(info): + if info is None: + return shared.cmd_opts.config + config = os.path.splitext(info.filename)[0] + ".yaml" if os.path.exists(config): return config @@ -345,14 +348,16 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed + current_checkpoint_info = None + else: + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: + return - current_checkpoint_info = sd_model.sd_checkpoint_info checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return - - if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 0c3feb202c5714abd50d879c1db2cd9a71ce93e3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 14:08:29 +0300 Subject: disable torch weight initialization and CLIP downloading/reading checkpoint to speedup creating sd model from config --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 0a6d55ca..ee241032 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -13,7 +13,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -319,7 +319,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False - sd_model = instantiate_from_config(sd_config.model) + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From ce3f639ec8758ce2bc90483336361d2dc25acd3a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 16:51:04 +0300 Subject: add more stuff to ignore when creating model from config prevent .vae.safetensors files from being listed as stable diffusion models --- modules/sd_models.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ee241032..1bb9088b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,6 +2,7 @@ import collections import os.path import sys import gc +import time from collections import namedtuple import torch import re @@ -61,7 +62,7 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() - model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) + model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) def modeltitle(path, shorthash): abspath = os.path.abspath(path) @@ -288,6 +289,17 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper +class Timer: + def __init__(self): + self.start = time.time() + + def elapsed(self): + end = time.time() + res = end - self.start + self.start = end + return res + + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -319,11 +331,17 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + timer = Timer() + with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) + elapsed_create = timer.elapsed() + load_model_weights(sd_model, checkpoint_info) + elapsed_load_weights = timer.elapsed() + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: @@ -338,7 +356,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) - print("Model loaded.") + elapsed_the_rest = timer.elapsed() + + print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") return sd_model @@ -349,7 +369,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model - if sd_model is None: # previous model load failed + if sd_model is None: # previous model load failed current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info @@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) + timer = Timer() + try: load_model_weights(sd_model, checkpoint_info) except Exception as e: @@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("Weights loaded.") + elapsed = timer.elapsed() + + print(f"Weights loaded in {elapsed:.1f}s.") return sd_model -- cgit v1.2.3 From 0f8603a55988d22616b17140e6c4a7e9d0736af5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 10 Jan 2023 17:46:59 +0300 Subject: add support for transformers==4.25.1 add fallback for when quick model creation fails --- modules/sd_models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1bb9088b..b5bc12f0 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting @@ -333,7 +333,11 @@ def load_model(checkpoint_info=None): timer = Timer() - with sd_disable_initialization.DisableInitialization(): + try: + with sd_disable_initialization.DisableInitialization(): + sd_model = instantiate_from_config(sd_config.model) + except Exception as e: + print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) elapsed_create = timer.elapsed() -- cgit v1.2.3 From 4fdacd31e48c6a7a35c1c25c559932585e8addde Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:24:56 +0300 Subject: possible fix for fallback for fast model creation from config --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b5bc12f0..a0a8a909 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -337,6 +337,9 @@ def load_model(checkpoint_info=None): with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) except Exception as e: + pass + + if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 1a23dc32ac5e16fac10115cafd0b841abd06e59f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 10:34:36 +0300 Subject: possible fix for fallback for fast model creation from config, attempt 2 --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a0a8a909..084ba7fa 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -333,6 +333,7 @@ def load_model(checkpoint_info=None): timer = Timer() + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From 4bd490727e156ff53107d53416d6b89be86f2a62 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 11 Jan 2023 18:54:04 +0300 Subject: fix for an error caused by skipping initialization, for realsies this time: TypeError: expected str, bytes or os.PathLike object, not NoneType --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 084ba7fa..c466f273 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -334,6 +334,7 @@ def load_model(checkpoint_info=None): timer = Timer() sd_model = None + try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) -- cgit v1.2.3 From a95f1353089bdeaccd7c266b40cdd79efedfe632 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 09:56:59 +0300 Subject: change hash to sha256 --- modules/sd_models.py | 116 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 74 insertions(+), 42 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index c466f273..7babb9ae 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,17 +14,56 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} +checkpoint_alisases = {} checkpoints_loaded = collections.OrderedDict() + +class CheckpointInfo: + def __init__(self, filename): + self.filename = filename + abspath = os.path.abspath(filename) + + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): + name = abspath.replace(shared.cmd_opts.ckpt_dir, '') + elif abspath.startswith(model_path): + name = abspath.replace(model_path, '') + else: + name = os.path.basename(filename) + + if name.startswith("\\") or name.startswith("/"): + name = name[1:] + + self.title = name + self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] + self.hash = model_hash(filename) + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + self.shorthash = None + self.sha256 = None + + def register(self): + checkpoints_list[self.title] = self + for id in self.ids: + checkpoint_alisases[id] = self + + def calculate_shorthash(self): + self.sha256 = hashes.sha256(self.filename, self.title) + self.shorthash = self.sha256[0:10] + + if self.shorthash not in self.ids: + self.ids += [self.shorthash, self.sha256] + self.register() + + return self.shorthash + + try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. @@ -43,10 +82,14 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - convert = lambda name: int(name) if name.isdigit() else name.lower() - alphanumeric_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)] - return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def checkpoint_tiles(): + def convert(name): + return int(name) if name.isdigit() else name.lower() + + def alphanumeric_key(key): + return [convert(c) for c in re.split('([0-9]+)', key)] + + return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) def find_checkpoint_config(info): @@ -62,48 +105,38 @@ def find_checkpoint_config(info): def list_models(): checkpoints_list.clear() + checkpoint_alisases.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"]) - def modeltitle(path, shorthash): - abspath = os.path.abspath(path) - - if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): - name = abspath.replace(shared.cmd_opts.ckpt_dir, '') - elif abspath.startswith(model_path): - name = abspath.replace(model_path, '') - else: - name = os.path.basename(path) - - if name.startswith("\\") or name.startswith("/"): - name = name[1:] - - shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] - - return f'{name} [{shorthash}]', shortname - cmd_ckpt = shared.cmd_opts.ckpt if os.path.exists(cmd_ckpt): - h = model_hash(cmd_ckpt) - title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) - shared.opts.data['sd_model_checkpoint'] = title + checkpoint_info = CheckpointInfo(cmd_ckpt) + checkpoint_info.register() + + shared.opts.data['sd_model_checkpoint'] = checkpoint_info.title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) + for filename in model_list: - h = model_hash(filename) - title, short_model_name = modeltitle(filename, h) + checkpoint_info = CheckpointInfo(filename) + checkpoint_info.register() + - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) +def get_closet_checkpoint_match(search_string): + checkpoint_info = checkpoint_alisases.get(search_string, None) + if checkpoint_info is not None: + return + found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) + if found: + return found[0] -def get_closet_checkpoint_match(searchString): - applicable = sorted([info for info in checkpoints_list.values() if searchString in info.title], key = lambda x:len(x.title)) - if len(applicable) > 0: - return applicable[0] return None def model_hash(filename): + """old hash that only looks at a small part of the file and is prone to collisions""" + try: with open(filename, "rb") as file: import hashlib @@ -119,7 +152,7 @@ def model_hash(filename): def select_checkpoint(): model_checkpoint = shared.opts.sd_model_checkpoint - checkpoint_info = checkpoints_list.get(model_checkpoint, None) + checkpoint_info = checkpoint_alisases.get(model_checkpoint, None) if checkpoint_info is not None: return checkpoint_info @@ -189,9 +222,8 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info, vae_file="auto"): - checkpoint_file = checkpoint_info.filename - sd_model_hash = checkpoint_info.hash +def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): + sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -201,9 +233,9 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.load_state_dict(checkpoints_loaded[checkpoint_info]) else: # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") - sd = read_state_dict(checkpoint_file) + sd = read_state_dict(checkpoint_info.filename) model.load_state_dict(sd, strict=False) del sd @@ -235,14 +267,14 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoints_loaded.popitem(last=False) # LRU model.sd_model_hash = sd_model_hash - model.sd_model_checkpoint = checkpoint_file + model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) + vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) sd_vae.load_vae(model, vae_file) -- cgit v1.2.3 From f9ac3352cb66ce2bc0aa4325130fc7267fb35e4f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 10:25:21 +0300 Subject: change hypernets to use sha256 hashes --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7babb9ae..8f00191c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -125,7 +125,7 @@ def list_models(): def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_alisases.get(search_string, None) if checkpoint_info is not None: - return + return checkpoint_info found = sorted([info for info in checkpoints_list.values() if search_string in info.title], key=lambda x: len(x.title)) if found: -- cgit v1.2.3 From febd2b722e80959b89a0e5966a159b4eb430c5a5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 13:37:55 +0300 Subject: update key to use with checkpoints' sha256 in cache --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f00191c..1fe6d11b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -54,7 +54,7 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: -- cgit v1.2.3 From 08c6f009a5ee92dd3218a942c08e8337c26352be Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 15:55:40 +0300 Subject: load hashes from cache for checkpoints that have them add checkpoint hash to footer --- modules/sd_models.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1fe6d11b..e5a0bc63 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -44,9 +44,11 @@ class CheckpointInfo: self.title = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] - self.shorthash = None - self.sha256 = None + + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.shorthash = self.sha256[0:10] if self.sha256 else None + + self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -269,6 +271,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename model.sd_checkpoint_info = checkpoint_info + shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 model.logvar = model.logvar.to(devices.device) # fix for training -- cgit v1.2.3 From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: fix bug with "Ignore selected VAE for..." option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): +def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) - sd_vae.load_vae(model, vae_file) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) def enable_midas_autodownload(): -- cgit v1.2.3 From c1928cdd6194928af0f53f70c51d59479b7025e2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 18:58:08 +0300 Subject: bring back short hashes to sd checkpoint selection --- modules/sd_models.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6a681cef..12083848 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -41,14 +41,16 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] - self.title = name + self.name = name self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] self.hash = model_hash(filename) - self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256_from_cache(self.filename, "checkpoint/" + name) self.shorthash = self.sha256[0:10] if self.sha256 else None - self.ids = [self.hash, self.model_name, self.title, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256] if self.shorthash else []) + self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + + self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self @@ -56,13 +58,15 @@ class CheckpointInfo: checkpoint_alisases[id] = self def calculate_shorthash(self): - self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.title) + self.sha256 = hashes.sha256(self.filename, "checkpoint/" + self.name) self.shorthash = self.sha256[0:10] if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256] self.register() + self.title = f'{self.name} [{self.shorthash}]' + return self.shorthash @@ -225,7 +229,10 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None def load_model_weights(model, checkpoint_info: CheckpointInfo): + title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + if checkpoint_info.title != title: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title cache_enabled = shared.opts.sd_checkpoint_cache > 0 -- cgit v1.2.3 From 84d9ce30cb427759547bc7876ed80ab91787d175 Mon Sep 17 00:00:00 2001 From: brkirch Date: Tue, 24 Jan 2023 23:51:45 -0500 Subject: Add option for float32 sampling with float16 UNet This also handles type casting so that ROCm and MPS torch devices work correctly without --no-half. One cast is required for deepbooru in deepbooru_model.py, some explicit casting is required for img2img and inpainting. depth_model can't be converted to float16 or it won't work correctly on some systems (it's known to have issues on MPS) so in sd_models.py model.depth_model is removed for model.half(). --- modules/sd_models.py | 10 ++++++++++ 1 file changed, 10 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..7c98991a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -257,16 +257,24 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): if not shared.cmd_opts.no_half: vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 if shared.cmd_opts.no_half_vae: model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None model.half() model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) @@ -372,6 +380,8 @@ def load_model(checkpoint_info=None): if shared.cmd_opts.no_half: sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True timer = Timer() -- cgit v1.2.3 From ee0a0da3244123cb6d2ba4097a54a1e9caccb687 Mon Sep 17 00:00:00 2001 From: Kyle Date: Wed, 25 Jan 2023 08:53:23 -0500 Subject: Add instruct-pix2pix hijack Allows loading instruct-pix2pix models via same method as inpainting models in sd_models.py and sd_hijack_ip2p.py Adds ddpm_edit.py necessary for instruct-pix2pix --- modules/sd_models.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 12083848..cddc2343 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting +from modules.sd_hijack_ip2p import should_hijack_ip2p model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -365,6 +366,15 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None + if should_hijack_ip2p(checkpoint_info): + sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.first_stage_key = "edited" + sd_config.model.params.cond_stage_key = "edit" + sd_config.model.params.image_size = 16 + sd_config.model.params.unet_config.params.in_channels = 8 + sd_config.model.params.unet_config.params.out_channels = 4 + if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -429,7 +439,7 @@ def reload_model_weights(sd_model=None, info=None): checkpoint_config = find_checkpoint_config(current_checkpoint_info) - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From d2ac95fa7b2a8d0bcc5361ee16dba9cbb81ff8b2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:28:12 +0300 Subject: remove the need to place configs near models --- modules/sd_models.py | 228 +++++++++++++++++++++++++-------------------------- 1 file changed, 113 insertions(+), 115 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072eb2e..fa208728 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -2,8 +2,6 @@ import collections import os.path import sys import gc -import time -from collections import namedtuple import torch import re import safetensors.torch @@ -14,10 +12,10 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes +from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path -from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting -from modules.sd_hijack_ip2p import should_hijack_ip2p +from modules.sd_hijack_inpainting import do_inpainting_hijack +from modules.timer import Timer model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -99,17 +97,6 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) -def find_checkpoint_config(info): - if info is None: - return shared.cmd_opts.config - - config = os.path.splitext(info.filename)[0] + ".yaml" - if os.path.exists(config): - return config - - return shared.cmd_opts.config - - def list_models(): checkpoints_list.clear() checkpoint_alisases.clear() @@ -215,9 +202,7 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - device = map_location or shared.weight_load_location - if device is None: - device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" + device = map_location or shared.weight_load_location or devices.get_optimal_device_name() pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) @@ -229,60 +214,74 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo): +def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): + sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + + if checkpoint_info in checkpoints_loaded: + # use checkpoint cache + print(f"Loading weights [{sd_model_hash}] from cache") + return checkpoints_loaded[checkpoint_info] + + print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + res = read_state_dict(checkpoint_info.filename) + timer.record("load weights from disk") + + return res + + +def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): title = checkpoint_info.title sd_model_hash = checkpoint_info.calculate_shorthash() + timer.record("calculate hash") + if checkpoint_info.title != title: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title - cache_enabled = shared.opts.sd_checkpoint_cache > 0 + if state_dict is None: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if cache_enabled and checkpoint_info in checkpoints_loaded: - # use checkpoint cache - print(f"Loading weights [{sd_model_hash}] from cache") - model.load_state_dict(checkpoints_loaded[checkpoint_info]) - else: - # load from file - print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") + model.load_state_dict(state_dict, strict=False) + del state_dict + timer.record("apply weights to model") - sd = read_state_dict(checkpoint_info.filename) - model.load_state_dict(sd, strict=False) - del sd - - if cache_enabled: - # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + if shared.opts.sd_checkpoint_cache > 0: + # cache newly loaded model + checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + + if shared.cmd_opts.opt_channelslast: + model.to(memory_format=torch.channels_last) + timer.record("apply channels_last") - if shared.cmd_opts.opt_channelslast: - model.to(memory_format=torch.channels_last) + if not shared.cmd_opts.no_half: + vae = model.first_stage_model + depth_model = getattr(model, 'depth_model', None) - if not shared.cmd_opts.no_half: - vae = model.first_stage_model - depth_model = getattr(model, 'depth_model', None) + # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 + if shared.cmd_opts.no_half_vae: + model.first_stage_model = None + # with --upcast-sampling, don't convert the depth model weights to float16 + if shared.cmd_opts.upcast_sampling and depth_model: + model.depth_model = None - # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16 - if shared.cmd_opts.no_half_vae: - model.first_stage_model = None - # with --upcast-sampling, don't convert the depth model weights to float16 - if shared.cmd_opts.upcast_sampling and depth_model: - model.depth_model = None + model.half() + model.first_stage_model = vae + if depth_model: + model.depth_model = depth_model - model.half() - model.first_stage_model = vae - if depth_model: - model.depth_model = depth_model + timer.record("apply half()") - devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 - devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - devices.dtype_unet = model.model.diffusion_model.dtype - devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 + devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 + devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 + devices.dtype_unet = model.model.diffusion_model.dtype + devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 - model.first_stage_model.to(devices.dtype_vae) + model.first_stage_model.to(devices.dtype_vae) + timer.record("apply dtype to VAE") # clean up cache if limit is reached - if cache_enabled: - while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache + 1: # we need to count the current model - checkpoints_loaded.popitem(last=False) # LRU + while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: + checkpoints_loaded.popitem(last=False) model.sd_model_hash = sd_model_hash model.sd_model_checkpoint = checkpoint_info.filename @@ -295,6 +294,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_vae.clear_loaded_vae() vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) sd_vae.load_vae(model, vae_file, vae_source) + timer.record("load VAE") def enable_midas_autodownload(): @@ -340,24 +340,20 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper -class Timer: - def __init__(self): - self.start = time.time() +def repair_config(sd_config): - def elapsed(self): - end = time.time() - res = end - self.start - self.start = end - return res + if not hasattr(sd_config.model.params, "use_ema"): + sd_config.model.params.use_ema = False + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True -def load_model(checkpoint_info=None): + +def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() - checkpoint_config = find_checkpoint_config(checkpoint_info) - - if checkpoint_config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -365,38 +361,27 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_config) - - if should_hijack_inpainting(checkpoint_info): - # Hardcoded config for now... - sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.unet_config.params.in_channels = 9 - sd_config.model.params.finetune_keys = None - - if should_hijack_ip2p(checkpoint_info): - sd_config.model.target = "modules.models.diffusion.ddpm_edit.LatentDiffusion" - sd_config.model.params.conditioning_key = "hybrid" - sd_config.model.params.first_stage_key = "edited" - sd_config.model.params.cond_stage_key = "edit" - sd_config.model.params.image_size = 16 - sd_config.model.params.unet_config.params.in_channels = 8 - sd_config.model.params.unet_config.params.out_channels = 4 + do_inpainting_hijack() - if not hasattr(sd_config.model.params, "use_ema"): - sd_config.model.params.use_ema = False + timer = Timer() - do_inpainting_hijack() + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - timer = Timer() + timer.record("find config") - sd_model = None + sd_config = OmegaConf.load(checkpoint_config) + repair_config(sd_config) + + timer.record("load config") + + print(f"Creating model from config: {checkpoint_config}") + sd_model = None try: with sd_disable_initialization.DisableInitialization(): sd_model = instantiate_from_config(sd_config.model) @@ -407,29 +392,35 @@ def load_model(checkpoint_info=None): print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) sd_model = instantiate_from_config(sd_config.model) - elapsed_create = timer.elapsed() + sd_model.used_config = checkpoint_config - load_model_weights(sd_model, checkpoint_info) + timer.record("create model") - elapsed_load_weights = timer.elapsed() + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) else: sd_model.to(shared.device) + timer.record("move model to device") + sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + sd_model.eval() shared.sd_model = sd_model sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + timer.record("load textual inversion embeddings") + script_callbacks.model_loaded_callback(sd_model) - elapsed_the_rest = timer.elapsed() + timer.record("scripts callbacks") - print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).") + print(f"Model loaded in {timer.summary()}.") return sd_model @@ -440,6 +431,7 @@ def reload_model_weights(sd_model=None, info=None): if not sd_model: sd_model = shared.sd_model + if sd_model is None: # previous model load failed current_checkpoint_info = None else: @@ -447,14 +439,6 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - checkpoint_config = find_checkpoint_config(current_checkpoint_info) - - if current_checkpoint_info is None or checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info) or should_hijack_ip2p(checkpoint_info) != should_hijack_ip2p(sd_model.sd_checkpoint_info): - del sd_model - checkpoints_loaded.clear() - load_model(checkpoint_info) - return shared.sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -464,21 +448,35 @@ def reload_model_weights(sd_model=None, info=None): timer = Timer() + state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) + + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + del sd_model + checkpoints_loaded.clear() + load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"]) + return shared.sd_model + try: - load_model_weights(sd_model, checkpoint_info) + load_model_weights(sd_model, checkpoint_info, state_dict, timer) except Exception as e: print("Failed to load checkpoint, restoring previous") - load_model_weights(sd_model, current_checkpoint_info) + load_model_weights(sd_model, current_checkpoint_info, None, timer) raise finally: sd_hijack.model_hijack.hijack(sd_model) + timer.record("hijack") + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) + timer.record("move model to device") - elapsed = timer.elapsed() - - print(f"Weights loaded in {elapsed:.1f}s.") + print(f"Weights loaded in {timer.summary()}.") return sd_model -- cgit v1.2.3 From 6f31d2210c189f8db118e6f95add7ba2a64f0238 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 27 Jan 2023 11:54:19 +0300 Subject: support detecting midas model fix broken api for checkpoint list --- modules/sd_models.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index fa208728..37dad18d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -439,12 +439,12 @@ def reload_model_weights(sd_model=None, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) - sd_hijack.model_hijack.undo_hijack(sd_model) + sd_hijack.model_hijack.undo_hijack(sd_model) timer = Timer() -- cgit v1.2.3 From 5eee2ac39863f9e44591b50d0710dd2615416a13 Mon Sep 17 00:00:00 2001 From: Max Audron Date: Wed, 25 Jan 2023 17:15:42 +0100 Subject: add data-dir flag and set all user data directories based on it --- modules/sd_models.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 37dad18d..b2d48a51 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -12,13 +12,13 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config from modules.paths import models_path from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) +model_path = os.path.abspath(os.path.join(paths.models_path, model_dir)) checkpoints_list = {} checkpoint_alisases = {} @@ -307,7 +307,7 @@ def enable_midas_autodownload(): location automatically. """ - midas_path = os.path.join(models_path, 'midas') + midas_path = os.path.join(paths.models_path, 'midas') # stable-diffusion-stability-ai hard-codes the midas model path to # a location that differs from where other scripts using this model look. -- cgit v1.2.3