aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/xy_grid.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/xy_grid.py')
-rw-r--r--scripts/xy_grid.py39
1 files changed, 30 insertions, 9 deletions
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index dd6db81c..eccfda87 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -78,7 +78,7 @@ axis_options = [
]
-def draw_xy_grid(xs, ys, x_label, y_label, cell):
+def draw_xy_grid(p, xs, ys, x_label, y_label, cell, draw_legend):
res = []
ver_texts = [[images.GridAnnotation(y_label(y))] for y in ys]
@@ -86,7 +86,7 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
first_pocessed = None
- state.job_count = len(xs) * len(ys)
+ state.job_count = len(xs) * len(ys) * p.n_iter
for iy, y in enumerate(ys):
for ix, x in enumerate(xs):
@@ -99,7 +99,8 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
res.append(processed.images[0])
grid = images.image_grid(res, rows=len(ys))
- grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+ if draw_legend:
+ grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
first_pocessed.images = [grid]
@@ -109,6 +110,9 @@ def draw_xy_grid(xs, ys, x_label, y_label, cell):
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*")
+re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*")
+re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*")
+
class Script(scripts.Script):
def title(self):
return "X/Y plot"
@@ -123,13 +127,14 @@ class Script(scripts.Script):
with gr.Row():
y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type")
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
+
+ draw_legend = gr.Checkbox(label='Draw legend', value=True)
+
+ return [x_type, x_values, y_type, y_values, draw_legend]
- return [x_type, x_values, y_type, y_values]
-
- def run(self, p, x_type, x_values, y_type, y_values):
+ def run(self, p, x_type, x_values, y_type, y_values, draw_legend):
modules.processing.fix_seed(p)
p.batch_size = 1
- p.batch_count = 1
def process_axis(opt, vals):
valslist = [x.strip() for x in vals.split(",")]
@@ -139,6 +144,7 @@ class Script(scripts.Script):
for val in valslist:
m = re_range.fullmatch(val)
+ mc = re_range_count.fullmatch(val)
if m is not None:
start = int(m.group(1))
@@ -146,6 +152,12 @@ class Script(scripts.Script):
step = int(m.group(3)) if m.group(3) is not None else 1
valslist_ext += list(range(start, end, step))
+ elif mc is not None:
+ start = int(mc.group(1))
+ end = int(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += [int(x) for x in np.linspace(start = start, stop = end, num = num).tolist()]
else:
valslist_ext.append(val)
@@ -155,12 +167,19 @@ class Script(scripts.Script):
for val in valslist:
m = re_range_float.fullmatch(val)
+ mc = re_range_count_float.fullmatch(val)
if m is not None:
start = float(m.group(1))
end = float(m.group(2))
step = float(m.group(3)) if m.group(3) is not None else 1
valslist_ext += np.arange(start, end + step, step).tolist()
+ elif mc is not None:
+ start = float(mc.group(1))
+ end = float(mc.group(2))
+ num = int(mc.group(3)) if mc.group(3) is not None else 1
+
+ valslist_ext += np.linspace(start = start, stop = end, num = num).tolist()
else:
valslist_ext.append(val)
@@ -184,14 +203,16 @@ class Script(scripts.Script):
return process_images(pc)
processed = draw_xy_grid(
+ p,
xs=xs,
ys=ys,
x_label=lambda x: x_opt.format_value(p, x_opt, x),
y_label=lambda y: y_opt.format_value(p, y_opt, y),
- cell=cell
+ cell=cell,
+ draw_legend=draw_legend
)
if opts.grid_save:
- images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, p=p)
+ images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
return processed