From 70931652a4289e28d83869b6d10cf11e80a70345 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Fri, 30 Sep 2022 18:02:46 -0700 Subject: [xy_grid] made -1 seed fixing apply to Var. seed too --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 146663b0..9c078888 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -218,7 +218,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed': + if axis_opt.label == 'Seed' or 'Var. seed': return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From cf141157e7b49b0b3a6e57dc7aa0d1345158b4c8 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Fri, 30 Sep 2022 22:02:29 -0700 Subject: Added X/Y plot parameters to extra_generation_params --- scripts/xy_grid.py | 8 ++++++++ 1 file changed, 8 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 9c078888..d9f8d55b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -244,6 +244,14 @@ class Script(scripts.Script): return process_images(pc) + if not x_opt.label == 'Nothing': + p.extra_generation_params["X/Y Plot X Type"] = x_opt.label + p.extra_generation_params["X Values"] = '{' + ", ".join([f'{x}' for x in xs]) + '}' + + if not y_opt.label == 'Nothing': + p.extra_generation_params["X/Y Plot Y Type"] = y_opt.label + p.extra_generation_params["Y Values"] = '{' + ", ".join([f'{y}' for y in ys]) + '}' + processed = draw_xy_grid( p, xs=xs, -- cgit v1.2.3 From eba0c29dbc3bad8c4e32f1fa3a03dc6f9caf1f5a Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 13:56:29 -0700 Subject: Updated xy_grid infotext formatting, parser regex --- modules/generation_parameters_copypaste.py | 2 +- scripts/xy_grid.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index ac1ba7f4..39d67d94 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,7 +1,7 @@ import re import gradio as gr -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r"\s*([\w ]+):\s*((?:{[^}]+})|(?:[^,]+))(?:,|$)" re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index d9f8d55b..f87c6c1f 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -245,12 +245,16 @@ class Script(scripts.Script): return process_images(pc) if not x_opt.label == 'Nothing': - p.extra_generation_params["X/Y Plot X Type"] = x_opt.label - p.extra_generation_params["X Values"] = '{' + ", ".join([f'{x}' for x in xs]) + '}' + p.extra_generation_params["XY Plot X Type"] = x_opt.label + p.extra_generation_params["X Values"] = '{' + x_values + '}' + if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: + p.extra_generation_params["Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' if not y_opt.label == 'Nothing': - p.extra_generation_params["X/Y Plot Y Type"] = y_opt.label - p.extra_generation_params["Y Values"] = '{' + ", ".join([f'{y}' for y in ys]) + '}' + p.extra_generation_params["XY Plot Y Type"] = y_opt.label + p.extra_generation_params["Y Values"] = '{' + y_values + '}' + if y_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: + p.extra_generation_params["Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' processed = draw_xy_grid( p, -- cgit v1.2.3 From b99a4f769f11ed74df0344a23069d3858613fbef Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sat, 1 Oct 2022 14:26:12 -0700 Subject: fixed expression error in condition --- scripts/xy_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f87c6c1f..f1f54d9c 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -218,7 +218,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label == 'Seed' or 'Var. seed': + if axis_opt.label in ["Seed","Var. seed"]: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list -- cgit v1.2.3 From fe6e2362e8fa5d739de6997ab155a26686d20a49 Mon Sep 17 00:00:00 2001 From: RnDMonkey Date: Sun, 2 Oct 2022 22:04:28 -0700 Subject: Update xy_grid.py Changed XY Plot infotext value keys to not be so generic. --- scripts/xy_grid.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f1f54d9c..ae011a17 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -246,15 +246,15 @@ class Script(scripts.Script): if not x_opt.label == 'Nothing': p.extra_generation_params["XY Plot X Type"] = x_opt.label - p.extra_generation_params["X Values"] = '{' + x_values + '}' + p.extra_generation_params["XY Plot X Values"] = '{' + x_values + '}' if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' + p.extra_generation_params["XY Plot Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' if not y_opt.label == 'Nothing': p.extra_generation_params["XY Plot Y Type"] = y_opt.label - p.extra_generation_params["Y Values"] = '{' + y_values + '}' + p.extra_generation_params["XY Plot Y Values"] = '{' + y_values + '}' if y_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' + p.extra_generation_params["XY Plot Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' processed = draw_xy_grid( p, -- cgit v1.2.3 From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 19:42:10 +0300 Subject: Hires fix rework --- modules/generation_parameters_copypaste.py | 32 ++++++++++++++ modules/images.py | 24 +++++++++-- modules/processing.py | 68 ++++++++++++------------------ modules/shared.py | 7 ++- modules/txt2img.py | 6 +-- modules/ui.py | 15 +++---- scripts/xy_grid.py | 4 +- 7 files changed, 96 insertions(+), 60 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8e7f0df0..d6fa822b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,6 @@ import base64 import io +import math import os import re from pathlib import Path @@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): return None +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + # old algorithm for auto-calculating first pass size + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + firstpass_width = math.ceil(scale * width / 64) * 64 + firstpass_height = math.ceil(scale * height / 64) * 64 + + hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires upscale'] = hr_scale + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + restore_old_hires_fix_params(res) + return res diff --git a/modules/images.py b/modules/images.py index f84fd485..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): return draw_grid_annotations(im, width, height, hor_texts, ver_texts) -def resize_image(resize_mode, im, width, height): +def resize_image(resize_mode, im, width, height, upscaler_name=None): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. + """ + + upscaler_name = upscaler_name or opts.upscaler_for_img2img + def resize(im, w, h): - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) if scale > 1.0: - upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] - assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" + upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] + assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" upscaler = upscalers[0] im = upscaler.scaler.upscale(im, scale, upscaler.data_path) diff --git a/modules/processing.py b/modules/processing.py index 42dc19ea..4654570c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength - self.firstphase_width = firstphase_width - self.firstphase_height = firstphase_height - self.truncate_x = 0 - self.truncate_y = 0 + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + + if firstphase_width != 0 or firstphase_height != 0: + print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) + self.hr_scale = self.width / firstphase_width + self.width = firstphase_width + self.height = firstphase_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - - if self.firstphase_width == 0 or self.firstphase_height == 0: - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - firstphase_width_truncated = int(scale * self.width) - firstphase_height_truncated = int(scale * self.height) - - else: - - width_ratio = self.width / self.firstphase_width - height_ratio = self.height / self.firstphase_height - - if width_ratio > height_ratio: - firstphase_width_truncated = self.firstphase_width - firstphase_height_truncated = self.firstphase_width * self.height / self.width - else: - firstphase_width_truncated = self.firstphase_height * self.width / self.height - firstphase_height_truncated = self.firstphase_height - - self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f - self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_upscaler is not None: + self.extra_generation_params["Hires upscaler"] = self.hr_upscaler def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + if self.enable_hr and latent_scale_mode is None: + assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) - - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + target_width = int(self.width * self.hr_scale) + target_height = int(self.height * self.hr_scale) - """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") - if opts.use_scale_latent_for_hires_fix: + if latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): save_intermediate(image, i) - image = images.resize_image(0, image, self.width, self.height) + image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) batch_images.append(image) @@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None diff --git a/modules/shared.py b/modules/shared.py index 7f430b93..b65559ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -545,6 +544,12 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": "bilinear", + "Latent (nearest)": "nearest", +} + sd_upscalers = [] sd_model = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 7f61e19a..e189a899 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: tiling=tiling, enable_hr=enable_hr, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 7070ea15..27cd9ddd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -684,11 +684,11 @@ def create_ui(): with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): @@ -729,8 +729,8 @@ def create_ui(): width, enable_hr, denoising_strength, - firstphase_width, - firstphase_height, + hr_scale, + hr_upscaler, ] + custom_inputs, outputs=[ @@ -762,7 +762,6 @@ def create_ui(): outputs=[hr_options], ) - txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -781,8 +780,8 @@ def create_ui(): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3e0b2805..f92f9776 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -202,7 +202,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None), @@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork self.model = shared.sd_model - self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): hypernetwork.apply_strength() opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 097a90b88bb92878cf435c513b4757b5b82ae299 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 19:19:11 +0300 Subject: add XY plot parameters to grid image and do not add them to individual images --- modules/processing.py | 2 +- scripts/xy_grid.py | 38 ++++++++++++++++++++++++-------------- 2 files changed, 25 insertions(+), 15 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index c7264aff..47712159 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -422,7 +422,7 @@ def fix_seed(p): p.subseed = get_fixed_seed(p.subseed) -def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration=0, position_in_batch=0): +def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iteration=0, position_in_batch=0): index = position_in_batch + iteration * p.batch_size clip_skip = getattr(p, 'clip_skip', opts.CLIP_stop_at_last_layers) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 59907f0b..78ff12c5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, paths, sd_samplers +from modules import images, paths, sd_samplers, processing from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state @@ -285,6 +285,7 @@ re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + class Script(scripts.Script): def title(self): return "X/Y plot" @@ -381,7 +382,7 @@ class Script(scripts.Script): ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed','Var. seed']: + if axis_opt.label in ['Seed', 'Var. seed']: return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] else: return axis_list @@ -403,24 +404,33 @@ class Script(scripts.Script): print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") shared.total_tqdm.updateTotal(total_steps * p.n_iter) + grid_infotext = [None] + def cell(x, y): pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) - return process_images(pc) + res = process_images(pc) + + if grid_infotext[0] is None: + pc.extra_generation_params = copy(pc.extra_generation_params) + + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - if not x_opt.label == 'Nothing': - p.extra_generation_params["XY Plot X Type"] = x_opt.label - p.extra_generation_params["XY Plot X Values"] = '{' + x_values + '}' - if x_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["XY Plot Fixed X Values"] = '{' + ", ".join([str(x) for x in xs])+ '}' + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - if not y_opt.label == 'Nothing': - p.extra_generation_params["XY Plot Y Type"] = y_opt.label - p.extra_generation_params["XY Plot Y Values"] = '{' + y_values + '}' - if y_opt.label in ["Seed","Var. seed"] and not no_fixed_seeds: - p.extra_generation_params["XY Plot Fixed Y Values"] = '{' + ", ".join([str(y) for y in ys])+ '}' + return res with SharedSettingsStackHelper(): processed = draw_xy_grid( @@ -435,6 +445,6 @@ class Script(scripts.Script): ) if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) return processed -- cgit v1.2.3 From 5851bc839b6f639cda59e84eb1ee8c706986633d Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Wed, 4 Jan 2023 22:03:32 +0100 Subject: Add element ids for script components and a few more in ui.py --- modules/ui.py | 16 ++++++++-------- scripts/custom_code.py | 4 +++- scripts/img2imgalt.py | 22 ++++++++++++---------- scripts/loopback.py | 6 ++++-- scripts/outpainting_mk_2.py | 12 +++++++----- scripts/poor_mans_outpainting.py | 10 ++++++---- scripts/prompt_matrix.py | 6 ++++-- scripts/prompts_from_file.py | 10 ++++++---- scripts/sd_upscale.py | 8 +++++--- scripts/xy_grid.py | 15 ++++++++------- 10 files changed, 63 insertions(+), 46 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/ui.py b/modules/ui.py index 04091e67..bb64fe20 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -560,7 +560,7 @@ Requested path was: {f} generation_info = None with gr.Column(): with gr.Row(elem_id=f"image_buttons_{tabname}"): - open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else 'open_folder') + open_folder_button = gr.Button(folder_symbol, elem_id="hidden_element" if shared.cmd_opts.hide_ui_dir_config else f'open_folder_{tabname}') if tabname != "extras": save = gr.Button('Save', elem_id=f'save_{tabname}') @@ -576,13 +576,13 @@ Requested path was: {f} if tabname != "extras": with gr.Row(): - download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False) + download_files = gr.File(None, file_count="multiple", interactive=False, show_label=False, visible=False, elem_id=f'download_files_{tabname}') with gr.Group(): - html_info = gr.HTML() - html_log = gr.HTML() + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') - generation_info = gr.Textbox(visible=False) + generation_info = gr.Textbox(visible=False, elem_id=f'generation_info_{tabname}') if tabname == 'txt2img' or tabname == 'img2img': generation_info_button = gr.Button(visible=False, elem_id=f"{tabname}_generation_info_button") generation_info_button.click( @@ -624,9 +624,9 @@ Requested path was: {f} ) else: - html_info_x = gr.HTML() - html_info = gr.HTML() - html_log = gr.HTML() + html_info_x = gr.HTML(elem_id=f'html_info_x_{tabname}') + html_info = gr.HTML(elem_id=f'html_info_{tabname}') + html_log = gr.HTML(elem_id=f'html_log_{tabname}') parameters_copypaste.bind_buttons(buttons, result_gallery, "txt2img" if tabname == "txt2img" else None) return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 22e7b77a..841fed97 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -14,7 +14,9 @@ class Script(scripts.Script): return cmd_opts.allow_code def ui(self, is_img2img): - code = gr.Textbox(label="Python code", lines=1) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_custom_code_' + + code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") return [code] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 1229f61b..cddd46e7 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -126,24 +126,26 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_i2i_alternative_test_' + info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. ''') - override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True) + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=elem_prefix + "override_sampler") - override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True) - original_prompt = gr.Textbox(label="Original prompt", lines=1) - original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1) + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=elem_prefix + "override_prompt") + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=elem_prefix + "original_prompt") + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=elem_prefix + "original_negative_prompt") - override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True) - st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50) + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=elem_prefix + "override_steps") + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=elem_prefix + "st") - override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True) + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=elem_prefix + "override_strength") - cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0) - randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0) - sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False) + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=elem_prefix + "cfg") + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=elem_prefix + "randomness") + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=elem_prefix + "sigma_adjustment") return [ info, diff --git a/scripts/loopback.py b/scripts/loopback.py index d8c68af8..5c1265a0 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -17,8 +17,10 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4) - denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_loopback_' + + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") + denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") return [loops, denoising_strength_change_factor] diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index cf71cb92..760cce64 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -129,13 +129,15 @@ class Script(scripts.Script): if not is_img2img: return None + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_outpainting_mk_2_' + info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8) - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0) - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05) + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=elem_prefix + "mask_blur") + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=elem_prefix + "noise_q") + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=elem_prefix + "color_variation") return [info, pixels, mask_blur, direction, noise_q, color_variation] diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index ea45beb0..6bcdcc02 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -21,10 +21,12 @@ class Script(scripts.Script): if not is_img2img: return None - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128) - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4) - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down']) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_poor_mans_outpainting_' + + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=elem_prefix + "inpainting_fill") + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") return [pixels, mask_blur, inpainting_fill, direction] diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 4c79eaef..59172315 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -45,8 +45,10 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False) - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_matrix_' + + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + "different_seeds") return [put_at_start, different_seeds] diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index e8386ed2..fc8ddd8a 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -112,11 +112,13 @@ class Script(scripts.Script): return "Prompts from file or textbox" def ui(self, is_img2img): - checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False) - checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False) + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_from_file_' + + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") - prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1) - file = gr.File(label="Upload prompt inputs", type='bytes') + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=elem_prefix + "prompt_txt") + file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=elem_prefix + "file") file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9739545c..9f483a67 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -17,10 +17,12 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_sd_upscale_' + info = gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64) - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0) - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=elem_prefix + "scale_factor") + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=elem_prefix + "upscaler_index") return [info, overlap, upscaler_index, scale_factor] diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 78ff12c5..90226ccd 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -292,18 +292,19 @@ class Script(scripts.Script): def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] + elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_xy_grid_' with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id="x_type") - x_values = gr.Textbox(label="X values", lines=1) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") + x_values = gr.Textbox(label="X values", lines=1, elem_id=elem_prefix + "x_values") with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id="y_type") - y_values = gr.Textbox(label="Y values", lines=1) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=elem_prefix + "y_type") + y_values = gr.Textbox(label="Y values", lines=1, elem_id=elem_prefix + "y_values") - draw_legend = gr.Checkbox(label='Draw legend', value=True) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False) + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=elem_prefix + "draw_legend") + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=elem_prefix + "include_lone_images") + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=elem_prefix + "no_fixed_seeds") return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] -- cgit v1.2.3 From c3109fa18a5a105eea5e343875b540939884f304 Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Thu, 5 Jan 2023 08:27:09 +0100 Subject: Adjusted prefix from i2i/t2i to txt2img and img2img and removed those prefixes from img exclusive scripts --- scripts/custom_code.py | 2 +- scripts/img2imgalt.py | 2 +- scripts/loopback.py | 2 +- scripts/outpainting_mk_2.py | 2 +- scripts/poor_mans_outpainting.py | 2 +- scripts/prompt_matrix.py | 2 +- scripts/prompts_from_file.py | 2 +- scripts/sd_upscale.py | 2 +- scripts/xy_grid.py | 2 +- 9 files changed, 9 insertions(+), 9 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 841fed97..b3bbee03 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -14,7 +14,7 @@ class Script(scripts.Script): return cmd_opts.allow_code def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_custom_code_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_custom_code_' code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index cddd46e7..c062dd24 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -126,7 +126,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_i2i_alternative_test_' + elem_prefix = 'script_i2i_alternative_test_' info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. diff --git a/scripts/loopback.py b/scripts/loopback.py index 5c1265a0..93eda1eb 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -17,7 +17,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_loopback_' + elem_prefix = 'script_loopback_' loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index 760cce64..c37bc238 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -129,7 +129,7 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_outpainting_mk_2_' + elem_prefix = 'script_outpainting_mk_2_' info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 6bcdcc02..784ee422 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -21,7 +21,7 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_poor_mans_outpainting_' + elem_prefix = 'script_poor_mans_outpainting_' pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 59172315..f610c334 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -45,7 +45,7 @@ class Script(scripts.Script): return "Prompt matrix" def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_matrix_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_matrix_' put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + "different_seeds") diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index fc8ddd8a..c6a0b709 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -112,7 +112,7 @@ class Script(scripts.Script): return "Prompts from file or textbox" def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_prompt_from_file_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_from_file_' checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 9f483a67..2aeeb106 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -17,7 +17,7 @@ class Script(scripts.Script): return is_img2img def ui(self, is_img2img): - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_sd_upscale_' + elem_prefix = 'script_sd_upscale_' info = gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 90226ccd..8c9cfb9b 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -292,7 +292,7 @@ class Script(scripts.Script): def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] - elem_prefix = ('i2i' if is_img2img else 't2i') + '_script_xy_grid_' + elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_xy_grid_' with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") -- cgit v1.2.3 From f185baeb28f348e4ec97cd7070ed219b5f74a48e Mon Sep 17 00:00:00 2001 From: me <25877290+Kryptortio@users.noreply.github.com> Date: Thu, 5 Jan 2023 09:29:07 +0100 Subject: Refactor elem_prefix as function elem_id --- scripts/custom_code.py | 9 ++++++--- scripts/img2imgalt.py | 30 +++++++++++++++++------------- scripts/loopback.py | 15 ++++++++++----- scripts/outpainting_mk_2.py | 18 +++++++++++------- scripts/poor_mans_outpainting.py | 17 ++++++++++------- scripts/prompt_matrix.py | 14 +++++++++----- scripts/prompts_from_file.py | 18 +++++++++++------- scripts/sd_upscale.py | 16 ++++++++++------ scripts/xy_grid.py | 20 ++++++++++++-------- 9 files changed, 96 insertions(+), 61 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/custom_code.py b/scripts/custom_code.py index b3bbee03..9ce1f650 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -3,20 +3,23 @@ import gradio as gr from modules.processing import Processed from modules.shared import opts, cmd_opts, state +import re class Script(scripts.Script): def title(self): return "Custom code" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id def show(self, is_img2img): return cmd_opts.allow_code def ui(self, is_img2img): - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_custom_code_' - - code = gr.Textbox(label="Python code", lines=1, elem_id=elem_prefix + "code") + code = gr.Textbox(label="Python code", lines=1, elem_id=self.elem_id("code")) return [code] diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index c062dd24..7555e874 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -16,6 +16,7 @@ import k_diffusion as K from PIL import Image from torch import autocast from einops import rearrange, repeat +import re def find_noise_for_image(p, cond, uncond, cfg_scale, steps): @@ -122,30 +123,33 @@ class Script(scripts.Script): def title(self): return "img2img alternative test" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_i2i_alternative_test_' - + def ui(self, is_img2img): info = gr.Markdown(''' * `CFG Scale` should be 2 or lower. ''') - override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=elem_prefix + "override_sampler") + override_sampler = gr.Checkbox(label="Override `Sampling method` to Euler?(this method is built for it)", value=True, elem_id=self.elem_id("override_sampler")) - override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=elem_prefix + "override_prompt") - original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=elem_prefix + "original_prompt") - original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=elem_prefix + "original_negative_prompt") + override_prompt = gr.Checkbox(label="Override `prompt` to the same value as `original prompt`?(and `negative prompt`)", value=True, elem_id=self.elem_id("override_prompt")) + original_prompt = gr.Textbox(label="Original prompt", lines=1, elem_id=self.elem_id("original_prompt")) + original_negative_prompt = gr.Textbox(label="Original negative prompt", lines=1, elem_id=self.elem_id("original_negative_prompt")) - override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=elem_prefix + "override_steps") - st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=elem_prefix + "st") + override_steps = gr.Checkbox(label="Override `Sampling Steps` to the same value as `Decode steps`?", value=True, elem_id=self.elem_id("override_steps")) + st = gr.Slider(label="Decode steps", minimum=1, maximum=150, step=1, value=50, elem_id=self.elem_id("st")) - override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=elem_prefix + "override_strength") + override_strength = gr.Checkbox(label="Override `Denoising strength` to 1?", value=True, elem_id=self.elem_id("override_strength")) - cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=elem_prefix + "cfg") - randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=elem_prefix + "randomness") - sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=elem_prefix + "sigma_adjustment") + cfg = gr.Slider(label="Decode CFG scale", minimum=0.0, maximum=15.0, step=0.1, value=1.0, elem_id=self.elem_id("cfg")) + randomness = gr.Slider(label="Randomness", minimum=0.0, maximum=1.0, step=0.01, value=0.0, elem_id=self.elem_id("randomness")) + sigma_adjustment = gr.Checkbox(label="Sigma adjustment for finding noise for image", value=False, elem_id=self.elem_id("sigma_adjustment")) return [ info, diff --git a/scripts/loopback.py b/scripts/loopback.py index 93eda1eb..4df7b73f 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -8,19 +8,24 @@ from modules import processing, shared, sd_samplers, images from modules.processing import Processed from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state +import re + class Script(scripts.Script): def title(self): return "Loopback" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_loopback_' - - loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=elem_prefix + "loops") - denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=elem_prefix + "denoising_strength_change_factor") + def ui(self, is_img2img): + loops = gr.Slider(minimum=1, maximum=32, step=1, label='Loops', value=4, elem_id=self.elem_id("loops")) + denoising_strength_change_factor = gr.Slider(minimum=0.9, maximum=1.1, step=0.01, label='Denoising strength change factor', value=1, elem_id=self.elem_id("denoising_strength_change_factor")) return [loops, denoising_strength_change_factor] diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index c37bc238..b4a0dc73 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,6 +10,7 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state +import re # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -122,6 +123,11 @@ class Script(scripts.Script): def title(self): return "Outpainting mk2" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img @@ -129,15 +135,13 @@ class Script(scripts.Script): if not is_img2img: return None - elem_prefix = 'script_outpainting_mk_2_' - info = gr.HTML("

Recommended settings: Sampling Steps: 80-100, Sampler: Euler a, Denoising strength: 0.8

") - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=elem_prefix + "mask_blur") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") - noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=elem_prefix + "noise_q") - color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=elem_prefix + "color_variation") + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=8, elem_id=self.elem_id("mask_blur")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) + noise_q = gr.Slider(label="Fall-off exponent (lower=higher detail)", minimum=0.0, maximum=4.0, step=0.01, value=1.0, elem_id=self.elem_id("noise_q")) + color_variation = gr.Slider(label="Color variation", minimum=0.0, maximum=1.0, step=0.01, value=0.05, elem_id=self.elem_id("color_variation")) return [info, pixels, mask_blur, direction, noise_q, color_variation] diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 784ee422..1c7dc467 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,26 +7,29 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state - +import re class Script(scripts.Script): def title(self): return "Poor man's outpainting" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img def ui(self, is_img2img): if not is_img2img: return None - - elem_prefix = 'script_poor_mans_outpainting_' - pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=elem_prefix + "pixels") - mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=elem_prefix + "mask_blur") - inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=elem_prefix + "inpainting_fill") - direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=elem_prefix + "direction") + pixels = gr.Slider(label="Pixels to expand", minimum=8, maximum=256, step=8, value=128, elem_id=self.elem_id("pixels")) + mask_blur = gr.Slider(label='Mask blur', minimum=0, maximum=64, step=1, value=4, elem_id=self.elem_id("mask_blur")) + inpainting_fill = gr.Radio(label='Masked content', choices=['fill', 'original', 'latent noise', 'latent nothing'], value='fill', type="index", elem_id=self.elem_id("inpainting_fill")) + direction = gr.CheckboxGroup(label="Outpainting direction", choices=['left', 'right', 'up', 'down'], value=['left', 'right', 'up', 'down'], elem_id=self.elem_id("direction")) return [pixels, mask_blur, inpainting_fill, direction] diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index f610c334..278d2e68 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,6 +10,7 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers +import re def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -44,11 +45,14 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def ui(self, is_img2img): - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_matrix_' - - put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=elem_prefix + "put_at_start") - different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=elem_prefix + "different_seeds") + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + + def ui(self, is_img2img): + put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) + different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) return [put_at_start, different_seeds] diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index c6a0b709..5c84c3e9 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -13,6 +13,7 @@ from modules import sd_samplers from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state +import re def process_string_tag(tag): @@ -111,14 +112,17 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def ui(self, is_img2img): - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_prompt_from_file_' - - checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=elem_prefix + "checkbox_iterate") - checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=elem_prefix + "checkbox_iterate_batch") + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id - prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=elem_prefix + "prompt_txt") - file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=elem_prefix + "file") + def ui(self, is_img2img): + checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) + checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) + + prompt_txt = gr.Textbox(label="List of prompt inputs", lines=1, elem_id=self.elem_id("prompt_txt")) + file = gr.File(label="Upload prompt inputs", type='bytes', elem_id=self.elem_id("file")) file.change(fn=load_prompt_file, inputs=[file], outputs=[file, prompt_txt, prompt_txt]) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 2aeeb106..247e755b 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -7,22 +7,26 @@ from PIL import Image from modules import processing, shared, sd_samplers, images, devices from modules.processing import Processed from modules.shared import opts, cmd_opts, state +import re class Script(scripts.Script): def title(self): return "SD upscale" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def show(self, is_img2img): return is_img2img - def ui(self, is_img2img): - elem_prefix = 'script_sd_upscale_' - + def ui(self, is_img2img): info = gr.HTML("

Will upscale the image by the selected scale factor; use width and height sliders to set tile size

") - overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=elem_prefix + "overlap") - scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=elem_prefix + "scale_factor") - upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=elem_prefix + "upscaler_index") + overlap = gr.Slider(minimum=0, maximum=256, step=16, label='Tile overlap', value=64, elem_id=self.elem_id("overlap")) + scale_factor = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label='Scale Factor', value=2.0, elem_id=self.elem_id("scale_factor")) + upscaler_index = gr.Radio(label='Upscaler', choices=[x.name for x in shared.sd_upscalers], value=shared.sd_upscalers[0].name, type="index", elem_id=self.elem_id("upscaler_index")) return [info, overlap, upscaler_index, scale_factor] diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8c9cfb9b..b277a439 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -290,21 +290,25 @@ class Script(scripts.Script): def title(self): return "X/Y plot" + def elem_id(self, item_id): + gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id + gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) + return gen_elem_id + def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] - elem_prefix = ('img2img' if is_img2img else 'txt2txt') + '_script_xy_grid_' with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=elem_prefix + "x_type") - x_values = gr.Textbox(label="X values", lines=1, elem_id=elem_prefix + "x_values") + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=elem_prefix + "y_type") - y_values = gr.Textbox(label="Y values", lines=1, elem_id=elem_prefix + "y_values") + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=elem_prefix + "draw_legend") - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=elem_prefix + "include_lone_images") - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=elem_prefix + "no_fixed_seeds") + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] -- cgit v1.2.3 From f8d0cf6a6ec4911559cfecb9a9d1d46b547b38e8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 5 Jan 2023 12:08:11 +0300 Subject: rework #6329 to remove duplicate code and add prevent tab names for showing in ids for scripts that only exist on one tab --- modules/scripts.py | 10 ++++++++++ scripts/custom_code.py | 6 ------ scripts/img2imgalt.py | 6 ------ scripts/loopback.py | 6 ------ scripts/outpainting_mk_2.py | 6 ------ scripts/poor_mans_outpainting.py | 6 ------ scripts/prompt_matrix.py | 6 ------ scripts/prompts_from_file.py | 6 ------ scripts/sd_upscale.py | 6 ------ scripts/xy_grid.py | 5 ----- 10 files changed, 10 insertions(+), 53 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/scripts.py b/modules/scripts.py index 722f8685..0c44f191 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,4 +1,5 @@ import os +import re import sys import traceback from collections import namedtuple @@ -128,6 +129,15 @@ class Script: """unused""" return "" + def elem_id(self, item_id): + """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id""" + + need_tabname = self.show(True) == self.show(False) + tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else "" + title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower())) + + return f'script_{tabname}{title}_{item_id}' + current_basedir = paths.script_path diff --git a/scripts/custom_code.py b/scripts/custom_code.py index 9ce1f650..d29113e6 100644 --- a/scripts/custom_code.py +++ b/scripts/custom_code.py @@ -3,18 +3,12 @@ import gradio as gr from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Custom code" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return cmd_opts.allow_code diff --git a/scripts/img2imgalt.py b/scripts/img2imgalt.py index 7555e874..cbdfc6b3 100644 --- a/scripts/img2imgalt.py +++ b/scripts/img2imgalt.py @@ -16,7 +16,6 @@ import k_diffusion as K from PIL import Image from torch import autocast from einops import rearrange, repeat -import re def find_noise_for_image(p, cond, uncond, cfg_scale, steps): @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "img2img alternative test" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/loopback.py b/scripts/loopback.py index 4df7b73f..1dab9476 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -8,18 +8,12 @@ from modules import processing, shared, sd_samplers, images from modules.processing import Processed from modules.sd_samplers import samplers from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Loopback" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/outpainting_mk_2.py b/scripts/outpainting_mk_2.py index b4a0dc73..0906da6a 100644 --- a/scripts/outpainting_mk_2.py +++ b/scripts/outpainting_mk_2.py @@ -10,7 +10,6 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re # this function is taken from https://github.com/parlance-zz/g-diffuser-bot @@ -123,11 +122,6 @@ class Script(scripts.Script): def title(self): return "Outpainting mk2" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/poor_mans_outpainting.py b/scripts/poor_mans_outpainting.py index 1c7dc467..d8feda00 100644 --- a/scripts/poor_mans_outpainting.py +++ b/scripts/poor_mans_outpainting.py @@ -7,18 +7,12 @@ from PIL import Image, ImageDraw from modules import images, processing, devices from modules.processing import Processed, process_images from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "Poor man's outpainting" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/prompt_matrix.py b/scripts/prompt_matrix.py index 278d2e68..dd95e588 100644 --- a/scripts/prompt_matrix.py +++ b/scripts/prompt_matrix.py @@ -10,7 +10,6 @@ from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state import modules.sd_samplers -import re def draw_xy_grid(xs, ys, x_label, y_label, cell): @@ -45,11 +44,6 @@ class Script(scripts.Script): def title(self): return "Prompt matrix" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): put_at_start = gr.Checkbox(label='Put variable parts at start of prompt', value=False, elem_id=self.elem_id("put_at_start")) different_seeds = gr.Checkbox(label='Use different seed for each picture', value=False, elem_id=self.elem_id("different_seeds")) diff --git a/scripts/prompts_from_file.py b/scripts/prompts_from_file.py index 5c84c3e9..2751f98a 100644 --- a/scripts/prompts_from_file.py +++ b/scripts/prompts_from_file.py @@ -13,7 +13,6 @@ from modules import sd_samplers from modules.processing import Processed, process_images from PIL import Image from modules.shared import opts, cmd_opts, state -import re def process_string_tag(tag): @@ -112,11 +111,6 @@ class Script(scripts.Script): def title(self): return "Prompts from file or textbox" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): checkbox_iterate = gr.Checkbox(label="Iterate seed every line", value=False, elem_id=self.elem_id("checkbox_iterate")) checkbox_iterate_batch = gr.Checkbox(label="Use same random seed for all lines", value=False, elem_id=self.elem_id("checkbox_iterate_batch")) diff --git a/scripts/sd_upscale.py b/scripts/sd_upscale.py index 247e755b..9b8ffd85 100644 --- a/scripts/sd_upscale.py +++ b/scripts/sd_upscale.py @@ -7,18 +7,12 @@ from PIL import Image from modules import processing, shared, sd_samplers, images, devices from modules.processing import Processed from modules.shared import opts, cmd_opts, state -import re class Script(scripts.Script): def title(self): return "SD upscale" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def show(self, is_img2img): return is_img2img diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b277a439..f04d9b7e 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -290,11 +290,6 @@ class Script(scripts.Script): def title(self): return "X/Y plot" - def elem_id(self, item_id): - gen_elem_id = ('img2img' if self.is_img2img else 'txt2txt') + '_script_' + re.sub(r'\s', '_', self.title().lower()) + '_' + item_id - gen_elem_id = re.sub(r'[^a-z_0-9]', '', gen_elem_id) - return gen_elem_id - def ui(self, is_img2img): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] -- cgit v1.2.3 From a5bbcd215304e0c83ab2b9fe7f172f88536d7629 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 19:56:09 +0300 Subject: fix bug with "Ignore selected VAE for..." option completely disabling VAE election rework VAE resolving code to be more simple --- modules/sd_models.py | 6 +- modules/sd_vae.py | 194 ++++++++++++++++++++------------------------------- modules/shared.py | 4 +- scripts/xy_grid.py | 27 ++++--- 4 files changed, 95 insertions(+), 136 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index e5a0bc63..6a681cef 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None return sd -def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): +def load_model_weights(model, checkpoint_info: CheckpointInfo): sd_model_hash = checkpoint_info.calculate_shorthash() cache_enabled = shared.opts.sd_checkpoint_cache > 0 @@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"): sd_vae.delete_base_vae() sd_vae.clear_loaded_vae() - vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file) - sd_vae.load_vae(model, vae_file) + vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename) + sd_vae.load_vae(model, vae_file, vae_source) def enable_midas_autodownload(): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 0a49daa1..6ea92711 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -9,23 +9,9 @@ import glob from copy import deepcopy -model_dir = "Stable-diffusion" -model_path = os.path.abspath(os.path.join(models_path, model_dir)) -vae_dir = "VAE" -vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) - - +vae_path = os.path.abspath(os.path.join(models_path, "VAE")) vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - - -default_vae_dict = {"auto": "auto", "None": None, None: None} -default_vae_list = ["auto", "None"] - - -default_vae_values = [default_vae_dict[x] for x in default_vae_list] -vae_dict = dict(default_vae_dict) -vae_list = list(default_vae_list) -first_load = True +vae_dict = {} base_vae = None @@ -64,100 +50,69 @@ def restore_base_vae(model): def get_filename(filepath): - return os.path.splitext(os.path.basename(filepath))[0] - - -def refresh_vae_list(vae_path=vae_path, model_path=model_path): - global vae_dict, vae_list - res = {} - candidates = [ - *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True), - *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True), - *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True), + return os.path.basename(filepath) + + +def refresh_vae_list(): + vae_dict.clear() + + paths = [ + os.path.join(sd_models.model_path, '**/*.vae.ckpt'), + os.path.join(sd_models.model_path, '**/*.vae.pt'), + os.path.join(sd_models.model_path, '**/*.vae.safetensors'), + os.path.join(vae_path, '**/*.ckpt'), + os.path.join(vae_path, '**/*.pt'), + os.path.join(vae_path, '**/*.safetensors'), ] - if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path): - candidates.append(shared.cmd_opts.vae_path) + + if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir): + paths += [ + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'), + os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'), + ] + + candidates = [] + for path in paths: + candidates += glob.iglob(path, recursive=True) + for filepath in candidates: name = get_filename(filepath) - res[name] = filepath - vae_list.clear() - vae_list.extend(default_vae_list) - vae_list.extend(list(res.keys())) - vae_dict.clear() - vae_dict.update(res) - vae_dict.update(default_vae_dict) - return vae_list - - -def get_vae_from_settings(vae_file="auto"): - # else, we load from settings, if not set to be default - if vae_file == "auto" and shared.opts.sd_vae is not None: - # if saved VAE settings isn't recognized, fallback to auto - vae_file = vae_dict.get(shared.opts.sd_vae, "auto") - # if VAE selected but not found, fallback to auto - if vae_file not in default_vae_values and not os.path.isfile(vae_file): - vae_file = "auto" - print(f"Selected VAE doesn't exist: {vae_file}") - return vae_file - - -def resolve_vae(checkpoint_file=None, vae_file="auto"): - global first_load, vae_dict, vae_list - - # if vae_file argument is provided, it takes priority, but not saved - if vae_file and vae_file not in default_vae_list: - if not os.path.isfile(vae_file): - print(f"VAE provided as function argument doesn't exist: {vae_file}") - vae_file = "auto" - # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported - if first_load and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - shared.opts.data['sd_vae'] = get_filename(vae_file) - else: - print(f"VAE provided as command line argument doesn't exist: {vae_file}") - # fallback to selector in settings, if vae selector not set to act as default fallback - if not shared.opts.sd_vae_as_default: - vae_file = get_vae_from_settings(vae_file) - # vae-path cmd arg takes priority for auto - if vae_file == "auto" and shared.cmd_opts.vae_path is not None: - if os.path.isfile(shared.cmd_opts.vae_path): - vae_file = shared.cmd_opts.vae_path - print(f"Using VAE provided as command line argument: {vae_file}") - # if still not found, try look for ".vae.pt" beside model - model_path = os.path.splitext(checkpoint_file)[0] - if vae_file == "auto": - vae_file_try = model_path + ".vae.pt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.ckpt" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.ckpt" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # if still not found, try look for ".vae.safetensors" beside model - if vae_file == "auto": - vae_file_try = model_path + ".vae.safetensors" - if os.path.isfile(vae_file_try): - vae_file = vae_file_try - print(f"Using VAE found similar to selected model: {vae_file}") - # No more fallbacks for auto - if vae_file == "auto": - vae_file = None - # Last check, just because - if vae_file and not os.path.exists(vae_file): - vae_file = None - - return vae_file - - -def load_vae(model, vae_file=None): - global first_load, vae_dict, vae_list, loaded_vae_file + vae_dict[name] = filepath + + +def find_vae_near_checkpoint(checkpoint_file): + checkpoint_path = os.path.splitext(checkpoint_file)[0] + for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]: + if os.path.isfile(vae_location): + return vae_location + + return None + + +def resolve_vae(checkpoint_file): + if shared.cmd_opts.vae_path is not None: + return shared.cmd_opts.vae_path, 'from commandline argument' + + vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) + if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"): + return vae_near_checkpoint, 'found near the checkpoint' + + if shared.opts.sd_vae == "None": + return None, None + + vae_from_options = vae_dict.get(shared.opts.sd_vae, None) + if vae_from_options is not None: + return vae_from_options, 'specified in settings' + + if shared.opts.sd_vae != "Automatic": + print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead") + + return None, None + + +def load_vae(model, vae_file=None, vae_source="from unknown source"): + global vae_dict, loaded_vae_file # save_settings = False cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0 @@ -165,12 +120,12 @@ def load_vae(model, vae_file=None): if vae_file: if cache_enabled and vae_file in checkpoints_loaded: # use vae checkpoint cache - print(f"Loading VAE weights [{get_filename(vae_file)}] from cache") + print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}") store_base_vae(model) _load_vae_dict(model, checkpoints_loaded[vae_file]) else: - assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}" - print(f"Loading VAE weights from: {vae_file}") + assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}" + print(f"Loading VAE weights {vae_source}: {vae_file}") store_base_vae(model) vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location) @@ -191,14 +146,12 @@ def load_vae(model, vae_file=None): vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file - vae_list.append(vae_opt) + elif loaded_vae_file: restore_base_vae(model) loaded_vae_file = vae_file - first_load = False - # don't call this from outside def _load_vae_dict(model, vae_dict_1): @@ -211,7 +164,10 @@ def clear_loaded_vae(): loaded_vae_file = None -def reload_vae_weights(sd_model=None, vae_file="auto"): +unspecified = object() + + +def reload_vae_weights(sd_model=None, vae_file=unspecified): from modules import lowvram, devices, sd_hijack if not sd_model: @@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): checkpoint_info = sd_model.sd_checkpoint_info checkpoint_file = checkpoint_info.filename - vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if vae_file == unspecified: + vae_file, vae_source = resolve_vae(checkpoint_file) + else: + vae_source = "from function argument" if loaded_vae_file == vae_file: return @@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): sd_hijack.model_hijack.undo_hijack(sd_model) - load_vae(sd_model, vae_file) + load_vae(sd_model, vae_file, vae_source) sd_hijack.model_hijack.hijack(sd_model) script_callbacks.model_loaded_callback(sd_model) @@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"): if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: sd_model.to(devices.device) - print("VAE Weights loaded.") + print("VAE weights loaded.") return sd_model diff --git a/modules/shared.py b/modules/shared.py index e0ec3136..9756adea 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False) parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False) parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False) -parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None) +parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None) parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False) parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)") parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None) @@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints), "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), - "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list), + "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index f04d9b7e..bd3087d4 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs): def find_vae(name: str): - if name.lower() in ['auto', 'none']: - return name + if name.lower() in ['auto', 'automatic']: + return modules.sd_vae.unspecified + if name.lower() == 'none': + return None else: - vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE')) - found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True) - if found: - return found[0] + choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + print(f"No VAE found for {name}; using automatic") + return modules.sd_vae.unspecified else: - return 'auto' + return modules.sd_vae.vae_dict[choices[0]] def apply_vae(p, x, xs): - if x.lower().strip() == 'none': - modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None') - else: - found = find_vae(x) - if found: - v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found) + modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): @@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object): def __exit__(self, exc_type, exc_value, tb): modules.sd_models.reload_model_weights(self.model) - modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae)) + + opts.data["sd_vae"] = self.vae + modules.sd_vae.reload_vae_weights(self.model) hypernetwork.load_hypernetwork(self.hypernetwork) hypernetwork.apply_strength() -- cgit v1.2.3 From f202ff1901c27d1f82d5e2684dba9e1ed24ffdf2 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 19:43:34 -0800 Subject: Make XY grid cancellation much faster --- scripts/xy_grid.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index bd3087d4..13a3a046 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -406,6 +406,9 @@ class Script(scripts.Script): grid_infotext = [None] def cell(x, y): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + pc = copy(p) x_opt.apply(pc, x, xs) y_opt.apply(pc, y, ys) -- cgit v1.2.3 From 029260b4ca7267d7a75319dbc11bca2a8c52774e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 21:40:57 -0800 Subject: Optimize XY grid to run slower axes fewer times --- scripts/xy_grid.py | 123 ++++++++++++++++++++++++++++++----------------------- 1 file changed, 70 insertions(+), 53 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..074ee919 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -175,76 +175,87 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"]) +AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"]) +AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing, None), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None), - AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None), - AxisOption("Prompt S/R", str, apply_prompt, format_value, None), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None), - AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), - AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), - AxisOption("VAE", str, apply_vae, format_value_add_label, None), - AxisOption("Styles", str, apply_styles, format_value_add_label, None), + AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), + AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), + AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), + AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), + AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), + AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), + AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), + AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), + AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), + AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), + AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), + AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), + AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), + AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), + AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), + AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), + AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), + AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), + AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), ] -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images): +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): ver_texts = [[images.GridAnnotation(y)] for y in y_labels] hor_texts = [[images.GridAnnotation(x)] for x in x_labels] # Temporary list of all the images that are generated to be populated into the grid. # Will be filled with empty images for any individual step that fails to process properly - image_cache = [] + image_cache = [None] * (len(xs) * len(ys)) processed_result = None cell_mode = "P" - cell_size = (1,1) + cell_size = (1, 1) state.job_count = len(xs) * len(ys) * p.n_iter - for iy, y in enumerate(ys): + def process_cell(x, y, ix, iy): + nonlocal image_cache, processed_result, cell_mode, cell_size + + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed: Processed = cell(x, y) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache[ix + iy * len(xs)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + + if swap_axes_processing_order: for ix, x in enumerate(xs): - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed:Processed = cell(x, y) - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] - - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - - image_cache.append(processed_image) - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache.append(Image.new(cell_mode, cell_size)) + for iy, y in enumerate(ys): + process_cell(x, y, ix, iy) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, ix, iy) if not processed_result: print("Unexpected error: draw_xy_grid failed to return even a single processed image") @@ -405,6 +416,11 @@ class Script(scripts.Script): grid_infotext = [None] + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` loop. + swap_axes_processing_order = x_opt.cost > y_opt.cost + def cell(x, y): if shared.state.interrupted: return Processed(p, [], p.seed, "") @@ -443,7 +459,8 @@ class Script(scripts.Script): y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], cell=cell, draw_legend=draw_legend, - include_lone_images=include_lone_images + include_lone_images=include_lone_images, + swap_axes_processing_order=swap_axes_processing_order ) if opts.grid_save: -- cgit v1.2.3 From 2144c2eb7f5842caed1227d4ec7e659c79a84ce9 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sun, 15 Jan 2023 21:41:58 -0800 Subject: Add swap axes button for XY Grid --- scripts/xy_grid.py | 26 ++++++++++++++++++++------ style.css | 10 ++++++++++ 2 files changed, 30 insertions(+), 6 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..99a660c1 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -23,6 +23,9 @@ import os import re +up_down_arrow_symbol = "\u2195\ufe0f" + + def apply_field(field): def fun(p, x, xs): setattr(p, field, x) @@ -293,17 +296,28 @@ class Script(scripts.Script): current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + with gr.Column(scale=1, elem_id="xy_grid_button_column"): + swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes") + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + def swap_axes(x_type, x_values, y_type, y_values): + nonlocal current_axis_options + return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values + + swap_args = [x_type, x_values, y_type, y_values] + swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): diff --git a/style.css b/style.css index 78fa9838..1fddfcc2 100644 --- a/style.css +++ b/style.css @@ -717,6 +717,16 @@ footer { line-height: 2.4em; } +#xy_grid_button_column { + min-width: 38px !important; +} + +#xy_grid_button_column button { + height: 100%; + margin-bottom: 0.7em; + margin-left: 1em; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From 972f5785073b8ba5957add72debd74fc56ee9329 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 09:27:52 +0300 Subject: fix problems related to checkpoint/VAE switching in XY plot --- scripts/xy_grid.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 13a3a046..0cdfa952 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -263,14 +263,12 @@ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork - self.model = shared.sd_model self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): - modules.sd_models.reload_model_weights(self.model) - opts.data["sd_vae"] = self.vae - modules.sd_vae.reload_vae_weights(self.model) + modules.sd_models.reload_model_weights() + modules.sd_vae.reload_vae_weights() hypernetwork.load_hypernetwork(self.hypernetwork) hypernetwork.apply_strength() -- cgit v1.2.3 From 55947857f035040d00249f02b17e39370033a99b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 17:36:56 +0300 Subject: add a button for XY Plot to fill in available values for axes that support this --- javascript/hints.js | 1 + scripts/xy_grid.py | 101 ++++++++++++++++++++++++++++++++++------------------ style.css | 12 +------ 3 files changed, 68 insertions(+), 46 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/javascript/hints.js b/javascript/hints.js index 244bfde2..fa5e5ae8 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -20,6 +20,7 @@ titles = { "\u{1f4be}": "Save style", "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", + "\u{1f4d2}": "Paste available values into the field", "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index e06c11cb..bf4ba92f 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images, paths, sd_samplers, processing +from modules import images, paths, sd_samplers, processing, sd_models, sd_vae from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state @@ -22,8 +22,9 @@ import glob import os import re +from modules.ui_components import ToolButton -up_down_arrow_symbol = "\u2195\ufe0f" +fill_values_symbol = "\U0001f4d2" # 📒 def apply_field(field): @@ -178,34 +179,49 @@ def str_permutations(x): """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" return x -AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"]) -AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"]) + +class AxisOption: + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = type + self.apply = apply + self.format_value = format_value + self.confirm = confirm + self.cost = cost + self.choices = choices + self.is_img2img = False + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False axis_options = [ - AxisOption("Nothing", str, do_nothing, format_nothing, None, 0), - AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0), - AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0), - AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0), - AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0), - AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0), - AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0), - AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0), - AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0), - AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0), - AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0), - AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0), - AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0), - AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0), - AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0), - AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0), - AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7), - AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0), + AxisOption("Nothing", str, do_nothing, format_value=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), + AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), + AxisOption("Hypernet str.", float, apply_hypernetwork_strength), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), + AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), ] @@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ if not processed_result: print("Unexpected error: draw_xy_grid failed to return even a single processed image") - return Processed() + return Processed(p, []) grid = images.image_grid(image_cache, rows=len(ys)) if draw_legend: @@ -302,23 +318,25 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img] + current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] with gr.Row(): - with gr.Column(scale=1, elem_id="xy_grid_button_column"): - swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes") with gr.Column(scale=19): with gr.Row(): x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) with gr.Row(): y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + + with gr.Row(variant="compact"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") def swap_axes(x_type, x_values, y_type, y_values): nonlocal current_axis_options @@ -327,6 +345,19 @@ class Script(scripts.Script): swap_args = [x_type, x_values, y_type, y_values] swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + def fill(x_type): + axis = axis_options[x_type] + return ", ".join(axis.choices()) if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + + def select_axis(x_type): + return gr.Button.update(visible=axis_options[x_type].choices is not None) + + x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) + y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): diff --git a/style.css b/style.css index 1fddfcc2..97f9402a 100644 --- a/style.css +++ b/style.css @@ -644,7 +644,7 @@ canvas[key="mask"] { max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 0.55em 0; + margin: 0.55em 0.7em 0.55em 0; } #quicksettings .gr-button-tool{ @@ -717,16 +717,6 @@ footer { line-height: 2.4em; } -#xy_grid_button_column { - min-width: 38px !important; -} - -#xy_grid_button_column button { - height: 100%; - margin-bottom: 0.7em; - margin-left: 1em; -} - /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running -- cgit v1.2.3 From e0e80050091ea7f58ae17c69f31d1b5de5e0ae20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 16 Jan 2023 23:09:08 +0300 Subject: make StableDiffusionProcessing class not hold a reference to shared.sd_model object --- modules/processing.py | 9 +++++---- scripts/xy_grid.py | 1 - 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/modules/processing.py b/modules/processing.py index ab7b3b7d..9c3673de 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -94,7 +94,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): return image_conditioning -class StableDiffusionProcessing(): +class StableDiffusionProcessing: """ The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing """ @@ -102,7 +102,6 @@ class StableDiffusionProcessing(): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) - self.sd_model = sd_model self.outpath_samples: str = outpath_samples self.outpath_grids: str = outpath_grids self.prompt: str = prompt @@ -156,6 +155,10 @@ class StableDiffusionProcessing(): self.all_subseeds = None self.iteration = 0 + @property + def sd_model(self): + return shared.sd_model + def txt2img_image_conditioning(self, x, width=None, height=None): self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} @@ -236,7 +239,6 @@ class StableDiffusionProcessing(): raise NotImplementedError() def close(self): - self.sd_model = None self.sampler = None @@ -471,7 +473,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if k == 'sd_model_checkpoint': sd_models.reload_model_weights() # make onchange call for changing SD model - p.sd_model = shared.sd_model if k == 'sd_vae': sd_vae.reload_vae_weights() # make onchange call for changing VAE diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index bf4ba92f..6629f5d5 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -86,7 +86,6 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) - p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): -- cgit v1.2.3 From 40ff6db5325fc34ad4fa35e80cb1e7768d9f7e75 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 08:36:07 +0300 Subject: extra networks UI rework of hypernets: rather than via settings, hypernets are added directly to prompt as --- html/card-no-preview.png | Bin 0 -> 84440 bytes html/extra-networks-card.html | 11 ++ html/extra-networks-no-cards.html | 8 ++ javascript/extraNetworks.js | 60 ++++++++ javascript/hints.js | 2 + javascript/ui.js | 9 +- modules/api/api.py | 7 +- modules/extra_networks.py | 147 +++++++++++++++++++ modules/extra_networks_hypernet.py | 21 +++ modules/generation_parameters_copypaste.py | 12 +- modules/hypernetworks/hypernetwork.py | 107 +++++++++----- modules/hypernetworks/ui.py | 5 +- modules/processing.py | 24 ++-- modules/sd_hijack_optimizations.py | 10 +- modules/shared.py | 21 ++- modules/textual_inversion/textual_inversion.py | 2 + modules/ui.py | 50 ++++--- modules/ui_components.py | 10 ++ modules/ui_extra_networks.py | 149 +++++++++++++++++++ modules/ui_extra_networks_hypernets.py | 34 +++++ modules/ui_extra_networks_textual_inversion.py | 32 +++++ script.js | 13 +- scripts/xy_grid.py | 29 ---- style.css | 190 +++++++++++++------------ webui.py | 26 +++- 25 files changed, 765 insertions(+), 214 deletions(-) create mode 100644 html/card-no-preview.png create mode 100644 html/extra-networks-card.html create mode 100644 html/extra-networks-no-cards.html create mode 100644 javascript/extraNetworks.js create mode 100644 modules/extra_networks.py create mode 100644 modules/extra_networks_hypernet.py create mode 100644 modules/ui_extra_networks.py create mode 100644 modules/ui_extra_networks_hypernets.py create mode 100644 modules/ui_extra_networks_textual_inversion.py (limited to 'scripts/xy_grid.py') diff --git a/html/card-no-preview.png b/html/card-no-preview.png new file mode 100644 index 00000000..e2beb269 Binary files /dev/null and b/html/card-no-preview.png differ diff --git a/html/extra-networks-card.html b/html/extra-networks-card.html new file mode 100644 index 00000000..7314b063 --- /dev/null +++ b/html/extra-networks-card.html @@ -0,0 +1,11 @@ +
+
+
+ +
+ {name} +
+
+ diff --git a/html/extra-networks-no-cards.html b/html/extra-networks-no-cards.html new file mode 100644 index 00000000..389358d6 --- /dev/null +++ b/html/extra-networks-no-cards.html @@ -0,0 +1,8 @@ +
+

Nothing here. Add some content to the following directories:

+ +
    +{dirs} +
+
+ diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js new file mode 100644 index 00000000..71e522d1 --- /dev/null +++ b/javascript/extraNetworks.js @@ -0,0 +1,60 @@ + +function setupExtraNetworksForTab(tabname){ + gradioApp().querySelector('#'+tabname+'_extra_tabs').classList.add('extra-networks') + + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_refresh')) + gradioApp().querySelector('#'+tabname+'_extra_tabs > div').appendChild(gradioApp().getElementById(tabname+'_extra_close')) +} + +var activePromptTextarea = null; +var activePositivePromptTextarea = null; + +function setupExtraNetworks(){ + setupExtraNetworksForTab('txt2img') + setupExtraNetworksForTab('img2img') + + function registerPrompt(id, isNegative){ + var textarea = gradioApp().querySelector("#" + id + " > label > textarea"); + + if (activePromptTextarea == null){ + activePromptTextarea = textarea + } + if (activePositivePromptTextarea == null && ! isNegative){ + activePositivePromptTextarea = textarea + } + + textarea.addEventListener("focus", function(){ + activePromptTextarea = textarea; + if(! isNegative) activePositivePromptTextarea = textarea; + }); + } + + registerPrompt('txt2img_prompt') + registerPrompt('txt2img_neg_prompt', true) + registerPrompt('img2img_prompt') + registerPrompt('img2img_neg_prompt', true) +} + +onUiLoaded(setupExtraNetworks) + +function cardClicked(textToAdd, allowNegativePrompt){ + textarea = allowNegativePrompt ? activePromptTextarea : activePositivePromptTextarea + + textarea.value = textarea.value + " " + textToAdd + updateInput(textarea) + + return false +} + +function saveCardPreview(event, tabname, filename){ + textarea = gradioApp().querySelector("#" + tabname + '_preview_filename > label > textarea') + button = gradioApp().getElementById(tabname + '_save_preview') + + textarea.value = filename + updateInput(textarea) + + button.click() + + event.stopPropagation() + event.preventDefault() +} diff --git a/javascript/hints.js b/javascript/hints.js index e746e20d..f4079f96 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -21,6 +21,8 @@ titles = { "\U0001F5D1": "Clear prompt", "\u{1f4cb}": "Apply selected styles to current prompt", "\u{1f4d2}": "Paste available values into the field", + "\u{1f3b4}": "Show extra networks", + "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt", "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back", diff --git a/javascript/ui.js b/javascript/ui.js index 3ba90ca8..a7e75439 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -196,8 +196,6 @@ function confirm_clear_prompt(prompt, negative_prompt) { return [prompt, negative_prompt] } - - opts = {} onUiUpdate(function(){ if(Object.keys(opts).length != 0) return; @@ -239,11 +237,14 @@ onUiUpdate(function(){ return } + prompt.parentElement.insertBefore(counter, prompt) counter.classList.add("token-counter") prompt.parentElement.style.position = "relative" - textarea.addEventListener("input", () => update_token_counter(id_button)); + textarea.addEventListener("input", function(){ + update_token_counter(id_button); + }); } registerTextarea('txt2img_prompt', 'txt2img_token_counter', 'txt2img_token_button') @@ -261,10 +262,8 @@ onUiUpdate(function(){ }) } } - }) - onOptionsChanged(function(){ elem = gradioApp().getElementById('sd_checkpoint_hash') sd_checkpoint_hash = opts.sd_checkpoint_hash || "" diff --git a/modules/api/api.py b/modules/api/api.py index 9814bbc2..2c371e6e 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -480,7 +480,7 @@ class Api: def train_hypernetwork(self, args: dict): try: shared.state.begin() - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] apply_optimizations = shared.opts.training_xattention_optimizations error = None filename = '' @@ -491,16 +491,15 @@ class Api: except Exception as e: error = e finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) if not apply_optimizations: sd_hijack.apply_optimizations() shared.state.end() - return TrainResponse(info = "train embedding complete: filename: {filename} error: {error}".format(filename = filename, error = error)) + return TrainResponse(info="train embedding complete: filename: {filename} error: {error}".format(filename=filename, error=error)) except AssertionError as msg: shared.state.end() - return TrainResponse(info = "train embedding error: {error}".format(error = error)) + return TrainResponse(info="train embedding error: {error}".format(error=error)) def get_memory(self): try: diff --git a/modules/extra_networks.py b/modules/extra_networks.py new file mode 100644 index 00000000..1978673d --- /dev/null +++ b/modules/extra_networks.py @@ -0,0 +1,147 @@ +import re +from collections import defaultdict + +from modules import errors + +extra_network_registry = {} + + +def initialize(): + extra_network_registry.clear() + + +def register_extra_network(extra_network): + extra_network_registry[extra_network.name] = extra_network + + +class ExtraNetworkParams: + def __init__(self, items=None): + self.items = items or [] + + +class ExtraNetwork: + def __init__(self, name): + self.name = name + + def activate(self, p, params_list): + """ + Called by processing on every run. Whatever the extra network is meant to do should be activated here. + Passes arguments related to this extra network in params_list. + User passes arguments by specifying this in his prompt: + + + + Where name matches the name of this ExtraNetwork object, and arg1:arg2:arg3 are any natural number of text arguments + separated by colon. + + Even if the user does not mention this ExtraNetwork in his prompt, the call will stil be made, with empty params_list - + in this case, all effects of this extra networks should be disabled. + + Can be called multiple times before deactivate() - each new call should override the previous call completely. + + For example, if this ExtraNetwork's name is 'hypernet' and user's prompt is: + + > "1girl, " + + params_list will be: + + [ + ExtraNetworkParams(items=["agm", "1.1"]), + ExtraNetworkParams(items=["ray"]) + ] + + """ + raise NotImplementedError + + def deactivate(self, p): + """ + Called at the end of processing for housekeeping. No need to do anything here. + """ + + raise NotImplementedError + + +def activate(p, extra_network_data): + """call activate for extra networks in extra_network_data in specified order, then call + activate for all remaining registered networks with an empty argument list""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + print(f"Skipping unknown extra network: {extra_network_name}") + continue + + try: + extra_network.activate(p, extra_network_args) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name} with arguments {extra_network_args}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.activate(p, []) + except Exception as e: + errors.display(e, f"activating extra network {extra_network_name}") + + +def deactivate(p, extra_network_data): + """call deactivate for extra networks in extra_network_data in specified order, then call + deactivate for all remaining registered networks""" + + for extra_network_name, extra_network_args in extra_network_data.items(): + extra_network = extra_network_registry.get(extra_network_name, None) + if extra_network is None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating extra network {extra_network_name}") + + for extra_network_name, extra_network in extra_network_registry.items(): + args = extra_network_data.get(extra_network_name, None) + if args is not None: + continue + + try: + extra_network.deactivate(p) + except Exception as e: + errors.display(e, f"deactivating unmentioned extra network {extra_network_name}") + + +re_extra_net = re.compile(r"<(\w+):([^>]+)>") + + +def parse_prompt(prompt): + res = defaultdict(list) + + def found(m): + name = m.group(1) + args = m.group(2) + + res[name].append(ExtraNetworkParams(items=args.split(":"))) + + return "" + + prompt = re.sub(re_extra_net, found, prompt) + + return prompt, res + + +def parse_prompts(prompts): + res = [] + extra_data = None + + for prompt in prompts: + updated_prompt, parsed_extra_data = parse_prompt(prompt) + + if extra_data is None: + extra_data = parsed_extra_data + + res.append(updated_prompt) + + return res, extra_data + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py new file mode 100644 index 00000000..6a0c4ba8 --- /dev/null +++ b/modules/extra_networks_hypernet.py @@ -0,0 +1,21 @@ +from modules import extra_networks +from modules.hypernetworks import hypernetwork + + +class ExtraNetworkHypernet(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('hypernet') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + hypernetwork.load_hypernetworks(names, multipliers) + + def deactivate(p, self): + pass diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index a381ff59..46e12dc6 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -79,8 +79,6 @@ def integrate_settings_paste_fields(component_dict): from modules import ui settings_map = { - 'sd_hypernetwork': 'Hypernet', - 'sd_hypernetwork_strength': 'Hypernet strength', 'CLIP_stop_at_last_layers': 'Clip skip', 'inpainting_mask_weight': 'Conditional mask weight', 'sd_model_checkpoint': 'Model hash', @@ -275,13 +273,9 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model if "Clip skip" not in res: res["Clip skip"] = "1" - if "Hypernet strength" not in res: - res["Hypernet strength"] = "1" - - if "Hypernet" in res: - hypernet_name = res["Hypernet"] - hypernet_hash = res.get("Hypernet hash", None) - res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + hypernet = res.get("Hypernet", None) + if hypernet is not None: + res["Prompt"] += f"""""" if "Hires resize-1" not in res: res["Hires resize-1"] = 0 diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 74e78582..80a47c79 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -25,7 +25,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"} class HypernetworkModule(torch.nn.Module): - multiplier = 1.0 activation_dict = { "linear": torch.nn.Identity, "relu": torch.nn.ReLU, @@ -41,6 +40,8 @@ class HypernetworkModule(torch.nn.Module): add_layer_norm=False, activate_output=False, dropout_structure=None): super().__init__() + self.multiplier = 1.0 + assert layer_structure is not None, "layer_structure must not be None" assert layer_structure[0] == 1, "Multiplier Sequence should start with size 1!" assert layer_structure[-1] == 1, "Multiplier Sequence should end with size 1!" @@ -115,7 +116,7 @@ class HypernetworkModule(torch.nn.Module): state_dict[to] = x def forward(self, x): - return x + self.linear(x) * (HypernetworkModule.multiplier if not self.training else 1) + return x + self.linear(x) * (self.multiplier if not self.training else 1) def trainables(self): layer_structure = [] @@ -125,9 +126,6 @@ class HypernetworkModule(torch.nn.Module): return layer_structure -def apply_strength(value=None): - HypernetworkModule.multiplier = value if value is not None else shared.opts.sd_hypernetwork_strength - #param layer_structure : sequence used for length, use_dropout : controlling boolean, last_layer_dropout : for compatibility check. def parse_dropout_structure(layer_structure, use_dropout, last_layer_dropout): if layer_structure is None: @@ -192,6 +190,20 @@ class Hypernetwork: for param in layer.parameters(): param.requires_grad = mode + def to(self, device): + for k, layers in self.layers.items(): + for layer in layers: + layer.to(device) + + return self + + def set_multiplier(self, multiplier): + for k, layers in self.layers.items(): + for layer in layers: + layer.multiplier = multiplier + + return self + def eval(self): for k, layers in self.layers.items(): for layer in layers: @@ -269,11 +281,13 @@ class Hypernetwork: self.optimizer_state_dict = None if self.optimizer_state_dict: self.optimizer_name = optimizer_saved_dict.get('optimizer_name', 'AdamW') - print("Loaded existing optimizer from checkpoint") - print(f"Optimizer name is {self.optimizer_name}") + if shared.opts.print_hypernet_extra: + print("Loaded existing optimizer from checkpoint") + print(f"Optimizer name is {self.optimizer_name}") else: self.optimizer_name = "AdamW" - print("No saved optimizer exists in checkpoint") + if shared.opts.print_hypernet_extra: + print("No saved optimizer exists in checkpoint") for size, sd in state_dict.items(): if type(size) == int: @@ -306,23 +320,43 @@ def list_hypernetworks(path): return res -def load_hypernetwork(filename): - path = shared.hypernetworks.get(filename, None) - # Prevent any file named "None.pt" from being loaded. - if path is not None and filename != "None": - print(f"Loading hypernetwork {filename}") - try: - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) +def load_hypernetwork(name): + path = shared.hypernetworks.get(name, None) - except Exception: - print(f"Error loading hypernetwork {path}", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - else: - if shared.loaded_hypernetwork is not None: - print("Unloading hypernetwork") + if path is None: + return None + + hypernetwork = Hypernetwork() + + try: + hypernetwork.load(path) + except Exception: + print(f"Error loading hypernetwork {path}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return None + + return hypernetwork + + +def load_hypernetworks(names, multipliers=None): + already_loaded = {} + + for hypernetwork in shared.loaded_hypernetworks: + if hypernetwork.name in names: + already_loaded[hypernetwork.name] = hypernetwork - shared.loaded_hypernetwork = None + shared.loaded_hypernetworks.clear() + + for i, name in enumerate(names): + hypernetwork = already_loaded.get(name, None) + if hypernetwork is None: + hypernetwork = load_hypernetwork(name) + + if hypernetwork is None: + continue + + hypernetwork.set_multiplier(multipliers[i] if multipliers else 1.0) + shared.loaded_hypernetworks.append(hypernetwork) def find_closest_hypernetwork_name(search: str): @@ -336,18 +370,27 @@ def find_closest_hypernetwork_name(search: str): return applicable[0] -def apply_hypernetwork(hypernetwork, context, layer=None): - hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None) +def apply_single_hypernetwork(hypernetwork, context_k, context_v, layer=None): + hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context_k.shape[2], None) if hypernetwork_layers is None: - return context, context + return context_k, context_v if layer is not None: layer.hyper_k = hypernetwork_layers[0] layer.hyper_v = hypernetwork_layers[1] - context_k = hypernetwork_layers[0](context) - context_v = hypernetwork_layers[1](context) + context_k = hypernetwork_layers[0](context_k) + context_v = hypernetwork_layers[1](context_v) + return context_k, context_v + + +def apply_hypernetworks(hypernetworks, context, layer=None): + context_k = context + context_v = context + for hypernetwork in hypernetworks: + context_k, context_v = apply_single_hypernetwork(hypernetwork, context_k, context_v, layer) + return context_k, context_v @@ -357,7 +400,7 @@ def attention_CrossAttention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = apply_hypernetwork(shared.loaded_hypernetwork, context, self) + context_k, context_v = apply_hypernetworks(shared.loaded_hypernetworks, context, self) k = self.to_k(context_k) v = self.to_v(context_v) @@ -464,8 +507,9 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi template_file = template_file.path path = shared.hypernetworks.get(hypernetwork_name, None) - shared.loaded_hypernetwork = Hypernetwork() - shared.loaded_hypernetwork.load(path) + hypernetwork = Hypernetwork() + hypernetwork.load(path) + shared.loaded_hypernetworks = [hypernetwork] shared.state.job = "train-hypernetwork" shared.state.textinfo = "Initializing hypernetwork training..." @@ -489,7 +533,6 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi else: images_dir = None - hypernetwork = shared.loaded_hypernetwork checkpoint = sd_models.select_checkpoint() initial_step = hypernetwork.step or 0 diff --git a/modules/hypernetworks/ui.py b/modules/hypernetworks/ui.py index 81e3f519..76599f5a 100644 --- a/modules/hypernetworks/ui.py +++ b/modules/hypernetworks/ui.py @@ -9,6 +9,7 @@ from modules import devices, sd_hijack, shared not_available = ["hardswish", "multiheadattention"] keys = list(x for x in modules.hypernetworks.hypernetwork.HypernetworkModule.activation_dict.keys() if x not in not_available) + def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, activation_func=None, weight_init=None, add_layer_norm=False, use_dropout=False, dropout_structure=None): filename = modules.hypernetworks.hypernetwork.create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure, activation_func, weight_init, add_layer_norm, use_dropout, dropout_structure) @@ -16,8 +17,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(*args): - - initial_hypernetwork = shared.loaded_hypernetwork + shared.loaded_hypernetworks = [] assert not shared.cmd_opts.lowvram, 'Training models with lowvram is not possible' @@ -34,7 +34,6 @@ Hypernetwork saved to {html.escape(filename)} except Exception: raise finally: - shared.loaded_hypernetwork = initial_hypernetwork shared.sd_model.cond_stage_model.to(devices.device) shared.sd_model.first_stage_model.to(devices.device) sd_hijack.apply_optimizations() diff --git a/modules/processing.py b/modules/processing.py index a3e9f709..b5deeacf 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -438,9 +438,6 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info or not shared.sd_model.sd_checkpoint_info.model_name else shared.sd_model.sd_checkpoint_info.model_name.replace(',', '').replace(':', '')), - "Hypernet": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.name), - "Hypernet hash": (None if shared.loaded_hypernetwork is None else shared.loaded_hypernetwork.shorthash()), - "Hypernet strength": (None if shared.loaded_hypernetwork is None or shared.opts.sd_hypernetwork_strength >= 1 else shared.opts.sd_hypernetwork_strength), "Batch size": (None if p.batch_size < 2 else p.batch_size), "Batch pos": (None if p.batch_size < 2 else position_in_batch), "Variation seed": (None if p.subseed_strength == 0 else all_subseeds[index]), @@ -468,14 +465,12 @@ def process_images(p: StableDiffusionProcessing) -> Processed: try: for k, v in p.override_settings.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': - shared.reload_hypernetworks() # make onchange call for changing hypernet if k == 'sd_model_checkpoint': - sd_models.reload_model_weights() # make onchange call for changing SD model + sd_models.reload_model_weights() if k == 'sd_vae': - sd_vae.reload_vae_weights() # make onchange call for changing VAE + sd_vae.reload_vae_weights() res = process_images_inner(p) @@ -484,9 +479,11 @@ def process_images(p: StableDiffusionProcessing) -> Processed: if p.override_settings_restore_afterwards: for k, v in stored_opts.items(): setattr(opts, k, v) - if k == 'sd_hypernetwork': shared.reload_hypernetworks() - if k == 'sd_model_checkpoint': sd_models.reload_model_weights() - if k == 'sd_vae': sd_vae.reload_vae_weights() + if k == 'sd_model_checkpoint': + sd_models.reload_model_weights() + + if k == 'sd_vae': + sd_vae.reload_vae_weights() return res @@ -564,10 +561,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: cache[0] = (required_prompts, steps) return cache[1] + p.all_prompts, extra_network_data = extra_networks.parse_prompts(p.all_prompts) + with torch.no_grad(), p.sd_model.ema_scope(): with devices.autocast(): p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + extra_networks.activate(p, extra_network_data) + with open(os.path.join(shared.script_path, "params.txt"), "w", encoding="utf8") as file: processed = Processed(p, [], p.seed, "") file.write(processed.infotext(p, 0)) @@ -681,6 +682,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if opts.grid_save: images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], opts.grid_format, info=infotext(), short_filename=not opts.grid_extended_filename, p=p, grid=True) + extra_networks.deactivate(p, extra_network_data) devices.torch_gc() res = Processed(p, output_images, p.all_seeds[0], infotext(), comments="".join(["\n\n" + x for x in comments]), subseed=p.all_subseeds[0], index_of_first_image=index_of_first_image, infotexts=infotexts) diff --git a/modules/sd_hijack_optimizations.py b/modules/sd_hijack_optimizations.py index cdc63ed7..4fa54329 100644 --- a/modules/sd_hijack_optimizations.py +++ b/modules/sd_hijack_optimizations.py @@ -44,7 +44,7 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) del context, context_k, context_v, x @@ -78,7 +78,7 @@ def split_cross_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) @@ -203,7 +203,7 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) * self.scale v = self.to_v(context_v) del context, context_k, context_v, x @@ -225,7 +225,7 @@ def sub_quad_attention_forward(self, x, context=None, mask=None): q = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k = self.to_k(context_k) v = self.to_v(context_v) del context, context_k, context_v, x @@ -284,7 +284,7 @@ def xformers_attention_forward(self, x, context=None, mask=None): q_in = self.to_q(x) context = default(context, x) - context_k, context_v = hypernetwork.apply_hypernetwork(shared.loaded_hypernetwork, context) + context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context) k_in = self.to_k(context_k) v_in = self.to_v(context_v) diff --git a/modules/shared.py b/modules/shared.py index 2f366454..c0e11f18 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -23,6 +23,7 @@ demo = None sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml") sd_model_file = os.path.join(script_path, 'model.ckpt') default_sd_model_file = sd_model_file + parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) @@ -145,7 +146,7 @@ config_filename = cmd_opts.ui_settings_file os.makedirs(cmd_opts.hypernetwork_dir, exist_ok=True) hypernetworks = {} -loaded_hypernetwork = None +loaded_hypernetworks = [] def reload_hypernetworks(): @@ -153,8 +154,6 @@ def reload_hypernetworks(): global hypernetworks hypernetworks = hypernetwork.list_hypernetworks(cmd_opts.hypernetwork_dir) - hypernetwork.load_hypernetwork(opts.sd_hypernetwork) - class State: @@ -399,8 +398,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}), "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list), "sd_vae_as_default": OptionInfo(True, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"), - "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks), - "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}), "inpainting_mask_weight": OptionInfo(1.0, "Inpainting conditioning mask strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}), "initial_noise_multiplier": OptionInfo(1.0, "Noise multiplier for img2img", gr.Slider, {"minimum": 0.5, "maximum": 1.5, "step": 0.01 }), "img2img_color_correction": OptionInfo(False, "Apply color correction to img2img results to match original colors."), @@ -661,3 +658,17 @@ mem_mon.start() def listfiles(dirname): filenames = [os.path.join(dirname, x) for x in sorted(os.listdir(dirname)) if not x.startswith(".")] return [file for file in filenames if os.path.isfile(file)] + + +def html_path(filename): + return os.path.join(script_path, "html", filename) + + +def html(filename): + path = html_path(filename) + + if os.path.exists(path): + with open(path, encoding="utf8") as file: + return file.read() + + return "" diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 5a7be422..4e90f690 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -50,6 +50,7 @@ class Embedding: self.sd_checkpoint = None self.sd_checkpoint_name = None self.optimizer_state_dict = None + self.filename = None def save(self, filename): embedding_data = { @@ -182,6 +183,7 @@ class EmbeddingDatabase: embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None) embedding.vectors = vec.shape[0] embedding.shape = vec.shape[-1] + embedding.filename = path if self.expected_shape == -1 or self.expected_shape == embedding.shape: self.register_embedding(embedding, shared.sd_model) diff --git a/modules/ui.py b/modules/ui.py index 06c11848..d23b2b8e 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -20,7 +20,7 @@ import numpy as np from PIL import Image, PngImagePlugin from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path @@ -90,6 +90,7 @@ refresh_symbol = '\U0001f504' # 🔄 save_style_symbol = '\U0001f4be' # 💾 apply_style_symbol = '\U0001f4cb' # 📋 clear_prompt_symbol = '\U0001F5D1' # 🗑️ +extra_networks_symbol = '\U0001F3B4' # 🎴 def plaintext_to_html(text): @@ -324,6 +325,8 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info: def update_token_counter(text, steps): try: + text, _ = extra_networks.parse_prompt(text) + _, prompt_flat_list, _ = prompt_parser.get_multicond_prompt_list([text]) prompt_schedules = prompt_parser.get_learned_conditioning_prompt_schedules(prompt_flat_list, steps) @@ -354,10 +357,10 @@ def create_toprow(is_img2img): negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=2, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)") with gr.Column(scale=1, elem_id="roll_col"): - paste = gr.Button(value=paste_symbol, elem_id="paste") - save_style = gr.Button(value=save_style_symbol, elem_id="style_create") - prompt_style_apply = gr.Button(value=apply_style_symbol, elem_id="style_apply") - clear_prompt_button = gr.Button(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + paste = ToolButton(value=paste_symbol, elem_id="paste") + clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt") + extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks") + token_counter = gr.HTML(value="", elem_id=f"{id_part}_token_counter") token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button") negative_token_counter = gr.HTML(value="", elem_id=f"{id_part}_negative_token_counter") @@ -395,11 +398,14 @@ def create_toprow(is_img2img): outputs=[], ) - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_styles_row"): prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button, negative_token_counter, negative_token_button + prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id="style_apply") + save_style = ToolButton(value=save_style_symbol, elem_id="style_create") + + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button def setup_progressbar(*args, **kwargs): @@ -616,11 +622,15 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img') + with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -794,14 +804,20 @@ def create_ui(): token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) + modules.scripts.scripts_current = modules.scripts.scripts_img2img modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button = create_toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) + with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks: + from modules import ui_extra_networks + extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img') + with FormRow().style(equal_height=False): with gr.Column(variant='compact', elem_id="img2img_settings"): copy_image_buttons = [] @@ -1064,6 +1080,8 @@ def create_ui(): token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter]) + ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery) + img2img_paste_fields = [ (img2img_prompt, "Prompt"), (img2img_negative_prompt, "Negative prompt"), @@ -1666,10 +1684,8 @@ def create_ui(): download_localization = gr.Button(value='Download localization template', elem_id="download_localization") reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies") - if os.path.exists("html/licenses.html"): - with open("html/licenses.html", encoding="utf8") as file: - with gr.TabItem("Licenses"): - gr.HTML(file.read(), elem_id="licenses") + with gr.TabItem("Licenses"): + gr.HTML(shared.html("licenses.html"), elem_id="licenses") gr.Button(value="Show all pages", elem_id="settings_show_all_pages") @@ -1756,11 +1772,9 @@ def create_ui(): if os.path.exists(os.path.join(script_path, "notification.mp3")): audio_notification = gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False) - if os.path.exists("html/footer.html"): - with open("html/footer.html", encoding="utf8") as file: - footer = file.read() - footer = footer.format(versions=versions_html()) - gr.HTML(footer, elem_id="footer") + footer = shared.html("footer.html") + footer = footer.format(versions=versions_html()) + gr.HTML(footer, elem_id="footer") text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False) settings_submit.click( diff --git a/modules/ui_components.py b/modules/ui_components.py index 97acff06..46324425 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -11,6 +11,16 @@ class ToolButton(gr.Button, gr.components.FormComponent): return "button" +class ToolButtonTop(gr.Button, gr.components.FormComponent): + """Small button with single emoji as text, with extra margin at top, fits inside gradio forms""" + + def __init__(self, **kwargs): + super().__init__(variant="tool-top", **kwargs) + + def get_block_name(self): + return "button" + + class FormRow(gr.Row, gr.components.FormComponent): """Same as gr.Row but fits inside gradio forms""" diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py new file mode 100644 index 00000000..253e90f7 --- /dev/null +++ b/modules/ui_extra_networks.py @@ -0,0 +1,149 @@ +import os.path + +from modules import shared +import gradio as gr +import json + +from modules.generation_parameters_copypaste import image_from_url_text + +extra_pages = [] + + +def register_page(page): + """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + + extra_pages.append(page) + + +class ExtraNetworksPage: + def __init__(self, title): + self.title = title + self.card_page = shared.html("extra-networks-card.html") + self.allow_negative_prompt = False + + def refresh(self): + pass + + def create_html(self, tabname): + items_html = '' + + for item in self.list_items(): + items_html += self.create_html_for_item(item, tabname) + + if items_html == '': + dirs = "".join([f"
  • {x}
  • " for x in self.allowed_directories_for_previews()]) + items_html = shared.html("extra-networks-no-cards.html").format(dirs=dirs) + + res = "
    " + items_html + "
    " + + return res + + def list_items(self): + raise NotImplementedError() + + def allowed_directories_for_previews(self): + return [] + + def create_html_for_item(self, item, tabname): + preview = item.get("preview", None) + + args = { + "preview_html": "style='background-image: url(" + json.dumps(preview) + ")'" if preview else '', + "prompt": json.dumps(item["prompt"]), + "tabname": json.dumps(tabname), + "local_preview": json.dumps(item["local_preview"]), + "name": item["name"], + "allow_negative_prompt": "true" if self.allow_negative_prompt else "false", + } + + return self.card_page.format(**args) + + +def intialize(): + extra_pages.clear() + + +class ExtraNetworksUi: + def __init__(self): + self.pages = None + self.stored_extra_pages = None + + self.button_save_preview = None + self.preview_target_filename = None + + self.tabname = None + + +def create_ui(container, button, tabname): + ui = ExtraNetworksUi() + ui.pages = [] + ui.stored_extra_pages = extra_pages.copy() + ui.tabname = tabname + + with gr.Tabs(elem_id=tabname+"_extra_tabs") as tabs: + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + button_close = gr.Button('Close', elem_id=tabname+"_extra_close") + + for page in ui.stored_extra_pages: + with gr.Tab(page.title): + page_elem = gr.HTML(page.create_html(ui.tabname)) + ui.pages.append(page_elem) + + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) + + button.click(fn=lambda: gr.update(visible=True), inputs=[], outputs=[container]) + button_close.click(fn=lambda: gr.update(visible=False), inputs=[], outputs=[container]) + + def refresh(): + res = [] + + for pg in ui.stored_extra_pages: + pg.refresh() + res.append(pg.create_html(ui.tabname)) + + return res + + button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) + + return ui + + +def path_is_parent(parent_path, child_path): + parent_path = os.path.abspath(parent_path) + child_path = os.path.abspath(child_path) + + return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path]) + + +def setup_ui(ui, gallery): + def save_preview(index, images, filename): + if len(images) == 0: + print("There is no image in gallery to save as a preview.") + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + index = int(index) + index = 0 if index < 0 else index + index = len(images) - 1 if index >= len(images) else index + + img_info = images[index if index >= 0 else 0] + image = image_from_url_text(img_info) + + is_allowed = False + for extra_page in ui.stored_extra_pages: + if any([path_is_parent(x, filename) for x in extra_page.allowed_directories_for_previews()]): + is_allowed = True + break + + assert is_allowed, f'writing to {filename} is not allowed' + + image.save(filename) + + return [page.create_html(ui.tabname) for page in ui.stored_extra_pages] + + ui.button_save_preview.click( + fn=save_preview, + _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}", + inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename], + outputs=[*ui.pages] + ) diff --git a/modules/ui_extra_networks_hypernets.py b/modules/ui_extra_networks_hypernets.py new file mode 100644 index 00000000..312dbaf0 --- /dev/null +++ b/modules/ui_extra_networks_hypernets.py @@ -0,0 +1,34 @@ +import os + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageHypernetworks(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Hypernetworks') + + def refresh(self): + shared.reload_hypernetworks() + + def list_items(self): + for name, path in shared.hypernetworks.items(): + path, ext = os.path.splitext(path) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [shared.cmd_opts.hypernetwork_dir] + diff --git a/modules/ui_extra_networks_textual_inversion.py b/modules/ui_extra_networks_textual_inversion.py new file mode 100644 index 00000000..e4a6e3bf --- /dev/null +++ b/modules/ui_extra_networks_textual_inversion.py @@ -0,0 +1,32 @@ +import os + +from modules import ui_extra_networks, sd_hijack + + +class ExtraNetworksPageTextualInversion(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Textual Inversion') + self.allow_negative_prompt = True + + def refresh(self): + sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True) + + def list_items(self): + for embedding in sd_hijack.model_hijack.embedding_db.word_embeddings.values(): + path, ext = os.path.splitext(embedding.filename) + preview_file = path + ".preview.png" + + preview = None + if os.path.isfile(preview_file): + preview = "./file=" + preview_file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(preview_file)) + + yield { + "name": embedding.name, + "filename": embedding.filename, + "preview": preview, + "prompt": embedding.name, + "local_preview": path + ".preview.png", + } + + def allowed_directories_for_previews(self): + return list(sd_hijack.model_hijack.embedding_db.embedding_dirs) diff --git a/script.js b/script.js index 3345e32b..97e0bfcf 100644 --- a/script.js +++ b/script.js @@ -13,6 +13,7 @@ function get_uiCurrentTabContent() { } uiUpdateCallbacks = [] +uiLoadedCallbacks = [] uiTabChangeCallbacks = [] optionsChangedCallbacks = [] let uiCurrentTab = null @@ -20,6 +21,9 @@ let uiCurrentTab = null function onUiUpdate(callback){ uiUpdateCallbacks.push(callback) } +function onUiLoaded(callback){ + uiLoadedCallbacks.push(callback) +} function onUiTabChange(callback){ uiTabChangeCallbacks.push(callback) } @@ -38,8 +42,15 @@ function executeCallbacks(queue, m) { queue.forEach(function(x){runCallback(x, m)}) } +var executedOnLoaded = false; + document.addEventListener("DOMContentLoaded", function() { var mutationObserver = new MutationObserver(function(m){ + if(!executedOnLoaded && gradioApp().querySelector('#txt2img_prompt')){ + executedOnLoaded = true; + executeCallbacks(uiLoadedCallbacks); + } + executeCallbacks(uiUpdateCallbacks, m); const newTab = get_uiCurrentTab(); if ( newTab && ( newTab !== uiCurrentTab ) ) { @@ -53,7 +64,7 @@ document.addEventListener("DOMContentLoaded", function() { /** * Add a ctrl+enter as a shortcut to start a generation */ - document.addEventListener('keydown', function(e) { +document.addEventListener('keydown', function(e) { var handled = false; if (e.key !== undefined) { if((e.key == "Enter" && (e.metaKey || e.ctrlKey || e.altKey))) handled = true; diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 6629f5d5..b1badec9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -11,7 +11,6 @@ import modules.scripts as scripts import gradio as gr from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.hypernetworks import hypernetwork from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -94,28 +93,6 @@ def confirm_checkpoints(p, xs): raise RuntimeError(f"Unknown checkpoint: {x}") -def apply_hypernetwork(p, x, xs): - if x.lower() in ["", "none"]: - name = None - else: - name = hypernetwork.find_closest_hypernetwork_name(x) - if not name: - raise RuntimeError(f"Unknown hypernetwork: {x}") - hypernetwork.load_hypernetwork(name) - - -def apply_hypernetwork_strength(p, x, xs): - hypernetwork.apply_strength(x) - - -def confirm_hypernetworks(p, xs): - for x in xs: - if x.lower() in ["", "none"]: - continue - if not hypernetwork.find_closest_hypernetwork_name(x): - raise RuntimeError(f"Unknown hypernetwork: {x}") - - def apply_clip_skip(p, x, xs): opts.data["CLIP_stop_at_last_layers"] = x @@ -208,8 +185,6 @@ axis_options = [ AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)), - AxisOption("Hypernet str.", float, apply_hypernetwork_strength), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), AxisOption("Sigma max", float, apply_field("s_tmax")), @@ -291,7 +266,6 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_ class SharedSettingsStackHelper(object): def __enter__(self): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.hypernetwork = opts.sd_hypernetwork self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -299,9 +273,6 @@ class SharedSettingsStackHelper(object): modules.sd_models.reload_model_weights() modules.sd_vae.reload_vae_weights() - hypernetwork.load_hypernetwork(self.hypernetwork) - hypernetwork.apply_strength() - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers diff --git a/style.css b/style.css index 3a515ebd..5e8bc2ca 100644 --- a/style.css +++ b/style.css @@ -132,13 +132,6 @@ } #roll_col > button { - min-width: 2em; - min-height: 2em; - max-width: 2em; - max-height: 2em; - flex-grow: 0; - padding-left: 0.25em; - padding-right: 0.25em; margin: 0.1em 0; } @@ -146,9 +139,10 @@ min-width: 0 !important; max-width: 8em !important; margin-right: 1em; + gap: 0; } #interrogate, #deepbooru{ - margin: 0em 0.25em 0.9em 0.25em; + margin: 0em 0.25em 0.5em 0.25em; min-width: 8em; max-width: 8em; } @@ -157,8 +151,17 @@ min-width: 8em !important; } +#txt2img_styles_row, #img2img_styles_row{ + gap: 0.25em; + margin-top: 0.5em; +} + +#txt2img_styles_row > button, #img2img_styles_row > button{ + margin: 0; +} + #txt2img_styles, #img2img_styles{ - margin-top: 1em; + padding: 0; } #txt2img_styles ul, #img2img_styles ul{ @@ -635,17 +638,21 @@ canvas[key="mask"] { background-color: rgb(31 41 55 / var(--tw-bg-opacity)); } -.gr-button-tool{ +.gr-button-tool, .gr-button-tool-top{ max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 1.6em 0.7em 0.55em 0; } -#tab_modelmerger .gr-button-tool{ +.gr-button-tool{ margin: 0.6em 0em 0.55em 0; } +.gr-button-tool-top, #settings .gr-button-tool{ + margin: 1.6em 0.7em 0.55em 0; +} + + #modelmerger_results_container{ margin-top: 1em; overflow: visible; @@ -763,81 +770,88 @@ footer { line-height: 2.4em; } -/* The following handles localization for right-to-left (RTL) languages like Arabic. -The rtl media type will only be activated by the logic in javascript/localization.js. -If you change anything above, you need to make sure it is RTL compliant by just running -your changes through converters like https://cssjanus.github.io/ or https://rtlcss.com/. -Then, you will need to add the RTL counterpart only if needed in the rtl section below.*/ -@media rtl { - /* this part was added manually */ - :host { - direction: rtl; - } - select, .file-preview, .gr-text-input, .output-html:has(.performance), #ti_progress { - direction: ltr; - } - #script_list > label > select, - #x_type > label > select, - #y_type > label > select { - direction: rtl; - } - .gr-radio, .gr-checkbox{ - margin-left: 0.25em; - } +#txt2img_extra_networks, #img2img_extra_networks{ + margin-top: -1em; +} - /* automatically generated with few manual modifications */ - .performance .time { - margin-right: unset; - margin-left: 0; - } - .justify-center.overflow-x-scroll { - justify-content: right; - } - .justify-center.overflow-x-scroll button:first-of-type { - margin-left: unset; - margin-right: auto; - } - .justify-center.overflow-x-scroll button:last-of-type { - margin-right: unset; - margin-left: auto; - } - #settings fieldset span.text-gray-500, #settings .gr-block.gr-box span.text-gray-500, #settings label.block span{ - margin-right: unset; - margin-left: 8em; - } - #txt2img_progressbar, #img2img_progressbar, #ti_progressbar{ - right: unset; - left: 0; - } - .progressDiv .progress{ - padding: 0 0 0 8px; - text-align: left; - } - #lightboxModal{ - left: unset; - right: 0; - } - .modalPrev, .modalNext{ - border-radius: 3px 0 0 3px; - } - .modalNext { - right: unset; - left: 0; - border-radius: 0 3px 3px 0; - } - #imageARPreview{ - left:unset; - right:0px; - } - #txt2img_skip, #img2img_skip{ - right: unset; - left: 0px; - } - #context-menu{ - box-shadow:-1px 1px 2px #CE6400; - } - .gr-box > div > div > input.gr-text-input{ - right: unset; - left: 0.5em; - } +.extra-networks > div > [id *= '_extra_']{ + margin: 0.3em; } + +.extra-network-cards .nocards{ + margin: 1.25em 0.5em 0.5em 0.5em; +} + +.extra-network-cards .nocards h1{ + font-size: 1.5em; + margin-bottom: 1em; +} + +.extra-network-cards .nocards li{ + margin-left: 0.5em; +} + +.extra-network-cards .card{ + display: inline-block; + margin: 0.5em; + width: 16em; + height: 24em; + box-shadow: 0 0 5px rgba(128, 128, 128, 0.5); + border-radius: 0.2em; + position: relative; + + background-size: auto 100%; + background-position: center; + overflow: hidden; + cursor: pointer; + + background-image: url('./file=html/card-no-preview.png') +} + +.extra-network-cards .card:hover{ + box-shadow: 0 0 2px 0.3em rgba(0, 128, 255, 0.35); +} + +.extra-network-cards .card .actions .additional{ + display: none; +} + +.extra-network-cards .card .actions{ + position: absolute; + bottom: 0; + left: 0; + right: 0; + padding: 0.5em; + color: white; + background: rgba(0,0,0,0.5); + box-shadow: 0 0 0.25em 0.25em rgba(0,0,0,0.5); + text-shadow: 0 0 0.2em black; +} + +.extra-network-cards .card .actions:hover{ + box-shadow: 0 0 0.75em 0.75em rgba(0,0,0,0.5) !important; +} + +.extra-network-cards .card .actions .name{ + font-size: 1.7em; + font-weight: bold; + line-break: anywhere; +} + +.extra-network-cards .card .actions:hover .additional{ + display: block; +} + +.extra-network-cards .card ul{ + margin: 0.25em 0 0.75em 0.25em; + cursor: unset; +} + +.extra-network-cards .card ul a{ + cursor: pointer; +} + +.extra-network-cards .card ul a:hover{ + color: red; +} + diff --git a/webui.py b/webui.py index 865a7300..e8dd822a 100644 --- a/webui.py +++ b/webui.py @@ -9,16 +9,18 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from modules import import_hook, errors +from modules import import_hook, errors, extra_networks +from modules import extra_networks_hypernet, ui_extra_networks_hypernets, ui_extra_networks_textual_inversion from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call from modules.paths import script_path import torch + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -84,10 +86,17 @@ def initialize(): shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks())) - shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: try: @@ -209,6 +218,15 @@ def webui(): modules.sd_models.list_models() + shared.reload_hypernetworks() + + ui_extra_networks.intialize() + ui_extra_networks.register_page(ui_extra_networks_textual_inversion.ExtraNetworksPageTextualInversion()) + ui_extra_networks.register_page(ui_extra_networks_hypernets.ExtraNetworksPageHypernetworks()) + + extra_networks.initialize() + extra_networks.register_extra_network(extra_networks_hypernet.ExtraNetworkHypernet()) + if __name__ == "__main__": if cmd_opts.nowebui: -- cgit v1.2.3 From ac2eb97db90fe35cdea00d3fdd4680289259bd42 Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 21 Jan 2023 22:43:37 +0300 Subject: fix auto fill and repair separate axisoptions --- scripts/xy_grid.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index b1badec9..8ff315a7 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -165,10 +165,14 @@ class AxisOption: self.confirm = confirm self.cost = cost self.choices = choices - self.is_img2img = False class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.is_img2img = False @@ -183,7 +187,8 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), AxisOption("Sigma Churn", float, apply_field("s_churn")), AxisOption("Sigma min", float, apply_field("s_tmin")), @@ -192,8 +197,8 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta")), AxisOption("Clip skip", int, apply_clip_skip), AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]), - AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), ] @@ -288,7 +293,7 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img] + current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] with gr.Row(): with gr.Column(scale=19): @@ -316,14 +321,14 @@ class Script(scripts.Script): swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) def fill(x_type): - axis = axis_options[x_type] + axis = current_axis_options[x_type] return ", ".join(axis.choices()) if axis.choices else gr.update() fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) def select_axis(x_type): - return gr.Button.update(visible=axis_options[x_type].choices is not None) + return gr.Button.update(visible=current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) -- cgit v1.2.3 From e5520232e853656e10e4a06f38db24f199474aba Mon Sep 17 00:00:00 2001 From: Vladimir Repin <32306715+mezotaken@users.noreply.github.com> Date: Sat, 21 Jan 2023 23:58:59 +0300 Subject: make current_axis_options class variable --- scripts/xy_grid.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..98254c64 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -293,17 +293,17 @@ class Script(scripts.Script): return "X/Y plot" def ui(self, is_img2img): - current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] with gr.Row(): with gr.Column(scale=19): with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) @@ -314,21 +314,20 @@ class Script(scripts.Script): swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") def swap_axes(x_type, x_values, y_type, y_values): - nonlocal current_axis_options - return current_axis_options[y_type].label, y_values, current_axis_options[x_type].label, x_values + return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values swap_args = [x_type, x_values, y_type, y_values] swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) def fill(x_type): - axis = current_axis_options[x_type] + axis = self.current_axis_options[x_type] return ", ".join(axis.choices()) if axis.choices else gr.update() fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) def select_axis(x_type): - return gr.Button.update(visible=current_axis_options[x_type].choices is not None) + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) @@ -403,10 +402,10 @@ class Script(scripts.Script): return valslist - x_opt = axis_options[x_type] + x_opt = self.current_axis_options[x_type] xs = process_axis(x_opt, x_values) - y_opt = axis_options[y_type] + y_opt = self.current_axis_options[y_type] ys = process_axis(y_opt, y_values) def fix_axis_seeds(axis_opt, axis_list): -- cgit v1.2.3 From 5560150fdaf5d974a122f0b226d6abe24dea12c0 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sat, 21 Jan 2023 16:58:45 -0500 Subject: aligns the axis buttons in x/y plot --- scripts/xy_grid.py | 2 +- style.css | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..0caece09 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -307,7 +307,7 @@ class Script(scripts.Script): y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - with gr.Row(variant="compact"): + with gr.Row(variant="compact", elem_id="axis_options"): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) diff --git a/style.css b/style.css index b215405d..bf8260d7 100644 --- a/style.css +++ b/style.css @@ -722,6 +722,10 @@ footer { margin-left: 0em; } +#axis_options { + margin-left: 0em; +} + .inactive{ opacity: 0.5; } -- cgit v1.2.3 From 8a3f85c4cc1910f59e04b5c8355a30c4c42431e5 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Sun, 22 Jan 2023 17:08:08 -0500 Subject: adds hires steps to x/y plot and fixes total_steps calculation --- scripts/xy_grid.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 8ff315a7..5990b78d 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -184,6 +184,7 @@ axis_options = [ AxisOption("Var. seed", int, apply_field("subseed")), AxisOption("Var. strength", float, apply_field("subseed_strength")), AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), AxisOption("CFG Scale", float, apply_field("cfg_scale")), AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), @@ -427,10 +428,21 @@ class Script(scripts.Script): total_steps = p.steps * len(xs) * len(ys) if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - total_steps *= 2 + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 + + total_steps *= p.n_iter - print(f"X/Y plot will create {len(xs) * len(ys) * p.n_iter} images on a {len(xs)}x{len(ys)} grid. (Total steps to process: {total_steps * p.n_iter})") - shared.total_tqdm.updateTotal(total_steps * p.n_iter) + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) grid_infotext = [None] -- cgit v1.2.3 From d30ac02f28bf5fa1ca5d4ba444180ba9e50b4860 Mon Sep 17 00:00:00 2001 From: EllangoK Date: Tue, 24 Jan 2023 02:21:32 -0500 Subject: renamed xy to xyz grid this is mostly just so git can detect it properly --- scripts/xy_grid.py | 498 ---------------------------------------------------- scripts/xyz_grid.py | 498 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 498 insertions(+), 498 deletions(-) delete mode 100644 scripts/xy_grid.py create mode 100644 scripts/xyz_grid.py (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py deleted file mode 100644 index 1a452355..00000000 --- a/scripts/xy_grid.py +++ /dev/null @@ -1,498 +0,0 @@ -from collections import namedtuple -from copy import copy -from itertools import permutations, chain -import random -import csv -from io import StringIO -from PIL import Image -import numpy as np - -import modules.scripts as scripts -import gradio as gr - -from modules import images, paths, sd_samplers, processing, sd_models, sd_vae -from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img -from modules.shared import opts, cmd_opts, state -import modules.shared as shared -import modules.sd_samplers -import modules.sd_models -import modules.sd_vae -import glob -import os -import re - -from modules.ui_components import ToolButton - -fill_values_symbol = "\U0001f4d2" # 📒 - - -def apply_field(field): - def fun(p, x, xs): - setattr(p, field, x) - - return fun - - -def apply_prompt(p, x, xs): - if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: - raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") - - p.prompt = p.prompt.replace(xs[0], x) - p.negative_prompt = p.negative_prompt.replace(xs[0], x) - - -def apply_order(p, x, xs): - token_order = [] - - # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen - for token in x: - token_order.append((p.prompt.find(token), token)) - - token_order.sort(key=lambda t: t[0]) - - prompt_parts = [] - - # Split the prompt up, taking out the tokens - for _, token in token_order: - n = p.prompt.find(token) - prompt_parts.append(p.prompt[0:n]) - p.prompt = p.prompt[n + len(token):] - - # Rebuild the prompt with the tokens in the order we want - prompt_tmp = "" - for idx, part in enumerate(prompt_parts): - prompt_tmp += part - prompt_tmp += x[idx] - p.prompt = prompt_tmp + p.prompt - - -def apply_sampler(p, x, xs): - sampler_name = sd_samplers.samplers_map.get(x.lower(), None) - if sampler_name is None: - raise RuntimeError(f"Unknown sampler: {x}") - - p.sampler_name = sampler_name - - -def confirm_samplers(p, xs): - for x in xs: - if x.lower() not in sd_samplers.samplers_map: - raise RuntimeError(f"Unknown sampler: {x}") - - -def apply_checkpoint(p, x, xs): - info = modules.sd_models.get_closet_checkpoint_match(x) - if info is None: - raise RuntimeError(f"Unknown checkpoint: {x}") - modules.sd_models.reload_model_weights(shared.sd_model, info) - - -def confirm_checkpoints(p, xs): - for x in xs: - if modules.sd_models.get_closet_checkpoint_match(x) is None: - raise RuntimeError(f"Unknown checkpoint: {x}") - - -def apply_clip_skip(p, x, xs): - opts.data["CLIP_stop_at_last_layers"] = x - - -def apply_upscale_latent_space(p, x, xs): - if x.lower().strip() != '0': - opts.data["use_scale_latent_for_hires_fix"] = True - else: - opts.data["use_scale_latent_for_hires_fix"] = False - - -def find_vae(name: str): - if name.lower() in ['auto', 'automatic']: - return modules.sd_vae.unspecified - if name.lower() == 'none': - return None - else: - choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] - if len(choices) == 0: - print(f"No VAE found for {name}; using automatic") - return modules.sd_vae.unspecified - else: - return modules.sd_vae.vae_dict[choices[0]] - - -def apply_vae(p, x, xs): - modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) - - -def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): - p.styles = x.split(',') - - -def format_value_add_label(p, opt, x): - if type(x) == float: - x = round(x, 8) - - return f"{opt.label}: {x}" - - -def format_value(p, opt, x): - if type(x) == float: - x = round(x, 8) - return x - - -def format_value_join_list(p, opt, x): - return ", ".join(x) - - -def do_nothing(p, x, xs): - pass - - -def format_nothing(p, opt, x): - return "" - - -def str_permutations(x): - """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" - return x - - -class AxisOption: - def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): - self.label = label - self.type = type - self.apply = apply - self.format_value = format_value - self.confirm = confirm - self.cost = cost - self.choices = choices - - -class AxisOptionImg2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = True - -class AxisOptionTxt2Img(AxisOption): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_img2img = False - - -axis_options = [ - AxisOption("Nothing", str, do_nothing, format_value=format_nothing), - AxisOption("Seed", int, apply_field("seed")), - AxisOption("Var. seed", int, apply_field("subseed")), - AxisOption("Var. strength", float, apply_field("subseed_strength")), - AxisOption("Steps", int, apply_field("steps")), - AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), - AxisOption("CFG Scale", float, apply_field("cfg_scale")), - AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), - AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), - AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), - AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), - AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), - AxisOption("Sigma Churn", float, apply_field("s_churn")), - AxisOption("Sigma min", float, apply_field("s_tmin")), - AxisOption("Sigma max", float, apply_field("s_tmax")), - AxisOption("Sigma noise", float, apply_field("s_noise")), - AxisOption("Eta", float, apply_field("eta")), - AxisOption("Clip skip", int, apply_clip_skip), - AxisOption("Denoising", float, apply_field("denoising_strength")), - AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), - AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), - AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), - AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), -] - - -def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): - ver_texts = [[images.GridAnnotation(y)] for y in y_labels] - hor_texts = [[images.GridAnnotation(x)] for x in x_labels] - - # Temporary list of all the images that are generated to be populated into the grid. - # Will be filled with empty images for any individual step that fails to process properly - image_cache = [None] * (len(xs) * len(ys)) - - processed_result = None - cell_mode = "P" - cell_size = (1, 1) - - state.job_count = len(xs) * len(ys) * p.n_iter - - def process_cell(x, y, ix, iy): - nonlocal image_cache, processed_result, cell_mode, cell_size - - state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" - - processed: Processed = cell(x, y) - - try: - # this dereference will throw an exception if the image was not processed - # (this happens in cases such as if the user stops the process from the UI) - processed_image = processed.images[0] - - if processed_result is None: - # Use our first valid processed result as a template container to hold our full results - processed_result = copy(processed) - cell_mode = processed_image.mode - cell_size = processed_image.size - processed_result.images = [Image.new(cell_mode, cell_size)] - - image_cache[ix + iy * len(xs)] = processed_image - if include_lone_images: - processed_result.images.append(processed_image) - processed_result.all_prompts.append(processed.prompt) - processed_result.all_seeds.append(processed.seed) - processed_result.infotexts.append(processed.infotexts[0]) - except: - image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) - - if swap_axes_processing_order: - for ix, x in enumerate(xs): - for iy, y in enumerate(ys): - process_cell(x, y, ix, iy) - else: - for iy, y in enumerate(ys): - for ix, x in enumerate(xs): - process_cell(x, y, ix, iy) - - if not processed_result: - print("Unexpected error: draw_xy_grid failed to return even a single processed image") - return Processed(p, []) - - grid = images.image_grid(image_cache, rows=len(ys)) - if draw_legend: - grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) - - processed_result.images[0] = grid - - return processed_result - - -class SharedSettingsStackHelper(object): - def __enter__(self): - self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers - self.vae = opts.sd_vae - - def __exit__(self, exc_type, exc_value, tb): - opts.data["sd_vae"] = self.vae - modules.sd_models.reload_model_weights() - modules.sd_vae.reload_vae_weights() - - opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - - -re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") - -re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") -re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") - - -class Script(scripts.Script): - def title(self): - return "X/Y plot" - - def ui(self, is_img2img): - self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] - - with gr.Row(): - with gr.Column(scale=19): - with gr.Row(): - x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) - x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) - fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) - - with gr.Row(): - y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) - y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) - fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) - - with gr.Row(variant="compact", elem_id="axis_options"): - draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) - include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) - no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) - swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") - - def swap_axes(x_type, x_values, y_type, y_values): - return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values - - swap_args = [x_type, x_values, y_type, y_values] - swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) - - def fill(x_type): - axis = self.current_axis_options[x_type] - return ", ".join(axis.choices()) if axis.choices else gr.update() - - fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) - fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) - - def select_axis(x_type): - return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) - - x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) - y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) - - return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] - - def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): - if not no_fixed_seeds: - modules.processing.fix_seed(p) - - if not opts.return_grid: - p.batch_size = 1 - - def process_axis(opt, vals): - if opt.label == 'Nothing': - return [0] - - valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] - - if opt.type == int: - valslist_ext = [] - - for val in valslist: - m = re_range.fullmatch(val) - mc = re_range_count.fullmatch(val) - if m is not None: - start = int(m.group(1)) - end = int(m.group(2))+1 - step = int(m.group(3)) if m.group(3) is not None else 1 - - valslist_ext += list(range(start, end, step)) - elif mc is not None: - start = int(mc.group(1)) - end = int(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - - valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] - else: - valslist_ext.append(val) - - valslist = valslist_ext - elif opt.type == float: - valslist_ext = [] - - for val in valslist: - m = re_range_float.fullmatch(val) - mc = re_range_count_float.fullmatch(val) - if m is not None: - start = float(m.group(1)) - end = float(m.group(2)) - step = float(m.group(3)) if m.group(3) is not None else 1 - - valslist_ext += np.arange(start, end + step, step).tolist() - elif mc is not None: - start = float(mc.group(1)) - end = float(mc.group(2)) - num = int(mc.group(3)) if mc.group(3) is not None else 1 - - valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() - else: - valslist_ext.append(val) - - valslist = valslist_ext - elif opt.type == str_permutations: - valslist = list(permutations(valslist)) - - valslist = [opt.type(x) for x in valslist] - - # Confirm options are valid before starting - if opt.confirm: - opt.confirm(p, valslist) - - return valslist - - x_opt = self.current_axis_options[x_type] - xs = process_axis(x_opt, x_values) - - y_opt = self.current_axis_options[y_type] - ys = process_axis(y_opt, y_values) - - def fix_axis_seeds(axis_opt, axis_list): - if axis_opt.label in ['Seed', 'Var. seed']: - return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] - else: - return axis_list - - if not no_fixed_seeds: - xs = fix_axis_seeds(x_opt, xs) - ys = fix_axis_seeds(y_opt, ys) - - if x_opt.label == 'Steps': - total_steps = sum(xs) * len(ys) - elif y_opt.label == 'Steps': - total_steps = sum(ys) * len(xs) - else: - total_steps = p.steps * len(xs) * len(ys) - - if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: - if x_opt.label == "Hires steps": - total_steps += sum(xs) * len(ys) - elif y_opt.label == "Hires steps": - total_steps += sum(ys) * len(xs) - elif p.hr_second_pass_steps: - total_steps += p.hr_second_pass_steps * len(xs) * len(ys) - else: - total_steps *= 2 - - total_steps *= p.n_iter - - image_cell_count = p.n_iter * p.batch_size - cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" - print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") - shared.total_tqdm.updateTotal(total_steps) - - grid_infotext = [None] - - # If one of the axes is very slow to change between (like SD model - # checkpoint), then make sure it is in the outer iteration of the nested - # `for` loop. - swap_axes_processing_order = x_opt.cost > y_opt.cost - - def cell(x, y): - if shared.state.interrupted: - return Processed(p, [], p.seed, "") - - pc = copy(p) - x_opt.apply(pc, x, xs) - y_opt.apply(pc, y, ys) - - res = process_images(pc) - - if grid_infotext[0] is None: - pc.extra_generation_params = copy(pc.extra_generation_params) - - if x_opt.label != 'Nothing': - pc.extra_generation_params["X Type"] = x_opt.label - pc.extra_generation_params["X Values"] = x_values - if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) - - if y_opt.label != 'Nothing': - pc.extra_generation_params["Y Type"] = y_opt.label - pc.extra_generation_params["Y Values"] = y_values - if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: - pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) - - grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) - - return res - - with SharedSettingsStackHelper(): - processed = draw_xy_grid( - p, - xs=xs, - ys=ys, - x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], - y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], - cell=cell, - draw_legend=draw_legend, - include_lone_images=include_lone_images, - swap_axes_processing_order=swap_axes_processing_order - ) - - if opts.grid_save: - images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) - - return processed diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py new file mode 100644 index 00000000..1a452355 --- /dev/null +++ b/scripts/xyz_grid.py @@ -0,0 +1,498 @@ +from collections import namedtuple +from copy import copy +from itertools import permutations, chain +import random +import csv +from io import StringIO +from PIL import Image +import numpy as np + +import modules.scripts as scripts +import gradio as gr + +from modules import images, paths, sd_samplers, processing, sd_models, sd_vae +from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img +from modules.shared import opts, cmd_opts, state +import modules.shared as shared +import modules.sd_samplers +import modules.sd_models +import modules.sd_vae +import glob +import os +import re + +from modules.ui_components import ToolButton + +fill_values_symbol = "\U0001f4d2" # 📒 + + +def apply_field(field): + def fun(p, x, xs): + setattr(p, field, x) + + return fun + + +def apply_prompt(p, x, xs): + if xs[0] not in p.prompt and xs[0] not in p.negative_prompt: + raise RuntimeError(f"Prompt S/R did not find {xs[0]} in prompt or negative prompt.") + + p.prompt = p.prompt.replace(xs[0], x) + p.negative_prompt = p.negative_prompt.replace(xs[0], x) + + +def apply_order(p, x, xs): + token_order = [] + + # Initally grab the tokens from the prompt, so they can be replaced in order of earliest seen + for token in x: + token_order.append((p.prompt.find(token), token)) + + token_order.sort(key=lambda t: t[0]) + + prompt_parts = [] + + # Split the prompt up, taking out the tokens + for _, token in token_order: + n = p.prompt.find(token) + prompt_parts.append(p.prompt[0:n]) + p.prompt = p.prompt[n + len(token):] + + # Rebuild the prompt with the tokens in the order we want + prompt_tmp = "" + for idx, part in enumerate(prompt_parts): + prompt_tmp += part + prompt_tmp += x[idx] + p.prompt = prompt_tmp + p.prompt + + +def apply_sampler(p, x, xs): + sampler_name = sd_samplers.samplers_map.get(x.lower(), None) + if sampler_name is None: + raise RuntimeError(f"Unknown sampler: {x}") + + p.sampler_name = sampler_name + + +def confirm_samplers(p, xs): + for x in xs: + if x.lower() not in sd_samplers.samplers_map: + raise RuntimeError(f"Unknown sampler: {x}") + + +def apply_checkpoint(p, x, xs): + info = modules.sd_models.get_closet_checkpoint_match(x) + if info is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + modules.sd_models.reload_model_weights(shared.sd_model, info) + + +def confirm_checkpoints(p, xs): + for x in xs: + if modules.sd_models.get_closet_checkpoint_match(x) is None: + raise RuntimeError(f"Unknown checkpoint: {x}") + + +def apply_clip_skip(p, x, xs): + opts.data["CLIP_stop_at_last_layers"] = x + + +def apply_upscale_latent_space(p, x, xs): + if x.lower().strip() != '0': + opts.data["use_scale_latent_for_hires_fix"] = True + else: + opts.data["use_scale_latent_for_hires_fix"] = False + + +def find_vae(name: str): + if name.lower() in ['auto', 'automatic']: + return modules.sd_vae.unspecified + if name.lower() == 'none': + return None + else: + choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()] + if len(choices) == 0: + print(f"No VAE found for {name}; using automatic") + return modules.sd_vae.unspecified + else: + return modules.sd_vae.vae_dict[choices[0]] + + +def apply_vae(p, x, xs): + modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x)) + + +def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _): + p.styles = x.split(',') + + +def format_value_add_label(p, opt, x): + if type(x) == float: + x = round(x, 8) + + return f"{opt.label}: {x}" + + +def format_value(p, opt, x): + if type(x) == float: + x = round(x, 8) + return x + + +def format_value_join_list(p, opt, x): + return ", ".join(x) + + +def do_nothing(p, x, xs): + pass + + +def format_nothing(p, opt, x): + return "" + + +def str_permutations(x): + """dummy function for specifying it in AxisOption's type when you want to get a list of permutations""" + return x + + +class AxisOption: + def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None): + self.label = label + self.type = type + self.apply = apply + self.format_value = format_value + self.confirm = confirm + self.cost = cost + self.choices = choices + + +class AxisOptionImg2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = True + +class AxisOptionTxt2Img(AxisOption): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_img2img = False + + +axis_options = [ + AxisOption("Nothing", str, do_nothing, format_value=format_nothing), + AxisOption("Seed", int, apply_field("seed")), + AxisOption("Var. seed", int, apply_field("subseed")), + AxisOption("Var. strength", float, apply_field("subseed_strength")), + AxisOption("Steps", int, apply_field("steps")), + AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")), + AxisOption("CFG Scale", float, apply_field("cfg_scale")), + AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value), + AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list), + AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]), + AxisOptionImg2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers_for_img2img]), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)), + AxisOption("Sigma Churn", float, apply_field("s_churn")), + AxisOption("Sigma min", float, apply_field("s_tmin")), + AxisOption("Sigma max", float, apply_field("s_tmax")), + AxisOption("Sigma noise", float, apply_field("s_noise")), + AxisOption("Eta", float, apply_field("eta")), + AxisOption("Clip skip", int, apply_clip_skip), + AxisOption("Denoising", float, apply_field("denoising_strength")), + AxisOptionTxt2Img("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]]), + AxisOptionImg2Img("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")), + AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)), + AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)), +] + + +def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order): + ver_texts = [[images.GridAnnotation(y)] for y in y_labels] + hor_texts = [[images.GridAnnotation(x)] for x in x_labels] + + # Temporary list of all the images that are generated to be populated into the grid. + # Will be filled with empty images for any individual step that fails to process properly + image_cache = [None] * (len(xs) * len(ys)) + + processed_result = None + cell_mode = "P" + cell_size = (1, 1) + + state.job_count = len(xs) * len(ys) * p.n_iter + + def process_cell(x, y, ix, iy): + nonlocal image_cache, processed_result, cell_mode, cell_size + + state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}" + + processed: Processed = cell(x, y) + + try: + # this dereference will throw an exception if the image was not processed + # (this happens in cases such as if the user stops the process from the UI) + processed_image = processed.images[0] + + if processed_result is None: + # Use our first valid processed result as a template container to hold our full results + processed_result = copy(processed) + cell_mode = processed_image.mode + cell_size = processed_image.size + processed_result.images = [Image.new(cell_mode, cell_size)] + + image_cache[ix + iy * len(xs)] = processed_image + if include_lone_images: + processed_result.images.append(processed_image) + processed_result.all_prompts.append(processed.prompt) + processed_result.all_seeds.append(processed.seed) + processed_result.infotexts.append(processed.infotexts[0]) + except: + image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size) + + if swap_axes_processing_order: + for ix, x in enumerate(xs): + for iy, y in enumerate(ys): + process_cell(x, y, ix, iy) + else: + for iy, y in enumerate(ys): + for ix, x in enumerate(xs): + process_cell(x, y, ix, iy) + + if not processed_result: + print("Unexpected error: draw_xy_grid failed to return even a single processed image") + return Processed(p, []) + + grid = images.image_grid(image_cache, rows=len(ys)) + if draw_legend: + grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts) + + processed_result.images[0] = grid + + return processed_result + + +class SharedSettingsStackHelper(object): + def __enter__(self): + self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers + self.vae = opts.sd_vae + + def __exit__(self, exc_type, exc_value, tb): + opts.data["sd_vae"] = self.vae + modules.sd_models.reload_model_weights() + modules.sd_vae.reload_vae_weights() + + opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers + + +re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") +re_range_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\(([+-]\d+(?:.\d*)?)\s*\))?\s*") + +re_range_count = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\[(\d+)\s*\])?\s*") +re_range_count_float = re.compile(r"\s*([+-]?\s*\d+(?:.\d*)?)\s*-\s*([+-]?\s*\d+(?:.\d*)?)(?:\s*\[(\d+(?:.\d*)?)\s*\])?\s*") + + +class Script(scripts.Script): + def title(self): + return "X/Y plot" + + def ui(self, is_img2img): + self.current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img == is_img2img] + + with gr.Row(): + with gr.Column(scale=19): + with gr.Row(): + x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type")) + x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values")) + fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False) + + with gr.Row(): + y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type")) + y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values")) + fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False) + + with gr.Row(variant="compact", elem_id="axis_options"): + draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) + include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images")) + no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button") + + def swap_axes(x_type, x_values, y_type, y_values): + return self.current_axis_options[y_type].label, y_values, self.current_axis_options[x_type].label, x_values + + swap_args = [x_type, x_values, y_type, y_values] + swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args) + + def fill(x_type): + axis = self.current_axis_options[x_type] + return ", ".join(axis.choices()) if axis.choices else gr.update() + + fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values]) + fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values]) + + def select_axis(x_type): + return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None) + + x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button]) + y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button]) + + return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds] + + def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds): + if not no_fixed_seeds: + modules.processing.fix_seed(p) + + if not opts.return_grid: + p.batch_size = 1 + + def process_axis(opt, vals): + if opt.label == 'Nothing': + return [0] + + valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals)))] + + if opt.type == int: + valslist_ext = [] + + for val in valslist: + m = re_range.fullmatch(val) + mc = re_range_count.fullmatch(val) + if m is not None: + start = int(m.group(1)) + end = int(m.group(2))+1 + step = int(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += list(range(start, end, step)) + elif mc is not None: + start = int(mc.group(1)) + end = int(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += [int(x) for x in np.linspace(start=start, stop=end, num=num).tolist()] + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == float: + valslist_ext = [] + + for val in valslist: + m = re_range_float.fullmatch(val) + mc = re_range_count_float.fullmatch(val) + if m is not None: + start = float(m.group(1)) + end = float(m.group(2)) + step = float(m.group(3)) if m.group(3) is not None else 1 + + valslist_ext += np.arange(start, end + step, step).tolist() + elif mc is not None: + start = float(mc.group(1)) + end = float(mc.group(2)) + num = int(mc.group(3)) if mc.group(3) is not None else 1 + + valslist_ext += np.linspace(start=start, stop=end, num=num).tolist() + else: + valslist_ext.append(val) + + valslist = valslist_ext + elif opt.type == str_permutations: + valslist = list(permutations(valslist)) + + valslist = [opt.type(x) for x in valslist] + + # Confirm options are valid before starting + if opt.confirm: + opt.confirm(p, valslist) + + return valslist + + x_opt = self.current_axis_options[x_type] + xs = process_axis(x_opt, x_values) + + y_opt = self.current_axis_options[y_type] + ys = process_axis(y_opt, y_values) + + def fix_axis_seeds(axis_opt, axis_list): + if axis_opt.label in ['Seed', 'Var. seed']: + return [int(random.randrange(4294967294)) if val is None or val == '' or val == -1 else val for val in axis_list] + else: + return axis_list + + if not no_fixed_seeds: + xs = fix_axis_seeds(x_opt, xs) + ys = fix_axis_seeds(y_opt, ys) + + if x_opt.label == 'Steps': + total_steps = sum(xs) * len(ys) + elif y_opt.label == 'Steps': + total_steps = sum(ys) * len(xs) + else: + total_steps = p.steps * len(xs) * len(ys) + + if isinstance(p, StableDiffusionProcessingTxt2Img) and p.enable_hr: + if x_opt.label == "Hires steps": + total_steps += sum(xs) * len(ys) + elif y_opt.label == "Hires steps": + total_steps += sum(ys) * len(xs) + elif p.hr_second_pass_steps: + total_steps += p.hr_second_pass_steps * len(xs) * len(ys) + else: + total_steps *= 2 + + total_steps *= p.n_iter + + image_cell_count = p.n_iter * p.batch_size + cell_console_text = f"; {image_cell_count} images per cell" if image_cell_count > 1 else "" + print(f"X/Y plot will create {len(xs) * len(ys) * image_cell_count} images on a {len(xs)}x{len(ys)} grid{cell_console_text}. (Total steps to process: {total_steps})") + shared.total_tqdm.updateTotal(total_steps) + + grid_infotext = [None] + + # If one of the axes is very slow to change between (like SD model + # checkpoint), then make sure it is in the outer iteration of the nested + # `for` loop. + swap_axes_processing_order = x_opt.cost > y_opt.cost + + def cell(x, y): + if shared.state.interrupted: + return Processed(p, [], p.seed, "") + + pc = copy(p) + x_opt.apply(pc, x, xs) + y_opt.apply(pc, y, ys) + + res = process_images(pc) + + if grid_infotext[0] is None: + pc.extra_generation_params = copy(pc.extra_generation_params) + + if x_opt.label != 'Nothing': + pc.extra_generation_params["X Type"] = x_opt.label + pc.extra_generation_params["X Values"] = x_values + if x_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed X Values"] = ", ".join([str(x) for x in xs]) + + if y_opt.label != 'Nothing': + pc.extra_generation_params["Y Type"] = y_opt.label + pc.extra_generation_params["Y Values"] = y_values + if y_opt.label in ["Seed", "Var. seed"] and not no_fixed_seeds: + pc.extra_generation_params["Fixed Y Values"] = ", ".join([str(y) for y in ys]) + + grid_infotext[0] = processing.create_infotext(pc, pc.all_prompts, pc.all_seeds, pc.all_subseeds) + + return res + + with SharedSettingsStackHelper(): + processed = draw_xy_grid( + p, + xs=xs, + ys=ys, + x_labels=[x_opt.format_value(p, x_opt, x) for x in xs], + y_labels=[y_opt.format_value(p, y_opt, y) for y in ys], + cell=cell, + draw_legend=draw_legend, + include_lone_images=include_lone_images, + swap_axes_processing_order=swap_axes_processing_order + ) + + if opts.grid_save: + images.save_image(processed.images[0], p.outpath_grids, "xy_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p) + + return processed -- cgit v1.2.3