| author | AUTOMATIC <16777216c@gmail.com> | 2022-08-23 21:02:43 +0000 |
|---|---|---|
| committer | AUTOMATIC <16777216c@gmail.com> | 2022-08-23 21:02:43 +0000 |
| commit | 7fd0f3166111c552a2ed4ee1d221583ff5cc1124 (patch) | |
| tree | 3fa5fae323b04d954080ba5fcad4a20f20eca458 /webui.py | |
| parent | e996f3c1189c774481458cba100dc308a98dd805 (diff) | |
| download | stable-diffusion-webui-gfx803-7fd0f3166111c552a2ed4ee1d221583ff5cc1124.tar.gz stable-diffusion-webui-gfx803-7fd0f3166111c552a2ed4ee1d221583ff5cc1124.tar.bz2 stable-diffusion-webui-gfx803-7fd0f3166111c552a2ed4ee1d221583ff5cc1124.zip | |
added prompt verification: if the prompt is too long, a warning is returned in the text field along with the part of the prompt that was truncated
Diffstat (limited to 'webui.py')
-rw-r--r-- | webui.py | 34 |
1 file changed, 34 insertions, 0 deletions
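For context, the new check leans on the Hugging Face tokenizer's `return_overflowing_tokens` option: with truncation enabled, the tokenizer also hands back the token ids it cut off, which the commit turns back into readable text for the warning. A minimal standalone sketch of that mechanism, assuming the `transformers` CLIP tokenizer used by Stable Diffusion v1 (the model name and 77-token limit are assumptions, not part of this commit):

    import sys
    from transformers import CLIPTokenizer

    # Assumption: the SD v1 text encoder uses the openai/clip-vit-large-patch14
    # tokenizer, whose context window (model_max_length) is 77 tokens.
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    max_length = tokenizer.model_max_length

    prompt = "a photo of a cat " * 40  # deliberately far past the limit

    # truncation=True cuts the prompt at max_length; return_overflowing_tokens=True
    # returns the ids that were cut off instead of silently dropping them.
    info = tokenizer([prompt], truncation=True, max_length=max_length,
                     return_overflowing_tokens=True, padding="max_length",
                     return_tensors="pt")

    overflow = info["overflowing_tokens"][0]
    if overflow.shape[0] > 0:
        # Map the dropped ids back to token strings, then to readable text.
        dropped = tokenizer.convert_ids_to_tokens([int(x) for x in overflow])
        print(f"{len(dropped)} tokens truncated:",
              tokenizer.convert_tokens_to_string(dropped), file=sys.stderr)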
@@ -45,6 +45,7 @@ parser.add_argument("--config", type=str, default="configs/stable-diffusion/v1-i
parser.add_argument("--ckpt", type=str, default="models/ldm/stable-diffusion-v1/model.ckpt", help="path to checkpoint of model",)
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast")
parser.add_argument("--gfpgan-dir", type=str, help="GFPGAN directory", default='./GFPGAN')
+parser.add_argument("--no-verify-input", action='store_true', help="do not verify input to check if it's too long")
opt = parser.parse_args()
GFPGAN_dir = opt.gfpgan_dir
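The check runs by default; passing the new flag disables it, e.g.:

    python webui.py --no-verify-input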
@@ -231,6 +232,25 @@ def draw_prompt_matrix(im, width, height, all_prompts):
return result
+def check_prompt_length(prompt, comments):
+ """this function tests if prompt is too long, and if so, adds a message to comments"""
+
+ tokenizer = model.cond_stage_model.tokenizer
+ max_length = model.cond_stage_model.max_length
+
+ info = tokenizer([prompt], truncation=True, max_length=max_length, return_overflowing_tokens=True, padding="max_length", return_tensors="pt")
+ ovf = info['overflowing_tokens'][0]
+ overflowing_count = ovf.shape[0]
+ if overflowing_count == 0:
+ return
+
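+ # get_vocab() maps token -> id; invert it so the overflowing ids can be mapped back to token strings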
+ vocab = {v: k for k, v in tokenizer.get_vocab().items()}
+ overflowing_words = [vocab.get(int(x), "") for x in ovf]
+ overflowing_text = tokenizer.convert_tokens_to_string(overflowing_words)
+
+ comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
+
+
def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name, batch_size, n_iter, steps, cfg_scale, width, height, prompt_matrix, use_GFPGAN):
"""this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch"""
@@ -248,6 +268,8 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
base_count = len(os.listdir(sample_path))
grid_count = len(os.listdir(outpath)) - 1
+ comments = []
+
prompt_matrix_parts = []
if prompt_matrix:
all_prompts = []
@@ -267,6 +289,15 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
print(f"Prompt matrix will create {len(all_prompts)} images using a total of {n_iter} batches.")
else:
+
+ if not opt.no_verify_input:
+ try:
+ check_prompt_length(prompt, comments)
+ except Exception:
+ import traceback
+ print("Error verifying input:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
all_prompts = batch_size * n_iter * [prompt]
all_seeds = [seed + x for x in range(len(all_prompts))]
@@ -333,6 +364,9 @@ def process_images(outpath, func_init, func_sample, prompt, seed, sampler_name,
Steps: {steps}, Sampler: {sampler_name}, CFG scale: {cfg_scale}, Seed: {seed}{', GFPGAN' if use_GFPGAN and GFPGAN is not None else ''}
""".strip()
+ for comment in comments:
+ info += "\n\n" + comment
+
return output_images, seed, info
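Putting the hunks together: unless --no-verify-input is given, process_images runs the length check, collects any warning into comments, and appends it to the info text returned with the output images. A condensed sketch of that flow (names mirror the diff; the info value here is a placeholder):

    comments = []
    if not opt.no_verify_input:
        try:
            check_prompt_length(prompt, comments)  # appends a warning string on overflow
        except Exception:
            import traceback
            print("Error verifying input:", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)

    info = "Steps: 50, Sampler: k_lms, CFG scale: 7.0, Seed: 1"  # placeholder
    for comment in comments:
        info += "\n\n" + comment  # the warning surfaces in the UI text field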