diff options
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/img2imgalt.py | 48 | ||||
-rw-r--r-- | scripts/loopback.py | 4 | ||||
-rw-r--r-- | scripts/outpainting_mk_2.py | 139 | ||||
-rw-r--r-- | scripts/prompts_from_file.py | 115 | ||||
-rw-r--r-- | scripts/xy_grid.py | 206 |
5 files changed, 369 insertions, 143 deletions
diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index f9894cb0..88abc093 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -34,6 +34,9 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps): sigma_in = torch.cat([sigmas[i] * s_in] * 2)
cond_in = torch.cat([uncond, cond])
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
t = dnw.sigma_to_t(sigma_in)
@@ -78,6 +81,9 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps): sigma_in = torch.cat([sigmas[i - 1] * s_in] * 2)
cond_in = torch.cat([uncond, cond])
+ image_conditioning = torch.cat([p.image_conditioning] * 2)
+ cond_in = {"c_concat": [image_conditioning], "c_crossattn": [cond_in]}
+
c_out, c_in = [K.utils.append_dims(k, x_in.ndim) for k in dnw.get_scalings(sigma_in)]
if i == 1:
@@ -120,17 +126,45 @@ class Script(scripts.Script): return is_img2img
def ui(self, is_img2img):
+ info = gr.Markdown('''
+ * `CFG Scale` should be 2 or lower.
+ ''')
+
+ override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True)
+
+ override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True)
original_prompt = gr.Textbox(label="Original prompt", lines=1)
original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1)
- cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
+
+ override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True)
st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50)
+
+ override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True)
+
+ cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0)
randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0)
sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False)
- return [original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment]
- def run(self, p, original_prompt, original_negative_prompt, cfg, st, randomness, sigma_adjustment):
- p.batch_size = 1
- p.batch_count = 1
+ return [
+ info,
+ override_sampler,
+ override_prompt, original_prompt, original_negative_prompt,
+ override_steps, st,
+ override_strength,
+ cfg, randomness, sigma_adjustment,
+ ]
+
+ def run(self, p, _, override_sampler, override_prompt, original_prompt, original_negative_prompt, override_steps, st, override_strength, cfg, randomness, sigma_adjustment):
+ # Override
+ if override_sampler:
+ p.sampler_index = [sampler.name for sampler in sd_samplers.samplers].index("Euler")
+ if override_prompt:
+ p.prompt = original_prompt
+ p.negative_prompt = original_negative_prompt
+ if override_steps:
+ p.steps = st
+ if override_strength:
+ p.denoising_strength = 1.0
def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength):
@@ -154,7 +188,7 @@ class Script(scripts.Script): rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)
- rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], [p.seed + x + 1 for x in range(p.init_latent.shape[0])])
+ rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)
@@ -166,7 +200,7 @@ class Script(scripts.Script): p.seed = p.seed + 1
- return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning)
+ return sampler.sample_img2img(p, p.init_latent, noise_dt, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
p.sample = sample_extra
diff --git a/scripts/loopback.py b/scripts/loopback.py index e90b58d4..d8c68af8 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -38,6 +38,7 @@ class Script(scripts.Script): grids = []
all_images = []
+ original_init_image = p.init_images
state.job_count = loops * batch_count
initial_color_corrections = [processing.setup_color_correction(p.init_images[0])]
@@ -45,6 +46,9 @@ class Script(scripts.Script): for n in range(batch_count):
history = []
+ # Reset to original init image at the start of each batch
+ p.init_images = original_init_image
+
for i in range(loops):
p.n_iter = 1
p.batch_size = 1
diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index a6468e09..2afd4aa5 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -172,54 +172,54 @@ class Script(scripts.Script): if down > 0:
down = target_h - init_img.height - up
- init_image = p.init_images[0]
-
- state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)
-
- def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
+ def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
is_horiz = is_left or is_right
is_vert = is_top or is_bottom
pixels_horiz = expand_pixels if is_horiz else 0
pixels_vert = expand_pixels if is_vert else 0
- res_w = init.width + pixels_horiz
- res_h = init.height + pixels_vert
- process_res_w = math.ceil(res_w / 64) * 64
- process_res_h = math.ceil(res_h / 64) * 64
-
- img = Image.new("RGB", (process_res_w, process_res_h))
- img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
- mask = Image.new("RGB", (process_res_w, process_res_h), "white")
- draw = ImageDraw.Draw(mask)
- draw.rectangle((
- expand_pixels + mask_blur if is_left else 0,
- expand_pixels + mask_blur if is_top else 0,
- mask.width - expand_pixels - mask_blur if is_right else res_w,
- mask.height - expand_pixels - mask_blur if is_bottom else res_h,
- ), fill="black")
-
- np_image = (np.asarray(img) / 255.0).astype(np.float64)
- np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
- noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
- out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
-
- target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
- target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
-
- crop_region = (
- 0 if is_left else out.width - target_width,
- 0 if is_top else out.height - target_height,
- target_width if is_left else out.width,
- target_height if is_top else out.height,
- )
-
- image_to_process = out.crop(crop_region)
- mask = mask.crop(crop_region)
-
- p.width = target_width if is_horiz else img.width
- p.height = target_height if is_vert else img.height
- p.init_images = [image_to_process]
- p.image_mask = mask
+ images_to_process = []
+ output_images = []
+ for n in range(count):
+ res_w = init[n].width + pixels_horiz
+ res_h = init[n].height + pixels_vert
+ process_res_w = math.ceil(res_w / 64) * 64
+ process_res_h = math.ceil(res_h / 64) * 64
+
+ img = Image.new("RGB", (process_res_w, process_res_h))
+ img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
+ mask = Image.new("RGB", (process_res_w, process_res_h), "white")
+ draw = ImageDraw.Draw(mask)
+ draw.rectangle((
+ expand_pixels + mask_blur if is_left else 0,
+ expand_pixels + mask_blur if is_top else 0,
+ mask.width - expand_pixels - mask_blur if is_right else res_w,
+ mask.height - expand_pixels - mask_blur if is_bottom else res_h,
+ ), fill="black")
+
+ np_image = (np.asarray(img) / 255.0).astype(np.float64)
+ np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
+ noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
+ output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))
+
+ target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
+ target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
+ p.width = target_width if is_horiz else img.width
+ p.height = target_height if is_vert else img.height
+
+ crop_region = (
+ 0 if is_left else output_images[n].width - target_width,
+ 0 if is_top else output_images[n].height - target_height,
+ target_width if is_left else output_images[n].width,
+ target_height if is_top else output_images[n].height,
+ )
+ mask = mask.crop(crop_region)
+ p.image_mask = mask
+
+ image_to_process = output_images[n].crop(crop_region)
+ images_to_process.append(image_to_process)
+
+ p.init_images = images_to_process
latent_mask = Image.new("RGB", (p.width, p.height), "white")
draw = ImageDraw.Draw(latent_mask)
@@ -232,31 +232,52 @@ class Script(scripts.Script): p.latent_mask = latent_mask
proc = process_images(p)
- proc_img = proc.images[0]
if initial_seed_and_info[0] is None:
initial_seed_and_info[0] = proc.seed
initial_seed_and_info[1] = proc.info
- out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
- out = out.crop((0, 0, res_w, res_h))
- return out
+ for n in range(count):
+ output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
+ output_images[n] = output_images[n].crop((0, 0, res_w, res_h))
- img = init_image
+ return output_images
- if left > 0:
- img = expand(img, left, is_left=True)
- if right > 0:
- img = expand(img, right, is_right=True)
- if up > 0:
- img = expand(img, up, is_top=True)
- if down > 0:
- img = expand(img, down, is_bottom=True)
+ batch_count = p.n_iter
+ batch_size = p.batch_size
+ p.n_iter = 1
+ state.job_count = batch_count * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
+ all_processed_images = []
+
+ for i in range(batch_count):
+ imgs = [init_img] * batch_size
+ state.job = f"Batch {i + 1} out of {batch_count}"
+
+ if left > 0:
+ imgs = expand(imgs, batch_size, left, is_left=True)
+ if right > 0:
+ imgs = expand(imgs, batch_size, right, is_right=True)
+ if up > 0:
+ imgs = expand(imgs, batch_size, up, is_top=True)
+ if down > 0:
+ imgs = expand(imgs, batch_size, down, is_bottom=True)
- res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
+ all_processed_images += imgs
+
+ all_images = all_processed_images
+
+ combined_grid_image = images.image_grid(all_processed_images)
+ unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
+ if opts.return_grid and not unwanted_grid_because_of_img_count:
+ all_images = [combined_grid_image] + all_processed_images
+
+ res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])
if opts.samples_save:
- images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
+ for img in all_processed_images:
+ images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)
- return res
+ if opts.grid_save and not unwanted_grid_because_of_img_count:
+ images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)
+ return res
diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index b24f1a80..1266be6f 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -1,7 +1,9 @@ +import copy
import math
import os
import sys
import traceback
+import shlex
import modules.scripts as scripts
import gradio as gr
@@ -10,6 +12,75 @@ from modules.processing import Processed, process_images from PIL import Image
from modules.shared import opts, cmd_opts, state
+
+def process_string_tag(tag):
+ return tag
+
+
+def process_int_tag(tag):
+ return int(tag)
+
+
+def process_float_tag(tag):
+ return float(tag)
+
+
+def process_boolean_tag(tag):
+ return True if (tag == "true") else False
+
+
+prompt_tags = {
+ "sd_model": None,
+ "outpath_samples": process_string_tag,
+ "outpath_grids": process_string_tag,
+ "prompt_for_display": process_string_tag,
+ "prompt": process_string_tag,
+ "negative_prompt": process_string_tag,
+ "styles": process_string_tag,
+ "seed": process_int_tag,
+ "subseed_strength": process_float_tag,
+ "subseed": process_int_tag,
+ "seed_resize_from_h": process_int_tag,
+ "seed_resize_from_w": process_int_tag,
+ "sampler_index": process_int_tag,
+ "batch_size": process_int_tag,
+ "n_iter": process_int_tag,
+ "steps": process_int_tag,
+ "cfg_scale": process_float_tag,
+ "width": process_int_tag,
+ "height": process_int_tag,
+ "restore_faces": process_boolean_tag,
+ "tiling": process_boolean_tag,
+ "do_not_save_samples": process_boolean_tag,
+ "do_not_save_grid": process_boolean_tag
+}
+
+
+def cmdargs(line):
+ args = shlex.split(line)
+ pos = 0
+ res = {}
+
+ while pos < len(args):
+ arg = args[pos]
+
+ assert arg.startswith("--"), f'must start with "--": {arg}'
+ tag = arg[2:]
+
+ func = prompt_tags.get(tag, None)
+ assert func, f'unknown commandline option: {arg}'
+
+ assert pos+1 < len(args), f'missing argument for command line option {arg}'
+
+ val = args[pos+1]
+
+ res[tag] = func(val)
+
+ pos += 2
+
+ return res
+
+
class Script(scripts.Script):
def title(self):
return "Prompts from file or textbox"
@@ -32,26 +103,48 @@ class Script(scripts.Script): return [ gr.Checkbox.update(visible = True), gr.File.update(visible = not checkbox_txt), gr.TextArea.update(visible = checkbox_txt) ]
def run(self, p, checkbox_txt, data: bytes, prompt_txt: str):
- if (checkbox_txt):
+ if checkbox_txt:
lines = [x.strip() for x in prompt_txt.splitlines()]
else:
lines = [x.strip() for x in data.decode('utf8', errors='ignore').split("\n")]
lines = [x for x in lines if len(x) > 0]
- img_count = len(lines) * p.n_iter
- batch_count = math.ceil(img_count / p.batch_size)
- loop_count = math.ceil(batch_count / p.n_iter)
- print(f"Will process {img_count} images in {batch_count} batches.")
-
p.do_not_save_grid = True
- state.job_count = batch_count
+ job_count = 0
+ jobs = []
+
+ for line in lines:
+ if "--" in line:
+ try:
+ args = cmdargs(line)
+ except Exception:
+ print(f"Error parsing line [line] as commandline:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ args = {"prompt": line}
+ else:
+ args = {"prompt": line}
+
+ n_iter = args.get("n_iter", 1)
+ if n_iter != 1:
+ job_count += n_iter
+ else:
+ job_count += 1
+
+ jobs.append(args)
+
+ print(f"Will process {len(lines)} lines in {job_count} jobs.")
+ state.job_count = job_count
images = []
- for loop_no in range(loop_count):
- state.job = f"{loop_no + 1} out of {loop_count}"
- p.prompt = lines[loop_no*p.batch_size:(loop_no+1)*p.batch_size] * p.n_iter
- proc = process_images(p)
+ for n, args in enumerate(jobs):
+ state.job = f"{state.job_no + 1} out of {state.job_count}"
+
+ copy_p = copy.copy(p)
+ for k, v in args.items():
+ setattr(copy_p, k, v)
+
+ proc = process_images(copy_p)
images += proc.images
return Processed(p, images, p.seed, "")
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index c89ca1a9..eff0c942 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,8 +10,9 @@ import numpy as np import modules.scripts as scripts
import gradio as gr
-from modules import images, hypernetwork
-from modules.processing import process_images, Processed, get_correct_sampler
+from modules import images
+from modules.hypernetworks import hypernetwork
+from modules.processing import process_images, Processed, get_correct_sampler, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.sd_samplers
@@ -27,6 +28,9 @@ def apply_field(field): def apply_prompt(p, x, xs):
+ if xs[0] not in p.prompt and xs[0] not in p.negative_prompt:
+ raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.")
+
p.prompt = p.prompt.replace(xs[0], x)
p.negative_prompt = p.negative_prompt.replace(xs[0], x)
@@ -73,14 +77,51 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index
+def confirm_samplers(p, xs):
+ samplers_dict = build_samplers_dict(p)
+ for x in xs:
+ if x.lower() not in samplers_dict.keys():
+ raise RuntimeError(f"Unknown sampler: {x}")
+
+
def apply_checkpoint(p, x, xs):
info = modules.sd_models.get_closet_checkpoint_match(x)
- assert info is not None, f'Checkpoint for {x} not found'
+ if info is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
+ p.sd_model = shared.sd_model
+
+
+def confirm_checkpoints(p, xs):
+ for x in xs:
+ if modules.sd_models.get_closet_checkpoint_match(x) is None:
+ raise RuntimeError(f"Unknown checkpoint: {x}")
def apply_hypernetwork(p, x, xs):
- hypernetwork.load_hypernetwork(x)
+ if x.lower() in ["", "none"]:
+ name = None
+ else:
+ name = hypernetwork.find_closest_hypernetwork_name(x)
+ if not name:
+ raise RuntimeError(f"Unknown hypernetwork: {x}")
+ hypernetwork.load_hypernetwork(name)
+
+
+def apply_hypernetwork_strength(p, x, xs):
+ hypernetwork.apply_strength(x)
+
+
+def confirm_hypernetworks(p, xs):
+ for x in xs:
+ if x.lower() in ["", "none"]:
+ continue
+ if not hypernetwork.find_closest_hypernetwork_name(x):
+ raise RuntimeError(f"Unknown hypernetwork: {x}")
+
+
+def apply_clip_skip(p, x, xs):
+ opts.data["CLIP_stop_at_last_layers"] = x
def format_value_add_label(p, opt, x):
@@ -113,38 +154,44 @@ def str_permutations(x): return x
-AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value"])
-AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value"])
+AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
+AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
axis_options = [
- AxisOption("Nothing", str, do_nothing, format_nothing),
- AxisOption("Seed", int, apply_field("seed"), format_value_add_label),
- AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label),
- AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label),
- AxisOption("Steps", int, apply_field("steps"), format_value_add_label),
- AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
- AxisOption("Prompt S/R", str, apply_prompt, format_value),
- AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list),
- AxisOption("Sampler", str, apply_sampler, format_value),
- AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
- AxisOption("Hypernetwork", str, apply_hypernetwork, format_value),
- AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label),
- AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
- AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
- AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
- AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
- AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
+ AxisOption("Nothing", str, do_nothing, format_nothing, None),
+ AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
+ AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
+ AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
+ AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
+ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
+ AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
+ AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
+ AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
+ AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
+ AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
+ AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
+ AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
+ AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
+ AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
+ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
+ AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
+ AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
]
-def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend):
- res = []
-
+def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
- first_processed = None
+ # Temporary list of all the images that are generated to be populated into the grid.
+ # Will be filled with empty images for any individual step that fails to process properly
+ image_cache = []
+
+ processed_result = None
+ cell_mode = "P"
+ cell_size = (1,1)
state.job_count = len(xs) * len(ys) * p.n_iter
@@ -152,22 +199,54 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend): for ix, x in enumerate(xs):
state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
- processed = cell(x, y)
- if first_processed is None:
- first_processed = processed
-
+ processed:Processed = cell(x, y)
try:
- res.append(processed.images[0])
+ # this dereference will throw an exception if the image was not processed
+ # (this happens in cases such as if the user stops the process from the UI)
+ processed_image = processed.images[0]
+
+ if processed_result is None:
+ # Use our first valid processed result as a template container to hold our full results
+ processed_result = copy(processed)
+ cell_mode = processed_image.mode
+ cell_size = processed_image.size
+ processed_result.images = [Image.new(cell_mode, cell_size)]
+
+ image_cache.append(processed_image)
+ if include_lone_images:
+ processed_result.images.append(processed_image)
+ processed_result.all_prompts.append(processed.prompt)
+ processed_result.all_seeds.append(processed.seed)
+ processed_result.infotexts.append(processed.infotexts[0])
except:
- res.append(Image.new(res[0].mode, res[0].size))
+ image_cache.append(Image.new(cell_mode, cell_size))
+
+ if not processed_result:
+ print("Unexpected error: draw_xy_grid failed to return even a single processed image")
+ return Processed()
- grid = images.image_grid(res, rows=len(ys))
+ grid = images.image_grid(image_cache, rows=len(ys))
if draw_legend:
- grid = images.draw_grid_annotations(grid, res[0].width, res[0].height, hor_texts, ver_texts)
+ grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts)
+
+ processed_result.images[0] = grid
+
+ return processed_result
+
- first_processed.images = [grid]
+class SharedSettingsStackHelper(object):
+ def __enter__(self):
+ self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers
+ self.hypernetwork = opts.sd_hypernetwork
+ self.model = shared.sd_model
+
+ def __exit__(self, exc_type, exc_value, tb):
+ modules.sd_models.reload_model_weights(self.model)
- return first_processed
+ hypernetwork.load_hypernetwork(self.hypernetwork)
+ hypernetwork.apply_strength()
+
+ opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers
re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*")
@@ -188,19 +267,21 @@ class Script(scripts.Script): x_values = gr.Textbox(label="X values", visible=False, lines=1)
with gr.Row():
- y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[4].label, visible=False, type="index", elem_id="y_type")
+ y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, visible=False, type="index", elem_id="y_type")
y_values = gr.Textbox(label="Y values", visible=False, lines=1)
draw_legend = gr.Checkbox(label='Draw legend', value=True)
+ include_lone_images = gr.Checkbox(label='Include Separate Images', value=False)
no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False)
- return [x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds]
+ return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
- def run(self, p, x_type, x_values, y_type, y_values, draw_legend, no_fixed_seeds):
+ def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
if not no_fixed_seeds:
modules.processing.fix_seed(p)
- p.batch_size = 1
+ if not opts.return_grid:
+ p.batch_size = 1
def process_axis(opt, vals):
if opt.label == 'Nothing':
@@ -256,17 +337,10 @@ class Script(scripts.Script): valslist = list(permutations(valslist))
valslist = [opt.type(x) for x in valslist]
-
+
# Confirm options are valid before starting
- if opt.label == "Sampler":
- samplers_dict = build_samplers_dict(p)
- for sampler_val in valslist:
- if sampler_val.lower() not in samplers_dict.keys():
- raise RuntimeError(f"Unknown sampler: {sampler_val}")
- elif opt.label == "Checkpoint name":
- for ckpt_val in valslist:
- if modules.sd_models.get_closet_checkpoint_match(ckpt_val) is None:
- raise RuntimeError(f"Checkpoint for {ckpt_val} not found")
+ if opt.confirm:
+ opt.confirm(p, valslist)
return valslist
@@ -277,7 +351,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values)
def fix_axis_seeds(axis_opt, axis_list):
- if axis_opt.label == 'Seed':
+ if axis_opt.label in ['Seed','Var. seed']:
return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list]
else:
return axis_list
@@ -293,6 +367,9 @@ class Script(scripts.Script): else:
total_steps = p.steps * len(xs) * len(ys)
+ if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr:
+ total_steps *= 2
+
print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})")
shared.total_tqdm.updateTotal(total_steps * p.n_iter)
@@ -303,22 +380,19 @@ class Script(scripts.Script): return process_images(pc)
- processed = draw_xy_grid(
- p,
- xs=xs,
- ys=ys,
- x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
- y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
- cell=cell,
- draw_legend=draw_legend
- )
+ with SharedSettingsStackHelper():
+ processed = draw_xy_grid(
+ p,
+ xs=xs,
+ ys=ys,
+ x_labels=[x_opt.format_value(p, x_opt, x) for x in xs],
+ y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
+ cell=cell,
+ draw_legend=draw_legend,
+ include_lone_images=include_lone_images
+ )
if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
- # restore checkpoint in case it was changed by axes
- modules.sd_models.reload_model_weights(shared.sd_model)
-
- hypernetwork.load_hypernetwork(opts.sd_hypernetwork)
-
return processed
|