From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: Use less RAM when creating models --- modules/sd_models.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a793..acb1e817 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -460,7 +460,6 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) - def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - sd_model = instantiate_from_config(sd_config.model) - except Exception: - pass + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + + except Exception as e: + errors.display(e, "creating model quickly", full_traceback=True) if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - sd_model = instantiate_from_config(sd_config.model) + + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config timer.record("create model") - load_model_weights(sd_model, checkpoint_info, state_dict, timer) + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) -- cgit v1.2.3 From 3bca90b249d749ed5429f76e380d2ffa52fc0d41 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 30 Jul 2023 13:48:27 +0300 Subject: hires fix checkpoint selection --- modules/sd_models.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..cb67e425 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -52,6 +52,7 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) @@ -81,6 +82,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -101,14 +103,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -131,11 +127,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -145,6 +144,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None -- cgit v1.2.3 From 4d9b096663288e2aa738723fa63950f3d41f6170 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 31 Jul 2023 10:43:31 +0300 Subject: additional memory improvements when switching between models of different types --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb67e425..4855037a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -582,7 +582,10 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + sd_model.to(device="meta") + + devices.torch_gc() load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model -- cgit v1.2.3 From b235022c615a7384f73c05fe240d8f4a28d103d4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 00:24:48 +0300 Subject: option to keep multiple models in memory --- modules/sd_models.py | 136 ++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 112 insertions(+), 24 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..77195f2f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -437,6 +437,7 @@ class SdModelData: try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -448,11 +449,24 @@ class SdModelData: def set_sd_model(self, v): self.sd_model = v + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + self.loaded_sd_models.insert(0, v) + model_data = SdModelData() def get_empty_cond(sd_model): + from modules import extra_networks, processing + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -460,19 +474,43 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) +def send_model_to_cpu(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def send_model_to_device(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) + else: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") + devices.torch_gc() + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import lowvram, sd_hijack + from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + timer = Timer() + if model_data.sd_model: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + send_model_to_trash(model_data.sd_model) model_data.sd_model = None - gc.collect() devices.torch_gc() - do_inpainting_hijack() - - timer = Timer() + timer.record("unload existing model") if already_loaded_state_dict is not None: state_dict = already_loaded_state_dict @@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + timer.record("load weights from state dict") - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - sd_model.to(shared.device) - + send_model_to_device(sd_model) timer.record("move model to device") sd_hijack.model_hijack.hijack(sd_model) @@ -525,7 +560,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("hijack") sd_model.eval() - model_data.sd_model = sd_model + model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit). + """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename: + already_loaded = loaded_model + continue + + if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: + print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") + model_data.loaded_sd_models.pop() + send_model_to_trash(loaded_model) + timer.record("send model to trash") + + if shared.opts.sd_checkpoints_keep_in_cpu: + send_model_to_cpu(sd_model) + timer.record("send model to cpu") + + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + + model_data.set_sd_model(already_loaded) + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + return model_data.sd_model + elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: + print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") + + model_data.sd_model = None + load_model(checkpoint_info) + return model_data.sd_model + elif len(model_data.loaded_sd_models) > 0: + sd_model = model_data.loaded_sd_models.pop() + model_data.sd_model = sd_model + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model + else: + return None + + def reload_model_weights(sd_model=None, info=None): - from modules import lowvram, devices, sd_hijack + from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() + timer = Timer() + if not sd_model: sd_model = model_data.sd_model @@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None): else: current_checkpoint_info = sd_model.sd_checkpoint_info if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return - - sd_unet.apply_unet("None") + return sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + if sd_model is not None: + sd_unet.apply_unet("None") + send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) - timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) @@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + send_model_to_trash(sd_model) + load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + return sd_model -- cgit v1.2.3 From 4b43480fe8b65a3bd24dc9bc03a7e910c9b0314f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 07:08:11 +0300 Subject: show metadata for SD checkpoints in the extra networks UI --- modules/sd_models.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..1af7fd78 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -33,6 +33,8 @@ class CheckpointInfo: self.filename = filename abspath = os.path.abspath(filename) + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): @@ -43,6 +45,19 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] + def read_metadata(): + metadata = read_metadata_from_safetensors(filename) + self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None) + + return metadata + + self.metadata = {} + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading metadata for {filename}") + self.name = name self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] @@ -55,15 +70,6 @@ class CheckpointInfo: self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) - self.metadata = {} - - _, ext = os.path.splitext(self.filename) - if ext.lower() == ".safetensors": - try: - self.metadata = read_metadata_from_safetensors(filename) - except Exception as e: - errors.display(e, f"reading checkpoint metadata: {filename}") - def register(self): checkpoints_list[self.title] = self for id in self.ids: -- cgit v1.2.3 From 07be13caa357b14f6afa247566d53339522b8e66 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 08:27:54 +0300 Subject: add metadata to checkpoint merger --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1af7fd78..8f72f21d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -85,7 +85,7 @@ class CheckpointInfo: if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] - checkpoints_list.pop(self.title) + checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' self.register() -- cgit v1.2.3 From 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 17:13:15 +0300 Subject: repair merge error --- modules/sd_models.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 40a450df..3c451a4b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd -- cgit v1.2.3 From 20549a50cb3c41868ce561c6658bfaa0d20ac7ba Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 3 Aug 2023 22:46:57 +0300 Subject: add style editor dialog rework toprow for img2img and txt2img to use a class with fields fix the console error when editing checkpoint user metadata --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f72f21d..1d93d893 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -68,7 +68,7 @@ class CheckpointInfo: self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' - self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self -- cgit v1.2.3 From 24f21583cdba2ae6cc51773b956c6ce068d3dfe4 Mon Sep 17 00:00:00 2001 From: AnyISalIn Date: Fri, 4 Aug 2023 11:43:27 +0800 Subject: fix: prevent cache model.state_dict() after model hijack Signed-off-by: AnyISalIn --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1d93d893..ba15b451 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) - del state_dict timer.record("apply weights to model") if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_info] = state_dict + + del state_dict if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) -- cgit v1.2.3 From f1975b0213f5be400889ec04b3891d1cb571fe20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:01:07 +0300 Subject: initial refiner support --- modules/sd_models.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..981aa93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) -- cgit v1.2.3 From c96e4750d895a47290dc7f96e030197069c75fa4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 7 Aug 2023 08:07:09 +0300 Subject: SD VAE rework 2 - the setting for preferring opts.sd_vae has been inverted and reworded - resolve_vae function made easier to read and now returns an object rather than a tuple - if the checkbox for overriding per-model preferences is checked, opts.sd_vae overrides checkpoint user metadata - changing VAE in user metadata for currently loaded model immediately applies the selection --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..d65735e3 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -356,7 +356,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename).tuple() sd_vae.load_vae(model, vae_file, vae_source) timer.record("load VAE") -- cgit v1.2.3 From 6e7828e1d271c644840047c3db60e669a232402a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 7 Aug 2023 08:16:20 +0300 Subject: apply unet overrides after switching model --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index d65735e3..53c1df54 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -699,6 +699,7 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") model_data.set_sd_model(sd_model) + sd_unet.apply_unet() return sd_model -- cgit v1.2.3 From 8c200c21564992b7af1d2d692524051158e533db Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Tue, 8 Aug 2023 10:48:03 +0900 Subject: Fix mismatch between shared.sd_model & shared.opts --- modules/sd_models.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 53c1df54..cded27d4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -623,6 +623,8 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): timer.record("send model to device") model_data.set_sd_model(already_loaded) + shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: -- cgit v1.2.3 From 386245a26427a64f364f66f6fecd03b3bccfd7f3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 10:25:35 +0300 Subject: split shared.py into multiple files; should resolve all circular reference import errors related to shared.py --- modules/sd_models.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 53c1df54..88a09899 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack from modules.timer import Timer import tomesd @@ -473,7 +473,6 @@ model_data = SdModelData() def get_empty_cond(sd_model): - from modules import extra_networks, processing p = processing.StableDiffusionProcessingTxt2Img() extra_networks.activate(p, {}) @@ -486,8 +485,6 @@ def get_empty_cond(sd_model): def send_model_to_cpu(m): - from modules import lowvram - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.send_everything_to_cpu() else: @@ -497,8 +494,6 @@ def send_model_to_cpu(m): def send_model_to_device(m): - from modules import lowvram - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) else: @@ -642,7 +637,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): def reload_model_weights(sd_model=None, info=None): - from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() timer = Timer() @@ -705,7 +699,6 @@ def reload_model_weights(sd_model=None, info=None): def unload_model_weights(sd_model=None, info=None): - from modules import devices, sd_hijack timer = Timer() if model_data.sd_model: -- cgit v1.2.3 From aa10faa591f1ca0bd93ae3d53a0a4c15a3fbaf82 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 14:47:44 +0300 Subject: fix checkpoint name jumping around in the list of checkpoints for no good reason --- modules/sd_models.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b490fa99..7a866a07 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -68,7 +68,9 @@ class CheckpointInfo: self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' - self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + if self.shorthash: + self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]'] def register(self): checkpoints_list[self.title] = self @@ -80,10 +82,14 @@ class CheckpointInfo: if self.sha256 is None: return - self.shorthash = self.sha256[0:10] + shorthash = self.sha256[0:10] + if self.shorthash == self.sha256[0:10]: + return self.shorthash + + self.shorthash = shorthash if self.shorthash not in self.ids: - self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] + self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]'] checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' -- cgit v1.2.3 From ac8a5d18d3ede6bcb8fa5a3da1c7c28e064cd65d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 10 Aug 2023 17:04:59 +0300 Subject: resolve merge issues --- modules/sd_models.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6cb2f34..a178adca 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -640,8 +640,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): timer.record("send model to device") model_data.set_sd_model(already_loaded) - shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title - shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title + shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: -- cgit v1.2.3 From 64311faa6848d641cc452115e4e1eb47d2a7b519 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 12 Aug 2023 12:39:59 +0300 Subject: put refiner into main UI, into the new accordions section add VAE from main model into infotext, not from refiner model option to make scripts UI without gr.Group fix inconsistencies with refiner when usings samplers that do more denoising than steps --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index a178adca..f6fbdcd6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") def get_closet_checkpoint_match(search_string): + if not search_string: + return None + checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: return checkpoint_info -- cgit v1.2.3 From 0815c45bcdec0a2e5c60bdd5b33d95813d799c01 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 10:44:17 +0300 Subject: send weights to target device instead of CPU memory --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..b01d44c5 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From 57e59c14c8a13a99d6422597d27d92ad10a51ca1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 11:28:00 +0300 Subject: Revert "send weights to target device instead of CPU memory" This reverts commit 0815c45bcdec0a2e5c60bdd5b33d95813d799c01. --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b01d44c5..f6fbdcd6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -579,7 +579,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.device): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From eaba3d7349c6f0e151be66ade3fdc848d693a10d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 16 Aug 2023 12:11:01 +0300 Subject: send weights to target device instead of CPU memory --- modules/sd_models.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f6fbdcd6..f912fe16 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -518,6 +518,13 @@ def send_model_to_cpu(m): devices.torch_gc() +def model_target_device(): + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + return devices.cpu + else: + return devices.device + + def send_model_to_device(m): if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) @@ -579,7 +586,15 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("create model") - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + if shared.cmd_opts.no_half: + weight_dtype_conversion = None + else: + weight_dtype_conversion = { + 'first_stage_model': None, + '': torch.float16, + } + + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") -- cgit v1.2.3 From 0dc74545c0b5510911757ed9f2be703aab58f014 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 17 Aug 2023 07:54:07 +0300 Subject: resolve the issue with loading fp16 checkpoints while using --no-half --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index f912fe16..685585b1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -343,7 +343,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.to(memory_format=torch.channels_last) timer.record("apply channels_last") - if not shared.cmd_opts.no_half: + if shared.cmd_opts.no_half: + model.float() + timer.record("apply float()") + else: vae = model.first_stage_model depth_model = getattr(model, 'depth_model', None) -- cgit v1.2.3 From 042e1d5d0b1fc0bfd358e3a90db7d163934bd238 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 15:00:14 +0900 Subject: Fix SD VAE switch error after model reuse --- modules/sd_models.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 685585b1..2c976561 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,6 +462,7 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] + self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -485,16 +486,27 @@ class SdModelData: return self.sd_model - def set_sd_model(self, v): + def set_sd_model(self, v, already_loaded=False): self.sd_model = v + if already_loaded: + sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) try: self.loaded_sd_models.remove(v) + self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: self.loaded_sd_models.insert(0, v) + self.loaded_vae_states[v.sd_model_hash] = dict( + base_vae=sd_vae.base_vae, + loaded_vae_file=sd_vae.loaded_vae_file, + checkpoint_info=sd_vae.checkpoint_info, + ) model_data = SdModelData() @@ -649,6 +661,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() + model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -660,7 +673,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): send_model_to_device(already_loaded) timer.record("send model to device") - model_data.set_sd_model(already_loaded) + model_data.set_sd_model(already_loaded, already_loaded=True) if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title @@ -678,6 +691,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model + sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) + sd_vae.base_vae = sd_vae_state.get("base_vae", None) + sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) + sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model else: -- cgit v1.2.3 From 5159edbf0e0e1d5a25fbd588e000487746790117 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 19:44:37 +0900 Subject: Store base_vae and loaded_vae_file in sd_model --- modules/sd_models.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 2c976561..150d550b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -462,7 +462,6 @@ class SdModelData: def __init__(self): self.sd_model = None self.loaded_sd_models = [] - self.loaded_vae_states = {} self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -489,24 +488,19 @@ class SdModelData: def set_sd_model(self, v, already_loaded=False): self.sd_model = v if already_loaded: - sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(v, "base_vae", None) + sd_vae.loaded_vae_file = getattr(v, "loaded_vae_file", None) + sd_vae.checkpoint_info = v.sd_checkpoint_info try: self.loaded_sd_models.remove(v) - self.loaded_vae_states.pop(v.sd_model_hash, {}).clear() except ValueError: pass if v is not None: + setattr(v, "base_vae", sd_vae.base_vae) + setattr(v, "loaded_vae_file", sd_vae.loaded_vae_file) self.loaded_sd_models.insert(0, v) - self.loaded_vae_states[v.sd_model_hash] = dict( - base_vae=sd_vae.base_vae, - loaded_vae_file=sd_vae.loaded_vae_file, - checkpoint_info=sd_vae.checkpoint_info, - ) model_data = SdModelData() @@ -661,7 +655,6 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") model_data.loaded_sd_models.pop() - model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear() send_model_to_trash(loaded_model) timer.record("send model to trash") @@ -691,10 +684,9 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): sd_model = model_data.loaded_sd_models.pop() model_data.sd_model = sd_model - sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {}) - sd_vae.base_vae = sd_vae_state.get("base_vae", None) - sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None) - sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None) + sd_vae.base_vae = getattr(sd_model, "base_vae", None) + sd_vae.loaded_vae_file = getattr(sd_model, "loaded_vae_file", None) + sd_vae.checkpoint_info = sd_model.sd_checkpoint_info print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") return sd_model -- cgit v1.2.3 From af5d2e8e5fc4440691fb7f1aa3492def1c755722 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 20:08:22 +0900 Subject: Change to access sd_model attribute with dot --- modules/sd_models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 150d550b..dd749122 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -498,8 +498,8 @@ class SdModelData: pass if v is not None: - setattr(v, "base_vae", sd_vae.base_vae) - setattr(v, "loaded_vae_file", sd_vae.loaded_vae_file) + v.base_vae = sd_vae.base_vae + v.loaded_vae_file = sd_vae.loaded_vae_file self.loaded_sd_models.insert(0, v) -- cgit v1.2.3 From 549b0fc5267e9539f321f0891aa757619b7248cb Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Sun, 20 Aug 2023 23:06:51 +0900 Subject: Change where VAE state are stored in model --- modules/sd_models.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index dd749122..d3775ec6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -498,8 +498,6 @@ class SdModelData: pass if v is not None: - v.base_vae = sd_vae.base_vae - v.loaded_vae_file = sd_vae.loaded_vae_file self.loaded_sd_models.insert(0, v) -- cgit v1.2.3 From be301f224d26ac4363ce3bd8bcb510b00bd6db27 Mon Sep 17 00:00:00 2001 From: Uminosachi <49424133+Uminosachi@users.noreply.github.com> Date: Mon, 21 Aug 2023 11:28:53 +0900 Subject: Fix for consistency with shared.opts.sd_vae of UI --- modules/sd_models.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index d3775ec6..27d15e66 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -671,6 +671,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): shared.opts.data["sd_checkpoint_hash"] = already_loaded.sd_checkpoint_info.sha256 print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + sd_vae.reload_vae_weights(already_loaded) return model_data.sd_model elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") -- cgit v1.2.3 From 016554e43740e0b7ded75e89255de81270de9d6c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 22 Aug 2023 18:49:08 +0300 Subject: add --medvram-sdxl --- modules/sd_models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 27d15e66..4331853a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -517,7 +517,7 @@ def get_empty_cond(sd_model): def send_model_to_cpu(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + if m.lowvram: lowvram.send_everything_to_cpu() else: m.to(devices.cpu) @@ -525,17 +525,17 @@ def send_model_to_cpu(m): devices.torch_gc() -def model_target_device(): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: +def model_target_device(m): + if lowvram.is_needed(m): return devices.cpu else: return devices.device def send_model_to_device(m): - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) - else: + lowvram.apply(m) + + if not m.lowvram: m.to(shared.device) @@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): '': torch.float16, } - with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion): + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion): load_model_weights(sd_model, checkpoint_info, state_dict, timer) timer.record("load weights from state dict") @@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None): script_callbacks.model_loaded_callback(sd_model) timer.record("script callbacks") - if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + if not sd_model.lowvram: sd_model.to(devices.device) timer.record("move model to device") -- cgit v1.2.3 From 0232a987bb8cbf337a5b34f38f7718aef92af269 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 23 Aug 2023 07:10:43 +0300 Subject: set devices.dtype_unet correctly --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 4331853a..547e93c4 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -345,6 +345,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.no_half: model.float() + devices.dtype_unet = torch.float32 timer.record("apply float()") else: vae = model.first_stage_model @@ -362,9 +363,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if depth_model: model.depth_model = depth_model + devices.dtype_unet = torch.float16 timer.record("apply half()") - devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3