From 5034f7d7597685aaa4779296983be0f49f4f991f Mon Sep 17 00:00:00 2001
From: Liam
Date: Tue, 27 Sep 2022 15:56:18 -0400
Subject: added token counter next to txt2img and img2img prompts

---
 modules/sd_hijack.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)

(limited to 'modules/sd_hijack.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 7b2030d4..4d799ac0 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
     dir_mtime = None
     layers = None
     circular_enabled = False
+    clip = None
 
     def load_textual_inversion_embeddings(self, dirname, model):
         mt = os.path.getmtime(dirname)
@@ -242,6 +243,7 @@ class StableDiffusionModelHijack:
         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
 
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+        self.clip = m.cond_stage_model
 
         if cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -268,6 +270,11 @@ class StableDiffusionModelHijack:
         for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
             layer.padding_mode = 'circular' if enable else 'zeros'
 
+    def tokenize(self, text):
+        max_length = self.clip.max_length - 2
+        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+        return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length}
+
 
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
@@ -294,14 +301,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-    def forward(self, text):
-        self.hijack.fixes = []
-        self.hijack.comments = []
-        remade_batch_tokens = []
+    def process_text(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length
         used_custom_terms = []
+        remade_batch_tokens = []
+        overflowing_words = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
 
         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +362,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                 ovf = remade_tokens[maxlen - 2:]
                 overflowing_words = [vocab.get(int(x), "") for x in ovf]
                 overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
-                self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
+                hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+            token_count = len(remade_tokens)
             remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
             remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
             cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,8 +372,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
 
             remade_batch_tokens.append(remade_tokens)
-            self.hijack.fixes.append(fixes)
+            hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def forward(self, text):
+        batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+        self.hijack.fixes = hijack_fixes
+        self.hijack.comments = hijack_comments
 
         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
--
cgit v1.2.3

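The rule behind the counter is fixed by the model: CLIP conditioning uses a context of max_length tokens (77 for Stable Diffusion), two of which are reserved for the start/end markers, which is why tokenize reports max_length - 2 = 75. A minimal standalone sketch of the same arithmetic, assuming the transformers package and the CLIP tokenizer Stable Diffusion ships with (the webui counts remade tokens, so textual-inversion embeddings can shift the number slightly):

    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    def count_prompt_tokens(prompt, model_max_length=77):
        # Tokenize with no truncation and no BOS/EOS, mirroring process_text above.
        ids = tokenizer(prompt, truncation=False, add_special_tokens=False)["input_ids"]
        # Two slots are reserved for the start/end tokens, hence max_length - 2.
        return len(ids), model_max_length - 2

    token_count, max_length = count_prompt_tokens("a photo of an astronaut riding a horse")
    print(f"{token_count}/{max_length}")  # e.g. 8/75
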
From e5707b66d6db2c019bfccf66f9ba53e3daaea40b Mon Sep 17 00:00:00 2001
From: Liam
Date: Tue, 27 Sep 2022 19:29:53 -0400
Subject: switched the token counter to use hidden buttons instead of api call

---
 javascript/helpers.js | 13 -------------
 javascript/ui.js      | 60 +++++++++++++---------------------------------------
 modules/sd_hijack.py  |  3 +--
 modules/ui.py         | 13 +++++++----
 4 files changed, 25 insertions(+), 64 deletions(-)
 delete mode 100644 javascript/helpers.js

(limited to 'modules/sd_hijack.py')

diff --git a/javascript/helpers.js b/javascript/helpers.js
deleted file mode 100644
index 1b26931f..00000000
--- a/javascript/helpers.js
+++ /dev/null
@@ -1,13 +0,0 @@
-// helper functions
-
-function debounce(func, wait_time) {
-	let timeout;
-	return function wrapped(...args) {
-		let call_function = () => {
-			clearTimeout(timeout);
-			func(...args)
-		}
-		clearTimeout(timeout);
-		timeout = setTimeout(call_function, wait_time);
-	};
-}
\ No newline at end of file
diff --git a/javascript/ui.js b/javascript/ui.js
index fbe5a11d..6cfa5c08 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -182,51 +182,21 @@ onUiUpdate(function(){
 	});
 
 	json_elem.parentElement.style.display="none"
-
-	let debounce_time = 800
-	if (!txt2img_textarea) {
-		txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea")
-		txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time))
-	}
-	if (!img2img_textarea) {
-		img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea")
-		img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time))
-	}
 })
 
+let wait_time = 800
+let token_timeout;
+function txt2img_token_counter(text) {
+	return update_token_counter("txt2img_token_button", text);
+}
+
+function img2img_token_counter(text) {
+	return update_token_counter("img2img_token_button", text);
+}
 
-let txt2img_textarea, img2img_textarea = undefined;
-function submit_prompt_text(source, e) {
-	let prompt_text;
-	if (source == "txt2img")
-		prompt_text = txt2img_textarea.value;
-	else if (source == "img2img")
-		prompt_text = img2img_textarea.value;
-	if (!prompt_text)
-		return;
-	params = {
-		method: "POST",
-		headers: {
-			"Accept": "application/json",
-			"Content-type": "application/json"
-		},
-		body: JSON.stringify({data:[prompt_text]})
-	}
-	fetch('http://127.0.0.1:7860/api/tokenize/', params)
-		.then((response) => response.json())
-		.then((data) => {
-			if (data?.data.length) {
-				let response_json = data.data[0]
-				if (elem = gradioApp().getElementById(source+"_token_counter")) {
-					if (response_json.token_count > response_json.max_length)
-						elem.classList.add("red");
-					else
-						elem.classList.remove("red");
-					elem.innerText = response_json.token_count + "/" + response_json.max_length;
-				}
-			}
-		})
-		.catch((error) => {
-			console.error('Error:', error);
-		});
-}
\ No newline at end of file
+function update_token_counter(button_id, text) {
+	if (token_timeout)
+		clearTimeout(token_timeout);
+	token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
+	return [];
+}
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 4d799ac0..bfbd07f9 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -273,8 +273,7 @@ class StableDiffusionModelHijack:
     def tokenize(self, text):
         max_length = self.clip.max_length - 2
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
-        return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length}
-
+        return remade_batch_tokens[0], token_count, max_length
 
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
diff --git a/modules/ui.py b/modules/ui.py
index 9a3d69c8..15bfd697 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts
 import modules.shared as shared
 from modules.sd_samplers import samplers, samplers_for_img2img
 from modules.sd_hijack import model_hijack
+from modules.helpers import debounce
 import modules.ldsr_model
 import modules.scripts
 import modules.gfpgan_model
@@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
         outputs=[seed, dummy_component]
     )
 
+def update_token_counter(text):
+    tokens, token_count, max_length = model_hijack.tokenize(text)
+    style_class = ' class="red"' if (token_count > max_length) else ""
+    return f"<span{style_class}>{token_count}/{max_length}</span>"
 
 def create_toprow(is_img2img):
     id_part = "img2img" if is_img2img else "txt2img"
@@ -339,15 +344,15 @@ def create_toprow(is_img2img):
     with gr.Row():
         with gr.Column(scale=80):
             with gr.Row():
-                prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2)
+                prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)
+                prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False)
 
         with gr.Column(scale=1, elem_id="roll_col"):
             roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
             paste = gr.Button(value=paste_symbol, elem_id="paste")
             token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter")
-            token_output = gr.JSON(visible=False)
-            if is_img2img: # only define the api function ONCE
-                token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])
+            hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
+            hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])
 
         with gr.Column(scale=10, elem_id="style_pos_col"):
             prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
--
cgit v1.2.3

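The commit replaces the fetch()-based counter with a pattern worth noting: Gradio at the time had no debounced server-side text event, so a prompt.change handler runs a client-side _js function that only resets a timer, and when the timer fires it clicks an invisible button whose click handler does the real work in Python. A self-contained sketch of the trick under assumed names (illustrative element id, stand-in counting function):

    import gradio as gr

    def count_tokens(text):
        # Stand-in for model_hijack.tokenize(); a real build would call the model's tokenizer.
        return f"{len(text.split())}/75"

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        counter = gr.HTML(value="")
        # Invisible button: client-side JS debounces keystrokes, then clicks it, so the
        # Python callback fires once per pause in typing rather than on every keystroke.
        hidden = gr.Button(visible=False, elem_id="token_button")
        hidden.click(fn=count_tokens, inputs=[prompt], outputs=[counter])

    demo.launch()
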
From c1c27dad3ba371a5ae344b267c760aa51e77f193 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 29 Sep 2022 11:31:48 +0300
Subject: new implementation for attention/emphasis

---
 modules/prompt_parser.py |  87 +++++++++++++++++++++++++++++++++++++++-
 modules/sd_hijack.py     | 102 ++++++++++++++++++++++++++++++++++++++++++++++--
 modules/shared.py        |   3 +-
 3 files changed, 186 insertions(+), 6 deletions(-)

(limited to 'modules/sd_hijack.py')

diff --git a/modules/prompt_parser.py b/modules/prompt_parser.py
index a6a25b28..f3c50adf 100644
--- a/modules/prompt_parser.py
+++ b/modules/prompt_parser.py
@@ -126,5 +126,90 @@ def reconstruct_cond_batch(c: ScheduledPromptBatch, current_step):
 
     return res
 
+
+re_attention = re.compile(r"""
+\\\(|
+\\\)|
+\\\[|
+\\]|
+\\\\|
+\\|
+\(|
+\[|
+:([+-]?[.\d]+)\)|
+\)|
+]|
+[^\\()\[\]:]+|
+:
+""", re.X)
+
+
+def parse_prompt_attention(text):
+    """
+    Parses a string with attention tokens and returns a list of pairs: text and its assoicated weight.
+    Accepted tokens are:
+      (abc) - increases attention to abc by a multiplier of 1.1
+      (abc:3.12) - increases attention to abc by a multiplier of 3.12
+      [abc] - decreases attention to abc by a multiplier of 1.1
+      \( - literal character '('
+      \[ - literal character '['
+      \) - literal character ')'
+      \] - literal character ']'
+      \\ - literal character '\'
+      anything else - just text
+
+    Example:
+
+        'a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).'
+
+    produces:
+
+    [
+        ['a ', 1.0],
+        ['house', 1.5730000000000004],
+        [' ', 1.1],
+        ['on', 1.0],
+        [' a ', 1.1],
+        ['hill', 0.55],
+        [', sun, ', 1.1],
+        ['sky', 1.4641000000000006],
+        ['.', 1.1]
+    ]
+    """
-#get_learned_conditioning_prompt_schedules(["fantasy landscape with a [mountain:lake:0.25] and [an oak:a christmas tree:0.75][ in foreground::0.6][ in background:0.25] [shoddy:masterful:0.5]"], 100)
+
+    res = []
+    round_brackets = []
+    square_brackets = []
+
+    round_bracket_multiplier = 1.1
+    square_bracket_multiplier = 1 / 1.1
+
+    def multiply_range(start_position, multiplier):
+        for p in range(start_position, len(res)):
+            res[p][1] *= multiplier
+
+    for m in re_attention.finditer(text):
+        text = m.group(0)
+        weight = m.group(1)
+
+        if text.startswith('\\'):
+            res.append([text[1:], 1.0])
+        elif text == '(':
+            round_brackets.append(len(res))
+        elif text == '[':
+            square_brackets.append(len(res))
+        elif weight is not None and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), float(weight))
+        elif text == ')' and len(round_brackets) > 0:
+            multiply_range(round_brackets.pop(), round_bracket_multiplier)
+        elif text == ']' and len(square_brackets) > 0:
+            multiply_range(square_brackets.pop(), square_bracket_multiplier)
+        else:
+            res.append([text, 1.0])
+
+    for pos in round_brackets:
+        multiply_range(pos, round_bracket_multiplier)
+
+    for pos in square_brackets:
+        multiply_range(pos, square_bracket_multiplier)
+
+    return res
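The docstring example repays a close reading: the prompt deliberately contains an unbalanced '('. Of the three parens opened before "house", one is closed by ':1.3)' and one by the following ')', so the third stays open and is weighted by the final round_brackets loop, giving everything after 'a ' an extra 1.1 factor. That is why 'house' lands at 1.3 * 1.1 * 1.1 = 1.573 and 'sky', inside three closed pairs plus the leftover, at 1.1**4 = 1.4641. A quick check, assuming the module above is importable from the repo root:

    from modules.prompt_parser import parse_prompt_attention

    res = parse_prompt_attention('a (((house:1.3)) [on] a (hill:0.5), sun, (((sky))).')
    assert abs(res[1][1] - 1.3 * 1.1 * 1.1) < 1e-9  # 'house': explicit weight, one ')', one leftover '('
    assert abs(res[3][1] - (1 / 1.1) * 1.1) < 1e-9  # 'on': the [..] decrease cancelled by the leftover '('
    assert abs(res[7][1] - 1.1 ** 4) < 1e-9         # 'sky': three closed pairs plus the leftover '('
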
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index bfbd07f9..2848a251 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -6,6 +6,7 @@ import torch
 import numpy as np
 from torch import einsum
 
+from modules import prompt_parser
 from modules.shared import opts, device, cmd_opts
 from ldm.util import default
 
@@ -211,6 +212,7 @@ class StableDiffusionModelHijack:
             param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
             assert len(param_dict) == 1, 'embedding file has multiple terms in it'
             emb = next(iter(param_dict.items()))[1]
+        # diffuser concepts
         elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
             assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
 
@@ -236,7 +238,7 @@ class StableDiffusionModelHijack:
                 print(traceback.format_exc(), file=sys.stderr)
                 continue
 
-        print(f"Loaded a total of {len(self.word_embeddings)} text inversion embeddings.")
+        print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
 
     def hijack(self, m):
         model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
@@ -275,6 +277,7 @@ class StableDiffusionModelHijack:
         _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
         return remade_batch_tokens[0], token_count, max_length
 
+
 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
@@ -300,7 +303,92 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult
 
-    def process_text(self, text):
+
+    def tokenize_line(self, line, used_custom_terms, hijack_comments):
+        id_start = self.wrapped.tokenizer.bos_token_id
+        id_end = self.wrapped.tokenizer.eos_token_id
+        maxlen = self.wrapped.max_length
+
+        if opts.enable_emphasis:
+            parsed = prompt_parser.parse_prompt_attention(line)
+        else:
+            parsed = [[line, 1.0]]
+
+        tokenized = self.wrapped.tokenizer([text for text, _ in parsed], truncation=False, add_special_tokens=False)["input_ids"]
+
+        fixes = []
+        remade_tokens = []
+        multipliers = []
+
+        for tokens, (text, weight) in zip(tokenized, parsed):
+            i = 0
+            while i < len(tokens):
+                token = tokens[i]
+
+                possible_matches = self.hijack.ids_lookup.get(token, None)
+
+                if possible_matches is None:
+                    remade_tokens.append(token)
+                    multipliers.append(weight)
+                else:
+                    found = False
+                    for ids, word in possible_matches:
+                        if tokens[i:i + len(ids)] == ids:
+                            emb_len = int(self.hijack.word_embeddings[word].shape[0])
+                            fixes.append((len(remade_tokens), word))
+                            remade_tokens += [0] * emb_len
+                            multipliers += [weight] * emb_len
+                            i += len(ids) - 1
+                            found = True
+                            used_custom_terms.append((word, self.hijack.word_embeddings_checksums[word]))
+                            break
+
+                    if not found:
+                        remade_tokens.append(token)
+                        multipliers.append(weight)
+                i += 1
+
+        if len(remade_tokens) > maxlen - 2:
+            vocab = {v: k for k, v in self.wrapped.tokenizer.get_vocab().items()}
+            ovf = remade_tokens[maxlen - 2:]
+            overflowing_words = [vocab.get(int(x), "") for x in ovf]
+            overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
+            hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+        token_count = len(remade_tokens)
+        remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
+        remade_tokens = [id_start] + remade_tokens[0:maxlen - 2] + [id_end]
+
+        multipliers = multipliers + [1.0] * (maxlen - 2 - len(multipliers))
+        multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]
+
+        return remade_tokens, fixes, multipliers, token_count
+
+    def process_text(self, texts):
+        used_custom_terms = []
+        remade_batch_tokens = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0
+
+        cache = {}
+        batch_multipliers = []
+        for line in texts:
+            if line in cache:
+                remade_tokens, fixes, multipliers = cache[line]
+            else:
+                remade_tokens, fixes, multipliers, token_count = self.tokenize_line(line, used_custom_terms, hijack_comments)
+
+                cache[line] = (remade_tokens, fixes, multipliers)
+
+            remade_batch_tokens.append(remade_tokens)
+            hijack_fixes.append(fixes)
+            batch_multipliers.append(multipliers)
+
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+
+    def process_text_old(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length
@@ -376,12 +464,18 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
         return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
 
     def forward(self, text):
-        batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+        if opts.use_old_emphasis_implementation:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
+        else:
+            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+
+
        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments
 
         if len(used_custom_terms) > 0:
-            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+            self.hijack.comments.append("Used embeddings: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
 
         tokens = torch.asarray(remade_batch_tokens).to(device)
         outputs = self.wrapped.transformer(input_ids=tokens)
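Both tokenizer paths end with the same truncate, pad, and wrap step: count the real tokens, pad short prompts to 75 with the end-of-text id, then bracket with BOS/EOS. Reduced to its essentials as a sketch (49406/49407 are CLIP's start/end token ids, 77 is the wrapped model's max_length, and the input ids are arbitrary):

    def pad_and_wrap(tokens, maxlen=77, id_start=49406, id_end=49407):
        token_count = len(tokens)
        tokens = tokens + [id_end] * (maxlen - 2 - len(tokens))  # pad short prompts with EOS
        return [id_start] + tokens[0:maxlen - 2] + [id_end], token_count

    ids, n = pad_and_wrap([320, 1125, 539])
    assert len(ids) == 77 and n == 3
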
diff --git a/modules/shared.py b/modules/shared.py
index ec1e569b..f88c2b02 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -195,7 +195,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "save_images_before_color_correction": OptionInfo(False, "Save a copy of image before applying color correction to img2img results"),
     "img2img_fix_steps": OptionInfo(False, "With img2img, do exactly the amount of steps the slider specifies (normally you'd do less with less denoising)."),
     "enable_quantization": OptionInfo(False, "Enable quantization in K samplers for sharper and cleaner results. This may change existing seeds. Requires restart to apply."),
-    "enable_emphasis": OptionInfo(True, "Use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "enable_emphasis": OptionInfo(True, "Eemphasis: use (text) to make model pay more attention to text and [text] to make it pay less attention"),
+    "use_old_emphasis_implementation": OptionInfo(False, "Use old emphasis implementation. Can be useful to reproduce old seeds."),
     "enable_batch_seeds": OptionInfo(True, "Make K-diffusion samplers produce same images in a batch as when making a single image"),
     "filter_nsfw": OptionInfo(False, "Filter NSFW content"),
     "random_artist_categories": OptionInfo([], "Allowed categories for random artists selection when using the Roll button", gr.CheckboxGroup, {"choices": artist_db.categories()}),
--
cgit v1.2.3

From c715ef04d1edb1a112a602639ed3bb292fdeb0e2 Mon Sep 17 00:00:00 2001
From: AUTOMATIC <16777216c@gmail.com>
Date: Thu, 29 Sep 2022 15:40:28 +0300
Subject: fix for incorrect model weight loading for #814

---
 modules/sd_hijack.py | 9 +++++++++
 modules/sd_models.py | 6 +++++-
 2 files changed, 14 insertions(+), 1 deletion(-)

(limited to 'modules/sd_hijack.py')

diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 2848a251..5945b7c2 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -245,6 +245,7 @@ class StableDiffusionModelHijack:
         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
 
         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+
         self.clip = m.cond_stage_model
 
         if cmd_opts.opt_split_attention_v1:
@@ -263,6 +264,14 @@ class StableDiffusionModelHijack:
 
         self.layers = flatten(m)
 
+    def undo_hijack(self, m):
+        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
+            m.cond_stage_model = m.cond_stage_model.wrapped
+
+        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
+        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
+            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
+
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
             return
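Why a hijacked model cannot simply have new weights loaded into it shows up in its parameter names: wrapping a module inside another module prefixes every state_dict key, so a vanilla checkpoint no longer lines up. A small demonstration with illustrative classes (not the webui's own):

    import torch

    class Wrapper(torch.nn.Module):
        def __init__(self, wrapped):
            super().__init__()
            self.wrapped = wrapped  # keeping a handle is also what lets undo_hijack unwrap

    inner = torch.nn.Linear(2, 2)
    print([k for k, _ in inner.named_parameters()])           # ['weight', 'bias']
    print([k for k, _ in Wrapper(inner).named_parameters()])  # ['wrapped.weight', 'wrapped.bias']
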
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 7a5edced..eb21e498 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -137,7 +137,7 @@ def load_model():
 
 def reload_model_weights(sd_model, info=None):
-    from modules import lowvram, devices
+    from modules import lowvram, devices, sd_hijack
 
     checkpoint_info = info or select_checkpoint()
 
     if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -148,8 +148,12 @@ def reload_model_weights(sd_model, info=None):
     else:
         sd_model.to(devices.cpu)
 
+    sd_hijack.model_hijack.undo_hijack(sd_model)
+
     load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
 
+    sd_hijack.model_hijack.hijack(sd_model)
+
     if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
         sd_model.to(devices.device)
--
cgit v1.2.3
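Hence the unwrap, load, rewrap order in reload_model_weights: undoing the hijack restores the vanilla module tree and its state_dict key names, the checkpoint is loaded onto that, and the wrappers are reinstalled over the fresh weights. Schematically, with illustrative names:

    def reload_weights(model, state_dict, hijack):
        hijack.undo_hijack(model)          # back to the vanilla layout
        model.load_state_dict(state_dict)  # keys now match the checkpoint
        hijack.hijack(model)               # reinstall wrappers on the fresh weights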