From af081211ee93622473ee575de30fed2fd8263c09 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 11 Jul 2023 21:16:43 +0300 Subject: getting SD2.1 to run on SDXL repo --- modules/sd_models.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 060e0007..8d639583 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -289,6 +289,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) + if hasattr(model, 'conditioner'): + sd_models_xl.extend_sdxl(model) + model.load_state_dict(state_dict, strict=False) del state_dict timer.record("apply weights to model") @@ -334,7 +337,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer model.sd_checkpoint_info = checkpoint_info shared.opts.data["sd_checkpoint_hash"] = checkpoint_info.sha256 - model.logvar = model.logvar.to(devices.device) # fix for training + if hasattr(model, 'logvar'): + model.logvar = model.logvar.to(devices.device) # fix for training sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() -- cgit v1.2.3 From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: SDXL support --- modules/launch_utils.py | 17 +++++++++++++ modules/lowvram.py | 51 +++++++++++++++++++++++++++----------- modules/paths.py | 9 ++++++- modules/processing.py | 7 ++++-- modules/prompt_parser.py | 23 ++++++++++++++--- modules/sd_hijack.py | 23 ++++++++++++++++- modules/sd_hijack_clip.py | 16 +++++++++--- modules/sd_hijack_open_clip.py | 38 +++++++++++++++++++++++++--- modules/sd_hijack_optimizations.py | 51 ++++++++++++++++++++++++++++++++------ modules/sd_models.py | 14 +++++++++-- modules/sd_models_config.py | 5 +++- modules/sd_models_xl.py | 27 +++++++++++++++++--- modules/sd_samplers_kdiffusion.py | 2 +- modules/shared.py | 2 ++ requirements.txt | 1 + requirements_versions.txt | 1 + 16 files changed, 242 insertions(+), 45 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 3b740dbd..aa9d1880 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -224,6 +224,20 @@ def run_extensions_installers(settings_file): run_extension_installer(os.path.join(extensions_dir, dirname_extension)) +def mute_sdxl_imports(): + """create fake modules that SDXL wants to import but doesn't actually use for our purposes""" + + import importlib + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('taming.modules.losses.lpips', None)) + module.LPIPS = None + sys.modules['taming.modules.losses.lpips'] = module + + module = importlib.util.module_from_spec(importlib.machinery.ModuleSpec('sgm.data', None)) + module.StableDataModuleFromConfig = None + sys.modules['sgm.data'] = module + + def prepare_environment(): torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu118") torch_command = os.environ.get('TORCH_COMMAND', f"pip install torch==2.0.1 torchvision==0.15.2 --extra-index-url {torch_index_url}") @@ -319,11 +333,14 @@ def prepare_environment(): if args.update_all_extensions: git_pull_recursive(extensions_dir) + mute_sdxl_imports() + if "--exit" in sys.argv: print("Exiting because of --exit argument") exit(0) + def configure_for_tests(): if "--api" not in sys.argv: sys.argv.append("--api") diff --git a/modules/lowvram.py b/modules/lowvram.py index d95bcfbf..da4f33a8 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -53,19 +53,46 @@ def setup_for_low_vram(sd_model, use_medvram): send_me_to_gpu(first_stage_model, None) return first_stage_model_decode(z) - # for SD1, cond_stage_model is CLIP and its NN is in the tranformer frield, but for SD2, it's open clip, and it's in model field - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.transformer = sd_model.cond_stage_model.model - - # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model and then - # send the model to GPU. Then put modules back. the modules will be in CPU. - stored = sd_model.cond_stage_model.transformer, sd_model.first_stage_model, getattr(sd_model, 'depth_model', None), getattr(sd_model, 'embedder', None), sd_model.model - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = None, None, None, None, None + to_remain_in_cpu = [ + (sd_model, 'first_stage_model'), + (sd_model, 'depth_model'), + (sd_model, 'embedder'), + (sd_model, 'model'), + (sd_model, 'embedder'), + ] + + is_sdxl = hasattr(sd_model, 'conditioner') + is_sd2 = not is_sdxl and hasattr(sd_model.cond_stage_model, 'model') + + if is_sdxl: + to_remain_in_cpu.append((sd_model, 'conditioner')) + elif is_sd2: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'model')) + else: + to_remain_in_cpu.append((sd_model.cond_stage_model, 'transformer')) + + # remove several big modules: cond, first_stage, depth/embedder (if applicable), and unet from the model + stored = [] + for obj, field in to_remain_in_cpu: + module = getattr(obj, field, None) + stored.append(module) + setattr(obj, field, None) + + # send the model to GPU. sd_model.to(devices.device) - sd_model.cond_stage_model.transformer, sd_model.first_stage_model, sd_model.depth_model, sd_model.embedder, sd_model.model = stored + + # put modules back. the modules will be in CPU. + for (obj, field), module in zip(to_remain_in_cpu, stored): + setattr(obj, field, module) # register hooks for those the first three models - sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + if is_sdxl: + sd_model.conditioner.register_forward_pre_hook(send_me_to_gpu) + elif is_sd2: + sd_model.cond_stage_model.model.register_forward_pre_hook(send_me_to_gpu) + else: + sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu) + sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu) sd_model.first_stage_model.encode = first_stage_model_encode_wrap sd_model.first_stage_model.decode = first_stage_model_decode_wrap @@ -75,10 +102,6 @@ def setup_for_low_vram(sd_model, use_medvram): sd_model.embedder.register_forward_pre_hook(send_me_to_gpu) parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model - if hasattr(sd_model.cond_stage_model, 'model'): - sd_model.cond_stage_model.model = sd_model.cond_stage_model.transformer - del sd_model.cond_stage_model.transformer - if use_medvram: sd_model.model.register_forward_pre_hook(send_me_to_gpu) else: diff --git a/modules/paths.py b/modules/paths.py index f509a85f..1100a8dc 100644 --- a/modules/paths.py +++ b/modules/paths.py @@ -20,7 +20,7 @@ assert sd_path is not None, f"Couldn't find Stable Diffusion in any of: {possibl path_dirs = [ (sd_path, 'ldm', 'Stable Diffusion', []), - (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', []), + (os.path.join(sd_path, '../generative-models'), 'sgm', 'Stable Diffusion XL', ["sgm"]), (os.path.join(sd_path, '../CodeFormer'), 'inference_codeformer.py', 'CodeFormer', []), (os.path.join(sd_path, '../BLIP'), 'models/blip.py', 'BLIP', []), (os.path.join(sd_path, '../k-diffusion'), 'k_diffusion/sampling.py', 'k_diffusion', ["atstart"]), @@ -36,6 +36,13 @@ for d, must_exist, what, options in path_dirs: d = os.path.abspath(d) if "atstart" in options: sys.path.insert(0, d) + elif "sgm" in options: + # Stable Diffusion XL repo has scripts dir with __init__.py in it which ruins every extension's scripts dir, so we + # import sgm and remove it from sys.path so that when a script imports scripts.something, it doesbn't use sgm's scripts dir. + + sys.path.insert(0, d) + import sgm + sys.path.pop(0) else: sys.path.append(d) paths[what] = d diff --git a/modules/processing.py b/modules/processing.py index cd568a20..85d35423 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -343,10 +343,13 @@ class StableDiffusionProcessing: return cache[1] def setup_conds(self): + prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index d7f9e9a9..33810669 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from collections import namedtuple from typing import List @@ -109,7 +111,19 @@ def get_learned_conditioning_prompt_schedules(prompts, steps): ScheduledPromptConditioning = namedtuple("ScheduledPromptConditioning", ["end_at_step", "cond"]) -def get_learned_conditioning(model, prompts, steps): +class SdConditioning(list): + """ + A list with prompts for stable diffusion's conditioner model. + Can also specify width and height of created image - SDXL needs it. + """ + def __init__(self, prompts, width=None, height=None): + super().__init__() + self.extend(prompts) + self.width = width or getattr(prompts, 'width', None) + self.height = height or getattr(prompts, 'height', None) + + +def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): """converts a list of prompts into a list of prompt schedules - each schedule is a list of ScheduledPromptConditioning, specifying the comdition (cond), and the sampling step at which this condition is to be replaced by the next one. @@ -160,11 +174,13 @@ def get_learned_conditioning(model, prompts, steps): re_AND = re.compile(r"\bAND\b") re_weight = re.compile(r"^(.*?)(?:\s*:\s*([-+]?(?:\d+\.?|\d*\.\d+)))?\s*$") -def get_multicond_prompt_list(prompts): + +def get_multicond_prompt_list(prompts: SdConditioning | list[str]): res_indexes = [] - prompt_flat_list = [] prompt_indexes = {} + prompt_flat_list = SdConditioning(prompts) + prompt_flat_list.clear() for prompt in prompts: subprompts = re_AND.split(prompt) @@ -201,6 +217,7 @@ class MulticondLearnedConditioning: self.shape: tuple = shape # the shape field is needed to send this object to DDIM/PLMS self.batch: List[List[ComposableScheduledPromptConditioning]] = batch + def get_multicond_learned_conditioning(model, prompts, steps) -> MulticondLearnedConditioning: """same as get_learned_conditioning, but returns a list of ScheduledPromptConditioning along with the weight objects for each prompt. For each prompt, the list is obtained by splitting the prompt using the AND separator. diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index c4b9211f..266811f9 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -15,6 +15,11 @@ import ldm.models.diffusion.ddim import ldm.models.diffusion.plms import ldm.modules.encoders.modules +import sgm.modules.attention +import sgm.modules.diffusionmodules.model +import sgm.modules.diffusionmodules.openaimodel +import sgm.modules.encoders.modules + attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward @@ -56,6 +61,9 @@ def apply_optimizations(option=None): ldm.modules.diffusionmodules.model.nonlinearity = silu ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + sgm.modules.diffusionmodules.model.nonlinearity = silu + sgm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th + if current_optimizer is not None: current_optimizer.undo() current_optimizer = None @@ -89,6 +97,10 @@ def undo_optimizations(): ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + def fix_checkpoint(): """checkpoints are now added and removed in embedding/hypernet code, since torch doesn't want @@ -170,10 +182,19 @@ class StableDiffusionModelHijack: if conditioner: for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] - if type(embedder).__name__ == 'FrozenOpenCLIPEmbedder': + typename = type(embedder).__name__ + if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenCLIPEmbedder': + model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + conditioner.embedders[i] = m.cond_stage_model + if typename == 'FrozenOpenCLIPEmbedder2': + embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..6c17a81d 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will - be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ @@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) self.hijack.comments.append(f"Used embeddings: {embeddings_list}") - return torch.hstack(zs) + if getattr(self.wrapped, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) def process_tokens(self, remade_batch_tokens, batch_multipliers): """ @@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z = z * (original_mean / new_mean) + z *= (original_mean / new_mean) return z diff --git a/modules/sd_hijack_open_clip.py b/modules/sd_hijack_open_clip.py index 6ac5bda6..fcf5ad07 100644 --- a/modules/sd_hijack_open_clip.py +++ b/modules/sd_hijack_open_clip.py @@ -16,10 +16,6 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit self.id_end = tokenizer.encoder[""] self.id_pad = 0 - self.is_trainable = getattr(wrapped, 'is_trainable', False) - self.input_key = getattr(wrapped, 'input_key', 'txt') - self.legacy_ucg_val = None - def tokenize(self, texts): assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' @@ -39,3 +35,37 @@ class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWit embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) return embedded + + +class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ','][0] + self.id_start = tokenizer.encoder[""] + self.id_end = tokenizer.encoder[""] + self.id_pad = 0 + + def tokenize(self, texts): + assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip' + + tokenized = [tokenizer.encode(text) for text in texts] + + return tokenized + + def encode_with_transformers(self, tokens): + d = self.wrapped.encode_with_transformer(tokens) + z = d[self.wrapped.layer] + + pooled = d.get("pooled") + if pooled is not None: + z.pooled = pooled + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + ids = tokenizer.encode(init_text) + ids = torch.asarray([ids], device=devices.device, dtype=torch.int) + embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0) + + return embedded diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index 53e27ade..e99c9ba5 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -14,7 +14,11 @@ from modules.hypernetworks import hypernetwork import ldm.modules.attention import ldm.modules.diffusionmodules.model +import sgm.modules.attention +import sgm.modules.diffusionmodules.model + diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward +sgm_diffusionmodules_model_AttnBlock_forward = sgm.modules.diffusionmodules.model.AttnBlock.forward class SdOptimization: @@ -39,6 +43,9 @@ class SdOptimization: ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward + sgm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sgm_diffusionmodules_model_AttnBlock_forward + class SdOptimizationXformers(SdOptimization): name = "xformers" @@ -51,6 +58,8 @@ class SdOptimizationXformers(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = xformers_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward + sgm.modules.attention.CrossAttention.forward = xformers_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward class SdOptimizationSdpNoMem(SdOptimization): @@ -65,6 +74,8 @@ class SdOptimizationSdpNoMem(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward class SdOptimizationSdp(SdOptimizationSdpNoMem): @@ -76,6 +87,8 @@ class SdOptimizationSdp(SdOptimizationSdpNoMem): def apply(self): ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward + sgm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward class SdOptimizationSubQuad(SdOptimization): @@ -86,6 +99,8 @@ class SdOptimizationSubQuad(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward + sgm.modules.attention.CrossAttention.forward = sub_quad_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward class SdOptimizationV1(SdOptimization): @@ -94,9 +109,9 @@ class SdOptimizationV1(SdOptimization): cmd_opt = "opt_split_attention_v1" priority = 10 - def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1 class SdOptimizationInvokeAI(SdOptimization): @@ -109,6 +124,7 @@ class SdOptimizationInvokeAI(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI class SdOptimizationDoggettx(SdOptimization): @@ -119,6 +135,8 @@ class SdOptimizationDoggettx(SdOptimization): def apply(self): ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward + sgm.modules.attention.CrossAttention.forward = split_cross_attention_forward + sgm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward def list_optimizers(res): @@ -155,7 +173,7 @@ def get_available_vram(): # see https://github.com/basujindal/stable-diffusion/pull/117 for discussion -def split_cross_attention_forward_v1(self, x, context=None, mask=None): +def split_cross_attention_forward_v1(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -196,7 +214,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): # taken from https://github.com/Doggettx/stable-diffusion and modified -def split_cross_attention_forward(self, x, context=None, mask=None): +def split_cross_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) @@ -262,11 +280,13 @@ def split_cross_attention_forward(self, x, context=None, mask=None): # -- Taken from https://github.com/invoke-ai/InvokeAI and modified -- mem_total_gb = psutil.virtual_memory().total // (1 << 30) + def einsum_op_compvis(q, k, v): s = einsum('b i d, b j d -> b i j', q, k) s = s.softmax(dim=-1, dtype=s.dtype) return einsum('b i j, b j d -> b i d', s, v) + def einsum_op_slice_0(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[0], slice_size): @@ -274,6 +294,7 @@ def einsum_op_slice_0(q, k, v, slice_size): r[i:end] = einsum_op_compvis(q[i:end], k[i:end], v[i:end]) return r + def einsum_op_slice_1(q, k, v, slice_size): r = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) for i in range(0, q.shape[1], slice_size): @@ -281,6 +302,7 @@ def einsum_op_slice_1(q, k, v, slice_size): r[:, i:end] = einsum_op_compvis(q[:, i:end], k, v) return r + def einsum_op_mps_v1(q, k, v): if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096 return einsum_op_compvis(q, k, v) @@ -290,12 +312,14 @@ def einsum_op_mps_v1(q, k, v): slice_size -= 1 return einsum_op_slice_1(q, k, v, slice_size) + def einsum_op_mps_v2(q, k, v): if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16: return einsum_op_compvis(q, k, v) else: return einsum_op_slice_0(q, k, v, 1) + def einsum_op_tensor_mem(q, k, v, max_tensor_mb): size_mb = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() // (1 << 20) if size_mb <= max_tensor_mb: @@ -305,6 +329,7 @@ def einsum_op_tensor_mem(q, k, v, max_tensor_mb): return einsum_op_slice_0(q, k, v, q.shape[0] // div) return einsum_op_slice_1(q, k, v, max(q.shape[1] // div, 1)) + def einsum_op_cuda(q, k, v): stats = torch.cuda.memory_stats(q.device) mem_active = stats['active_bytes.all.current'] @@ -315,6 +340,7 @@ def einsum_op_cuda(q, k, v): # Divide factor of safety as there's copying and fragmentation return einsum_op_tensor_mem(q, k, v, mem_free_total / 3.3 / (1 << 20)) + def einsum_op(q, k, v): if q.device.type == 'cuda': return einsum_op_cuda(q, k, v) @@ -328,7 +354,8 @@ def einsum_op(q, k, v): # Tested on i7 with 8MB L3 cache. return einsum_op_tensor_mem(q, k, v, 32) -def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): + +def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q = self.to_q(x) @@ -356,7 +383,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): # Based on Birch-san's modified implementation of sub-quadratic attention from https://github.com/Birch-san/diffusers/pull/1 # The sub_quad_attention_forward function is under the MIT License listed under Memory Efficient Attention in the Licenses section of the web UI interface -def sub_quad_attention_forward(self, x, context=None, mask=None): +def sub_quad_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): assert mask is None, "attention-mask not currently implemented for SubQuadraticCrossAttnProcessor." h = self.heads @@ -392,6 +419,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): return x + def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_size_min=None, chunk_threshold=None, use_checkpoint=True): bytes_per_token = torch.finfo(q.dtype).bits//8 batch_x_heads, q_tokens, _ = q.shape @@ -442,7 +470,7 @@ def get_xformers_flash_attention_op(q, k, v): return None -def xformers_attention_forward(self, x, context=None, mask=None): +def xformers_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): h = self.heads q_in = self.to_q(x) context = default(context, x) @@ -465,9 +493,10 @@ def xformers_attention_forward(self, x, context=None, mask=None): out = rearrange(out, 'b n h d -> b n (h d)', h=h) return self.to_out(out) + # Based on Diffusers usage of scaled dot product attention from https://github.com/huggingface/diffusers/blob/c7da8fd23359a22d0df2741688b5b4f33c26df21/src/diffusers/models/cross_attention.py # The scaled_dot_product_attention_forward function contains parts of code under Apache-2.0 license listed under Scaled Dot Product Attention in the Licenses section of the web UI interface -def scaled_dot_product_attention_forward(self, x, context=None, mask=None): +def scaled_dot_product_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): batch_size, sequence_length, inner_dim = x.shape if mask is not None: @@ -507,10 +536,12 @@ def scaled_dot_product_attention_forward(self, x, context=None, mask=None): hidden_states = self.to_out[1](hidden_states) return hidden_states -def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None): + +def scaled_dot_product_no_mem_attention_forward(self, x, context=None, mask=None, additional_tokens=None, n_times_crossframe_attn_in_self=0): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return scaled_dot_product_attention_forward(self, x, context, mask) + def cross_attention_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -569,6 +600,7 @@ def cross_attention_attnblock_forward(self, x): return h3 + def xformers_attnblock_forward(self, x): try: h_ = x @@ -592,6 +624,7 @@ def xformers_attnblock_forward(self, x): except NotImplementedError: return cross_attention_attnblock_forward(self, x) + def sdp_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) @@ -612,10 +645,12 @@ def sdp_attnblock_forward(self, x): out = self.proj_out(out) return x + out + def sdp_no_mem_attnblock_forward(self, x): with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=False): return sdp_attnblock_forward(self, x) + def sub_quad_attnblock_forward(self, x): h_ = x h_ = self.norm(h_) diff --git a/modules/sd_models.py b/modules/sd_models.py index 8d639583..e4aae597 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -411,6 +411,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' +sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' class SdModelData: @@ -445,6 +446,15 @@ class SdModelData: model_data = SdModelData() +def get_empty_cond(sd_model): + if hasattr(sd_model, 'conditioner'): + d = sd_model.get_learned_conditioning([""]) + return d['crossattn'] + else: + return sd_model.cond_stage_model([""]) + + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -465,7 +475,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict + clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict timer.record("find config") @@ -517,7 +527,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("scripts callbacks") with devices.autocast(), torch.no_grad(): - sd_model.cond_stage_model_empty_prompt = sd_model.cond_stage_model([""]) + sd_model.cond_stage_model_empty_prompt = get_empty_cond(sd_model) timer.record("calculate empty prompt") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 96501569..2e92479a 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") +config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -70,7 +71,9 @@ def guess_model_config_from_state_dict(sd, filename): diffusion_model_input = sd.get('model.diffusion_model.input_blocks.0.0.weight', None) sd2_variations_weight = sd.get('embedder.model.ln_final.weight', None) - if sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: + if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: + return config_sdxl + elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: return config_unclip diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index d43b8868..e8e270c3 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -1,18 +1,30 @@ from __future__ import annotations +import sys + import torch import sgm.models.diffusion import sgm.modules.diffusionmodules.denoiser_scaling import sgm.modules.diffusionmodules.discretizer -from modules import devices +from modules import devices, shared, prompt_parser -def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: list[str]): +def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: prompt_parser.SdConditioning | list[str]): for embedder in self.conditioner.embedders: embedder.ucg_rate = 0.0 - c = self.conditioner({'txt': batch}) + width = getattr(self, 'target_width', 1024) + height = getattr(self, 'target_height', 1024) + + sdxl_conds = { + "txt": batch, + "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + } + + c = self.conditioner(sdxl_conds) return c @@ -26,7 +38,7 @@ def extend_sdxl(model): model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - model.cond_stage_model = [x for x in model.conditioner.embedders if type(x).__name__ == 'FrozenOpenCLIPEmbedder'][0] + model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] model.cond_stage_key = model.cond_stage_model.input_key model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" @@ -34,7 +46,14 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.is_xl = True + sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.modules.attention.print = lambda *args: None +sgm.modules.diffusionmodules.model.print = lambda *args: None +sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None +sgm.modules.encoders.modules.print = lambda *args: None + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 73289ce4..5552a8dc 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -186,7 +186,7 @@ class CFGDenoiser(torch.nn.Module): for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(cond_in[a:b], image_cond_in[a:b])) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b])) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size diff --git a/modules/shared.py b/modules/shared.py index b7518de6..71afd94f 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,6 +428,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), + "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), + "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { diff --git a/requirements.txt b/requirements.txt index 3142085e..b3f8a7f4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,6 +14,7 @@ kornia lark numpy omegaconf +open-clip-torch piexif psutil diff --git a/requirements_versions.txt b/requirements_versions.txt index f71b9d6c..b826bf43 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -15,6 +15,7 @@ kornia==0.6.7 lark==1.1.2 numpy==1.23.5 omegaconf==2.2.3 +open-clip-torch==2.20.0 piexif==1.1.3 psutil~=5.9.5 pytorch_lightning==1.9.4 -- cgit v1.2.3 From e16ebc917dfc902f041963df0d4e99e8141cf82f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 17:32:35 +0300 Subject: repair --no-half for SDXL --- modules/sd_models.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e4aae597..9e8cb3cf 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -395,10 +395,11 @@ def repair_config(sd_config): if not hasattr(sd_config.model.params, "use_ema"): sd_config.model.params.use_ema = False - if shared.cmd_opts.no_half: - sd_config.model.params.unet_config.params.use_fp16 = False - elif shared.cmd_opts.upcast_sampling: - sd_config.model.params.unet_config.params.use_fp16 = True + if hasattr(sd_config.model.params, 'unet_config'): + if shared.cmd_opts.no_half: + sd_config.model.params.unet_config.params.use_fp16 = False + elif shared.cmd_opts.upcast_sampling: + sd_config.model.params.unet_config.params.use_fp16 = True if getattr(sd_config.model.params.first_stage_config.params.ddconfig, "attn_type", None) == "vanilla-xformers" and not shared.xformers_available: sd_config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla" -- cgit v1.2.3 From 6c5f83b19b331d51bde28c5033d13d0d64c11e54 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 21:17:50 +0300 Subject: add support for SDXL loras with te1/te2 modules --- extensions-builtin/Lora/lora.py | 41 +++++++++++++++++++++++++++++++---------- modules/sd_models.py | 3 ++- modules/sd_models_xl.py | 1 - 3 files changed, 33 insertions(+), 12 deletions(-) (limited to 'modules/sd_models.py') diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py index 03f1ef85..4b5da7b5 100644 --- a/extensions-builtin/Lora/lora.py +++ b/extensions-builtin/Lora/lora.py @@ -68,6 +68,14 @@ def convert_diffusers_name_to_compvis(key, is_sd2): return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"): + if 'mlp_fc1' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}" + elif 'mlp_fc2' in m[1]: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}" + else: + return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}" + return key @@ -142,10 +150,20 @@ class LoraUpDownModule: def assign_lora_names_to_compvis_modules(sd_model): lora_layer_mapping = {} - for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): - lora_name = name.replace(".", "_") - lora_layer_mapping[lora_name] = module - module.lora_layer_name = lora_name + if shared.sd_model.is_sdxl: + for i, embedder in enumerate(shared.sd_model.conditioner.embedders): + if not hasattr(embedder, 'wrapped'): + continue + + for name, module in embedder.wrapped.named_modules(): + lora_name = f'{i}_{name.replace(".", "_")}' + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + else: + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name for name, module in shared.sd_model.model.named_modules(): lora_name = name.replace(".", "_") @@ -168,10 +186,10 @@ def load_lora(name, lora_on_disk): keys_failed_to_match = {} is_sd2 = 'model_transformer_resblocks' in shared.sd_model.lora_layer_mapping - for key_diffusers, weight in sd.items(): - key_diffusers_without_lora_parts, lora_key = key_diffusers.split(".", 1) - key = convert_diffusers_name_to_compvis(key_diffusers_without_lora_parts, is_sd2) + for key_lora, weight in sd.items(): + key_lora_without_lora_parts, lora_key = key_lora.split(".", 1) + key = convert_diffusers_name_to_compvis(key_lora_without_lora_parts, is_sd2) sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: @@ -180,12 +198,15 @@ def load_lora(name, lora_on_disk): sd_module = shared.sd_model.lora_layer_mapping.get(m.group(1), None) # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model" - if sd_module is None and "lora_unet" in key_diffusers_without_lora_parts: - key = key_diffusers_without_lora_parts.replace("lora_unet", "diffusion_model") + if sd_module is None and "lora_unet" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_unet", "diffusion_model") + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + elif sd_module is None and "lora_te1_text_model" in key_lora_without_lora_parts: + key = key_lora_without_lora_parts.replace("lora_te1_text_model", "0_transformer_text_model") sd_module = shared.sd_model.lora_layer_mapping.get(key, None) if sd_module is None: - keys_failed_to_match[key_diffusers] = key + keys_failed_to_match[key_lora] = key continue lora_module = lora.modules.get(key, None) diff --git a/modules/sd_models.py b/modules/sd_models.py index 9e8cb3cf..07702175 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,7 +289,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) - if hasattr(model, 'conditioner'): + model.is_sdxl = hasattr(model, 'conditioner') + if model.is_sdxl: sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index af445a61..a7240dc0 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -48,7 +48,6 @@ def extend_sdxl(model): discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) - model.is_sdxl = True sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -- cgit v1.2.3 From 6d8dcdefa07d5f8f7e528046b0facdcc51185e60 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:16:01 +0300 Subject: initial SDXL refiner support --- modules/sd_hijack.py | 18 ++++++++++---- modules/sd_models.py | 3 ++- modules/sd_models_config.py | 3 +++ modules/sd_models_xl.py | 57 ++++++++++++++++++++++++++++++++++++--------- modules/shared.py | 9 +++++-- 5 files changed, 71 insertions(+), 19 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 647cdfbe..2b274c18 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -180,21 +180,29 @@ class StableDiffusionModelHijack: def hijack(self, m): conditioner = getattr(m, 'conditioner', None) if conditioner: + text_cond_models = [] + for i in range(len(conditioner.embedders)): embedder = conditioner.embedders[i] typename = type(embedder).__name__ if typename == 'FrozenOpenCLIPEmbedder': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) - m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenCLIPEmbedder': - model_embeddings = m.cond_stage_model.transformer.text_model.embeddings + model_embeddings = embedder.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) - conditioner.embedders[i] = m.cond_stage_model + conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self) + text_cond_models.append(conditioner.embedders[i]) + + if len(text_cond_models) == 1: + m.cond_stage_model = text_cond_models[0] + else: + m.cond_stage_model = conditioner if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation: model_embeddings = m.cond_stage_model.roberta.embeddings diff --git a/modules/sd_models.py b/modules/sd_models.py index 07702175..267f4d8e 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -414,6 +414,7 @@ def repair_config(sd_config): sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight' sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight' sdxl_clip_weight = 'conditioner.embedders.1.model.ln_final.weight' +sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: @@ -477,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict or sdxl_clip_weight in state_dict + clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) timer.record("find config") diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 04c09ab0..8266fa39 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -14,6 +14,7 @@ config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") +config_sdxl_refiner = os.path.join(sd_xl_repo_configs_path, "sd_xl_refiner.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") config_unclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-l-inference.yaml") config_unopenclip = os.path.join(sd_repo_configs_path, "v2-1-stable-unclip-h-inference.yaml") @@ -72,6 +73,8 @@ def guess_model_config_from_state_dict(sd, filename): if sd.get('conditioner.embedders.1.model.ln_final.weight', None) is not None: return config_sdxl + if sd.get('conditioner.embedders.0.model.ln_final.weight', None) is not None: + return config_sdxl_refiner elif sd.get('depth_model.model.pretrained.act_postprocess3.0.project.0.bias', None) is not None: return config_depth_model elif sd2_variations_weight is not None and sd2_variations_weight.shape[0] == 768: diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index a7240dc0..01320c7a 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -14,15 +14,20 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: width = getattr(self, 'target_width', 1024) height = getattr(self, 'target_height', 1024) + is_negative_prompt = getattr(batch, 'is_negative_prompt', False) + aesthetic_score = shared.opts.sdxl_refiner_low_aesthetic_score if is_negative_prompt else shared.opts.sdxl_refiner_high_aesthetic_score + + devices_args = dict(device=devices.device, dtype=devices.dtype) sdxl_conds = { "txt": batch, - "original_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left]).repeat(len(batch), 1).to(devices.device, devices.dtype), - "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), + "original_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "crop_coords_top_left": torch.tensor([shared.opts.sdxl_crop_top, shared.opts.sdxl_crop_left], **devices_args).repeat(len(batch), 1), + "target_size_as_tuple": torch.tensor([height, width], **devices_args).repeat(len(batch), 1), + "aesthetic_score": torch.tensor([aesthetic_score], **devices_args).repeat(len(batch), 1), } - force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + force_zero_negative_prompt = is_negative_prompt and all(x == '' for x in batch) c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c @@ -35,25 +40,55 @@ def apply_model(self: sgm.models.diffusion.DiffusionEngine, x, t, cond): def get_first_stage_encoding(self, x): # SDXL's encode_first_stage does everything so get_first_stage_encoding is just there for compatibility return x + +sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning +sgm.models.diffusion.DiffusionEngine.apply_model = apply_model +sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding + + +def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text, nvpt): + res = [] + + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'encode_embedding_init_text')]: + encoded = embedder.encode_embedding_init_text(init_text, nvpt) + res.append(encoded) + + return torch.cat(res, dim=1) + + +def process_texts(self, texts): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]: + return embedder.process_texts(texts) + + +def get_target_prompt_token_count(self, token_count): + for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'get_target_prompt_token_count')]: + return embedder.get_target_prompt_token_count(token_count) + + +# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist +sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text +sgm.modules.GeneralConditioner.process_texts = process_texts +sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count + + def extend_sdxl(model): + """this adds a bunch of parameters to make SDXL model look a bit more like SD1.5 to the rest of the codebase.""" + dtype = next(model.model.diffusion_model.parameters()).dtype model.model.diffusion_model.dtype = dtype model.model.conditioning_key = 'crossattn' - - model.cond_stage_model = [x for x in model.conditioner.embedders if 'CLIPEmbedder' in type(x).__name__][0] - model.cond_stage_key = model.cond_stage_model.input_key + model.cond_stage_key = 'txt' + # model.cond_stage_model will be set in sd_hijack model.parameterization = "v" if isinstance(model.denoiser.scaling, sgm.modules.diffusionmodules.denoiser_scaling.VScaling) else "eps" discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization() model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype) + model.conditioner.wrapped = torch.nn.Module() -sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning -sgm.models.diffusion.DiffusionEngine.apply_model = apply_model -sgm.models.diffusion.DiffusionEngine.get_first_stage_encoding = get_first_stage_encoding - sgm.modules.attention.print = lambda *args: None sgm.modules.diffusionmodules.model.print = lambda *args: None sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None diff --git a/modules/shared.py b/modules/shared.py index 71afd94f..234ede0d 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -428,8 +428,13 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors"), - "sdxl_crop_top": OptionInfo(0, "SDXL top coordinate of the crop"), - "sdxl_crop_left": OptionInfo(0, "SDXL left coordinate of the crop"), +})) + +options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { + "sdxl_crop_top": OptionInfo(0, "crop top coordinate"), + "sdxl_crop_left": OptionInfo(0, "crop left coordinate"), + "sdxl_refiner_low_aesthetic_score": OptionInfo(2.5, "SDXL low aesthetic score", gr.Number).info("used for refiner model negative prompt"), + "sdxl_refiner_high_aesthetic_score": OptionInfo(6.0, "SDXL high aesthetic score", gr.Number).info("used for refiner model prompt"), })) options_templates.update(options_section(('optimizations', "Optimizations"), { -- cgit v1.2.3 From b7dbeda0d9e475aafa9db0cfe015bf724502ec20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 14 Jul 2023 09:19:08 +0300 Subject: linter --- modules/sd_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 267f4d8e..729f03d7 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -478,7 +478,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) - clip_is_included_into_sd = any([x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict]) + clip_is_included_into_sd = any(x for x in [sd1_clip_weight, sd2_clip_weight, sdxl_clip_weight, sdxl_refiner_clip_weight] if x in state_dict) timer.record("find config") -- cgit v1.2.3 From 699108bfbb05c2a7d2ee4a2c7abcfaa0a244d8ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 18:56:14 +0300 Subject: hide cards for networks of incompatible stable diffusion version in Lora extra networks interface --- extensions-builtin/Lora/network.py | 20 +++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 2 ++ extensions-builtin/Lora/ui_edit_user_metadata.py | 20 +++++++++---- extensions-builtin/Lora/ui_extra_networks_lora.py | 34 +++++++++++++++++++---- html/extra-networks-card.html | 2 +- javascript/extraNetworks.js | 2 +- modules/sd_models.py | 3 ++ modules/ui_extra_networks.py | 3 +- modules/ui_extra_networks_user_metadata.py | 7 ++++- style.css | 6 +++- 10 files changed, 84 insertions(+), 15 deletions(-) (limited to 'modules/sd_models.py') diff --git a/extensions-builtin/Lora/network.py b/extensions-builtin/Lora/network.py index fe42dbdd..8ecfa29a 100644 --- a/extensions-builtin/Lora/network.py +++ b/extensions-builtin/Lora/network.py @@ -1,5 +1,6 @@ import os from collections import namedtuple +import enum from modules import sd_models, cache, errors, hashes, shared @@ -8,6 +9,13 @@ NetworkWeights = namedtuple('NetworkWeights', ['network_key', 'sd_key', 'w', 'sd metadata_tags_order = {"ss_sd_model_name": 1, "ss_resolution": 2, "ss_clip_skip": 3, "ss_num_train_images": 10, "ss_tag_frequency": 20} +class SdVersion(enum.Enum): + Unknown = 1 + SD1 = 2 + SD2 = 3 + SDXL = 4 + + class NetworkOnDisk: def __init__(self, name, filename): self.name = name @@ -44,6 +52,18 @@ class NetworkOnDisk: '' ) + self.sd_version = self.detect_version() + + def detect_version(self): + if str(self.metadata.get('ss_base_model_version', "")).startswith("sdxl_"): + return SdVersion.SDXL + elif str(self.metadata.get('ss_v2', "")) == "True": + return SdVersion.SD2 + elif len(self.metadata): + return SdVersion.SD1 + + return SdVersion.Unknown + def set_hash(self, v): self.hash = v self.shorthash = self.hash[0:12] diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py index f478f718..cd28afc9 100644 --- a/extensions-builtin/Lora/scripts/lora_script.py +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -63,6 +63,8 @@ shared.options_templates.update(shared.options_section(('extra_networks', "Extra "sd_lora": shared.OptionInfo("None", "Add network to prompt", gr.Dropdown, lambda: {"choices": ["None", *networks.available_networks]}, refresh=networks.list_available_networks), "lora_preferred_name": shared.OptionInfo("Alias from file", "When adding to prompt, refer to Lora by", gr.Radio, {"choices": ["Alias from file", "Filename"]}), "lora_add_hashes_to_infotext": shared.OptionInfo(True, "Add Lora hashes to infotext"), + "lora_show_all": shared.OptionInfo(False, "Always show all networks on the Lora page").info("otherwise, those detected as for incompatible version of Stable Diffusion will be hidden"), + "lora_hide_unknown_for_versions": shared.OptionInfo([], "Hide networks of unknown versions for model versions", gr.CheckboxGroup, {"choices": ["SD1", "SD2", "SDXL"]}), })) diff --git a/extensions-builtin/Lora/ui_edit_user_metadata.py b/extensions-builtin/Lora/ui_edit_user_metadata.py index 354a1d68..c8730443 100644 --- a/extensions-builtin/Lora/ui_edit_user_metadata.py +++ b/extensions-builtin/Lora/ui_edit_user_metadata.py @@ -46,14 +46,17 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) def __init__(self, ui, tabname, page): super().__init__(ui, tabname, page) + self.select_sd_version = None + self.taginfo = None self.edit_activation_text = None self.slider_preferred_weight = None self.edit_notes = None - def save_lora_user_metadata(self, name, desc, activation_text, preferred_weight, notes): + def save_lora_user_metadata(self, name, desc, sd_version, activation_text, preferred_weight, notes): user_metadata = self.get_user_metadata(name) user_metadata["description"] = desc + user_metadata["sd version"] = sd_version user_metadata["activation text"] = activation_text user_metadata["preferred weight"] = preferred_weight user_metadata["notes"] = notes @@ -112,11 +115,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) gradio_tags = [(tag, str(count)) for tag, count in tags[0:24]] return [ - *values[0:4], + *values[0:5], + item.get("sd_version", "Unknown"), gr.HighlightedText.update(value=gradio_tags, visible=True if tags else False), user_metadata.get('activation text', ''), float(user_metadata.get('preferred weight', 0.0)), - user_metadata.get('notes', ''), gr.update(visible=True if tags else False), gr.update(value=self.generate_random_prompt_from_tags(tags), visible=True if tags else False), ] @@ -141,10 +144,15 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) return ", ".join(sorted(res)) + def create_extra_default_items_in_left_column(self): + + # this would be a lot better as gr.Radio but I can't make it work + self.select_sd_version = gr.Dropdown(['SD1', 'SD2', 'SDXL', 'Unknown'], value='Unknown', label='Stable Diffusion version', interactive=True) + def create_editor(self): self.create_default_editor_elems() - self.taginfo = gr.HighlightedText(label="Tags") + self.taginfo = gr.HighlightedText(label="Training dataset tags") self.edit_activation_text = gr.Text(label='Activation text', info="Will be added to prompt along with Lora") self.slider_preferred_weight = gr.Slider(label='Preferred weight', info="Set to 0 to disable", minimum=0.0, maximum=2.0, step=0.01) @@ -178,10 +186,11 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) self.edit_description, self.html_filedata, self.html_preview, + self.edit_notes, + self.select_sd_version, self.taginfo, self.edit_activation_text, self.slider_preferred_weight, - self.edit_notes, row_random_prompt, random_prompt, ] @@ -192,6 +201,7 @@ class LoraUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor) edited_components = [ self.edit_description, + self.select_sd_version, self.edit_activation_text, self.slider_preferred_weight, self.edit_notes, diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py index b6171a26..4b32098b 100644 --- a/extensions-builtin/Lora/ui_extra_networks_lora.py +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -1,7 +1,9 @@ import os + +import network import networks -from modules import shared, ui_extra_networks +from modules import shared, ui_extra_networks, paths from modules.ui_extra_networks import quote_js from ui_edit_user_metadata import LoraUserMetadataEditor @@ -13,14 +15,13 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): def refresh(self): networks.list_available_networks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): lora_on_disk = networks.available_networks.get(name) path, ext = os.path.splitext(lora_on_disk.filename) alias = lora_on_disk.get_alias() - # in 1.5 filename changes to be full filename instead of path without extension, and metadata is dict instead of json string item = { "name": name, "filename": lora_on_disk.filename, @@ -30,6 +31,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): "local_preview": f"{path}.{shared.opts.samples_format}", "metadata": lora_on_disk.metadata, "sort_keys": {'default': index, **self.get_sort_keys(lora_on_disk.filename)}, + "sd_version": lora_on_disk.sd_version.name, } self.read_user_metadata(item) @@ -40,15 +42,37 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): if activation_text: item["prompt"] += " + " + quote_js(" " + activation_text) + sd_version = item["user_metadata"].get("sd version") + if sd_version in network.SdVersion.__members__: + item["sd_version"] = sd_version + sd_version = network.SdVersion[sd_version] + else: + sd_version = lora_on_disk.sd_version + + if shared.opts.lora_show_all or not enable_filter: + pass + elif sd_version == network.SdVersion.Unknown: + model_version = network.SdVersion.SDXL if shared.sd_model.is_sdxl else network.SdVersion.SD2 if shared.sd_model.is_sd2 else network.SdVersion.SD1 + if model_version.name in shared.opts.lora_hide_unknown_for_versions: + return None + elif shared.sd_model.is_sdxl and sd_version != network.SdVersion.SDXL: + return None + elif shared.sd_model.is_sd2 and sd_version != network.SdVersion.SD2: + return None + elif shared.sd_model.is_sd1 and sd_version != network.SdVersion.SD1: + return None + return item def list_items(self): for index, name in enumerate(networks.available_networks): item = self.create_item(name, index) - yield item + + if item is not None: + yield item def allowed_directories_for_previews(self): - return [shared.cmd_opts.lora_dir] + return [shared.cmd_opts.lora_dir, os.path.join(paths.models_path, "LyCORIS")] def create_user_metadata_editor(self, ui, tabname): return LoraUserMetadataEditor(ui, tabname, self) diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html index eb8b1a67..39674666 100644 --- a/html/extra-networks-card.html +++ b/html/extra-networks-card.html @@ -1,8 +1,8 @@
{background_image}
- {edit_button} {metadata_button} + {edit_button}
diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index e453094a..5582a6e5 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -213,7 +213,7 @@ function popup(contents) { globalPopupInner.classList.add('global-popup-inner'); globalPopup.appendChild(globalPopupInner); - gradioApp().appendChild(globalPopup); + gradioApp().querySelector('.main').appendChild(globalPopup); } globalPopupInner.innerHTML = ''; diff --git a/modules/sd_models.py b/modules/sd_models.py index 729f03d7..4d9382dd 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -290,6 +290,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer state_dict = get_checkpoint_state_dict(checkpoint_info, timer) model.is_sdxl = hasattr(model, 'conditioner') + model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model') + model.is_sd1 = not model.is_sdxl and not model.is_sd2 + if model.is_sdxl: sd_models_xl.extend_sdxl(model) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6c73998f..49612298 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -62,7 +62,8 @@ def get_single_card(page: str = "", tabname: str = "", name: str = ""): page = next(iter([x for x in extra_pages if x.name == page]), None) try: - item = page.create_item(name) + item = page.create_item(name, enable_filter=False) + page.items[name] = item except Exception as e: errors.display(e, "creating item for extra network") item = page.items.get(name) diff --git a/modules/ui_extra_networks_user_metadata.py b/modules/ui_extra_networks_user_metadata.py index 01ff4e4b..63d4b503 100644 --- a/modules/ui_extra_networks_user_metadata.py +++ b/modules/ui_extra_networks_user_metadata.py @@ -42,6 +42,9 @@ class UserMetadataEditor: return user_metadata + def create_extra_default_items_in_left_column(self): + pass + def create_default_editor_elems(self): with gr.Row(): with gr.Column(scale=2): @@ -49,6 +52,8 @@ class UserMetadataEditor: self.edit_description = gr.Textbox(label="Description", lines=4) self.html_filedata = gr.HTML() + self.create_extra_default_items_in_left_column() + with gr.Column(scale=1, min_width=0): self.html_preview = gr.HTML() @@ -111,7 +116,7 @@ class UserMetadataEditor: table = '' + "".join(f"" for name, value in params) + '' - return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', ''), + return html.escape(name), user_metadata.get('description', ''), table, self.get_card_html(name), user_metadata.get('notes', '') def write_user_metadata(self, name, metadata): item = self.page.items.get(name, {}) diff --git a/style.css b/style.css index 8a66c3d2..e249cfd3 100644 --- a/style.css +++ b/style.css @@ -841,7 +841,7 @@ footer { .extra-network-cards .card .card-button { text-shadow: 2px 2px 3px black; - padding: 0.25em; + padding: 0.25em 0.1em; font-size: 200%; width: 1.5em; } @@ -957,6 +957,10 @@ div.block.gradio-box.edit-user-metadata { text-align: left; } +.edit-user-metadata .file-metadata th, .edit-user-metadata .file-metadata td{ + padding: 0.3em 1em; +} + .edit-user-metadata .wrap.translucent{ background: var(--body-background-fill); } -- cgit v1.2.3 From f0e2098f1a533c88396536282c1d6cd7d847a51c Mon Sep 17 00:00:00 2001 From: brkirch Date: Mon, 17 Jul 2023 23:39:38 -0400 Subject: Add support for `--upcast-sampling` with SD XL --- modules/sd_hijack_unet.py | 8 +++++++- modules/sd_models.py | 2 +- 2 files changed, 8 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_hijack_unet.py b/modules/sd_hijack_unet.py index ca1daf45..2101f1a0 100644 --- a/modules/sd_hijack_unet.py +++ b/modules/sd_hijack_unet.py @@ -39,7 +39,10 @@ def apply_model(orig_func, self, x_noisy, t, cond, **kwargs): if isinstance(cond, dict): for y in cond.keys(): - cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + if isinstance(cond[y], list): + cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]] + else: + cond[y] = cond[y].to(devices.dtype_unet) if isinstance(cond[y], torch.Tensor) else cond[y] with devices.autocast(): return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float() @@ -77,3 +80,6 @@ first_stage_sub = lambda orig_func, self, x, **kwargs: orig_func(self, x.to(devi CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.decode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.encode_first_stage', first_stage_sub, first_stage_cond) CondFunc('ldm.models.diffusion.ddpm.LatentDiffusion.get_first_stage_encoding', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).float(), first_stage_cond) + +CondFunc('sgm.modules.diffusionmodules.wrappers.OpenAIWrapper.forward', apply_model, unet_needs_upcast) +CondFunc('sgm.modules.diffusionmodules.openaimodel.timestep_embedding', lambda orig_func, timesteps, *args, **kwargs: orig_func(timesteps, *args, **kwargs).to(torch.float32 if timesteps.dtype == torch.int64 else devices.dtype_unet), unet_needs_upcast) diff --git a/modules/sd_models.py b/modules/sd_models.py index 4d9382dd..5813b550 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -326,7 +326,7 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer timer.record("apply half()") - devices.dtype_unet = model.model.diffusion_model.dtype + devices.dtype_unet = torch.float16 if model.is_sdxl and not shared.cmd_opts.no_half else model.model.diffusion_model.dtype devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3 From b270ded268c92950a35a7a326da54496ef4151c8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 18 Jul 2023 18:10:04 +0300 Subject: fix the issue with /sdapi/v1/options failing (this time for sure!) fix automated tests downloading CLIP model --- .github/workflows/run_tests.yaml | 1 + modules/api/models.py | 6 ++---- modules/cmd_args.py | 1 + modules/sd_models.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) (limited to 'modules/sd_models.py') diff --git a/.github/workflows/run_tests.yaml b/.github/workflows/run_tests.yaml index e9370cc0..3dafaf8d 100644 --- a/.github/workflows/run_tests.yaml +++ b/.github/workflows/run_tests.yaml @@ -41,6 +41,7 @@ jobs: --skip-prepare-environment --skip-torch-cuda-test --test-server + --do-not-download-clip --no-half --disable-opt-split-attention --use-cpu all diff --git a/modules/api/models.py b/modules/api/models.py index 96cfe920..4cd20a92 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -209,11 +209,9 @@ class PreprocessResponse(BaseModel): fields = {} for key, metadata in opts.data_labels.items(): value = opts.data.get(key) - if key == 'sd_model_checkpoint': - value = None - optType = opts.typemap.get(type(metadata.default), type(value)) + optType = opts.typemap.get(type(metadata.default), type(metadata.default)) - if isinstance(optType, types.NoneType): + if metadata.default is None: pass elif metadata is not None: fields.update({key: (Optional[optType], Field(default=metadata.default, description=metadata.label))}) diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ae78f469..e401f641 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -15,6 +15,7 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) diff --git a/modules/sd_models.py b/modules/sd_models.py index 5813b550..fb31a793 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -494,7 +494,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: - with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd): + with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): sd_model = instantiate_from_config(sd_config.model) except Exception: pass -- cgit v1.2.3 From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: Use less RAM when creating models --- modules/cmd_args.py | 1 + modules/sd_disable_initialization.py | 106 +++++++++++++++++++++++++++++++++-- modules/sd_models.py | 16 ++++-- webui.py | 4 +- 4 files changed, 114 insertions(+), 13 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index dd5fadc4..cb4ec5f7 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6..695c5736 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,31 @@ import open_clip import torch import transformers.utils.hub +from modules import shared -class DisableInitialization: + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +44,7 @@ class DisableInitialization: """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,8 +109,81 @@ class DisableInitialization: self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - for obj, field, original in self.replaced: - setattr(obj, field, original) + self.restore() - self.replaced.clear() +class InitializeOnMeta(ReplaceHelper): + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + def set_device(x): + x["device"] = "meta" + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() + + +class LoadStateDictOnMeta(ReplaceHelper): + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. + Meant to be used together with InitializeOnMeta above. + + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device): + super().__init__() + self.state_dict = state_dict + self.device = device + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + + for name, param in params: + if param.is_meta: + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + + original(self, state_dict, prefix, *args, **kwargs) + + for name, _ in params: + key = prefix + name + if key in sd: + del sd[key] + + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a793..acb1e817 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -460,7 +460,6 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) - def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - sd_model = instantiate_from_config(sd_config.model) - except Exception: - pass + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + + except Exception as e: + errors.display(e, "creating model quickly", full_traceback=True) if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - sd_model = instantiate_from_config(sd_config.model) + + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config timer.record("create model") - load_model_weights(sd_model, checkpoint_info, state_dict, timer) + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) diff --git a/webui.py b/webui.py index 2314735f..51248c39 100644 --- a/webui.py +++ b/webui.py @@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False): if modules.sd_hijack.current_optimizer is None: modules.sd_hijack.apply_optimizations() - Thread(target=load_model).start() + devices.first_time_calculation() - Thread(target=devices.first_time_calculation).start() + Thread(target=load_model).start() shared.reload_hypernetworks() startup_timer.record("reload hypernetworks") -- cgit v1.2.3 From 3bca90b249d749ed5429f76e380d2ffa52fc0d41 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 30 Jul 2023 13:48:27 +0300 Subject: hires fix checkpoint selection --- modules/generation_parameters_copypaste.py | 3 ++ modules/processing.py | 47 +++++++++++++++++++----------- modules/sd_models.py | 22 ++++++++------ modules/shared.py | 19 ++++++++---- modules/txt2img.py | 3 +- modules/ui.py | 8 ++++- 6 files changed, 68 insertions(+), 34 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a3448be9..4e286558 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -280,6 +280,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Hires sampler" not in res: res["Hires sampler"] = "Use same sampler" + if "Hires checkpoint" not in res: + res["Hires checkpoint"] = "Use same checkpoint" + if "Hires prompt" not in res: res["Hires prompt"] = "" diff --git a/modules/processing.py b/modules/processing.py index b0992ee1..7026487a 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -935,7 +935,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): cached_hr_uc = [None, None] cached_hr_c = [None, None] - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength @@ -946,11 +946,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y self.hr_upscale_to_x = hr_resize_x self.hr_upscale_to_y = hr_resize_y + self.hr_checkpoint_name = hr_checkpoint_name + self.hr_checkpoint_info = None self.hr_sampler_name = hr_sampler_name self.hr_prompt = hr_prompt self.hr_negative_prompt = hr_negative_prompt self.all_hr_prompts = None self.all_hr_negative_prompts = None + self.latent_scale_mode = None if firstphase_width != 0 or firstphase_height != 0: self.hr_upscale_to_x = self.width @@ -973,6 +976,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: + if self.hr_checkpoint_name: + self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name) + + if self.hr_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {self.hr_checkpoint_name}') + + self.extra_generation_params["Hires checkpoint"] = self.hr_checkpoint_info.short_title + if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name: self.extra_generation_params["Hires sampler"] = self.hr_sampler_name @@ -982,6 +993,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt): self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt + self.latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") + if self.enable_hr and self.latent_scale_mode is None: + if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): + raise Exception(f"could not find upscaler named {self.hr_upscaler}") + if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): self.hr_resize_x = self.width self.hr_resize_y = self.height @@ -1020,14 +1036,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f - # special case: the user has chosen to do nothing - if self.hr_upscale_to_x == self.width and self.hr_upscale_to_y == self.height: - self.enable_hr = False - self.denoising_strength = None - self.extra_generation_params.pop("Hires upscale", None) - self.extra_generation_params.pop("Hires resize", None) - return - if not state.processing_has_refined_job_count: if state.job_count == -1: state.job_count = self.n_iter @@ -1045,17 +1053,22 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest") - if self.enable_hr and latent_scale_mode is None: - if not any(x.name == self.hr_upscaler for x in shared.sd_upscalers): - raise Exception(f"could not find upscaler named {self.hr_upscaler}") - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) if not self.enable_hr: return samples + current = shared.sd_model.sd_checkpoint_info + try: + if self.hr_checkpoint_info is not None: + sd_models.reload_model_weights(info=self.hr_checkpoint_info) + + return self.sample_hr_pass(samples, seeds, subseeds, subseed_strength, prompts) + finally: + sd_models.reload_model_weights(info=current) + + def sample_hr_pass(self, samples, seeds, subseeds, subseed_strength, prompts): self.is_hr_pass = True target_width = self.hr_upscale_to_x @@ -1073,11 +1086,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): info = create_infotext(self, self.all_prompts, self.all_seeds, self.all_subseeds, [], iteration=self.iteration, position_in_batch=index) images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, info=info, p=self, suffix="-before-highres-fix") - if latent_scale_mode is not None: + if self.latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode["mode"], antialias=latent_scale_mode["antialias"]) + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=self.latent_scale_mode["mode"], antialias=self.latent_scale_mode["antialias"]) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -1193,7 +1206,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_uc = None self.hr_c = None - if self.enable_hr: + if self.enable_hr and self.hr_checkpoint_info is None: if shared.opts.hires_fix_use_firstpass_conds: self.calculate_hr_conds() diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..cb67e425 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -52,6 +52,7 @@ class CheckpointInfo: self.shorthash = self.sha256[0:10] if self.sha256 else None self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' + self.short_title = self.name_for_extra if self.shorthash is None else f'{self.name_for_extra} [{self.shorthash}]' self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) @@ -81,6 +82,7 @@ class CheckpointInfo: checkpoints_list.pop(self.title) self.title = f'{self.name} [{self.shorthash}]' + self.short_title = f'{self.name_for_extra} [{self.shorthash}]' self.register() return self.shorthash @@ -101,14 +103,8 @@ def setup_model(): enable_midas_autodownload() -def checkpoint_tiles(): - def convert(name): - return int(name) if name.isdigit() else name.lower() - - def alphanumeric_key(key): - return [convert(c) for c in re.split('([0-9]+)', key)] - - return sorted([x.title for x in checkpoints_list.values()], key=alphanumeric_key) +def checkpoint_tiles(use_short=False): + return [x.short_title if use_short else x.title for x in checkpoints_list.values()] def list_models(): @@ -131,11 +127,14 @@ def list_models(): elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file: print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr) - for filename in sorted(model_list, key=str.lower): + for filename in model_list: checkpoint_info = CheckpointInfo(filename) checkpoint_info.register() +re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") + + def get_closet_checkpoint_match(search_string): checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: @@ -145,6 +144,11 @@ def get_closet_checkpoint_match(search_string): if found: return found[0] + search_string_without_checksum = re.sub(re_strip_checksum, '', search_string) + found = sorted([info for info in checkpoints_list.values() if search_string_without_checksum in info.title], key=lambda x: len(x.title)) + if found: + return found[0] + return None diff --git a/modules/shared.py b/modules/shared.py index aa72c9c8..807fb9e3 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -220,12 +220,19 @@ class State: return import modules.sd_samplers - if opts.show_progress_grid: - self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) - else: - self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) - self.current_image_sampling_step = self.sampling_step + try: + if opts.show_progress_grid: + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) + else: + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) + + self.current_image_sampling_step = self.sampling_step + + except Exception: + # when switching models during genration, VAE would be on CPU, so creating an image will fail. + # we silently ignore this error + errors.record_exception() def assign_current_image(self, image): self.current_image = image @@ -512,7 +519,7 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_tab_order": OptionInfo([], "UI tab order", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "hidden_tabs": OptionInfo([], "Hidden UI tabs", ui_components.DropdownMulti, lambda: {"choices": list(tab_names)}).needs_restart(), "ui_reorder_list": OptionInfo([], "txt2img/img2img UI item order", ui_components.DropdownMulti, lambda: {"choices": list(shared_items.ui_reorder_categories())}).info("selected items appear first").needs_restart(), - "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires sampler selection").needs_restart(), + "hires_fix_show_sampler": OptionInfo(False, "Hires fix: show hires checkpoint and sampler selection").needs_restart(), "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_restart(), "disable_token_counters": OptionInfo(False, "Disable prompt token counters").needs_restart(), })) diff --git a/modules/txt2img.py b/modules/txt2img.py index 29d94e8c..935ed418 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -41,6 +41,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_second_pass_steps=hr_second_pass_steps, hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, + hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, diff --git a/modules/ui.py b/modules/ui.py index 07ecee7b..6d8265f2 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -476,6 +476,10 @@ def create_ui(): hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + checkpoint_choices = lambda: ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True) + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=checkpoint_choices(), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": checkpoint_choices()}, "hr_checkpoint_refresh") + hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: @@ -553,6 +557,7 @@ def create_ui(): hr_second_pass_steps, hr_resize_x, hr_resize_y, + hr_checkpoint_name, hr_sampler_index, hr_prompt, hr_negative_prompt, @@ -630,8 +635,9 @@ def create_ui(): (hr_second_pass_steps, "Hires steps"), (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), + (hr_checkpoint_name, "Hires checkpoint"), (hr_sampler_index, "Hires sampler"), - (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" else gr.update()), + (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), -- cgit v1.2.3 From 4d9b096663288e2aa738723fa63950f3d41f6170 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 31 Jul 2023 10:43:31 +0300 Subject: additional memory improvements when switching between models of different types --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index cb67e425..4855037a 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -582,7 +582,10 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + sd_model.to(device="meta") + + devices.torch_gc() load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model -- cgit v1.2.3 From b235022c615a7384f73c05fe240d8f4a28d103d4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 00:24:48 +0300 Subject: option to keep multiple models in memory --- modules/lowvram.py | 3 + modules/sd_hijack.py | 6 +- modules/sd_hijack_inpainting.py | 5 +- modules/sd_models.py | 136 +++++++++++++++++++++++++++++++++------- modules/sd_models_xl.py | 8 +-- modules/shared.py | 12 +++- 6 files changed, 135 insertions(+), 35 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/lowvram.py b/modules/lowvram.py index 3f830664..96f52b7b 100644 --- a/modules/lowvram.py +++ b/modules/lowvram.py @@ -15,6 +15,9 @@ def send_everything_to_cpu(): def setup_for_low_vram(sd_model, use_medvram): + if getattr(sd_model, 'lowvram', False): + return + sd_model.lowvram = True parents = {} diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfa5f0eb..7d692e3c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -30,8 +30,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention # silence new console spam from SD2 -ldm.modules.attention.print = lambda *args: None -ldm.modules.diffusionmodules.model.print = lambda *args: None +ldm.modules.attention.print = shared.ldm_print +ldm.modules.diffusionmodules.model.print = shared.ldm_print +ldm.util.print = shared.ldm_print +ldm.models.diffusion.ddpm.print = shared.ldm_print optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index c1977b19..97350f4f 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -91,7 +91,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0, e_t -def do_inpainting_hijack(): - # p_sample_plms is needed because PLMS can't work with dicts as conditionings - - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms +ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..77195f2f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight' class SdModelData: def __init__(self): self.sd_model = None + self.loaded_sd_models = [] self.was_loaded_at_least_once = False self.lock = threading.Lock() @@ -437,6 +437,7 @@ class SdModelData: try: load_model() + except Exception as e: errors.display(e, "loading stable diffusion model", full_traceback=True) print("", file=sys.stderr) @@ -448,11 +449,24 @@ class SdModelData: def set_sd_model(self, v): self.sd_model = v + try: + self.loaded_sd_models.remove(v) + except ValueError: + pass + + if v is not None: + self.loaded_sd_models.insert(0, v) + model_data = SdModelData() def get_empty_cond(sd_model): + from modules import extra_networks, processing + + p = processing.StableDiffusionProcessingTxt2Img() + extra_networks.activate(p, {}) + if hasattr(sd_model, 'conditioner'): d = sd_model.get_learned_conditioning([""]) return d['crossattn'] @@ -460,19 +474,43 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) +def send_model_to_cpu(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + m.to(devices.cpu) + + devices.torch_gc() + + +def send_model_to_device(m): + from modules import lowvram + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram) + else: + m.to(shared.device) + + +def send_model_to_trash(m): + m.to(device="meta") + devices.torch_gc() + + def load_model(checkpoint_info=None, already_loaded_state_dict=None): - from modules import lowvram, sd_hijack + from modules import sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() + timer = Timer() + if model_data.sd_model: - sd_hijack.model_hijack.undo_hijack(model_data.sd_model) + send_model_to_trash(model_data.sd_model) model_data.sd_model = None - gc.collect() devices.torch_gc() - do_inpainting_hijack() - - timer = Timer() + timer.record("unload existing model") if already_loaded_state_dict is not None: state_dict = already_loaded_state_dict @@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): load_model_weights(sd_model, checkpoint_info, state_dict, timer) + timer.record("load weights from state dict") - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) - else: - sd_model.to(shared.device) - + send_model_to_device(sd_model) timer.record("move model to device") sd_hijack.model_hijack.hijack(sd_model) @@ -525,7 +560,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): timer.record("hijack") sd_model.eval() - model_data.sd_model = sd_model + model_data.set_sd_model(sd_model) model_data.was_loaded_at_least_once = True sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) # Reload embeddings after model load as they may or may not fit the model @@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): return sd_model +def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): + """ + Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models. + If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary). + If not, returns the model that can be used to load weights from checkpoint_info's file. + If no such model exists, returns None. + Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit). + """ + + already_loaded = None + for i in reversed(range(len(model_data.loaded_sd_models))): + loaded_model = model_data.loaded_sd_models[i] + if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename: + already_loaded = loaded_model + continue + + if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0: + print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}") + model_data.loaded_sd_models.pop() + send_model_to_trash(loaded_model) + timer.record("send model to trash") + + if shared.opts.sd_checkpoints_keep_in_cpu: + send_model_to_cpu(sd_model) + timer.record("send model to cpu") + + if already_loaded is not None: + send_model_to_device(already_loaded) + timer.record("send model to device") + + model_data.set_sd_model(already_loaded) + print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}") + return model_data.sd_model + elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit: + print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})") + + model_data.sd_model = None + load_model(checkpoint_info) + return model_data.sd_model + elif len(model_data.loaded_sd_models) > 0: + sd_model = model_data.loaded_sd_models.pop() + model_data.sd_model = sd_model + + print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}") + return sd_model + else: + return None + + def reload_model_weights(sd_model=None, info=None): - from modules import lowvram, devices, sd_hijack + from modules import devices, sd_hijack checkpoint_info = info or select_checkpoint() + timer = Timer() + if not sd_model: sd_model = model_data.sd_model @@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None): else: current_checkpoint_info = sd_model.sd_checkpoint_info if sd_model.sd_model_checkpoint == checkpoint_info.filename: - return - - sd_unet.apply_unet("None") + return sd_model - if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: - lowvram.send_everything_to_cpu() - else: - sd_model.to(devices.cpu) + sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) + if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + if sd_model is not None: + sd_unet.apply_unet("None") + send_model_to_cpu(sd_model) sd_hijack.model_hijack.undo_hijack(sd_model) - timer = Timer() - state_dict = get_checkpoint_state_dict(checkpoint_info, timer) checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info) @@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None): timer.record("find config") if sd_model is None or checkpoint_config != sd_model.used_config: - del sd_model + if sd_model is not None: + send_model_to_trash(sd_model) + load_model(checkpoint_info, already_loaded_state_dict=state_dict) return model_data.sd_model @@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None): print(f"Weights loaded in {timer.summary()}.") + model_data.set_sd_model(sd_model) + return sd_model diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index bc219508..01123321 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -98,10 +98,10 @@ def extend_sdxl(model): model.conditioner.wrapped = torch.nn.Module() -sgm.modules.attention.print = lambda *args: None -sgm.modules.diffusionmodules.model.print = lambda *args: None -sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None -sgm.modules.encoders.modules.print = lambda *args: None +sgm.modules.attention.print = shared.ldm_print +sgm.modules.diffusionmodules.model.print = shared.ldm_print +sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print +sgm.modules.encoders.modules.print = shared.ldm_print # this gets the code to load the vanilla attention that we override sgm.modules.attention.SDP_IS_AVAILABLE = True diff --git a/modules/shared.py b/modules/shared.py index aa72c9c8..0184fcd0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -392,6 +392,7 @@ options_templates.update(options_section(('system', "System"), { "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."), "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""), "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), + "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), })) options_templates.update(options_section(('training', "Training"), { @@ -411,7 +412,9 @@ options_templates.update(options_section(('training', "Training"), { options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), - "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), + "sd_checkpoints_limit": OptionInfo(1, "Maximum number of checkpoints loaded at the same time", gr.Slider, {"minimum": 1, "maximum": 10, "step": 1}), + "sd_checkpoints_keep_in_cpu": OptionInfo(True, "Only keep one model on device").info("will keep models other than the currently used one in RAM rather than VRAM"), + "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}).info("obsolete; set to 0 and use the two settings above instead"), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": shared_items.sd_vae_items()}, refresh=shared_items.refresh_vae_list).info("choose VAE model: Automatic = use one with same filename as checkpoint; None = use VAE from checkpoint"), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), @@ -889,3 +892,10 @@ def walk_files(path, allowed_extensions=None): continue yield os.path.join(root, filename) + + +def ldm_print(*args, **kwargs): + if opts.hide_ldm_prints: + return + + print(*args, **kwargs) -- cgit v1.2.3 From 4b43480fe8b65a3bd24dc9bc03a7e910c9b0314f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 07:08:11 +0300 Subject: show metadata for SD checkpoints in the extra networks UI --- modules/sd_models.py | 26 ++++++++++++++++---------- modules/ui_extra_networks_checkpoints.py | 1 + 2 files changed, 17 insertions(+), 10 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index acb1e817..1af7fd78 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -14,7 +14,7 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config -from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl +from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd @@ -33,6 +33,8 @@ class CheckpointInfo: self.filename = filename abspath = os.path.abspath(filename) + self.is_safetensors = os.path.splitext(filename)[1].lower() == ".safetensors" + if shared.cmd_opts.ckpt_dir is not None and abspath.startswith(shared.cmd_opts.ckpt_dir): name = abspath.replace(shared.cmd_opts.ckpt_dir, '') elif abspath.startswith(model_path): @@ -43,6 +45,19 @@ class CheckpointInfo: if name.startswith("\\") or name.startswith("/"): name = name[1:] + def read_metadata(): + metadata = read_metadata_from_safetensors(filename) + self.modelspec_thumbnail = metadata.pop('modelspec.thumbnail', None) + + return metadata + + self.metadata = {} + if self.is_safetensors: + try: + self.metadata = cache.cached_data_for_file('safetensors-metadata', "checkpoint/" + name, filename, read_metadata) + except Exception as e: + errors.display(e, f"reading metadata for {filename}") + self.name = name self.name_for_extra = os.path.splitext(os.path.basename(filename))[0] self.model_name = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0] @@ -55,15 +70,6 @@ class CheckpointInfo: self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) - self.metadata = {} - - _, ext = os.path.splitext(self.filename) - if ext.lower() == ".safetensors": - try: - self.metadata = read_metadata_from_safetensors(filename) - except Exception as e: - errors.display(e, f"reading checkpoint metadata: {filename}") - def register(self): checkpoints_list[self.title] = self for id in self.ids: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 76780cfd..2bb0a222 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -23,6 +23,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): "search_term": self.search_terms_from_path(checkpoint.filename) + " " + (checkpoint.sha256 or ""), "onclick": '"' + html.escape(f"""return selectCheckpoint({quote_js(name)})""") + '"', "local_preview": f"{path}.{shared.opts.samples_format}", + "metadata": checkpoint.metadata, "sort_keys": {'default': index, **self.get_sort_keys(checkpoint.filename)}, } -- cgit v1.2.3 From 07be13caa357b14f6afa247566d53339522b8e66 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 08:27:54 +0300 Subject: add metadata to checkpoint merger --- modules/extras.py | 39 +++++++++++++++++++++++++++++++++------ modules/sd_models.py | 2 +- modules/ui_checkpoint_merger.py | 20 ++++++++++++++++++-- 3 files changed, 52 insertions(+), 9 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/extras.py b/modules/extras.py index e9c0263e..2a310ae3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import json import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,7 +72,20 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} + + for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]: + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") - metadata = None + metadata = {} + + if save_metadata and copy_metadata_fields: + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: + metadata.update(secondary_model_info.metadata) + if tertiary_model_info: + metadata.update(tertiary_model_info.metadata) if save_metadata: - metadata = {"format": "pt"} + try: + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + + metadata["format"] = "pt" + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) + safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None) else: torch.save(theta_0, output_modelname) diff --git a/modules/sd_models.py b/modules/sd_models.py index 1af7fd78..8f72f21d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -85,7 +85,7 @@ class CheckpointInfo: if self.shorthash not in self.ids: self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] - checkpoints_list.pop(self.title) + checkpoints_list.pop(self.title, None) self.title = f'{self.name} [{self.shorthash}]' self.register() diff --git a/modules/ui_checkpoint_merger.py b/modules/ui_checkpoint_merger.py index 8e72258a..4863d861 100644 --- a/modules/ui_checkpoint_merger.py +++ b/modules/ui_checkpoint_merger.py @@ -51,7 +51,6 @@ class UiCheckpointMerger: with FormRow(): self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format") self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half") - self.save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata") with FormRow(): with gr.Column(): @@ -65,16 +64,30 @@ class UiCheckpointMerger: with FormRow(): self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights") - with gr.Row(): + with gr.Accordion("Metadata", open=False) as metadata_editor: + with FormRow(): + self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata") + self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe") + self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata") + + self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format") + self.read_metadata = gr.Button("Read metadata from selected checkpoints") + + with FormRow(): self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary') with gr.Column(variant='compact', elem_id="modelmerger_results_container"): with gr.Group(elem_id="modelmerger_results_panel"): self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False) + self.metadata_editor = metadata_editor self.blocks = modelmerger_interface def setup_ui(self, dummy_component, sd_model_checkpoint_component): + self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False) + + self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json]) + self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result]) self.modelmerger_merge.click( fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]), @@ -93,6 +106,9 @@ class UiCheckpointMerger: self.bake_in_vae, self.discard_weights, self.save_metadata, + self.add_merge_recipe, + self.copy_metadata_fields, + self.metadata_json, ], outputs=[ self.primary_model_name, -- cgit v1.2.3 From 390bffa81b747a7eb38ac7a0cd6dfb9fcc388151 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 17:13:15 +0300 Subject: repair merge error --- modules/sd_models.py | 1 - 1 file changed, 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 40a450df..3c451a4b 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -15,7 +15,6 @@ import ldm.modules.midas as midas from ldm.util import instantiate_from_config from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache -from modules.sd_hijack_inpainting import do_inpainting_hijack from modules.timer import Timer import tomesd -- cgit v1.2.3 From 20549a50cb3c41868ce561c6658bfaa0d20ac7ba Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 3 Aug 2023 22:46:57 +0300 Subject: add style editor dialog rework toprow for img2img and txt2img to use a class with fields fix the console error when editing checkpoint user metadata --- modules/sd_models.py | 2 +- modules/styles.py | 5 +- modules/ui.py | 230 ++++++++++--------------- modules/ui_common.py | 32 +++- modules/ui_extra_networks_checkpoints.py | 2 +- modules/ui_extra_networks_hypernets.py | 2 +- modules/ui_extra_networks_textual_inversion.py | 2 +- modules/ui_prompt_styles.py | 110 ++++++++++++ style.css | 13 ++ 9 files changed, 248 insertions(+), 150 deletions(-) create mode 100644 modules/ui_prompt_styles.py (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 8f72f21d..1d93d893 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -68,7 +68,7 @@ class CheckpointInfo: self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]' - self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) + self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else []) def register(self): checkpoints_list[self.title] = self diff --git a/modules/styles.py b/modules/styles.py index ec0e1bc5..0740fe1b 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -106,10 +106,7 @@ class StyleDatabase: if os.path.exists(path): shutil.copy(path, f"{path}.bak") - fd = os.open(path, os.O_RDWR | os.O_CREAT) - with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: - # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, - # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() + with open(path, "w", encoding="utf-8-sig", newline='') as file: writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer.writeheader() writer.writerows(style._asdict() for k, style in self.styles.items()) diff --git a/modules/ui.py b/modules/ui.py index ac2787eb..c059dcec 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,7 +12,7 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -92,19 +92,6 @@ def send_gradio_gallery_to_image(x): return image_from_url_text(x[0]) -def add_style(name: str, prompt: str, negative_prompt: str): - if name is None: - return [gr_show() for x in range(4)] - - style = modules.styles.PromptStyle(name, prompt, negative_prompt) - shared.prompt_styles.styles[style.name] = style - # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we - # reserialize all styles every time we save them - shared.prompt_styles.save_styles(shared.styles_filename) - - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] - - def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): from modules import processing, devices @@ -129,13 +116,6 @@ def resize_from_to_html(width, height, scale_by): return f"resize: from {width}x{height} to {target_width}x{target_height}" -def apply_styles(prompt, prompt_neg, styles): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) - - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] - - def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles): if mode in {0, 1, 3, 4}: return [interrogation_function(ii_singles[mode]), None] @@ -267,71 +247,67 @@ def update_token_counter(text, steps): return f"{token_count}/{max_length}" -def create_toprow(is_img2img): - id_part = "img2img" if is_img2img else "txt2img" +class Toprow: + def __init__(self, is_img2img): + id_part = "img2img" if is_img2img else "txt2img" + self.id_part = id_part - with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): - with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - - with gr.Row(): - with gr.Column(scale=80): - with gr.Row(): - negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - - button_interrogate = None - button_deepbooru = None - if is_img2img: - with gr.Column(scale=1, elem_classes="interrogate-col"): - button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") - button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") - - with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): - with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): - interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") - skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") - submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') - - skip.click( - fn=lambda: shared.state.skip(), - inputs=[], - outputs=[], - ) + with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"): + with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6): + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) - interrupt.click( - fn=lambda: shared.state.interrupt(), - inputs=[], - outputs=[], - ) + with gr.Row(): + with gr.Column(scale=80): + with gr.Row(): + self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"]) + + self.button_interrogate = None + self.button_deepbooru = None + if is_img2img: + with gr.Column(scale=1, elem_classes="interrogate-col"): + self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate") + self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") + + with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"): + with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"): + self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt") + self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip") + self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') + + self.skip.click( + fn=lambda: shared.state.skip(), + inputs=[], + outputs=[], + ) - with gr.Row(elem_id=f"{id_part}_tools"): - paste = ToolButton(value=paste_symbol, elem_id="paste") - clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") - extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") - prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply") - save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create") - restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) - - token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) - token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") - negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) - negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") - - clear_prompt_button.click( - fn=lambda *x: x, - _js="confirm_clear_prompt", - inputs=[prompt, negative_prompt], - outputs=[prompt, negative_prompt], - ) + self.interrupt.click( + fn=lambda: shared.state.interrupt(), + inputs=[], + outputs=[], + ) - with gr.Row(elem_id=f"{id_part}_styles_row"): - prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) - create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") + with gr.Row(elem_id=f"{id_part}_tools"): + self.paste = ToolButton(value=paste_symbol, elem_id="paste") + self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False) + + self.token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"]) + self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + self.negative_token_counter = gr.HTML(value="0/75", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"]) + self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button") + + self.clear_prompt_button.click( + fn=lambda *x: x, + _js="confirm_clear_prompt", + inputs=[self.prompt, self.negative_prompt], + outputs=[self.prompt, self.negative_prompt], + ) - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button + self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt) def setup_progressbar(*args, **kwargs): @@ -419,14 +395,14 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False) + toprow = txt2img_toprow = Toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: from modules import ui_extra_networks - extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img') with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): @@ -532,9 +508,9 @@ def create_ui(): _js="submit", inputs=[ dummy_component, - txt2img_prompt, - txt2img_negative_prompt, - txt2img_prompt_styles, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, steps, sampler_index, restore_faces, @@ -569,12 +545,12 @@ def create_ui(): show_progress=False, ) - txt2img_prompt.submit(**txt2img_args) - submit.click(**txt2img_args) + toprow.prompt.submit(**txt2img_args) + toprow.submit.click(**txt2img_args) res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False) - restore_progress_button.click( + toprow.restore_progress_button.click( fn=progress.restore_progress, _js="restoreProgressTxt2img", inputs=[dummy_component], @@ -593,7 +569,7 @@ def create_ui(): txt_prompt_img ], outputs=[ - txt2img_prompt, + toprow.prompt, txt_prompt_img ], show_progress=False, @@ -607,8 +583,8 @@ def create_ui(): ) txt2img_paste_fields = [ - (txt2img_prompt, "Prompt"), - (txt2img_negative_prompt, "Negative prompt"), + (toprow.prompt, "Prompt"), + (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), (sampler_index, "Sampler"), (restore_faces, "Face restoration"), @@ -621,7 +597,7 @@ def create_ui(): (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), - (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), @@ -639,12 +615,12 @@ def create_ui(): ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None, + paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None, )) txt2img_preview_params = [ - txt2img_prompt, - txt2img_negative_prompt, + toprow.prompt, + toprow.negative_prompt, steps, sampler_index, cfg_scale, @@ -653,8 +629,8 @@ def create_ui(): height, ] - token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) @@ -662,13 +638,13 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True) + toprow = img2img_toprow = Toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: from modules import ui_extra_networks - extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img') with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): @@ -889,7 +865,7 @@ def create_ui(): img2img_prompt_img ], outputs=[ - img2img_prompt, + toprow.prompt, img2img_prompt_img ], show_progress=False, @@ -901,9 +877,9 @@ def create_ui(): inputs=[ dummy_component, dummy_component, - img2img_prompt, - img2img_negative_prompt, - img2img_prompt_styles, + toprow.prompt, + toprow.negative_prompt, + toprow.ui_styles.dropdown, init_img, sketch, init_img_with_mask, @@ -962,11 +938,11 @@ def create_ui(): inpaint_color_sketch, init_img_inpaint, ], - outputs=[img2img_prompt, dummy_component], + outputs=[toprow.prompt, dummy_component], ) - img2img_prompt.submit(**img2img_args) - submit.click(**img2img_args) + toprow.prompt.submit(**img2img_args) + toprow.submit.click(**img2img_args) res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False) @@ -978,7 +954,7 @@ def create_ui(): show_progress=False, ) - restore_progress_button.click( + toprow.restore_progress_button.click( fn=progress.restore_progress, _js="restoreProgressImg2img", inputs=[dummy_component], @@ -991,46 +967,24 @@ def create_ui(): show_progress=False, ) - img2img_interrogate.click( + toprow.button_interrogate.click( fn=lambda *args: process_interrogate(interrogate, *args), **interrogate_args, ) - img2img_deepbooru.click( + toprow.button_deepbooru.click( fn=lambda *args: process_interrogate(interrogate_deepbooru, *args), **interrogate_args, ) - prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] - style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] - - for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): - button.click( - fn=add_style, - _js="ask_for_style_name", - # Have to pass empty dummy component here, because the JavaScript and Python function have to accept - # the same number of parameters, but we only know the style-name after the JavaScript prompt - inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_styles, img2img_prompt_styles], - ) - - for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): - button.click( - fn=apply_styles, - _js=js_func, - inputs=[prompt, negative_prompt, styles], - outputs=[prompt, negative_prompt, styles], - ) - - token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) - negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter]) + toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter]) + toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter]) ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) img2img_paste_fields = [ - (img2img_prompt, "Prompt"), - (img2img_negative_prompt, "Negative prompt"), + (toprow.prompt, "Prompt"), + (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), (sampler_index, "Sampler"), (restore_faces, "Face restoration"), @@ -1044,7 +998,7 @@ def create_ui(): (subseed_strength, "Variation seed strength"), (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), - (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), + (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), *modules.scripts.scripts_img2img.infotext_fields @@ -1052,7 +1006,7 @@ def create_ui(): parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( - paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None, + paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, )) modules.scripts.scripts_current = None diff --git a/modules/ui_common.py b/modules/ui_common.py index 11eb2a4b..ba75fa73 100644 --- a/modules/ui_common.py +++ b/modules/ui_common.py @@ -223,20 +223,44 @@ Requested path was: {f} def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component] + + label = None + for comp in refresh_components: + label = getattr(comp, 'label', None) + if label is not None: + break + def refresh(): refresh_method() args = refreshed_args() if callable(refreshed_args) else refreshed_args for k, v in args.items(): - setattr(refresh_component, k, v) + for comp in refresh_components: + setattr(comp, k, v) - return gr.update(**(args or {})) + return [gr.update(**(args or {})) for _ in refresh_components] - refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id) + refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh") refresh_button.click( fn=refresh, inputs=[], - outputs=[refresh_component] + outputs=[*refresh_components] ) return refresh_button + +def setup_dialog(button_show, dialog, *, button_close=None): + """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window.""" + + dialog.visible = False + + button_show.click( + fn=lambda: gr.update(visible=True), + inputs=[], + outputs=[dialog], + ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }") + + if button_close: + button_close.click(fn=None, _js="closePopup") + diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 2bb0a222..891d8f2c 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -12,7 +12,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.refresh_checkpoints() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name) path, ext = os.path.splitext(checkpoint.filename) return { diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py index e53ccb42..514a4562 100644 --- a/modules/ui_extra_networks_hypernets.py +++ b/modules/ui_extra_networks_hypernets.py @@ -11,7 +11,7 @@ class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): def refresh(self): shared.reload_hypernetworks() - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): full_path = shared.hypernetworks[name] path, ext = os.path.splitext(full_path) diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py index d1794e50..73134698 100644 --- a/modules/ui_extra_networks_textual_inversion.py +++ b/modules/ui_extra_networks_textual_inversion.py @@ -12,7 +12,7 @@ class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): def refresh(self): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) - def create_item(self, name, index=None): + def create_item(self, name, index=None, enable_filter=True): embedding = sd_hijack.model_hijack.embedding_db.word_embeddings.get(name) path, ext = os.path.splitext(embedding.filename) diff --git a/modules/ui_prompt_styles.py b/modules/ui_prompt_styles.py new file mode 100644 index 00000000..85eb3a64 --- /dev/null +++ b/modules/ui_prompt_styles.py @@ -0,0 +1,110 @@ +import gradio as gr + +from modules import shared, ui_common, ui_components, styles + +styles_edit_symbol = '\U0001f58c\uFE0F' # 🖌️ +styles_materialize_symbol = '\U0001f4cb' # 📋 + + +def select_style(name): + style = shared.prompt_styles.styles.get(name) + existing = style is not None + empty = not name + + prompt = style.prompt if style else gr.update() + negative_prompt = style.negative_prompt if style else gr.update() + + return prompt, negative_prompt, gr.update(visible=existing), gr.update(visible=not empty) + + +def save_style(name, prompt, negative_prompt): + if not name: + return gr.update(visible=False) + + style = styles.PromptStyle(name, prompt, negative_prompt) + shared.prompt_styles.styles[style.name] = style + shared.prompt_styles.save_styles(shared.styles_filename) + + return gr.update(visible=True) + + +def delete_style(name): + if name == "": + return + + shared.prompt_styles.styles.pop(name, None) + shared.prompt_styles.save_styles(shared.styles_filename) + + return '', '', '' + + +def materialize_styles(prompt, negative_prompt, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + negative_prompt = shared.prompt_styles.apply_negative_styles_to_prompt(negative_prompt, styles) + + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=negative_prompt), gr.Dropdown.update(value=[])] + + +def refresh_styles(): + return gr.update(choices=list(shared.prompt_styles.styles)), gr.update(choices=list(shared.prompt_styles.styles)) + + +class UiPromptStyles: + def __init__(self, tabname, main_ui_prompt, main_ui_negative_prompt): + self.tabname = tabname + + with gr.Row(elem_id=f"{tabname}_styles_row"): + self.dropdown = gr.Dropdown(label="Styles", show_label=False, elem_id=f"{tabname}_styles", choices=list(shared.prompt_styles.styles), value=[], multiselect=True, tooltip="Styles") + edit_button = ui_components.ToolButton(value=styles_edit_symbol, elem_id=f"{tabname}_styles_edit_button", tooltip="Edit styles") + + with gr.Box(elem_id=f"{tabname}_styles_dialog", elem_classes="popup-dialog") as styles_dialog: + with gr.Row(): + self.selection = gr.Dropdown(label="Styles", elem_id=f"{tabname}_styles_edit_select", choices=list(shared.prompt_styles.styles), value=[], allow_custom_value=True, info="Styles allow you to add custom text to prompt. Use the {prompt} token in style text, and it will be replaced with user's prompt when applying style. Otherwise, style's text will be added to the end of the prompt.") + ui_common.create_refresh_button([self.dropdown, self.selection], shared.prompt_styles.reload, lambda: {"choices": list(shared.prompt_styles.styles)}, f"refresh_{tabname}_styles") + self.materialize = ui_components.ToolButton(value=styles_materialize_symbol, elem_id=f"{tabname}_style_apply", tooltip="Apply all selected styles from the style selction dropdown in main UI to the prompt.") + + with gr.Row(): + self.prompt = gr.Textbox(label="Prompt", show_label=True, elem_id=f"{tabname}_edit_style_prompt", lines=3) + + with gr.Row(): + self.neg_prompt = gr.Textbox(label="Negative prompt", show_label=True, elem_id=f"{tabname}_edit_style_neg_prompt", lines=3) + + with gr.Row(): + self.save = gr.Button('Save', variant='primary', elem_id=f'{tabname}_edit_style_save', visible=False) + self.delete = gr.Button('Delete', variant='primary', elem_id=f'{tabname}_edit_style_delete', visible=False) + self.close = gr.Button('Close', variant='secondary', elem_id=f'{tabname}_edit_style_close') + + self.selection.change( + fn=select_style, + inputs=[self.selection], + outputs=[self.prompt, self.neg_prompt, self.delete, self.save], + show_progress=False, + ) + + self.save.click( + fn=save_style, + inputs=[self.selection, self.prompt, self.neg_prompt], + outputs=[self.delete], + show_progress=False, + ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) + + self.delete.click( + fn=delete_style, + _js='function(name){ if(name == "") return ""; return confirm("Delete style " + name + "?") ? name : ""; }', + inputs=[self.selection], + outputs=[self.selection, self.prompt, self.neg_prompt], + show_progress=False, + ).then(refresh_styles, outputs=[self.dropdown, self.selection], show_progress=False) + + self.materialize.click( + fn=materialize_styles, + inputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + outputs=[main_ui_prompt, main_ui_negative_prompt, self.dropdown], + show_progress=False, + ).then(fn=None, _js="function(){update_"+tabname+"_tokens(); closePopup();}", show_progress=False) + + ui_common.setup_dialog(button_show=edit_button, dialog=styles_dialog, button_close=self.close) + + + + diff --git a/style.css b/style.css index 6c92d6e7..cf8470e4 100644 --- a/style.css +++ b/style.css @@ -972,3 +972,16 @@ div.block.gradio-box.edit-user-metadata { .edit-user-metadata-buttons{ margin-top: 1.5em; } + + + + +div.block.gradio-box.popup-dialog, .popup-dialog { + width: 56em; + background: var(--body-background-fill); + padding: 2em !important; +} + +div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-child{ + margin-top: 1em; +} -- cgit v1.2.3 From 24f21583cdba2ae6cc51773b956c6ce068d3dfe4 Mon Sep 17 00:00:00 2001 From: AnyISalIn Date: Fri, 4 Aug 2023 11:43:27 +0800 Subject: fix: prevent cache model.state_dict() after model hijack Signed-off-by: AnyISalIn --- modules/sd_models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 1d93d893..ba15b451 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -303,12 +303,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer sd_models_xl.extend_sdxl(model) model.load_state_dict(state_dict, strict=False) - del state_dict timer.record("apply weights to model") if shared.opts.sd_checkpoint_cache > 0: # cache newly loaded model - checkpoints_loaded[checkpoint_info] = model.state_dict().copy() + checkpoints_loaded[checkpoint_info] = state_dict + + del state_dict if shared.cmd_opts.opt_channelslast: model.to(memory_format=torch.channels_last) -- cgit v1.2.3