diff options
Diffstat (limited to 'modules/styles.py')
-rw-r--r-- | modules/styles.py | 54 |
1 files changed, 10 insertions, 44 deletions
diff --git a/modules/styles.py b/modules/styles.py index 4d218cd7..81d9800d 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -2,7 +2,6 @@ import csv import fnmatch
import os
import os.path
-import re
import typing
import shutil
@@ -14,22 +13,6 @@ class PromptStyle(typing.NamedTuple): path: str = None
-def clean_text(text: str) -> str:
- """
- Iterating through a list of regular expressions and replacement strings, we
- clean up the prompt and style text to make it easier to match against each
- other.
- """
- re_list = [
- ("multiple commas", re.compile("(,+\s+)+,?"), ", "),
- ("multiple spaces", re.compile("\s{2,}"), " "),
- ]
- for _, regex, replace in re_list:
- text = regex.sub(replace, text)
-
- return text.strip(", ")
-
-
def merge_prompts(style_prompt: str, prompt: str) -> str:
if "{prompt}" in style_prompt:
res = style_prompt.replace("{prompt}", prompt)
@@ -44,7 +27,7 @@ def apply_styles_to_prompt(prompt, styles): for style in styles:
prompt = merge_prompts(style, prompt)
- return clean_text(prompt)
+ return prompt
def unwrap_style_text_from_prompt(style_text, prompt):
@@ -56,8 +39,8 @@ def unwrap_style_text_from_prompt(style_text, prompt): Note that the "cleaned" version of the style text is only used for matching
purposes here. It isn't returned; the original style text is not modified.
"""
- stripped_prompt = clean_text(prompt)
- stripped_style_text = clean_text(style_text)
+ stripped_prompt = prompt
+ stripped_style_text = style_text
if "{prompt}" in stripped_style_text:
# Work out whether the prompt is wrapped in the style text. If so, we
# return True and the "inner" prompt text that isn't part of the style.
@@ -115,10 +98,8 @@ class StyleDatabase: self.path = path
folder, file = os.path.split(self.path)
- self.default_file = file.split("*")[0] + ".csv"
- if self.default_file == ".csv":
- self.default_file = "styles.csv"
- self.default_path = os.path.join(folder, self.default_file)
+ filename, _, ext = file.partition('*')
+ self.default_path = os.path.join(folder, filename + ext)
self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]
@@ -172,10 +153,8 @@ class StyleDatabase: row["name"], prompt, negative_prompt, path
)
- def get_style_paths(self) -> list():
- """
- Returns a list of all distinct paths, including the default path, of
- files that styles are loaded from."""
+ def get_style_paths(self) -> set:
+ """Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
@@ -189,9 +168,9 @@ class StyleDatabase: style_paths.add(style.path)
# Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths.discard("do_not_save")
- return list(style_paths)
+ return style_paths
def get_style_prompts(self, styles):
return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -213,20 +192,7 @@ class StyleDatabase: # The path argument is deprecated, but kept for backwards compatibility
_ = path
- # Update any styles without a path to the default path
- for style in list(self.styles.values()):
- if not style.path:
- self.styles[style.name] = style._replace(path=self.default_path)
-
- # Create a list of all distinct paths, including the default path
- style_paths = set()
- style_paths.add(self.default_path)
- for _, style in self.styles.items():
- if style.path:
- style_paths.add(style.path)
-
- # Remove any paths for styles that are just list dividers
- style_paths.remove("do_not_save")
+ style_paths = self.get_style_paths()
csv_names = [os.path.split(path)[1].lower() for path in style_paths]
|