From f874b1bcad05d7ea4c3cc28df82904ac7c12e64f Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Wed, 30 Aug 2023 08:54:31 +0300
Subject: keep order in list of checkpoints when loading model that doesn't have a checksum

---
 modules/sd_models.py | 22 +++++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 547e93c4..930d0bee 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -27,6 +27,24 @@ checkpoint_alisases = checkpoint_aliases  # for compatibility with old name
 
 checkpoints_loaded = collections.OrderedDict()
 
+def replace_key(d, key, new_key, value):
+    keys = list(d.keys())
+
+    d[new_key] = value
+
+    if key not in keys:
+        return d
+
+    index = keys.index(key)
+    keys[index] = new_key
+
+    new_d = {k: d[k] for k in keys}
+
+    d.clear()
+    d.update(new_d)
+    return d
+
+
 class CheckpointInfo:
     def __init__(self, filename):
         self.filename = filename
@@ -91,9 +109,11 @@ class CheckpointInfo:
         if self.shorthash not in self.ids:
             self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]', f'{self.name_for_extra} [{self.shorthash}]']
 
-        checkpoints_list.pop(self.title, None)
+        old_title = self.title
         self.title = f'{self.name} [{self.shorthash}]'
         self.short_title = f'{self.name_for_extra} [{self.shorthash}]'
+
+        replace_key(checkpoints_list, old_title, self.title, self)
         self.register()
 
         return self.shorthash
--
cgit v1.2.3

From e4726cccf960257e1b456db84a59f28cea019c8f Mon Sep 17 00:00:00 2001
From: w-e-w <40751091+w-e-w@users.noreply.github.com>
Date: Fri, 8 Sep 2023 09:46:34 +0900
Subject: parsing string to path

---
 modules/sd_models.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..9b0923de 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -49,11 +49,12 @@ class CheckpointInfo:
     def __init__(self, filename):
         self.filename = filename
         abspath = os.path.abspath(filename)
+        abs_ckpt_dir = os.path.abspath(shared.cmd_opts.ckpt_dir) if shared.cmd_opts.ckpt_dir is not None else None
 
         self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors"
 
-        if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir):
-            name = abspath.replace(shared.cmd_opts.ckpt_dir, '')
+        if abs_ckpt_dir and abspath.startswith(abs_ckpt_dir):
+            name = abspath.replace(abs_ckpt_dir, '')
         elif abspath.startswith(model_path):
             name = abspath.replace(model_path, '')
         else:
--
cgit v1.2.3

From 813535d38bbcdd8ccc51d0618a7d9fd353677bb9 Mon Sep 17 00:00:00 2001
From: "qiuwen.wang"
Date: Fri, 15 Sep 2023 18:23:23 +0800
Subject: use dict[key]=model; did not update orderdict order, should use move to end

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..6d17dd3c 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -309,6 +309,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     if checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
         print(f"Loading weights [{sd_model_hash}] from cache")
+        checkpoints_loaded.move_to_end(checkpoint_info)
         return checkpoints_loaded[checkpoint_info]
 
     print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
--
cgit v1.2.3
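A standalone sketch of what the replace_key() helper added above does: it renames a dictionary key in place, keeping the entry at its original position instead of pushing it to the end the way a plain delete-and-reinsert would, which is what keeps the checkpoint list order stable when a title gains its hash. The checkpoint names below are made up for illustration.

    checkpoints = {'a.safetensors': 'A', 'b.safetensors': 'B', 'c.safetensors': 'C'}

    def replace_key(d, key, new_key, value):
        # same logic as the patch above: swap `key` for `new_key` without changing its position
        keys = list(d.keys())
        d[new_key] = value
        if key not in keys:
            return d
        keys[keys.index(key)] = new_key
        new_d = {k: d[k] for k in keys}
        d.clear()
        d.update(new_d)
        return d

    replace_key(checkpoints, 'b.safetensors', 'b.safetensors [abcd1234]', 'B')
    print(list(checkpoints))
    # ['a.safetensors', 'b.safetensors [abcd1234]', 'c.safetensors'] - the renamed entry keeps its slot;
    # a plain `del d[key]; d[new_key] = value` would have moved it to the end instead.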
From d9d94141dcfc1a84e98370bc137ffd888509b65e Mon Sep 17 00:00:00 2001
From: woweenie <145132974+woweenie@users.noreply.github.com>
Date: Fri, 15 Sep 2023 18:59:44 +0200
Subject: patch DDPM.register_betas so that users can put given_betas in model yaml

---
 modules/sd_models.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..8e4983a4 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -7,7 +7,7 @@ import threading
 import torch
 import re
 import safetensors.torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, ListConfig
 from os import mkdir
 from urllib import request
 import ldm.modules.midas as midas
@@ -17,6 +17,7 @@ from ldm.util import instantiate_from_config
 from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
 from modules.timer import Timer
 import tomesd
+import numpy as np
 
 model_dir = "Stable-diffusion"
 model_path = os.path.abspath(os.path.join(paths.models_path, model_dir))
@@ -132,6 +133,7 @@ def setup_model():
     os.makedirs(model_path, exist_ok=True)
 
     enable_midas_autodownload()
+    patch_given_betas()
 
 
 def checkpoint_tiles(use_short=False):
@@ -453,6 +455,17 @@ def enable_midas_autodownload():
     midas.api.load_model = load_model_wrapper
 
 
+def patch_given_betas():
+    original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+    def patched_register_schedule(*args, **kwargs):
+        if args[1] is not None and isinstance(args[1], ListConfig):
+            modified_args = list(args)  # Convert args tuple to a list
+            modified_args[1] = np.array(args[1])  # Modify the desired element
+            args = tuple(modified_args)  # Convert the list back to a tuple
+        original_register_schedule(*args, **kwargs)
+    ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+
 def repair_config(sd_config):
 
     if not hasattr(sd_config.model.params, "use_ema"):
--
cgit v1.2.3

From 8e355fbd7552f1a7f5124c4685d6fa36f3d0ede1 Mon Sep 17 00:00:00 2001
From: 王秋文/qwwang
Date: Mon, 18 Sep 2023 16:45:42 +0800
Subject: fix

---
 modules/sd_models.py | 1 +
 1 file changed, 1 insertion(+)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 6d17dd3c..eedb38c6 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -309,6 +309,7 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     if checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
         print(f"Loading weights [{sd_model_hash}] from cache")
+        # move to end as latest
        checkpoints_loaded.move_to_end(checkpoint_info)
         return checkpoints_loaded[checkpoint_info]
 
     print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
--
cgit v1.2.3
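The move_to_end() call introduced and then commented above is what turns the checkpoints_loaded OrderedDict into an LRU-style cache: a cache hit is promoted to the most recent position, so when the cache is over its limit the stalest entry can be evicted from the front. A minimal standalone sketch of that pattern (not webui code; the size limit here is invented, the webui reads it from opts.sd_checkpoint_cache):

    from collections import OrderedDict

    cache = OrderedDict()
    MAX_ENTRIES = 2  # assumed limit for this sketch

    def get(key, load):
        if key in cache:
            cache.move_to_end(key)       # move to end as latest
            return cache[key]
        cache[key] = load(key)
        while len(cache) > MAX_ENTRIES:
            cache.popitem(last=False)    # evict the least recently used entry
        return cache[key]

    get('model-a', lambda k: f'weights of {k}')
    get('model-b', lambda k: f'weights of {k}')
    get('model-a', lambda k: f'weights of {k}')   # hit: model-a becomes the latest entry
    get('model-c', lambda k: f'weights of {k}')   # evicts model-b, not the recently used model-a
    print(list(cache))                            # ['model-a', 'model-c']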
From 87b50397a6da273fe0160016a209e4eb0cbf4a89 Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sat, 30 Sep 2023 09:11:31 +0300
Subject: add missing import, simplify code, use patches module for #13276

---
 modules/sd_models.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index e3755253..5ef7aa13 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas
 
 from ldm.util import instantiate_from_config
 
-from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
+from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
 from modules.timer import Timer
 import tomesd
 import numpy as np
@@ -130,6 +130,8 @@ except Exception:
 
 
 def setup_model():
+    """called once at startup to do various one-time tasks related to SD models"""
+
     os.makedirs(model_path, exist_ok=True)
 
     enable_midas_autodownload()
@@ -458,14 +460,17 @@ def enable_midas_autodownload():
 
 
 def patch_given_betas():
-    original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
+    import ldm.models.diffusion.ddpm
+
     def patched_register_schedule(*args, **kwargs):
-        if args[1] is not None and isinstance(args[1], ListConfig):
-            modified_args = list(args)  # Convert args tuple to a list
-            modified_args[1] = np.array(args[1])  # Modify the desired element
-            args = tuple(modified_args)  # Convert the list back to a tuple
+        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""
+
+        if isinstance(args[1], ListConfig):
+            args = (args[0], np.array(args[1]), *args[2:])
+
         original_register_schedule(*args, **kwargs)
-    ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule
+
+    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)
 
 
 def repair_config(sd_config):
--
cgit v1.2.3

From 76010a51ef1f3805a7487723599035bc2356c3fb Mon Sep 17 00:00:00 2001
From: wangqiuwen
Date: Sat, 7 Oct 2023 15:36:01 +0800
Subject: up

---
 modules/sd_models.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index eedb38c6..3a060ab6 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,4 +1,5 @@
 import collections
+import copy
 import os.path
 import sys
 import gc
@@ -309,8 +310,6 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
     if checkpoint_info in checkpoints_loaded:
         # use checkpoint cache
         print(f"Loading weights [{sd_model_hash}] from cache")
-        # move to end as latest
-        checkpoints_loaded.move_to_end(checkpoint_info)
         return checkpoints_loaded[checkpoint_info]
 
     print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -352,12 +351,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
 
-    model.load_state_dict(state_dict, strict=False)
-    timer.record("apply weights to model")
-
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
-        checkpoints_loaded[checkpoint_info] = state_dict
+        checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict)
+
+    model.load_state_dict(state_dict, strict=False)
+    timer.record("apply weights to model")
 
     del state_dict
--
cgit v1.2.3
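Backing up to the patch_given_betas() wrapper above and its patches.patch rewrite: the reason it exists is that a given_betas list written in a model's yaml arrives as an OmegaConf ListConfig rather than a plain sequence, and the wrapper converts it to a NumPy array before register_schedule sees it. A small standalone illustration of that conversion, with made-up beta values:

    import numpy as np
    from omegaconf import OmegaConf, ListConfig

    conf = OmegaConf.create({'model': {'params': {'given_betas': [0.0001, 0.0002, 0.0003]}}})
    given_betas = conf.model.params.given_betas

    print(isinstance(given_betas, ListConfig))   # True - lists loaded through OmegaConf come back as ListConfig
    betas = np.array(given_betas)                # what the patched register_schedule substitutes in
    print(betas.dtype, betas.shape)              # float64 (3,)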
[{sd_model_hash}] from cache") + # move to end as latest + checkpoints_loaded.move_to_end(checkpoint_info) return checkpoints_loaded[checkpoint_info] print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}") -- cgit v1.2.3 From 0619df9835833079f8ba5cb5a510b55c4433acaf Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 14 Oct 2023 08:01:04 +0300 Subject: use shallow copy for #13535 --- modules/sd_models.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 2b43868e..c8efeedc 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -1,5 +1,4 @@ import collections -import copy import os.path import sys import gc @@ -360,7 +359,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict) + checkpoints_loaded[checkpoint_info] = state_dict.copy() model.load_state_dict(state_dict, strict=False) timer.record("apply weights to model") -- cgit v1.2.3 From 282903bb6798f49af66f6935ee4dc0015895cf7c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 15 Oct 2023 09:41:02 +0300 Subject: repair unload sd checkpoint button --- modules/api/api.py | 11 +++++------ modules/sd_models.py | 13 +------------ modules/ui_settings.py | 24 +++++++++++++++++------- 3 files changed, 23 insertions(+), 25 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/api/api.py b/modules/api/api.py index efedafa4..09083874 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder from secrets import compare_digest import modules.shared as shared -from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste +from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models from modules.api import models from modules.shared import opts from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images from modules.textual_inversion.textual_inversion import create_embedding, train_embedding from modules.textual_inversion.preprocess import preprocess from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork -from PIL import PngImagePlugin,Image -from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases +from PIL import PngImagePlugin, Image from modules.sd_models_config import find_checkpoint_config_near_filename from modules.realesrgan_model import get_realesrgan_models from modules import devices @@ -541,12 +540,12 @@ class Api: return {} def unloadapi(self): - unload_model_weights() + sd_models.unload_model_weights() return {} def reloadapi(self): - reload_model_weights() + sd_models.send_model_to_device(shared.sd_model) return {} @@ -566,7 +565,7 @@ class Api: def set_config(self, req: dict[str, Any]): checkpoint_name = req.get("sd_model_checkpoint", None) - if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases: + if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases: raise RuntimeError(f"model {checkpoint_name!r} not found") 
From 282903bb6798f49af66f6935ee4dc0015895cf7c Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 15 Oct 2023 09:41:02 +0300
Subject: repair unload sd checkpoint button

---
 modules/api/api.py     | 11 +++++------
 modules/sd_models.py   | 13 +------------
 modules/ui_settings.py | 24 +++++++++++++++++-------
 3 files changed, 23 insertions(+), 25 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/api/api.py b/modules/api/api.py
index efedafa4..09083874 100644
--- a/modules/api/api.py
+++ b/modules/api/api.py
@@ -17,15 +17,14 @@ from fastapi.encoders import jsonable_encoder
 from secrets import compare_digest
 
 import modules.shared as shared
-from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
+from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
 from modules.api import models
 from modules.shared import opts
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
 from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
 from modules.textual_inversion.preprocess import preprocess
 from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
-from PIL import PngImagePlugin,Image
-from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
+from PIL import PngImagePlugin, Image
 from modules.sd_models_config import find_checkpoint_config_near_filename
 from modules.realesrgan_model import get_realesrgan_models
 from modules import devices
@@ -541,12 +540,12 @@ class Api:
 
         return {}
 
     def unloadapi(self):
-        unload_model_weights()
+        sd_models.unload_model_weights()
 
         return {}
 
     def reloadapi(self):
-        reload_model_weights()
+        sd_models.send_model_to_device(shared.sd_model)
 
         return {}
@@ -566,7 +565,7 @@ class Api:
 
     def set_config(self, req: dict[str, Any]):
         checkpoint_name = req.get("sd_model_checkpoint", None)
-        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
+        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
            raise RuntimeError(f"model {checkpoint_name!r} not found")
 
        for k, v in req.items():
diff --git a/modules/sd_models.py b/modules/sd_models.py
index c8efeedc..3b6cdea1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -1,7 +1,6 @@
 import collections
 import os.path
 import sys
-import gc
 import threading
 
 import torch
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None):
 
 
 def unload_model_weights(sd_model=None, info=None):
-    timer = Timer()
-
-    if model_data.sd_model:
-        model_data.sd_model.to(devices.cpu)
-        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
-        model_data.sd_model = None
-        sd_model = None
-        gc.collect()
-        devices.torch_gc()
-
-    print(f"Unloaded weights {timer.summary()}.")
+    send_model_to_cpu(sd_model or shared.sd_model)
 
     return sd_model
diff --git a/modules/ui_settings.py b/modules/ui_settings.py
index 74a3aef3..e054d00a 100644
--- a/modules/ui_settings.py
+++ b/modules/ui_settings.py
@@ -1,6 +1,6 @@
 import gradio as gr
 
-from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
+from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
 from modules.call_queue import wrap_gradio_call
 from modules.shared import opts
 from modules.ui_components import FormRow
@@ -177,8 +177,8 @@ class UiSettings:
                     download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                     reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                 with gr.Row():
-                    unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
-                    reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
+                    unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
+                    reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
                 with gr.Row():
                     calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
                     calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
@@ -194,16 +194,26 @@ class UiSettings:
 
         self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)
 
+        def call_func_and_return_text(func, text):
+            def handler():
+                t = timer.Timer()
+                func()
+                t.record(text)
+
+                return f'{text} in {t.total:.1f}s'
+
+            return handler
+
         unload_sd_model.click(
-            fn=sd_models.unload_model_weights,
+            fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
             inputs=[],
-            outputs=[]
+            outputs=[self.result]
         )
 
         reload_sd_model.click(
-            fn=sd_models.reload_model_weights,
+            fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
             inputs=[],
-            outputs=[]
+            outputs=[self.result]
         )
 
         request_notifications.click(
--
cgit v1.2.3
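The unloadapi()/reloadapi() handlers above are what the API exposes for freeing and restoring the loaded checkpoint. Assuming the routes are registered elsewhere in api.py under the usual names (/sdapi/v1/unload-checkpoint and /sdapi/v1/reload-checkpoint, which are not visible in this diff) and a local instance launched with --api, a client call could look like this sketch:

    import requests

    BASE = "http://127.0.0.1:7860"   # assumed local webui address

    # drop the current checkpoint from the device to free memory
    requests.post(f"{BASE}/sdapi/v1/unload-checkpoint").raise_for_status()

    # later, move the checkpoint back before generating again
    requests.post(f"{BASE}/sdapi/v1/reload-checkpoint").raise_for_status()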
From ff1609f91ea0e9a90ba7b6ecc6d794c39c1f8c8f Mon Sep 17 00:00:00 2001
From: Ritesh Gangnani
Date: Sun, 5 Nov 2023 19:13:49 +0530
Subject: Add SSD-1B as a supported model

---
 modules/sd_hijack.py       | 11 +++++++++++
 modules/sd_models.py       |  8 ++++++--
 modules/sd_models_types.py |  5 ++++-
 3 files changed, 21 insertions(+), 3 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 592f0055..d19f853e 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -180,6 +180,17 @@ class StableDiffusionModelHijack:
         except Exception as e:
             errors.display(e, "applying cross attention optimization")
             undo_optimizations()
+
+    def conv_ssd(self, m):
+        delattr(m.model.diffusion_model.middle_block, '1')
+        delattr(m.model.diffusion_model.middle_block, '2')
+        for i in ['9','8','7','6','5','4']:
+            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i)
+            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i)
+            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i)
+            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
+        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
+        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1')
 
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 930d0bee..ef96d29d 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -346,10 +346,14 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sdxl = hasattr(model, 'conditioner')
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
-
+    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
+
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
-
+
+    if model.is_ssd:
+        sd_hijack.model_hijack.conv_ssd(model)
+
     model.load_state_dict(state_dict, strict=False)
     timer.record("apply weights to model")
 
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index 5ffd2f4f..1f28942a 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -22,7 +22,10 @@ class WebuiSdModel(LatentDiffusion):
     """structure with additional information about the file with model's weights"""
 
     is_sdxl: bool
-    """True if the model's architecture is SDXL"""
+    """True if the model's architecture is SDXL or SSD"""
+
+    is_ssd: bool
+    """True if the model is SSD"""
 
     is_sd2: bool
     """True if the model's architecture is SD 2.x"""
--
cgit v1.2.3
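The conv_ssd() hook above converts a loaded SDXL network into the slimmer SSD-1B layout purely by deleting submodules with delattr: torch's nn.Module keeps children in an internal _modules mapping keyed by name, so deleting the string name removes the block (and its parameters) from the model. A tiny standalone demonstration of that mechanism, unrelated to the actual UNet:

    import torch
    import torch.nn as nn

    block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    print(len(block))            # 3 children, named '0', '1', '2'

    delattr(block, '2')          # drop the last child by name, as conv_ssd does for transformer blocks
    print(len(block))            # 2 - the submodule and its parameters are gone

    out = block(torch.randn(1, 8))
    print(out.shape)             # torch.Size([1, 8]) - forward still runs over the remaining children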
From 44c5097375ae4cf40300c09473bb46cf6c5d6cb7 Mon Sep 17 00:00:00 2001
From: Ritesh Gangnani
Date: Sun, 5 Nov 2023 20:31:57 +0530
Subject: Use devices.torch_gc() instead of empty_cache()

---
 modules/sd_hijack.py | 5 +----
 modules/sd_models.py | 1 -
 2 files changed, 1 insertion(+), 5 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 059ffe8f..0a7e5135 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -1,5 +1,3 @@
-import gc
-
 import torch
 from torch.nn.functional import silu
 from types import MethodType
@@ -193,8 +191,7 @@ class StableDiffusionModelHijack:
             delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
         delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
         delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1')
-        torch.cuda.empty_cache()
-        gc.collect()
+        devices.torch_gc()
 
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index ef96d29d..2242c363 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -347,7 +347,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
     model.is_sd1 = not model.is_sdxl and not model.is_sd2
     model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
-
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
--
cgit v1.2.3

From 80d639a440929e9effe4620ce74333de283e7efc Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 5 Nov 2023 19:32:21 +0300
Subject: linter

---
 modules/sd_hijack.py       | 2 +-
 modules/sd_models.py       | 4 ++--
 modules/sd_models_types.py | 2 +-
 3 files changed, 4 insertions(+), 4 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 4fff418d..c6d17764 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -183,7 +183,7 @@ class StableDiffusionModelHijack:
         except Exception as e:
             errors.display(e, "applying cross attention optimization")
             undo_optimizations()
-    
+
     def conv_ssd(self, m):
         delattr(m.model.diffusion_model.middle_block, '1')
         delattr(m.model.diffusion_model.middle_block, '2')
diff --git a/modules/sd_models.py b/modules/sd_models.py
index d76dc580..1036a3b1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -355,10 +355,10 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
     model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
     if model.is_sdxl:
         sd_models_xl.extend_sdxl(model)
-    
+
     if model.is_ssd:
         sd_hijack.model_hijack.conv_ssd(model)
-    
+
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
         checkpoints_loaded[checkpoint_info] = state_dict.copy()
diff --git a/modules/sd_models_types.py b/modules/sd_models_types.py
index 1f28942a..f911fbb6 100644
--- a/modules/sd_models_types.py
+++ b/modules/sd_models_types.py
@@ -23,7 +23,7 @@ class WebuiSdModel(LatentDiffusion):
 
     is_sdxl: bool
     """True if the model's architecture is SDXL or SSD"""
-    
+
     is_ssd: bool
     """True if the model is SSD"""
 
--
cgit v1.2.3
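devices.torch_gc(), which replaces the raw empty_cache()/gc.collect() pair above, is the webui's central "free memory now" helper. Its exact implementation is not part of this log, but a helper of that shape does roughly the following: collect Python garbage first, then return cached GPU blocks to the driver when a GPU is present. A hedged sketch, not the actual modules/devices.py code:

    import gc
    import torch

    def torch_gc():
        """Release unreferenced Python objects, then return cached CUDA memory to the driver."""
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    torch_gc()   # also safe on CPU-only machines: the CUDA branch is simply skipped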
From 6ad666e4794a57dd65790dd6a259d5d4330d45ed Mon Sep 17 00:00:00 2001
From: AUTOMATIC1111 <16777216c@gmail.com>
Date: Sun, 5 Nov 2023 19:46:20 +0300
Subject: more changes for #13865: fix formatting, rename the function, add comment and add a readme entry

---
 README.md            |  1 +
 modules/sd_hijack.py | 24 +++++++++++++-----------
 modules/sd_models.py |  2 +-
 3 files changed, 15 insertions(+), 12 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/README.md b/README.md
index c7a4e363..25ba070e 100644
--- a/README.md
+++ b/README.md
@@ -91,6 +91,7 @@ A browser interface based on Gradio library for Stable Diffusion.
 - Eased resolution restriction: generated image's dimensions must be a multiple of 8 rather than 64
 - Now with a license!
 - Reorder elements in the UI from settings screen
+- [Segmind Stable Diffusion](https://huggingface.co/segmind/SSD-1B) support
 
 ## Installation and Running
 Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for:
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index c6d17764..fba23c38 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -184,17 +184,19 @@ class StableDiffusionModelHijack:
             errors.display(e, "applying cross attention optimization")
             undo_optimizations()
 
-    def conv_ssd(self, m):
-        delattr(m.model.diffusion_model.middle_block, '1')
-        delattr(m.model.diffusion_model.middle_block, '2')
-        for i in ['9','8','7','6','5','4']:
-            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks,i)
-            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks,i)
-            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks,i)
-            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
-        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
-        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1')
-        devices.torch_gc()
+    def convert_sdxl_to_ssd(self, m):
+        """Converts an SDXL model to a Segmind Stable Diffusion model (see https://huggingface.co/segmind/SSD-1B)"""
+
+        delattr(m.model.diffusion_model.middle_block, '1')
+        delattr(m.model.diffusion_model.middle_block, '2')
+        for i in ['9', '8', '7', '6', '5', '4']:
+            delattr(m.model.diffusion_model.input_blocks[7][1].transformer_blocks, i)
+            delattr(m.model.diffusion_model.input_blocks[8][1].transformer_blocks, i)
+            delattr(m.model.diffusion_model.output_blocks[0][1].transformer_blocks, i)
+            delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks, i)
+        delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks, '1')
+        delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks, '1')
+        devices.torch_gc()
 
     def hijack(self, m):
         conditioner = getattr(m, 'conditioner', None)
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 1036a3b1..841402e8 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -357,7 +357,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
         sd_models_xl.extend_sdxl(model)
 
     if model.is_ssd:
-        sd_hijack.model_hijack.conv_ssd(model)
+        sd_hijack.model_hijack.convert_sdxl_to_ssd(model)
 
     if shared.opts.sd_checkpoint_cache > 0:
         # cache newly loaded model
--
cgit v1.2.3
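The loader above flags model.is_ssd by checking that the SDXL middle-block attention key is absent from the state dict. The same probe can be run against a checkpoint file without loading any tensors, for example with safetensors' safe_open (a sketch; the file path is a placeholder):

    from safetensors import safe_open

    SSD_PROBE_KEY = 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight'

    with safe_open('model.safetensors', framework='pt', device='cpu') as f:   # placeholder path
        keys = set(f.keys())

    is_sdxl_like = any(k.startswith('conditioner.') for k in keys)   # rough stand-in for the is_sdxl check
    is_ssd_like = is_sdxl_like and SSD_PROBE_KEY not in keys
    print(is_ssd_like)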
From 6080045b2a0964e63bdcd33dd26015f8a51411f6 Mon Sep 17 00:00:00 2001
From: MrCheeze
Date: Fri, 1 Dec 2023 22:58:05 -0500
Subject: Add support for SD 2.1 Turbo, by converting the state dict from SGM to LDM on load

---
 modules/sd_models.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

(limited to 'modules/sd_models.py')

diff --git a/modules/sd_models.py b/modules/sd_models.py
index 841402e8..9355f1e1 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -230,15 +230,19 @@ def select_checkpoint():
     return checkpoint_info
 
 
-checkpoint_dict_replacements = {
+checkpoint_dict_replacements_sd1 = {
     'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
     'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
     'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
 }
 
+checkpoint_dict_replacements_sd2_turbo = {  # Converts SD 2.1 Turbo from SGM to LDM format.
+    'conditioner.embedders.0.': 'cond_stage_model.',
+}
+
 
-def transform_checkpoint_dict_key(k):
-    for text, replacement in checkpoint_dict_replacements.items():
+def transform_checkpoint_dict_key(k, replacements):
+    for text, replacement in replacements.items():
         if k.startswith(text):
             k = replacement + k[len(text):]
 
@@ -249,9 +253,14 @@ def get_state_dict_from_checkpoint(pl_sd):
     pl_sd = pl_sd.pop("state_dict", pl_sd)
     pl_sd.pop("state_dict", None)
 
+    is_sd2_turbo = 'conditioner.embedders.0.model.ln_final.weight' in pl_sd and pl_sd['conditioner.embedders.0.model.ln_final.weight'].size()[0] == 1024
+
     sd = {}
     for k, v in pl_sd.items():
-        new_key = transform_checkpoint_dict_key(k)
+        if is_sd2_turbo:
+            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd2_turbo)
+        else:
+            new_key = transform_checkpoint_dict_key(k, checkpoint_dict_replacements_sd1)
 
         if new_key is not None:
             sd[new_key] = v
--
cgit v1.2.3
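To make the SGM-to-LDM key conversion above concrete, here is the replacement table applied to the single prefix it rewrites, using the same key that the loader checks when detecting an SD 2.1 Turbo checkpoint (standalone usage sketch):

    checkpoint_dict_replacements_sd2_turbo = {
        'conditioner.embedders.0.': 'cond_stage_model.',
    }

    def transform_checkpoint_dict_key(k, replacements):
        for text, replacement in replacements.items():
            if k.startswith(text):
                k = replacement + k[len(text):]
        return k

    key = 'conditioner.embedders.0.model.ln_final.weight'
    print(transform_checkpoint_dict_key(key, checkpoint_dict_replacements_sd2_turbo))
    # cond_stage_model.model.ln_final.weight - the LDM-style name the SD 2.x text encoder expects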