From 9f267af3f7404d8d8a9123e8e1c07a6557eba54d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 14 Sep 2022 17:56:21 +0300 Subject: added a second style field added the ability to use {prompt} in styles added a button to apply style to textbox rearranged top row for UI --- modules/styles.py | 96 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 57 insertions(+), 39 deletions(-) (limited to 'modules/styles.py') diff --git a/modules/styles.py b/modules/styles.py index bc7f070f..eeedcd08 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -20,49 +20,67 @@ class PromptStyle(typing.NamedTuple): negative_prompt: str -def load_styles(path: str) -> dict[str, PromptStyle]: - styles = {"None": PromptStyle("None", "", "")} +def merge_prompts(style_prompt: str, prompt: str) -> str: + if "{prompt}" in style_prompt: + res = style_prompt.replace("{prompt}", prompt) + else: + parts = filter(None, (prompt.strip(), style_prompt.strip())) + res = ", ".join(parts) - if os.path.exists(path): - with open(path, "r", encoding="utf8", newline='') as file: - reader = csv.DictReader(file) - for row in reader: - # Support loading old CSV format with "name, text"-columns - prompt = row["prompt"] if "prompt" in row else row["text"] - negative_prompt = row.get("negative_prompt", "") - styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + return res - return styles +def apply_styles_to_prompt(prompt, styles): + for style in styles: + prompt = merge_prompts(style, prompt) -def merge_prompts(style_prompt: str, prompt: str) -> str: - parts = filter(None, (prompt.strip(), style_prompt.strip())) - return ", ".join(parts) + return prompt -def apply_style(processing: StableDiffusionProcessing, style: PromptStyle) -> None: - if isinstance(processing.prompt, list): - processing.prompt = [merge_prompts(style.prompt, p) for p in processing.prompt] - else: - processing.prompt = merge_prompts(style.prompt, processing.prompt) +class StyleDatabase: + def __init__(self, path: str): + self.no_style = PromptStyle("None", "", "") + self.styles = {"None": self.no_style} - if isinstance(processing.negative_prompt, list): - processing.negative_prompt = [merge_prompts(style.negative_prompt, p) for p in processing.negative_prompt] - else: - processing.negative_prompt = merge_prompts(style.negative_prompt, processing.negative_prompt) - - -def save_styles(path: str, styles: abc.Iterable[PromptStyle]) -> None: - # Write to temporary file first, so we don't nuke the file if something goes wrong - fd, temp_path = tempfile.mkstemp(".csv") - with os.fdopen(fd, "w", encoding="utf8", newline='') as file: - # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, - # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() - writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) - writer.writeheader() - writer.writerows(style._asdict() for style in styles) - - # Always keep a backup file around - if os.path.exists(path): - shutil.move(path, path + ".bak") - shutil.move(temp_path, path) + if not os.path.exists(path): + return + + with open(path, "r", encoding="utf8", newline='') as file: + reader = csv.DictReader(file) + for row in reader: + # Support loading old CSV format with "name, text"-columns + prompt = row["prompt"] if "prompt" in row else row["text"] + negative_prompt = row.get("negative_prompt", "") + self.styles[row["name"]] = PromptStyle(row["name"], prompt, negative_prompt) + + def apply_styles_to_prompt(self, prompt, styles): + return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).prompt for x in styles]) + + def apply_negative_styles_to_prompt(self, prompt, styles): + return apply_styles_to_prompt(prompt, [self.styles.get(x, self.no_style).negative_prompt for x in styles]) + + def apply_styles(self, p: StableDiffusionProcessing) -> None: + if isinstance(p.prompt, list): + p.prompt = [self.apply_styles_to_prompt(prompt, p.styles) for prompt in p.prompt] + else: + p.prompt = self.apply_styles_to_prompt(p.prompt, p.styles) + + if isinstance(p.negative_prompt, list): + p.negative_prompt = [self.apply_negative_styles_to_prompt(prompt, p.styles) for prompt in p.negative_prompt] + else: + p.negative_prompt = self.apply_negative_styles_to_prompt(p.negative_prompt, p.styles) + + def save_styles(self, path: str) -> None: + # Write to temporary file first, so we don't nuke the file if something goes wrong + fd, temp_path = tempfile.mkstemp(".csv") + with os.fdopen(fd, "w", encoding="utf8", newline='') as file: + # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple, + # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict() + writer = csv.DictWriter(file, fieldnames=PromptStyle._fields) + writer.writeheader() + writer.writerows(style._asdict() for k, style in self.styles.items()) + + # Always keep a backup file around + if os.path.exists(path): + shutil.move(path, path + ".bak") + shutil.move(temp_path, path) -- cgit v1.2.3