From 86867e153f4449167e3489323df35cf04f1fffa0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 9 Sep 2022 23:16:02 +0300 Subject: support for prompt styles fix broken prompt matrix --- modules/styles.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 modules/styles.py (limited to 'modules/styles.py') diff --git a/modules/styles.py b/modules/styles.py new file mode 100644 index 00000000..58fb7d75 --- /dev/null +++ b/modules/styles.py @@ -0,0 +1,41 @@ +import csv +import os.path +from collections import namedtuple + +PromptStyle = namedtuple("PromptStyle", ["name", "text"]) + + +def load_styles(filename): + res = {"None": PromptStyle("None", "")} + + if os.path.exists(filename): + with open(filename, "r", encoding="utf8", newline='') as file: + reader = csv.DictReader(file) + + for row in reader: + res[row["name"]] = PromptStyle(row["name"], row["text"]) + + return res + + +def apply_style_text(style_text, prompt): + return prompt + ", " + style_text if prompt else style_text + + +def apply_style(p, style): + if type(p.prompt) == list: + p.prompt = [apply_style_text(style.text, x) for x in p.prompt] + else: + p.prompt = apply_style_text(style.text, p.prompt) + + +def save_style(filename, style): + with open(filename, "a", encoding="utf8", newline='') as file: + atstart = file.tell() == 0 + + writer = csv.DictWriter(file, fieldnames=["name", "text"]) + + if atstart: + writer.writeheader() + + writer.writerow({"name": style.name, "text": style.text}) -- cgit v1.2.3