From 5034f7d7597685aaa4779296983be0f49f4f991f Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 15:56:18 -0400 Subject: added token counter next to txt2img and img2img prompts --- modules/ui.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index f7ca5588..3b9c8525 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -22,6 +22,7 @@ from modules.paths import script_path from modules.shared import opts, cmd_opts import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img +from modules.sd_hijack import model_hijack import modules.ldsr_model import modules.scripts import modules.gfpgan_model @@ -337,11 +338,15 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2) + prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2) with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") + token_output = gr.JSON(visible=False) + if is_img2img: # only define the api function ONCE + token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output]) with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) -- cgit v1.2.3 From e5707b66d6db2c019bfccf66f9ba53e3daaea40b Mon Sep 17 00:00:00 2001 From: Liam Date: Tue, 27 Sep 2022 19:29:53 -0400 Subject: switched the token counter to use hidden buttons instead of api call --- modules/ui.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 9a3d69c8..15bfd697 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,6 +23,7 @@ from modules.shared import opts, cmd_opts import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack +from modules.helpers import debounce import modules.ldsr_model import modules.scripts import modules.gfpgan_model @@ -330,6 +331,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: outputs=[seed, dummy_component] ) +def update_token_counter(text): + tokens, token_count, max_length = model_hijack.tokenize(text) + style_class = ' class="red"' if (token_count > max_length) else "" + return f"{token_count}/{max_length}" def create_toprow(is_img2img): id_part = "img2img" if is_img2img else "txt2img" @@ -339,15 +344,15 @@ def create_toprow(is_img2img): with gr.Row(): with gr.Column(scale=80): with gr.Row(): - prompt = gr.Textbox(label="Prompt", elem_id=id_part+"_prompt", show_label=False, placeholder="Prompt", lines=2) + prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) + prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False) with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) paste = gr.Button(value=paste_symbol, elem_id="paste") token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") - token_output = gr.JSON(visible=False) - if is_img2img: # only define the api function ONCE - token_counter.change(fn=model_hijack.tokenize, api_name="tokenize", inputs=[token_counter], outputs=[token_output]) + hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") + hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter]) with gr.Column(scale=10, elem_id="style_pos_col"): prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1) -- cgit v1.2.3 From 7ca9858c4c05b67089b095142ff792e07b5962a9 Mon Sep 17 00:00:00 2001 From: Liam Date: Wed, 28 Sep 2022 09:43:54 -0400 Subject: removed unused import; now using javascript to watch prompt textarea --- modules/ui.py | 2 -- 1 file changed, 2 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 15bfd697..4e24eb55 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -23,7 +23,6 @@ from modules.shared import opts, cmd_opts import modules.shared as shared from modules.sd_samplers import samplers, samplers_for_img2img from modules.sd_hijack import model_hijack -from modules.helpers import debounce import modules.ldsr_model import modules.scripts import modules.gfpgan_model @@ -345,7 +344,6 @@ def create_toprow(is_img2img): with gr.Column(scale=80): with gr.Row(): prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2) - prompt.change(fn=lambda *args: [], _js=f"{id_part}_token_counter", inputs=[prompt], outputs=[], preprocess=False) with gr.Column(scale=1, elem_id="roll_col"): roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0) -- cgit v1.2.3