author     Liam <liamthekerr@gmail.com>    2022-09-27 19:56:18 +0000
committer  Liam <liamthekerr@gmail.com>    2022-09-27 19:56:18 +0000
commit     5034f7d7597685aaa4779296983be0f49f4f991f (patch)
tree       a4d2fe6104e5034fae22476d7705634a551359ff
parent     ca3e5519e8b6dc020c5e7ae508738afb5dc6f3ec (diff)
added token counter next to txt2img and img2img prompts
-rw-r--r--  javascript/helpers.js | 13
-rw-r--r--  javascript/ui.js      | 47
-rw-r--r--  modules/sd_hijack.py  | 30
-rw-r--r--  modules/ui.py         |  7
-rw-r--r--  style.css             |  4
5 files changed, 92 insertions, 9 deletions
diff --git a/javascript/helpers.js b/javascript/helpers.js
new file mode 100644
index 00000000..1b26931f
--- /dev/null
+++ b/javascript/helpers.js
@@ -0,0 +1,13 @@
+// helper functions
+
+function debounce(func, wait_time) {
+    let timeout;
+    return function wrapped(...args) {
+        let call_function = () => {
+            clearTimeout(timeout);
+            func(...args)
+        }
+        clearTimeout(timeout);
+        timeout = setTimeout(call_function, wait_time);
+    };
+}
\ No newline at end of file
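For context, a minimal usage sketch of the debounce helper added above; the element id and callback below are hypothetical and not part of this commit. Every call resets the timer, so the wrapped function runs once, wait_time milliseconds after the last event.

```javascript
// Hypothetical usage of debounce(); "example_prompt" is a placeholder id, not an element from this commit.
const textarea = document.getElementById("example_prompt");

// The wrapped callback receives the same arguments as the original event handler.
const onTypingPaused = debounce((event) => {
    console.log("settled value:", event.target.value);
}, 500);

// Rapid keystrokes keep resetting the 500 ms timer; the callback fires once per pause in typing.
textarea.addEventListener("input", onTypingPaused);
```

This is the behaviour ui.js below relies on: it waits 800 ms after the last keystroke before requesting a token count.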
diff --git a/javascript/ui.js b/javascript/ui.js
index 076e9436..77e0f4c1 100644
--- a/javascript/ui.js
+++ b/javascript/ui.js
@@ -183,4 +183,51 @@ onUiUpdate(function(){
     });

     json_elem.parentElement.style.display="none"
+
+    let debounce_time = 800
+    if (!txt2img_textarea) {
+        txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea")
+        txt2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "txt2img"), debounce_time))
+    }
+    if (!img2img_textarea) {
+        img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea")
+        img2img_textarea?.addEventListener("input", debounce(submit_prompt_text.bind(null, "img2img"), debounce_time))
+    }
 })
+
+
+let txt2img_textarea, img2img_textarea = undefined;
+function submit_prompt_text(source, e) {
+    let prompt_text;
+    if (source == "txt2img")
+        prompt_text = txt2img_textarea.value;
+    else if (source == "img2img")
+        prompt_text = img2img_textarea.value;
+    if (!prompt_text)
+        return;
+    params = {
+        method: "POST",
+        headers: {
+            "Accept": "application/json",
+            "Content-type": "application/json"
+        },
+        body: JSON.stringify({data:[prompt_text]})
+    }
+    fetch('http://127.0.0.1:7860/api/tokenize/', params)
+        .then((response) => response.json())
+        .then((data) => {
+            if (data?.data.length) {
+                let response_json = data.data[0]
+                if (elem = gradioApp().getElementById(source+"_token_counter")) {
+                    if (response_json.token_count > response_json.max_length)
+                        elem.classList.add("red");
+                    else
+                        elem.classList.remove("red");
+                    elem.innerText = response_json.token_count + "/" + response_json.max_length;
+                }
+            }
+        })
+        .catch((error) => {
+            console.error('Error:', error);
+        });
+}
\ No newline at end of file
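The handler above expects the gradio endpoint to wrap the return value of StableDiffusionModelHijack.tokenize (added in modules/sd_hijack.py below) in a data array. A rough sketch of the payload it parses, with made-up token ids and counts:

```javascript
// Illustrative response body for POST /api/tokenize/ (token ids and counts are made up).
const example_response = {
    data: [
        {
            tokens: [320, 1125, 539, 320, 2368],
            token_count: 5,
            max_length: 75
        }
    ]
};

// The handler reads the first entry and renders "token_count/max_length",
// adding the "red" CSS class when the prompt exceeds max_length.
const info = example_response.data[0];
console.log(info.token_count + "/" + info.max_length);  // "5/75"
```

Only token_count and max_length are used for display; the tokens list is currently ignored by the UI.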
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py
index 7b2030d4..4d799ac0 100644
--- a/modules/sd_hijack.py
+++ b/modules/sd_hijack.py
@@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
     dir_mtime = None
     layers = None
     circular_enabled = False
+    clip = None

     def load_textual_inversion_embeddings(self, dirname, model):
         mt = os.path.getmtime(dirname)
@@ -242,6 +243,7 @@ class StableDiffusionModelHijack:
         model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)

         m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
+        self.clip = m.cond_stage_model

         if cmd_opts.opt_split_attention_v1:
             ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -268,6 +270,11 @@ class StableDiffusionModelHijack:
         for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
             layer.padding_mode = 'circular' if enable else 'zeros'

+    def tokenize(self, text):
+        max_length = self.clip.max_length - 2
+        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
+        return {"tokens": remade_batch_tokens[0], "token_count":token_count, "max_length":max_length}
+

 class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
     def __init__(self, wrapped, hijack):
@@ -294,14 +301,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             if mult != 1.0:
                 self.token_mults[ident] = mult

-    def forward(self, text):
-        self.hijack.fixes = []
-        self.hijack.comments = []
-        remade_batch_tokens = []
+    def process_text(self, text):
         id_start = self.wrapped.tokenizer.bos_token_id
         id_end = self.wrapped.tokenizer.eos_token_id
         maxlen = self.wrapped.max_length
         used_custom_terms = []
+        remade_batch_tokens = []
+        overflowing_words = []
+        hijack_comments = []
+        hijack_fixes = []
+        token_count = 0

         cache = {}
         batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +362,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                     ovf = remade_tokens[maxlen - 2:]
                     overflowing_words = [vocab.get(int(x), "") for x in ovf]
                     overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))
-
-                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
-
+                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+                token_count = len(remade_tokens)
                 remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                 remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                 cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,8 +372,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
             multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

             remade_batch_tokens.append(remade_tokens)
-            self.hijack.fixes.append(fixes)
+            hijack_fixes.append(fixes)
             batch_multipliers.append(multipliers)
+        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count
+
+    def forward(self, text):
+        batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
+        self.hijack.fixes = hijack_fixes
+        self.hijack.comments = hijack_comments

         if len(used_custom_terms) > 0:
             self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
diff --git a/modules/ui.py b/modules/ui.py
index f7ca5588..3b9c8525 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -22,6 +22,7 @@ from modules.paths import script_path
 from modules.shared import opts, cmd_opts
 import modules.shared as shared
 from modules.sd_samplers import samplers, samplers_for_img2img
+from modules.sd_hijack import model_hijack
 import modules.ldsr_model
 import modules.scripts
 import modules.gfpgan_model
@@ -337,11 +338,15 @@ def create_toprow(is_img2img):
         with gr.Row():
             with gr.Column(scale=80):
                 with gr.Row():
-                    prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
+                    prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2)

             with gr.Column(scale=1, elem_id="roll_col"):
                 roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
                 paste = gr.Button(value=paste_symbol, elem_id="paste")
+                token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
+                token_output = gr.JSON(visible=False)
+                if is_img2img: # only define the api function ONCE
+                    token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output])

             with gr.Column(scale=10, elem_id="style_pos_col"):
                 prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
diff --git a/style.css b/style.css
@@ -389,3 +389,7 @@ input[type="range"]{
     border-radius: 8px;
     display: none;
 }
+
+.red {
+  color: red;
+}