diff options
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r-- | scripts/xy_grid.py | 46 |
1 files changed, 43 insertions, 3 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 146663b0..6344e612 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -1,7 +1,9 @@ from collections import namedtuple
from copy import copy
+from itertools import permutations, chain
import random
-
+import csv
+from io import StringIO
from PIL import Image
import numpy as np
@@ -29,6 +31,31 @@ def apply_prompt(p, x, xs): p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i
@@ -60,16 +87,26 @@ def format_value_add_label(p, opt, x): def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
-
return x
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
def do_nothing(p, x, xs):
pass
+
def format_nothing(p, opt, x):
return ""
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@@ -82,6 +119,7 @@ axis_options = [ AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@@ -159,7 +197,7 @@ class Script(scripts.Script): if opt.label == 'Nothing':
return [0]
- valslist = [x.strip() for x in vals.split(",")]
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
if opt.type == int:
valslist_ext = []
@@ -206,6 +244,8 @@ class Script(scripts.Script): valslist_ext.append(val)
valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]
|