From 5ba04f9ec050a66e918571f07e8863f157f05b44 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Wed, 21 Dec 2022 13:45:58 +0100 Subject: Attempting to solve slow loads for `safetensors`. Fixes #5893 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..cd938656 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -168,7 +168,10 @@ def get_state_dict_from_checkpoint(pl_sd): def read_state_dict(checkpoint_file, print_global_state=False, map_location=None): _, extension = os.path.splitext(checkpoint_file) if extension.lower() == ".safetensors": - pl_sd = safetensors.torch.load_file(checkpoint_file, device=map_location or shared.weight_load_location) + device = map_location or shared.weight_load_location + if device is None: + device = "cuda:0" if torch.cuda.is_available() else "cpu" + pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3 From f55ac33d446185680604e872ceda2ae858821d5c Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Sat, 31 Dec 2022 11:27:02 -0500 Subject: validate textual inversion embeddings --- modules/sd_models.py | 3 ++ modules/textual_inversion/textual_inversion.py | 43 +++++++++++++++++++++++--- modules/ui.py | 2 -- 3 files changed, 41 insertions(+), 7 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ecdd91c5..ebd4dff7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -325,6 +325,9 @@ def load_model(checkpoint_info=None): script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") + + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model + return sd_model diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index f6112578..103ace60 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -23,6 +23,8 @@ class Embedding: self.vec = vec self.name = name self.step = step + self.shape = None + self.vectors = 0 self.cached_checksum = None self.sd_checkpoint = None self.sd_checkpoint_name = None @@ -57,8 +59,10 @@ class EmbeddingDatabase: def __init__(self, embeddings_dir): self.ids_lookup = {} self.word_embeddings = {} + self.skipped_embeddings = [] self.dir_mtime = None self.embeddings_dir = embeddings_dir + self.expected_shape = -1 def register_embedding(self, embedding, model): @@ -75,14 +79,35 @@ class EmbeddingDatabase: return embedding - def load_textual_inversion_embeddings(self): + def get_expected_shape(self): + expected_shape = -1 # initialize with unknown + idx = torch.tensor(0).to(shared.device) + if expected_shape == -1: + try: # matches sd15 signature + first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx) + expected_shape = first_embedding.shape[0] + except: + pass + if expected_shape == -1: + try: # matches sd20 signature + first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx) + expected_shape = first_embedding.shape[0] + except: + pass + if expected_shape == -1: + print('Could not determine expected embeddings shape from model') + return expected_shape + + def load_textual_inversion_embeddings(self, force_reload = False): mt = os.path.getmtime(self.embeddings_dir) - if self.dir_mtime is not None and mt <= self.dir_mtime: + if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime: return self.dir_mtime = mt self.ids_lookup.clear() self.word_embeddings.clear() + self.skipped_embeddings = [] + self.expected_shape = self.get_expected_shape() def process_file(path, filename): name = os.path.splitext(filename)[0] @@ -122,7 +147,14 @@ class EmbeddingDatabase: embedding.step = data.get('step', None) embedding.sd_checkpoint = data.get('sd_checkpoint', None) embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) - self.register_embedding(embedding, shared.sd_model) + embedding.vectors = vec.shape[0] + embedding.shape = vec.shape[-1] + + if (self.expected_shape == -1) or (self.expected_shape == embedding.shape): + self.register_embedding(embedding, shared.sd_model) + else: + self.skipped_embeddings.append(name) + # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape)) for fn in os.listdir(self.embeddings_dir): try: @@ -137,8 +169,9 @@ class EmbeddingDatabase: print(traceback.format_exc(), file=sys.stderr) continue - print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.") - print("Embeddings:", ', '.join(self.word_embeddings.keys())) + print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys()))) + if (len(self.skipped_embeddings) > 0): + print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings))) def find_embedding_at_position(self, tokens, offset): token = tokens[offset] diff --git a/modules/ui.py b/modules/ui.py index 57ee0465..397dd804 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1157,8 +1157,6 @@ def create_ui(): with gr.Column(variant='panel'): submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False) - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") -- cgit v1.2.3 From 311354c0bb8930ea939d6aa6b3edd50c69301320 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 00:38:09 +0300 Subject: fix the issue with training on SD2.0 --- modules/sd_models.py | 2 ++ modules/textual_inversion/textual_inversion.py | 3 +-- 2 files changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ebd4dff7..bff8d6c9 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -228,6 +228,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + model.logvar = model.logvar.to(devices.device) # fix for training + sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 66f40367..1e5722e7 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -282,7 +282,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ return embedding, filename scheduler = LearnRateScheduler(learn_rate, steps, initial_step) - # dataset loading may take a while, so input validations and early returns should be done before this + # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." old_parallel_processing_allowed = shared.parallel_processing_allowed @@ -310,7 +310,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ loss_step = 0 _loss_step = 0 #internal - last_saved_file = "" last_saved_image = "" forced_filename = "" -- cgit v1.2.3 From 8f96f9289981a66741ba770d14f3d27ce335a0fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 3 Jan 2023 18:39:14 +0300 Subject: call script callbacks for reloaded model after loading embeddings --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index bff8d6c9..b98b05fc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -324,12 +324,12 @@ def load_model(checkpoint_info=None): sd_model.eval() shared.sd_model = sd_model + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model + script_callbacks.model_loaded_callback(sd_model) print("Model loaded.") - sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload = True) # Reload embeddings after model load as they may or may not fit the model - return sd_model -- cgit v1.2.3 From 02d7abf5141431b9a3a8a189bb3136c71abd5e79 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:35:07 +0300 Subject: helpful error message when trying to load 2.0 without config failing to load model weights from settings won't break generation for currently loaded model anymore --- modules/errors.py | 25 +++++++++++++++++++++++-- modules/sd_models.py | 26 ++++++++++++++++++-------- modules/shared.py | 9 +++++++-- webui.py | 12 ++++++++++-- 4 files changed, 58 insertions(+), 14 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/errors.py b/modules/errors.py index 372dc51a..a668c014 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -2,9 +2,30 @@ import sys import traceback +def print_error_explanation(message): + lines = message.strip().split("\n") + max_len = max([len(x) for x in lines]) + + print('=' * max_len, file=sys.stderr) + for line in lines: + print(line, file=sys.stderr) + print('=' * max_len, file=sys.stderr) + + +def display(e: Exception, task): + print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + message = str(e) + if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message: + print_error_explanation(""" +The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file. +See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this. + """) + + def run(code, task): try: code() except Exception as e: - print(f"{task}: {type(e).__name__}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) + display(task, e) diff --git a/modules/sd_models.py b/modules/sd_models.py index b98b05fc..6846b74a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -278,6 +278,7 @@ def enable_midas_autodownload(): midas.api.load_model = load_model_wrapper + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -312,6 +313,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.use_fp16 = False sd_model = instantiate_from_config(sd_config.model) + load_model_weights(sd_model, checkpoint_info) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: @@ -336,10 +338,12 @@ def load_model(checkpoint_info=None): def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() - + if not sd_model: sd_model = shared.sd_model + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return @@ -356,13 +360,19 @@ def reload_model_weights(sd_model=None, info=None): sd_hijack.model_hijack.undo_hijack(sd_model) - load_model_weights(sd_model, checkpoint_info) - - sd_hijack.model_hijack.hijack(sd_model) - script_callbacks.model_loaded_callback(sd_model) - - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: - sd_model.to(devices.device) + try: + load_model_weights(sd_model, checkpoint_info) + except Exception as e: + print("Failed to load checkpoint, restoring previous") + load_model_weights(sd_model, current_checkpoint_info) + raise + finally: + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) print("Weights loaded.") + return sd_model diff --git a/modules/shared.py b/modules/shared.py index 23657a93..7588c47b 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -14,7 +14,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, sd_vae, extensions, script_loading +from modules import localization, sd_vae, extensions, script_loading, errors from modules.paths import models_path, script_path, sd_path @@ -494,7 +494,12 @@ class Options: return False if self.data_labels[key].onchange is not None: - self.data_labels[key].onchange() + try: + self.data_labels[key].onchange() + except Exception as e: + errors.display(e, f"changing setting {key} to {value}") + setattr(self, key, oldval) + return False return True diff --git a/webui.py b/webui.py index c7d55a97..13375e71 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook +from modules import import_hook, errors from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path @@ -61,7 +61,15 @@ def initialize(): modelloader.load_upscalers() modules.sd_vae.refresh_vae_list() - modules.sd_models.load_model() + + try: + modules.sd_models.load_model() + except Exception as e: + errors.display(e, "loading stable diffusion model") + print("", file=sys.stderr) + print("Stable diffusion model failed to load, exiting", file=sys.stderr) + exit(1) + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) -- cgit v1.2.3 From 8d8a05a3bbb50fdfeab51679a919d2487bd97976 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 12:47:42 +0300 Subject: find configs for models at runtime rather than when starting --- modules/sd_hijack_inpainting.py | 5 ++++- modules/sd_models.py | 31 ++++++++++++++++++------------- 2 files changed, 22 insertions(+), 14 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index 3c214a35..31d2c898 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F def should_hijack_inpainting(checkpoint_info): + from modules import sd_models + ckpt_basename = os.path.basename(checkpoint_info.filename).lower() - cfg_basename = os.path.basename(checkpoint_info.config).lower() + cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower() + return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename diff --git a/modules/sd_models.py b/modules/sd_models.py index 6846b74a..6dca4ddf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) -CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config']) +CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name']) checkpoints_list = {} checkpoints_loaded = collections.OrderedDict() @@ -48,6 +48,14 @@ def checkpoint_tiles(): return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key) +def find_checkpoint_config(info): + config = os.path.splitext(info.filename)[0] + ".yaml" + if os.path.exists(config): + return config + + return shared.cmd_opts.config + + def list_models(): checkpoints_list.clear() model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"]) @@ -73,7 +81,7 @@ def list_models(): if os.path.exists(cmd_ckpt): h = model_hash(cmd_ckpt) title, short_model_name = modeltitle(cmd_ckpt, h) - checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config) + checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name) shared.opts.data['sd_model_checkpoint'] = title elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) @@ -81,12 +89,7 @@ def list_models(): h = model_hash(filename) title, short_model_name = modeltitle(filename, h) - basename, _ = os.path.splitext(filename) - config = basename + ".yaml" - if not os.path.exists(config): - config = shared.cmd_opts.config - - checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config) + checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name) def get_closet_checkpoint_match(searchString): @@ -282,9 +285,10 @@ def enable_midas_autodownload(): def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + checkpoint_config = find_checkpoint_config(checkpoint_info) - if checkpoint_info.config != shared.cmd_opts.config: - print(f"Loading config from: {checkpoint_info.config}") + if checkpoint_config != shared.cmd_opts.config: + print(f"Loading config from: {checkpoint_config}") if shared.sd_model: sd_hijack.model_hijack.undo_hijack(shared.sd_model) @@ -292,7 +296,7 @@ def load_model(checkpoint_info=None): gc.collect() devices.torch_gc() - sd_config = OmegaConf.load(checkpoint_info.config) + sd_config = OmegaConf.load(checkpoint_config) if should_hijack_inpainting(checkpoint_info): # Hardcoded config for now... @@ -302,7 +306,7 @@ def load_model(checkpoint_info=None): sd_config.model.params.finetune_keys = None # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False @@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None): sd_model = shared.sd_model current_checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_config = find_checkpoint_config(current_checkpoint_info) if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): + if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): del sd_model checkpoints_loaded.clear() load_model(checkpoint_info) -- cgit v1.2.3 From 0cd6399b8b1699b8b7acad6f0ad2988111fe618e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 14:29:13 +0300 Subject: fix broken inpainting model --- modules/sd_models.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6dca4ddf..a568823d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -305,9 +305,6 @@ def load_model(checkpoint_info=None): sd_config.model.params.unet_config.params.in_channels = 9 sd_config.model.params.finetune_keys = None - # Create a "fake" config with a different name so that we know to unload it when switching models. - checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml")) - if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False -- cgit v1.2.3 From 642142556d8ecdea9beb86d7618b628b1803ab98 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 15:09:53 +0300 Subject: use commandline-supplied cuda device name instead of cuda:0 for safetensors PR that doesn't fix anything --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index ee918f24..76a89e88 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -173,7 +173,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None if extension.lower() == ".safetensors": device = map_location or shared.weight_load_location if device is None: - device = "cuda:0" if torch.cuda.is_available() else "cpu" + device = devices.get_cuda_device_string() if torch.cuda.is_available() else "cpu" pl_sd = safetensors.torch.load_file(checkpoint_file, device=device) else: pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location) -- cgit v1.2.3