From f98f4f73aa4898c754681f411608df5f248619f6 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 4 Jun 2023 10:56:48 +0300 Subject: infer styles from prompts, and an option to control the behavior --- modules/styles.py | 67 +++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 65 insertions(+), 2 deletions(-) (limited to 'modules/styles.py') diff --git a/modules/styles.py b/modules/styles.py index 34e1b5e1..ec0e1bc5 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -1,6 +1,7 @@ import csv import os import os.path +import re import typing import shutil @@ -28,6 +29,44 @@ def apply_styles_to_prompt(prompt, styles): return prompt +re_spaces = re.compile(" +") + + +def extract_style_text_from_prompt(style_text, prompt): + stripped_prompt = re.sub(re_spaces, " ", prompt.strip()) + stripped_style_text = re.sub(re_spaces, " ", style_text.strip()) + if "{prompt}" in stripped_style_text: + left, right = stripped_style_text.split("{prompt}", 2) + if stripped_prompt.startswith(left) and stripped_prompt.endswith(right): + prompt = stripped_prompt[len(left):len(stripped_prompt)-len(right)] + return True, prompt + else: + if stripped_prompt.endswith(stripped_style_text): + prompt = stripped_prompt[:len(stripped_prompt)-len(stripped_style_text)] + + if prompt.endswith(', '): + prompt = prompt[:-2] + + return True, prompt + + return False, prompt + + +def extract_style_from_prompts(style: PromptStyle, prompt, negative_prompt): + if not style.prompt and not style.negative_prompt: + return False, prompt, negative_prompt + + match_positive, extracted_positive = extract_style_text_from_prompt(style.prompt, prompt) + if not match_positive: + return False, prompt, negative_prompt + + match_negative, extracted_negative = extract_style_text_from_prompt(style.negative_prompt, negative_prompt) + if not match_negative: + return False, prompt, negative_prompt + + return True, extracted_positive, extracted_negative + + class StyleDatabase: def __init__(self, path: str): self.no_style = PromptStyle("None", "", "") @@ -67,10 +106,34 @@ class StyleDatabase: if os.path.exists(path): shutil.copy(path, f"{path}.bak") - fd = os.open(path, os.O_RDWR|os.O_CREAT) + fd = os.open(path, os.O_RDWR | os.O_CREAT) with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file: # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) writer.writeheader() - writer.writerows(style._asdict() for k, style in self.styles.items()) + writer.writerows(style._asdict() for k, style in self.styles.items()) + + def extract_styles_from_prompt(self, prompt, negative_prompt): + extracted = [] + + applicable_styles = list(self.styles.values()) + + while True: + found_style = None + + for style in applicable_styles: + is_match, new_prompt, new_neg_prompt = extract_style_from_prompts(style, prompt, negative_prompt) + if is_match: + found_style = style + prompt = new_prompt + negative_prompt = new_neg_prompt + break + + if not found_style: + break + + applicable_styles.remove(found_style) + extracted.append(found_style.name) + + return list(reversed(extracted)), prompt, negative_prompt -- cgit v1.2.3