From da464a3fb39ecc6ea7b22fe87271194480d8501c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 12 Jul 2023 23:52:43 +0300 Subject: SDXL support --- modules/sd_hijack_clip.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..6c17a81d 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -42,6 +42,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): self.hijack: sd_hijack.StableDiffusionModelHijack = hijack self.chunk_length = 75 + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + def empty_chunk(self): """creates an empty PromptChunk and returns it""" @@ -199,8 +203,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): """ Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will - be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, and for SD2 it's 1024. + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" """ @@ -233,7 +238,10 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) self.hijack.comments.append(f"Used embeddings: {embeddings_list}") - return torch.hstack(zs) + if getattr(self.wrapped, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + else: + return torch.hstack(zs) def process_tokens(self, remade_batch_tokens, batch_multipliers): """ @@ -256,9 +264,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise batch_multipliers = torch.asarray(batch_multipliers).to(devices.device) original_mean = z.mean() - z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) + z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape) new_mean = z.mean() - z = z * (original_mean / new_mean) + z *= (original_mean / new_mean) return z -- cgit v1.2.3 From 594c8e7b263d9b37f4b18b56b159aeb6d1bba1b4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 13 Jul 2023 11:35:52 +0300 Subject: fix CLIP doing the unneeded normalization revert SD2.1 back to use the original repo add SDXL's force_zero_embeddings to negative prompt --- modules/processing.py | 2 +- modules/prompt_parser.py | 14 ++++++++++---- modules/sd_hijack.py | 2 +- modules/sd_hijack_clip.py | 15 +++++++++++++++ modules/sd_models_config.py | 1 - modules/sd_models_xl.py | 3 ++- 6 files changed, 29 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/processing.py b/modules/processing.py index 85d35423..f01a6907 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -344,7 +344,7 @@ class StableDiffusionProcessing: def setup_conds(self): prompts = prompt_parser.SdConditioning(self.prompts, width=self.width, height=self.height) - negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height) + negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py index 33810669..b29d079d 100644 --- a/modules/prompt_parser.py +++ b/modules/prompt_parser.py @@ -116,11 +116,17 @@ class SdConditioning(list): A list with prompts for stable diffusion's conditioner model. Can also specify width and height of created image - SDXL needs it. """ - def __init__(self, prompts, width=None, height=None): + def __init__(self, prompts, is_negative_prompt=False, width=None, height=None, copy_from=None): super().__init__() self.extend(prompts) - self.width = width or getattr(prompts, 'width', None) - self.height = height or getattr(prompts, 'height', None) + + if copy_from is None: + copy_from = prompts + + self.is_negative_prompt = is_negative_prompt or getattr(copy_from, 'is_negative_prompt', False) + self.width = width or getattr(copy_from, 'width', None) + self.height = height or getattr(copy_from, 'height', None) + def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): @@ -153,7 +159,7 @@ def get_learned_conditioning(model, prompts: SdConditioning | list[str], steps): res.append(cached) continue - texts = [x[1] for x in prompt_schedule] + texts = SdConditioning([x[1] for x in prompt_schedule], copy_from=prompts) conds = model.get_learned_conditioning(texts) cond_schedule = [] diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 266811f9..647cdfbe 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -190,7 +190,7 @@ class StableDiffusionModelHijack: if typename == 'FrozenCLIPEmbedder': model_embeddings = m.cond_stage_model.transformer.text_model.embeddings model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self) - m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(embedder, self) + m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self) conditioner.embedders[i] = m.cond_stage_model if typename == 'FrozenOpenCLIPEmbedder2': embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self) diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 6c17a81d..b3771909 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -323,3 +323,18 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) return embedded + + +class FrozenCLIPEmbedderForSDXLWithCustomWords(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + def encode_with_transformers(self, tokens): + outputs = self.wrapped.transformer(input_ids=tokens, output_hidden_states=self.wrapped.layer == "hidden") + + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + + return z diff --git a/modules/sd_models_config.py b/modules/sd_models_config.py index 2e92479a..04c09ab0 100644 --- a/modules/sd_models_config.py +++ b/modules/sd_models_config.py @@ -12,7 +12,6 @@ sd_xl_repo_configs_path = os.path.join(paths.paths['Stable Diffusion XL'], "conf config_default = shared.sd_default_config config_sd2 = os.path.join(sd_repo_configs_path, "v2-inference.yaml") config_sd2v = os.path.join(sd_repo_configs_path, "v2-inference-v.yaml") -config_sd2v = os.path.join(sd_xl_repo_configs_path, "sd_2_1_768.yaml") config_sd2_inpainting = os.path.join(sd_repo_configs_path, "v2-inpainting-inference.yaml") config_sdxl = os.path.join(sd_xl_repo_configs_path, "sd_xl_base.yaml") config_depth_model = os.path.join(sd_repo_configs_path, "v2-midas-inference.yaml") diff --git a/modules/sd_models_xl.py b/modules/sd_models_xl.py index 1dd4459f..b799ff46 100644 --- a/modules/sd_models_xl.py +++ b/modules/sd_models_xl.py @@ -22,7 +22,8 @@ def get_learned_conditioning(self: sgm.models.diffusion.DiffusionEngine, batch: "target_size_as_tuple": torch.tensor([height, width]).repeat(len(batch), 1).to(devices.device, devices.dtype), } - c = self.conditioner(sdxl_conds) + force_zero_negative_prompt = getattr(batch, 'is_negative_prompt', False) and all(x == '' for x in batch) + c = self.conditioner(sdxl_conds, force_zero_embeddings=['txt'] if force_zero_negative_prompt else []) return c -- cgit v1.2.3 From 2b1bae0d755c2d5201f6a6aadeadb5588208d43f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 15 Jul 2023 08:41:22 +0300 Subject: add textual inversion hashes to infotext --- modules/processing.py | 7 ++++--- modules/sd_hijack.py | 5 ++++- modules/sd_hijack_clip.py | 15 ++++++++++++--- modules/shared.py | 1 + modules/textual_inversion/textual_inversion.py | 9 ++++++++- style.css | 4 ++++ 6 files changed, 33 insertions(+), 8 deletions(-) (limited to 'modules/sd_hijack_clip.py') diff --git a/modules/processing.py b/modules/processing.py index cd568a20..49441e77 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -732,9 +732,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.setup_conds() - if len(model_hijack.comments) > 0: - for comment in model_hijack.comments: - comments[comment] = 1 + for comment in model_hijack.comments: + comments[comment] = 1 + + p.extra_generation_params.update(model_hijack.extra_generation_params) if p.n_iter > 1: shared.state.job = f"Batch {n+1} out of {p.n_iter}" diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 3b6f95ce..6b5aae4b 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -147,7 +147,6 @@ def undo_weighted_forward(sd_model): class StableDiffusionModelHijack: fixes = None - comments = [] layers = None circular_enabled = False clip = None @@ -156,6 +155,9 @@ class StableDiffusionModelHijack: embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() def __init__(self): + self.extra_generation_params = {} + self.comments = [] + self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): @@ -236,6 +238,7 @@ class StableDiffusionModelHijack: def clear_comments(self): self.comments = [] + self.extra_generation_params = {} def get_prompt_lengths(self, text): if self.clip is None: diff --git a/modules/sd_hijack_clip.py b/modules/sd_hijack_clip.py index 3b5a7666..c1d780a3 100644 --- a/modules/sd_hijack_clip.py +++ b/modules/sd_hijack_clip.py @@ -229,9 +229,18 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): z = self.process_tokens(tokens, multipliers) zs.append(z) - if len(used_embeddings) > 0: - embeddings_list = ", ".join([f'{name} [{embedding.checksum()}]' for name, embedding in used_embeddings.items()]) - self.hijack.comments.append(f"Used embeddings: {embeddings_list}") + if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: + hashes = [] + for name, embedding in used_embeddings.items(): + shorthash = embedding.shorthash + if not shorthash: + continue + + name = name.replace(":", "").replace(",", "") + hashes.append(f"{name}: {shorthash}") + + if hashes: + self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) return torch.hstack(zs) diff --git a/modules/shared.py b/modules/shared.py index 48478a68..a32fd4ed 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -472,6 +472,7 @@ options_templates.update(options_section(('extra_networks', "Extra Networks"), { "extra_networks_card_height": OptionInfo(0, "Card height for Extra Networks").info("in pixels"), "extra_networks_add_text_separator": OptionInfo(" ", "Extra networks separator").info("extra text to add before <...> when adding extra network to prompt"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order").needs_restart(), + "textual_inversion_add_hashes_to_infotext": OptionInfo(True, "Add Textual Inversion hashes to infotext"), "sd_hypernetwork": OptionInfo("None", "Add hypernetwork to prompt", gr.Dropdown, lambda: {"choices": ["None", *hypernetworks]}, refresh=reload_hypernetworks), })) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index cbe975b7..38e072a8 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors +from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -49,6 +49,8 @@ class Embedding: self.sd_checkpoint_name = None self.optimizer_state_dict = None self.filename = None + self.hash = None + self.shorthash = None def save(self, filename): embedding_data = { @@ -82,6 +84,10 @@ class Embedding: self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}' return self.cached_checksum + def set_hash(self, v): + self.hash = v + self.shorthash = self.hash[0:12] + class DirWithTextualInversionEmbeddings: def __init__(self, path): @@ -199,6 +205,7 @@ class EmbeddingDatabase: embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] embedding.filename = path + embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '') if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/style.css b/style.css index a424067f..9e13d7fd 100644 --- a/style.css +++ b/style.css @@ -231,6 +231,10 @@ button.custom-button{ padding-top: 0.5em; } +.html-log .comments:empty{ + padding-top: 0; +} + .html-log .performance { font-size: 0.85em; color: #444; -- cgit v1.2.3