aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/xy_grid.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-10-06 17:30:29 +0000
committerGitHub <noreply@github.com>2022-10-06 17:30:29 +0000
commitab4ddbf333eef170804ef8de67001f77c8fdd64c (patch)
tree21cb1109f8eae463aa4066eec0926cd71ab81740 /scripts/xy_grid.py
parent2a7f48cdb8dcf9acb02610cccae0d1ee5d260bc2 (diff)
parentcf7c784fcc0c84a8a4edd8d3aca4dda4c7025c43 (diff)
downloadstable-diffusion-webui-gfx803-ab4ddbf333eef170804ef8de67001f77c8fdd64c.tar.gz
stable-diffusion-webui-gfx803-ab4ddbf333eef170804ef8de67001f77c8fdd64c.tar.bz2
stable-diffusion-webui-gfx803-ab4ddbf333eef170804ef8de67001f77c8fdd64c.zip
Merge branch 'master' into gallery-styling
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r--scripts/xy_grid.py46
1 files changed, 43 insertions, 3 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index 146663b0..6344e612 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -1,7 +1,9 @@
from collections import namedtuple
from copy import copy
+from itertools import permutations, chain
import random
-
+import csv
+from io import StringIO
from PIL import Image
import numpy as np
@@ -29,6 +31,31 @@ def apply_prompt(p, x, xs):
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
+def apply_order(p, x, xs):
+ token_order = []
+
+ # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen
+ for token in x:
+ token_order.append((p.prompt.find(token), token))
+
+ token_order.sort(key=lambda t: t[0])
+
+ prompt_parts = []
+
+ # Split the prompt up, taking out the tokens
+ for _, token in token_order:
+ n = p.prompt.find(token)
+ prompt_parts.append(p.prompt[0:n])
+ p.prompt = p.prompt[n + len(token):]
+
+ # Rebuild the prompt with the tokens in the order we want
+ prompt_tmp = ""
+ for idx, part in enumerate(prompt_parts):
+ prompt_tmp += part
+ prompt_tmp += x[idx]
+ p.prompt = prompt_tmp + p.prompt
+
+
samplers_dict = {}
for i, sampler in enumerate(modules.sd_samplers.samplers):
samplers_dict[sampler.name.lower()] = i
@@ -60,16 +87,26 @@ def format_value_add_label(p, opt, x):
def format_value(p, opt, x):
if type(x) == float:
x = round(x, 8)
-
return x
+
+def format_value_join_list(p, opt, x):
+ return ", ".join(x)
+
+
def do_nothing(p, x, xs):
pass
+
def format_nothing(p, opt, x):
return ""
+def str_permutations(x):
+ """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
+ return x
+
+
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
@@ -82,6 +119,7 @@ axis_options = [
AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
AxisOption("Sampler", str, apply_sampler, format_value),
AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
@@ -159,7 +197,7 @@ class Script(scripts.Script):
if opt.label == 'Nothing':
return [0]
- valslist = [x.strip() for x in vals.split(",")]
+ valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))]
if opt.type == int:
valslist_ext = []
@@ -206,6 +244,8 @@ class Script(scripts.Script):
valslist_ext.append(val)
valslist = valslist_ext
+ elif opt.type == str_permutations:
+ valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]