From f4535f6e4f001314bd155bc6e1b6908e02792b9a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 31 Dec 2022 23:40:55 +0300 Subject: make it so that memory/embeddings info is displayed in a separate UI element from generation parameters, and is preserved when you change the displayed infotext by clicking on gallery images --- modules/txt2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/txt2img.py') diff --git a/modules/txt2img.py b/modules/txt2img.py index c8f81176..7f61e19a 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -59,4 +59,4 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: if opts.do_not_show_images: processed.images = [] - return processed.images, generation_info_js, plaintext_to_html(processed.info) + return processed.images, generation_info_js, plaintext_to_html(processed.info), plaintext_to_html(processed.comments) -- cgit v1.2.3 From ef27a18b6b7cb1a8eebdc9b2e88d25baf2c2414d Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 2 Jan 2023 19:42:10 +0300 Subject: Hires fix rework --- modules/generation_parameters_copypaste.py | 32 ++++++++++++++ modules/images.py | 24 +++++++++-- modules/processing.py | 68 ++++++++++++------------------ modules/shared.py | 7 ++- modules/txt2img.py | 6 +-- modules/ui.py | 15 +++---- scripts/xy_grid.py | 4 +- 7 files changed, 96 insertions(+), 60 deletions(-) (limited to 'modules/txt2img.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 8e7f0df0..d6fa822b 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -1,5 +1,6 @@ import base64 import io +import math import os import re from pathlib import Path @@ -164,6 +165,35 @@ def find_hypernetwork_key(hypernet_name, hypernet_hash=None): return None +def restore_old_hires_fix_params(res): + """for infotexts that specify old First pass size parameter, convert it into + width, height, and hr scale""" + + firstpass_width = res.get('First pass size-1', None) + firstpass_height = res.get('First pass size-2', None) + + if firstpass_width is None or firstpass_height is None: + return + + firstpass_width, firstpass_height = int(firstpass_width), int(firstpass_height) + width = int(res.get("Size-1", 512)) + height = int(res.get("Size-2", 512)) + + if firstpass_width == 0 or firstpass_height == 0: + # old algorithm for auto-calculating first pass size + desired_pixel_count = 512 * 512 + actual_pixel_count = width * height + scale = math.sqrt(desired_pixel_count / actual_pixel_count) + firstpass_width = math.ceil(scale * width / 64) * 64 + firstpass_height = math.ceil(scale * height / 64) * 64 + + hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height + + res['Size-1'] = firstpass_width + res['Size-2'] = firstpass_height + res['Hires upscale'] = hr_scale + + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -221,6 +251,8 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + restore_old_hires_fix_params(res) + return res diff --git a/modules/images.py b/modules/images.py index f84fd485..c3a5fc8b 100644 --- a/modules/images.py +++ b/modules/images.py @@ -230,16 +230,32 @@ def draw_prompt_matrix(im, width, height, all_prompts): return draw_grid_annotations(im, width, height, hor_texts, ver_texts) -def resize_image(resize_mode, im, width, height): +def resize_image(resize_mode, im, width, height, upscaler_name=None): + """ + Resizes an image with the specified resize_mode, width, and height. + + Args: + resize_mode: The mode to use when resizing the image. + 0: Resize the image to the specified width and height. + 1: Resize the image to fill the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, cropping the excess. + 2: Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center the image within the dimensions, filling empty with data from image. + im: The image to resize. + width: The width to resize the image to. + height: The height to resize the image to. + upscaler_name: The name of the upscaler to use. If not provided, defaults to opts.upscaler_for_img2img. + """ + + upscaler_name = upscaler_name or opts.upscaler_for_img2img + def resize(im, w, h): - if opts.upscaler_for_img2img is None or opts.upscaler_for_img2img == "None" or im.mode == 'L': + if upscaler_name is None or upscaler_name == "None" or im.mode == 'L': return im.resize((w, h), resample=LANCZOS) scale = max(w / im.width, h / im.height) if scale > 1.0: - upscalers = [x for x in shared.sd_upscalers if x.name == opts.upscaler_for_img2img] - assert len(upscalers) > 0, f"could not find upscaler named {opts.upscaler_for_img2img}" + upscalers = [x for x in shared.sd_upscalers if x.name == upscaler_name] + assert len(upscalers) > 0, f"could not find upscaler named {upscaler_name}" upscaler = upscalers[0] im = upscaler.scaler.upscale(im, scale, upscaler.data_path) diff --git a/modules/processing.py b/modules/processing.py index 42dc19ea..4654570c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -658,14 +658,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool=False, denoising_strength: float=0.75, firstphase_width: int=0, firstphase_height: int=0, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength - self.firstphase_width = firstphase_width - self.firstphase_height = firstphase_height - self.truncate_x = 0 - self.truncate_y = 0 + self.hr_scale = hr_scale + self.hr_upscaler = hr_upscaler + + if firstphase_width != 0 or firstphase_height != 0: + print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) + self.hr_scale = self.width / firstphase_width + self.width = firstphase_width + self.height = firstphase_height def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: @@ -674,47 +678,29 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["First pass size"] = f"{self.firstphase_width}x{self.firstphase_height}" - - if self.firstphase_width == 0 or self.firstphase_height == 0: - desired_pixel_count = 512 * 512 - actual_pixel_count = self.width * self.height - scale = math.sqrt(desired_pixel_count / actual_pixel_count) - self.firstphase_width = math.ceil(scale * self.width / 64) * 64 - self.firstphase_height = math.ceil(scale * self.height / 64) * 64 - firstphase_width_truncated = int(scale * self.width) - firstphase_height_truncated = int(scale * self.height) - - else: - - width_ratio = self.width / self.firstphase_width - height_ratio = self.height / self.firstphase_height - - if width_ratio > height_ratio: - firstphase_width_truncated = self.firstphase_width - firstphase_height_truncated = self.firstphase_width * self.height / self.width - else: - firstphase_width_truncated = self.firstphase_height * self.width / self.height - firstphase_height_truncated = self.firstphase_height - - self.truncate_x = int(self.firstphase_width - firstphase_width_truncated) // opt_f - self.truncate_y = int(self.firstphase_height - firstphase_height_truncated) // opt_f + self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_upscaler is not None: + self.extra_generation_params["Hires upscaler"] = self.hr_upscaler def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_default_mode + if self.enable_hr and latent_scale_mode is None: + assert len([x for x in shared.sd_upscalers if x.name == self.hr_upscaler]) > 0, f"could not find upscaler named {self.hr_upscaler}" + + x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) + if not self.enable_hr: - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) return samples - x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x, self.firstphase_width, self.firstphase_height)) - - samples = samples[:, :, self.truncate_y//2:samples.shape[2]-self.truncate_y//2, self.truncate_x//2:samples.shape[3]-self.truncate_x//2] + target_width = int(self.width * self.hr_scale) + target_height = int(self.height * self.hr_scale) - """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" def save_intermediate(image, index): + """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" + if not opts.save or self.do_not_save_samples or not opts.save_images_before_highres_fix: return @@ -723,11 +709,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix") - if opts.use_scale_latent_for_hires_fix: + if latent_scale_mode is not None: for i in range(samples.shape[0]): save_intermediate(samples, i) - samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear") + samples = torch.nn.functional.interpolate(samples, size=(target_height // opt_f, target_width // opt_f), mode=latent_scale_mode) # Avoid making the inpainting conditioning unless necessary as # this does need some extra compute to decode / encode the image again. @@ -747,7 +733,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): save_intermediate(image, i) - image = images.resize_image(0, image, self.width, self.height) + image = images.resize_image(0, image, target_width, target_height, upscaler_name=self.hr_upscaler) image = np.array(image).astype(np.float32) / 255.0 image = np.moveaxis(image, 2, 0) batch_images.append(image) @@ -764,7 +750,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None diff --git a/modules/shared.py b/modules/shared.py index 7f430b93..b65559ee 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -327,7 +327,6 @@ options_templates.update(options_section(('upscaling', "Upscaling"), { "ESRGAN_tile_overlap": OptionInfo(8, "Tile overlap, in pixels for ESRGAN upscalers. Low values = visible seam.", gr.Slider, {"minimum": 0, "maximum": 48, "step": 1}), "realesrgan_enabled_models": OptionInfo(["R-ESRGAN 4x+", "R-ESRGAN 4x+ Anime6B"], "Select which Real-ESRGAN models to show in the web UI. (Requires restart)", gr.CheckboxGroup, lambda: {"choices": realesrgan_models_names()}), "upscaler_for_img2img": OptionInfo(None, "Upscaler for img2img", gr.Dropdown, lambda: {"choices": [x.name for x in sd_upscalers]}), - "use_scale_latent_for_hires_fix": OptionInfo(False, "Upscale latent space image when doing hires. fix"), })) options_templates.update(options_section(('face-restoration', "Face restoration"), { @@ -545,6 +544,12 @@ opts = Options() if os.path.exists(config_filename): opts.load(config_filename) +latent_upscale_default_mode = "Latent" +latent_upscale_modes = { + "Latent": "bilinear", + "Latent (nearest)": "nearest", +} + sd_upscalers = [] sd_model = None diff --git a/modules/txt2img.py b/modules/txt2img.py index 7f61e19a..e189a899 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, firstphase_width: int, firstphase_height: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -33,8 +33,8 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: tiling=tiling, enable_hr=enable_hr, denoising_strength=denoising_strength if enable_hr else None, - firstphase_width=firstphase_width if enable_hr else None, - firstphase_height=firstphase_height if enable_hr else None, + hr_scale=hr_scale, + hr_upscaler=hr_upscaler, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 7070ea15..27cd9ddd 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -684,11 +684,11 @@ def create_ui(): with gr.Row(): restore_faces = gr.Checkbox(label='Restore faces', value=False, visible=len(shared.face_restorers) > 1, elem_id="txt2img_restore_faces") tiling = gr.Checkbox(label='Tiling', value=False, elem_id="txt2img_tiling") - enable_hr = gr.Checkbox(label='Highres. fix', value=False, elem_id="txt2img_enable_hr") + enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") with gr.Row(visible=False) as hr_options: - firstphase_width = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass width", value=0, elem_id="txt2img_firstphase_width") - firstphase_height = gr.Slider(minimum=0, maximum=1024, step=8, label="Firstpass height", value=0, elem_id="txt2img_firstphase_height") + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") with gr.Row(equal_height=True): @@ -729,8 +729,8 @@ def create_ui(): width, enable_hr, denoising_strength, - firstphase_width, - firstphase_height, + hr_scale, + hr_upscaler, ] + custom_inputs, outputs=[ @@ -762,7 +762,6 @@ def create_ui(): outputs=[hr_options], ) - txt2img_paste_fields = [ (txt2img_prompt, "Prompt"), (txt2img_negative_prompt, "Negative prompt"), @@ -781,8 +780,8 @@ def create_ui(): (denoising_strength, "Denoising strength"), (enable_hr, lambda d: "Denoising strength" in d), (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), - (firstphase_width, "First pass size-1"), - (firstphase_height, "First pass size-2"), + (hr_scale, "Hires upscale"), + (hr_upscaler, "Hires upscaler"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 3e0b2805..f92f9776 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -202,7 +202,7 @@ axis_options = [ AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None), AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None), AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None), - AxisOption("Upscale latent space for hires.", str, apply_upscale_latent_space, format_value_add_label, None), + AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None), AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None), AxisOption("VAE", str, apply_vae, format_value_add_label, None), AxisOption("Styles", str, apply_styles, format_value_add_label, None), @@ -267,7 +267,6 @@ class SharedSettingsStackHelper(object): self.CLIP_stop_at_last_layers = opts.CLIP_stop_at_last_layers self.hypernetwork = opts.sd_hypernetwork self.model = shared.sd_model - self.use_scale_latent_for_hires_fix = opts.use_scale_latent_for_hires_fix self.vae = opts.sd_vae def __exit__(self, exc_type, exc_value, tb): @@ -278,7 +277,6 @@ class SharedSettingsStackHelper(object): hypernetwork.apply_strength() opts.data["CLIP_stop_at_last_layers"] = self.CLIP_stop_at_last_layers - opts.data["use_scale_latent_for_hires_fix"] = self.use_scale_latent_for_hires_fix re_range = re.compile(r"\s*([+-]?\s*\d+)\s*-\s*([+-]?\s*\d+)(?:\s*\(([+-]\d+)\s*\))?\s*") -- cgit v1.2.3 From 81490780949fffed77493b4bd741e96ec737fe27 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Wed, 4 Jan 2023 22:04:40 +0300 Subject: added the option to specify target resolution with possibility of truncating for hires fix; also sampling steps --- javascript/hints.js | 11 ++++--- modules/generation_parameters_copypaste.py | 9 ++++-- modules/processing.py | 51 +++++++++++++++++++++++++++--- modules/txt2img.py | 5 ++- modules/ui.py | 24 ++++++++++---- 5 files changed, 81 insertions(+), 19 deletions(-) (limited to 'modules/txt2img.py') diff --git a/javascript/hints.js b/javascript/hints.js index 63e17e05..dda66e09 100644 --- a/javascript/hints.js +++ b/javascript/hints.js @@ -81,9 +81,6 @@ titles = { "vram": "Torch active: Peak amount of VRAM used by Torch during generation, excluding cached data.\nTorch reserved: Peak amount of VRAM allocated by Torch, including all active and cached data.\nSys VRAM: Peak amount of VRAM allocation across all applications / total GPU VRAM (peak utilization%).", - "Highres. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", - "Scale latent": "Uscale the image in latent space. Alternative is to produce the full image from latent representation, upscale that, and then move it back to latent space.", - "Eta noise seed delta": "If this values is non-zero, it will be added to seed and used to initialize RNG for noises when using samplers with Eta. You can use this to produce even more variation of images, or you can use this to match images of other software if you know what you are doing.", "Do not add watermark to images": "If this option is enabled, watermark will not be added to created images. Warning: if you do not add watermark, you may be behaving in an unethical manner.", @@ -100,7 +97,13 @@ titles = { "Clip skip": "Early stopping parameter for CLIP model; 1 is stop at last layer as usual, 2 is stop at penultimate layer, etc.", "Approx NN": "Cheap neural network approximation. Very fast compared to VAE, but produces pictures with 4 times smaller horizontal/vertical resoluton and lower quality.", - "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality." + "Approx cheap": "Very cheap approximation. Very fast compared to VAE, but produces pictures with 8 times smaller horizontal/vertical resoluton and extremely low quality.", + + "Hires. fix": "Use a two step process to partially create an image at smaller resolution, upscale, and then improve details in it without changing composition", + "Hires steps": "Number of sampling steps for upscaled picture. If 0, uses same as for original.", + "Upscale by": "Adjusts the size of the image by multiplying the original width and height by the selected value. Ignored if either Resize width to or Resize height to are non-zero.", + "Resize width to": "Resizes image to this width. If 0, width is inferred from either of two nearby sliders.", + "Resize height to": "Resizes image to this height. If 0, height is inferred from either of two nearby sliders." } diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 4baf4d9a..12a9de3d 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -212,11 +212,10 @@ def restore_old_hires_fix_params(res): firstpass_width = math.ceil(scale * width / 64) * 64 firstpass_height = math.ceil(scale * height / 64) * 64 - hr_scale = width / firstpass_width if firstpass_width > 0 else height / firstpass_height - res['Size-1'] = firstpass_width res['Size-2'] = firstpass_height - res['Hires upscale'] = hr_scale + res['Hires resize-1'] = width + res['Hires resize-2'] = height def parse_generation_parameters(x: str): @@ -276,6 +275,10 @@ Steps: 20, Sampler: Euler a, CFG scale: 7, Seed: 965400086, Size: 512x512, Model hypernet_hash = res.get("Hypernet hash", None) res["Hypernet"] = find_hypernetwork_key(hypernet_name, hypernet_hash) + if "Hires resize-1" not in res: + res["Hires resize-1"] = 0 + res["Hires resize-2"] = 0 + restore_old_hires_fix_params(res) return res diff --git a/modules/processing.py b/modules/processing.py index 47712159..9cad05f2 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -662,12 +662,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): sampler = None - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, **kwargs): + def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs): super().__init__(**kwargs) self.enable_hr = enable_hr self.denoising_strength = denoising_strength self.hr_scale = hr_scale self.hr_upscaler = hr_upscaler + self.hr_second_pass_steps = hr_second_pass_steps + self.hr_resize_x = hr_resize_x + self.hr_resize_y = hr_resize_y + self.hr_upscale_to_x = hr_resize_x + self.hr_upscale_to_y = hr_resize_y if firstphase_width != 0 or firstphase_height != 0: print("firstphase_width/firstphase_height no longer supported; use hr_scale", file=sys.stderr) @@ -675,6 +680,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.width = firstphase_width self.height = firstphase_height + self.truncate_x = 0 + self.truncate_y = 0 + def init(self, all_prompts, all_seeds, all_subseeds): if self.enable_hr: if state.job_count == -1: @@ -682,7 +690,38 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): else: state.job_count = state.job_count * 2 - self.extra_generation_params["Hires upscale"] = self.hr_scale + if self.hr_resize_x == 0 and self.hr_resize_y == 0: + self.extra_generation_params["Hires upscale"] = self.hr_scale + self.hr_upscale_to_x = int(self.width * self.hr_scale) + self.hr_upscale_to_y = int(self.height * self.hr_scale) + else: + self.extra_generation_params["Hires resize"] = f"{self.hr_resize_x}x{self.hr_resize_y}" + + if self.hr_resize_y == 0: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + elif self.hr_resize_x == 0: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + else: + target_w = self.hr_resize_x + target_h = self.hr_resize_y + src_ratio = self.width / self.height + dst_ratio = self.hr_resize_x / self.hr_resize_y + + if src_ratio < dst_ratio: + self.hr_upscale_to_x = self.hr_resize_x + self.hr_upscale_to_y = self.hr_resize_x * self.height // self.width + else: + self.hr_upscale_to_x = self.hr_resize_y * self.width // self.height + self.hr_upscale_to_y = self.hr_resize_y + + self.truncate_x = (self.hr_upscale_to_x - target_w) // opt_f + self.truncate_y = (self.hr_upscale_to_y - target_h) // opt_f + + if self.hr_second_pass_steps: + self.extra_generation_params["Hires steps"] = self.hr_second_pass_steps + if self.hr_upscaler is not None: self.extra_generation_params["Hires upscaler"] = self.hr_upscaler @@ -699,8 +738,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: return samples - target_width = int(self.width * self.hr_scale) - target_height = int(self.height * self.hr_scale) + target_width = self.hr_upscale_to_x + target_height = self.hr_upscale_to_y def save_intermediate(image, index): """saves image before applying hires fix, if enabled in options; takes as an argument either an image or batch with latent space images""" @@ -755,13 +794,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] + noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) # GC now before running the next img2img to prevent running out of memory x = None devices.torch_gc() - samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.steps, image_conditioning=image_conditioning) + samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning) return samples diff --git a/modules/txt2img.py b/modules/txt2img.py index e189a899..38b5f591 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, @@ -35,6 +35,9 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: denoising_strength=denoising_strength if enable_hr else None, hr_scale=hr_scale, hr_upscaler=hr_upscaler, + hr_second_pass_steps=hr_second_pass_steps, + hr_resize_x=hr_resize_x, + hr_resize_y=hr_resize_y, ) p.scripts = modules.scripts.scripts_txt2img diff --git a/modules/ui.py b/modules/ui.py index 44f4f3a4..04091e67 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -637,10 +637,10 @@ def create_sampler_and_steps_selection(choices, tabname): with FormRow(elem_id=f"sampler_selection_{tabname}"): sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") sampler_index.save_to_config = True - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): - steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling Steps", value=20) + steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") return steps, sampler_index @@ -709,10 +709,16 @@ def create_ui(): enable_hr = gr.Checkbox(label='Hires. fix', value=False, elem_id="txt2img_enable_hr") elif category == "hires_fix": - with FormRow(visible=False, elem_id="txt2img_hires_fix") as hr_options: - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormGroup(visible=False, elem_id="txt2img_hires_fix") as hr_options: + with FormRow(elem_id="txt2img_hires_fix_row1"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + + with FormRow(elem_id="txt2img_hires_fix_row2"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") elif category == "batch": if not opts.dimensions_and_batch_together: @@ -753,6 +759,9 @@ def create_ui(): denoising_strength, hr_scale, hr_upscaler, + hr_second_pass_steps, + hr_resize_x, + hr_resize_y, ] + custom_inputs, outputs=[ @@ -804,6 +813,9 @@ def create_ui(): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (hr_scale, "Hires upscale"), (hr_upscaler, "Hires upscaler"), + (hr_second_pass_steps, "Hires steps"), + (hr_resize_x, "Hires resize-1"), + (hr_resize_y, "Hires resize-2"), *modules.scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields) -- cgit v1.2.3 From 865228a83736bea9ede33e98041f2a7d0ca5daaa Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 14 Jan 2023 14:56:39 +0300 Subject: change style dropdowns to multiselect --- javascript/localization.js | 6 ++---- modules/img2img.py | 4 ++-- modules/styles.py | 12 ++++++++--- modules/txt2img.py | 4 ++-- modules/ui.py | 53 ++++++++++++++++++++++++++-------------------- style.css | 15 ++++++++++--- 6 files changed, 57 insertions(+), 37 deletions(-) (limited to 'modules/txt2img.py') diff --git a/javascript/localization.js b/javascript/localization.js index f92d2d24..bf9e1506 100644 --- a/javascript/localization.js +++ b/javascript/localization.js @@ -10,10 +10,8 @@ ignore_ids_for_localization={ modelmerger_tertiary_model_name: 'OPTION', train_embedding: 'OPTION', train_hypernetwork: 'OPTION', - txt2img_style_index: 'OPTION', - txt2img_style2_index: 'OPTION', - img2img_style_index: 'OPTION', - img2img_style2_index: 'OPTION', + txt2img_styles: 'OPTION', + img2img_styles 'OPTION', setting_random_artist_categories: 'SPAN', setting_face_restoration_model: 'SPAN', setting_realesrgan_enabled_models: 'SPAN', diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..79382cc1 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img @@ -101,7 +101,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro outpath_grids=opts.outdir_grids or opts.outdir_img2img_grids, prompt=prompt, negative_prompt=negative_prompt, - styles=[prompt_style, prompt_style2], + styles=prompt_styles, seed=seed, subseed=subseed, subseed_strength=subseed_strength, diff --git a/modules/styles.py b/modules/styles.py index ce6e71ca..990d5623 100644 --- a/modules/styles.py +++ b/modules/styles.py @@ -40,12 +40,18 @@ def apply_styles_to_prompt(prompt, styles): class StyleDatabase: def __init__(self, path: str): self.no_style = PromptStyle("None", "", "") - self.styles = {"None": self.no_style} + self.styles = {} + self.path = path - if not os.path.exists(path): + self.reload() + + def reload(self): + self.styles.clear() + + if not os.path.exists(self.path): return - with open(path, "r", encoding="utf-8-sig", newline='') as file: + with open(self.path, "r", encoding="utf-8-sig", newline='') as file: reader = csv.DictReader(file) for row in reader: # Support loading old CSV format with "name, text"-columns diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..5a71793b 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,13 +8,13 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, outpath_grids=opts.outdir_grids or opts.outdir_txt2img_grids, prompt=prompt, - styles=[prompt_style, prompt_style2], + styles=prompt_styles, negative_prompt=negative_prompt, seed=seed, subseed=subseed, diff --git a/modules/ui.py b/modules/ui.py index 202e84e5..db198a47 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -180,7 +180,7 @@ def add_style(name: str, prompt: str, negative_prompt: str): # reserialize all styles every time we save them shared.prompt_styles.save_styles(shared.styles_filename) - return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(4)] + return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)] def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y): @@ -197,11 +197,11 @@ def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resiz return f"resize: from {p.width}x{p.height} to {p.hr_resize_x or p.hr_upscale_to_x}x{p.hr_resize_y or p.hr_upscale_to_y}" -def apply_styles(prompt, prompt_neg, style1_name, style2_name): - prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, [style1_name, style2_name]) - prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, [style1_name, style2_name]) +def apply_styles(prompt, prompt_neg, styles): + prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles) + prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles) - return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value="None"), gr.Dropdown.update(value="None")] + return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])] def interrogate(image): @@ -374,13 +374,10 @@ def create_toprow(is_img2img): ) with gr.Row(): - with gr.Column(scale=1, elem_id="style_pos_col"): - prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) + prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True) + create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles") - with gr.Column(scale=1, elem_id="style_neg_col"): - prompt_style2 = gr.Dropdown(label="Style 2", elem_id=f"{id_part}_style2_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys()))) - - return prompt, prompt_style, negative_prompt, prompt_style2, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button + return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, token_counter, token_button def setup_progressbar(*args, **kwargs): @@ -590,7 +587,7 @@ def create_ui(): modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - txt2img_prompt, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) + txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _,txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -684,8 +681,7 @@ def create_ui(): inputs=[ txt2img_prompt, txt2img_negative_prompt, - txt2img_prompt_style, - txt2img_prompt_style2, + txt2img_prompt_styles, steps, sampler_index, restore_faces, @@ -780,7 +776,7 @@ def create_ui(): modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) + img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) with gr.Row(elem_id='img2img_progress_row'): img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) @@ -921,8 +917,7 @@ def create_ui(): dummy_component, img2img_prompt, img2img_negative_prompt, - img2img_prompt_style, - img2img_prompt_style2, + img2img_prompt_styles, init_img, sketch, init_img_with_mask, @@ -977,7 +972,7 @@ def create_ui(): ) prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)] - style_dropdowns = [(txt2img_prompt_style, txt2img_prompt_style2), (img2img_prompt_style, img2img_prompt_style2)] + style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles] style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"] for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts): @@ -987,15 +982,15 @@ def create_ui(): # Have to pass empty dummy component here, because the JavaScript and Python function have to accept # the same number of parameters, but we only know the style-name after the JavaScript prompt inputs=[dummy_component, prompt, negative_prompt], - outputs=[txt2img_prompt_style, img2img_prompt_style, txt2img_prompt_style2, img2img_prompt_style2], + outputs=[txt2img_prompt_styles, img2img_prompt_styles], ) - for button, (prompt, negative_prompt), (style1, style2), js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): + for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs): button.click( fn=apply_styles, _js=js_func, - inputs=[prompt, negative_prompt, style1, style2], - outputs=[prompt, negative_prompt, style1, style2], + inputs=[prompt, negative_prompt, styles], + outputs=[prompt, negative_prompt, styles], ) token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) @@ -1530,6 +1525,7 @@ def create_ui(): previous_section = None current_tab = None + current_row = None with gr.Tabs(elem_id="settings"): for i, (k, item) in enumerate(opts.data_labels.items()): section_must_be_skipped = item.section[0] is None @@ -1538,10 +1534,14 @@ def create_ui(): elem_id, text = item.section if current_tab is not None: + current_row.__exit__() current_tab.__exit__() + gr.Group() current_tab = gr.TabItem(elem_id="settings_{}".format(elem_id), label=text) current_tab.__enter__() + current_row = gr.Column(variant='compact') + current_row.__enter__() previous_section = item.section @@ -1556,6 +1556,7 @@ def create_ui(): components.append(component) if current_tab is not None: + current_row.__exit__() current_tab.__exit__() with gr.TabItem("Actions"): @@ -1774,7 +1775,13 @@ def create_ui(): apply_field(x, 'value') if type(x) == gr.Dropdown: - apply_field(x, 'value', lambda val: val in x.choices, getattr(x, 'init_field', None)) + def check_dropdown(val): + if x.multiselect: + return all([value in x.choices for value in val]) + else: + return val in x.choices + + apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None)) visit(txt2img_interface, loadsave, "txt2img") visit(img2img_interface, loadsave, "img2img") diff --git a/style.css b/style.css index 14b15191..c1ebd501 100644 --- a/style.css +++ b/style.css @@ -114,6 +114,7 @@ min-width: unset !important; flex-grow: 0 !important; padding: 0.4em 0; + gap: 0; } #roll_col > button { @@ -141,10 +142,14 @@ min-width: 8em !important; } -#txt2img_style_index, #txt2img_style2_index, #img2img_style_index, #img2img_style2_index{ +#txt2img_styles, #img2img_styles{ margin-top: 1em; } +#txt2img_styles ul, #img2img_styles ul{ + max-height: 35em; +} + .gr-form{ background: transparent; } @@ -154,10 +159,14 @@ margin-bottom: 0; } -#toprow div{ +#toprow div.gr-box, #toprow div.gr-form{ border: none; gap: 0; background: transparent; + box-shadow: none; +} +#toprow div{ + gap: 0; } #resize_mode{ @@ -615,7 +624,7 @@ canvas[key="mask"] { max-width: 2.5em; min-width: 2.5em !important; height: 2.4em; - margin: 0.55em 0; + margin: 1.6em 0 0 0; } #quicksettings .gr-button-tool{ -- cgit v1.2.3 From d8b90ac121cbf0c18b1dc9d56a5e1d14ca51e74e Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 15 Jan 2023 18:50:56 +0300 Subject: big rework of progressbar/preview system to allow multiple users to prompts at the same time and do not get previews of each other --- javascript/progressbar.js | 249 ++++++++++++++++--------- javascript/textualInversion.js | 13 +- javascript/ui.js | 33 +++- modules/call_queue.py | 19 +- modules/hypernetworks/hypernetwork.py | 6 +- modules/img2img.py | 2 +- modules/progress.py | 96 ++++++++++ modules/sd_samplers.py | 2 +- modules/shared.py | 16 +- modules/textual_inversion/preprocess.py | 2 +- modules/textual_inversion/textual_inversion.py | 6 +- modules/txt2img.py | 2 +- modules/ui.py | 41 ++-- modules/ui_progress.py | 101 ---------- style.css | 74 +++++--- webui.py | 3 + 16 files changed, 390 insertions(+), 275 deletions(-) create mode 100644 modules/progress.py delete mode 100644 modules/ui_progress.py (limited to 'modules/txt2img.py') diff --git a/javascript/progressbar.js b/javascript/progressbar.js index d6323ed9..b7524ef7 100644 --- a/javascript/progressbar.js +++ b/javascript/progressbar.js @@ -1,82 +1,25 @@ // code related to showing and updating progressbar shown as the image is being made -global_progressbars = {} -galleries = {} -galleryObservers = {} - -// this tracks launches of window.setTimeout for progressbar to prevent starting a new timeout when the previous is still running -timeoutIds = {} -function check_progressbar(id_part, id_progressbar, id_progressbar_span, id_skip, id_interrupt, id_preview, id_gallery){ - // gradio 3.8's enlightened approach allows them to create two nested div elements inside each other with same id - // every time you use gr.HTML(elem_id='xxx'), so we handle this here - var progressbar = gradioApp().querySelector("#"+id_progressbar+" #"+id_progressbar) - var progressbarParent - if(progressbar){ - progressbarParent = gradioApp().querySelector("#"+id_progressbar) - } else{ - progressbar = gradioApp().getElementById(id_progressbar) - progressbarParent = null - } - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - - if(opts.show_progress_in_title && progressbar && progressbar.offsetParent){ - if(progressbar.innerText){ - let newtitle = '[' + progressbar.innerText.trim() + '] Stable Diffusion'; - if(document.title != newtitle){ - document.title = newtitle; - } - }else{ - let newtitle = 'Stable Diffusion' - if(document.title != newtitle){ - document.title = newtitle; - } - } - } - - if(progressbar!= null && progressbar != global_progressbars[id_progressbar]){ - global_progressbars[id_progressbar] = progressbar - - var mutationObserver = new MutationObserver(function(m){ - if(timeoutIds[id_part]) return; - - preview = gradioApp().getElementById(id_preview) - gallery = gradioApp().getElementById(id_gallery) +galleries = {} +storedGallerySelections = {} +galleryObservers = {} - if(preview != null && gallery != null){ - preview.style.width = gallery.clientWidth + "px" - preview.style.height = gallery.clientHeight + "px" - if(progressbarParent) progressbar.style.width = progressbarParent.clientWidth + "px" +function rememberGallerySelection(id_gallery){ + storedGallerySelections[id_gallery] = getGallerySelectedIndex(id_gallery) +} - //only watch gallery if there is a generation process going on - check_gallery(id_gallery); +function getGallerySelectedIndex(id_gallery){ + let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') + let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - if(progressDiv){ - timeoutIds[id_part] = window.setTimeout(function() { - timeoutIds[id_part] = null - requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt) - }, 500) - } else{ - if (skip) { - skip.style.display = "none" - } - interrupt.style.display = "none" + let currentlySelectedIndex = -1 + galleryButtons.forEach(function(v, i){ if(v==galleryBtnSelected) { currentlySelectedIndex = i } }) - //disconnect observer once generation finished, so user can close selected image if they want - if (galleryObservers[id_gallery]) { - galleryObservers[id_gallery].disconnect(); - galleries[id_gallery] = null; - } - } - } - - }); - mutationObserver.observe( progressbar, { childList:true, subtree:true }) - } + return currentlySelectedIndex } +// this is a workaround for https://github.com/gradio-app/gradio/issues/2984 function check_gallery(id_gallery){ let gallery = gradioApp().getElementById(id_gallery) // if gallery has no change, no need to setting up observer again. @@ -85,10 +28,16 @@ function check_gallery(id_gallery){ if(galleryObservers[id_gallery]){ galleryObservers[id_gallery].disconnect(); } - let prevSelectedIndex = selected_gallery_index(); + + storedGallerySelections[id_gallery] = -1 + galleryObservers[id_gallery] = new MutationObserver(function (){ let galleryButtons = gradioApp().querySelectorAll('#'+id_gallery+' .gallery-item') let galleryBtnSelected = gradioApp().querySelector('#'+id_gallery+' .gallery-item.\\!ring-2') + let currentlySelectedIndex = getGallerySelectedIndex(id_gallery) + prevSelectedIndex = storedGallerySelections[id_gallery] + storedGallerySelections[id_gallery] = -1 + if (prevSelectedIndex !== -1 && galleryButtons.length>prevSelectedIndex && !galleryBtnSelected) { // automatically re-open previously selected index (if exists) activeElement = gradioApp().activeElement; @@ -120,30 +69,150 @@ function check_gallery(id_gallery){ } onUiUpdate(function(){ - check_progressbar('txt2img', 'txt2img_progressbar', 'txt2img_progress_span', 'txt2img_skip', 'txt2img_interrupt', 'txt2img_preview', 'txt2img_gallery') - check_progressbar('img2img', 'img2img_progressbar', 'img2img_progress_span', 'img2img_skip', 'img2img_interrupt', 'img2img_preview', 'img2img_gallery') - check_progressbar('ti', 'ti_progressbar', 'ti_progress_span', '', 'ti_interrupt', 'ti_preview', 'ti_gallery') + check_gallery('txt2img_gallery') + check_gallery('img2img_gallery') }) -function requestMoreProgress(id_part, id_progressbar_span, id_skip, id_interrupt){ - btn = gradioApp().getElementById(id_part+"_check_progress"); - if(btn==null) return; - - btn.click(); - var progressDiv = gradioApp().querySelectorAll('#' + id_progressbar_span).length > 0; - var skip = id_skip ? gradioApp().getElementById(id_skip) : null - var interrupt = gradioApp().getElementById(id_interrupt) - if(progressDiv && interrupt){ - if (skip) { - skip.style.display = "block" +function request(url, data, handler, errorHandler){ + var xhr = new XMLHttpRequest(); + var url = url; + xhr.open("POST", url, true); + xhr.setRequestHeader("Content-Type", "application/json"); + xhr.onreadystatechange = function () { + if (xhr.readyState === 4) { + if (xhr.status === 200) { + var js = JSON.parse(xhr.responseText); + handler(js) + } else{ + errorHandler() + } } - interrupt.style.display = "block" + }; + var js = JSON.stringify(data); + xhr.send(js); +} + +function pad2(x){ + return x<10 ? '0'+x : x +} + +function formatTime(secs){ + if(secs > 3600){ + return pad2(Math.floor(secs/60/60)) + ":" + pad2(Math.floor(secs/60)%60) + ":" + pad2(Math.floor(secs)%60) + } else if(secs > 60){ + return pad2(Math.floor(secs/60)) + ":" + pad2(Math.floor(secs)%60) + } else{ + return Math.floor(secs) + "s" } } -function requestProgress(id_part){ - btn = gradioApp().getElementById(id_part+"_check_progress_initial"); - if(btn==null) return; +function randomId(){ + return "task(" + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7) + Math.random().toString(36).slice(2, 7)+")" +} + +// starts sending progress requests to "/internal/progress" uri, creating progressbar above progressbarContainer element and +// preview inside gallery element. Cleans up all created stuff when the task is over and calls atEnd. +// calls onProgress every time there is a progress update +function requestProgress(id_task, progressbarContainer, gallery, atEnd, onProgress){ + var dateStart = new Date() + var wasEverActive = false + var parentProgressbar = progressbarContainer.parentNode + var parentGallery = gallery.parentNode + + var divProgress = document.createElement('div') + divProgress.className='progressDiv' + var divInner = document.createElement('div') + divInner.className='progress' + + divProgress.appendChild(divInner) + parentProgressbar.insertBefore(divProgress, progressbarContainer) + + var livePreview = document.createElement('div') + livePreview.className='livePreview' + parentGallery.insertBefore(livePreview, gallery) + + var removeProgressBar = function(){ + parentProgressbar.removeChild(divProgress) + parentGallery.removeChild(livePreview) + atEnd() + } + + var fun = function(id_task, id_live_preview){ + request("/internal/progress", {"id_task": id_task, "id_live_preview": id_live_preview}, function(res){ + console.log(res) + + if(res.completed){ + removeProgressBar() + return + } + + var rect = progressbarContainer.getBoundingClientRect() + + if(rect.width){ + divProgress.style.width = rect.width + "px"; + } + + progressText = "" + + divInner.style.width = ((res.progress || 0) * 100.0) + '%' + + if(res.progress > 0){ + progressText = ((res.progress || 0) * 100.0).toFixed(0) + '%' + } + + if(res.eta){ + progressText += " ETA: " + formatTime(res.eta) + } else if(res.textinfo){ + progressText += " " + res.textinfo + } + + divInner.textContent = progressText + + var elapsedFromStart = (new Date() - dateStart) / 1000 + + if(res.active) wasEverActive = true; + + if(! res.active && wasEverActive){ + removeProgressBar() + return + } + + if(elapsedFromStart > 5 && !res.queued && !res.active){ + removeProgressBar() + return + } + + + if(res.live_preview){ + var img = new Image(); + img.onload = function() { + var rect = gallery.getBoundingClientRect() + if(rect.width){ + livePreview.style.width = rect.width + "px" + livePreview.style.height = rect.height + "px" + } + + livePreview.innerHTML = '' + livePreview.appendChild(img) + if(livePreview.childElementCount > 2){ + livePreview.removeChild(livePreview.firstElementChild) + } + } + img.src = res.live_preview; + } + + + if(onProgress){ + onProgress(res) + } + + setTimeout(() => { + fun(id_task, res.id_live_preview); + }, 500) + }, function(){ + removeProgressBar() + }) + } - btn.click(); + fun(id_task, 0) } diff --git a/javascript/textualInversion.js b/javascript/textualInversion.js index 8061be08..0354b860 100644 --- a/javascript/textualInversion.js +++ b/javascript/textualInversion.js @@ -1,8 +1,17 @@ + function start_training_textual_inversion(){ - requestProgress('ti') gradioApp().querySelector('#ti_error').innerHTML='' - return args_to_array(arguments) + var id = randomId() + requestProgress(id, gradioApp().getElementById('ti_output'), gradioApp().getElementById('ti_gallery'), function(){}, function(progress){ + gradioApp().getElementById('ti_progress').innerHTML = progress.textinfo + }) + + var res = args_to_array(arguments) + + res[0] = id + + return res } diff --git a/javascript/ui.js b/javascript/ui.js index f8279124..ecf97cb3 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -126,18 +126,41 @@ function create_submit_args(args){ return res } +function showSubmitButtons(tabname, show){ + gradioApp().getElementById(tabname+'_interrupt').style.display = show ? "none" : "block" + gradioApp().getElementById(tabname+'_skip').style.display = show ? "none" : "block" +} + function submit(){ - requestProgress('txt2img') + rememberGallerySelection('txt2img_gallery') + showSubmitButtons('txt2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function(){ + showSubmitButtons('txt2img', true) + + }) - return create_submit_args(arguments) + var res = create_submit_args(arguments) + + res[0] = id + + return res } function submit_img2img(){ - requestProgress('img2img') + rememberGallerySelection('img2img_gallery') + showSubmitButtons('img2img', false) + + var id = randomId() + requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function(){ + showSubmitButtons('img2img', true) + }) - res = create_submit_args(arguments) + var res = create_submit_args(arguments) - res[0] = get_tab_index('mode_img2img') + res[0] = id + res[1] = get_tab_index('mode_img2img') return res } diff --git a/modules/call_queue.py b/modules/call_queue.py index 4cd49533..92097c15 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -4,7 +4,7 @@ import threading import traceback import time -from modules import shared +from modules import shared, progress queue_lock = threading.Lock() @@ -22,12 +22,23 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - shared.state.begin() + # if the first argument is a string that says "task(...)", it is treated as a job id + if len(args) > 0 and type(args[0]) == str and args[0][0:5] == "task(" and args[0][-1] == ")": + id_task = args[0] + progress.add_task_to_queue(id_task) + else: + id_task = None with queue_lock: - res = func(*args, **kwargs) + shared.state.begin() + progress.start_task(id_task) + + try: + res = func(*args, **kwargs) + finally: + progress.finish_task(id_task) - shared.state.end() + shared.state.end() return res diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 3aebefa8..ae6af516 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -453,7 +453,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images @@ -629,7 +629,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, description = f"Training hypernetwork [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' @@ -701,7 +700,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, torch.cuda.set_rng_state_all(cuda_rng_state) hypernetwork.train() if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/img2img.py b/modules/img2img.py index f62783c6..f4a03c57 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -59,7 +59,7 @@ def process_batch(p, input_dir, output_dir, args): processed_image.save(os.path.join(output_dir, filename)) -def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args): is_batch = mode == 5 if mode == 0: # img2img diff --git a/modules/progress.py b/modules/progress.py new file mode 100644 index 00000000..3327b883 --- /dev/null +++ b/modules/progress.py @@ -0,0 +1,96 @@ +import base64 +import io +import time + +import gradio as gr +from pydantic import BaseModel, Field + +from modules.shared import opts + +import modules.shared as shared + + +current_task = None +pending_tasks = {} +finished_tasks = [] + + +def start_task(id_task): + global current_task + + current_task = id_task + pending_tasks.pop(id_task, None) + + +def finish_task(id_task): + global current_task + + if current_task == id_task: + current_task = None + + finished_tasks.append(id_task) + if len(finished_tasks) > 16: + finished_tasks.pop(0) + + +def add_task_to_queue(id_job): + pending_tasks[id_job] = time.time() + + +class ProgressRequest(BaseModel): + id_task: str = Field(default=None, title="Task ID", description="id of the task to get progress for") + id_live_preview: int = Field(default=-1, title="Live preview image ID", description="id of last received last preview image") + + +class ProgressResponse(BaseModel): + active: bool = Field(title="Whether the task is being worked on right now") + queued: bool = Field(title="Whether the task is in queue") + completed: bool = Field(title="Whether the task has already finished") + progress: float = Field(default=None, title="Progress", description="The progress with a range of 0 to 1") + eta: float = Field(default=None, title="ETA in secs") + live_preview: str = Field(default=None, title="Live preview image", description="Current live preview; a data: uri") + id_live_preview: int = Field(default=None, title="Live preview image ID", description="Send this together with next request to prevent receiving same image") + textinfo: str = Field(default=None, title="Info text", description="Info text used by WebUI.") + + +def setup_progress_api(app): + return app.add_api_route("/internal/progress", progressapi, methods=["POST"], response_model=ProgressResponse) + + +def progressapi(req: ProgressRequest): + active = req.id_task == current_task + queued = req.id_task in pending_tasks + completed = req.id_task in finished_tasks + + if not active: + return ProgressResponse(active=active, queued=queued, completed=completed, id_live_preview=-1, textinfo="In queue..." if queued else "Waiting...") + + progress = 0 + + if shared.state.job_count > 0: + progress += shared.state.job_no / shared.state.job_count + if shared.state.sampling_steps > 0: + progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps + + progress = min(progress, 1) + + elapsed_since_start = time.time() - shared.state.time_start + predicted_duration = elapsed_since_start / progress if progress > 0 else None + eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None + + id_live_preview = req.id_live_preview + shared.state.set_current_image() + if opts.live_previews_enable and shared.state.id_live_preview != req.id_live_preview: + image = shared.state.current_image + if image is not None: + buffered = io.BytesIO() + image.save(buffered, format="png") + live_preview = 'data:image/png;base64,' + base64.b64encode(buffered.getvalue()).decode("ascii") + id_live_preview = shared.state.id_live_preview + else: + live_preview = None + else: + live_preview = None + + return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) + diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 7616fded..76e0e0d5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -140,7 +140,7 @@ def store_latent(decoded): if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: - shared.state.current_image = sample_to_image(decoded) + shared.state.assign_current_image(sample_to_image(decoded)) class InterruptedException(BaseException): diff --git a/modules/shared.py b/modules/shared.py index 51df056c..de99aca9 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -152,6 +152,7 @@ def reload_hypernetworks(): hypernetwork.load_hypernetwork(opts.sd_hypernetwork) + class State: skipped = False interrupted = False @@ -165,6 +166,7 @@ class State: current_latent = None current_image = None current_image_sampling_step = 0 + id_live_preview = 0 textinfo = None time_start = None need_restart = False @@ -207,6 +209,7 @@ class State: self.current_latent = None self.current_image = None self.current_image_sampling_step = 0 + self.id_live_preview = 0 self.skipped = False self.interrupted = False self.textinfo = None @@ -220,8 +223,8 @@ class State: devices.torch_gc() - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" def set_current_image(self): + """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" if not parallel_processing_allowed: return @@ -234,12 +237,16 @@ class State: import modules.sd_samplers if opts.show_progress_grid: - self.current_image = modules.sd_samplers.samples_to_image_grid(self.current_latent) + self.assign_current_image(modules.sd_samplers.samples_to_image_grid(self.current_latent)) else: - self.current_image = modules.sd_samplers.sample_to_image(self.current_latent) + self.assign_current_image(modules.sd_samplers.sample_to_image(self.current_latent)) self.current_image_sampling_step = self.sampling_step + def assign_current_image(self, image): + self.current_image = image + self.id_live_preview += 1 + state = State() state.server_start = time.time() @@ -424,8 +431,6 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"), })) options_templates.update(options_section(('ui', "User interface"), { - "show_progressbar": OptionInfo(True, "Show progressbar"), - "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "return_grid": OptionInfo(True, "Show grid in results for web"), "do_not_show_images": OptionInfo(False, "Do not show any images in results for web"), "add_model_hash_to_info": OptionInfo(True, "Add model hash to generation information"), @@ -446,6 +451,7 @@ options_templates.update(options_section(('ui', "User interface"), { options_templates.update(options_section(('ui', "Live previews"), { "live_previews_enable": OptionInfo(True, "Show live previews of the created image"), + "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}), "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py index 3c1042ad..64abff4d 100644 --- a/modules/textual_inversion/preprocess.py +++ b/modules/textual_inversion/preprocess.py @@ -12,7 +12,7 @@ from modules.shared import opts, cmd_opts from modules.textual_inversion import autocrop -def preprocess(process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): +def preprocess(id_task, process_src, process_dst, process_width, process_height, preprocess_txt_action, process_flip, process_split, process_caption, process_caption_deepbooru=False, split_threshold=0.5, overlap_ratio=0.2, process_focal_crop=False, process_focal_crop_face_weight=0.9, process_focal_crop_entropy_weight=0.3, process_focal_crop_edges_weight=0.5, process_focal_crop_debug=False): try: if process_caption: shared.interrogator.load() diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 63935878..7e4a6d24 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -345,7 +345,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) @@ -510,7 +510,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) - shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: # Before saving, change name to match current checkpoint. embedding_name_every = f'{embedding_name}-{steps_done}' @@ -560,7 +559,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ shared.sd_model.first_stage_model.to(devices.cpu) if image is not None: - shared.state.current_image = image + shared.state.assign_current_image(image) + last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False) last_saved_image += f", prompt: {preview_text}" diff --git a/modules/txt2img.py b/modules/txt2img.py index 38b5f591..ca5d4550 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -8,7 +8,7 @@ import modules.processing as processing from modules.ui import plaintext_to_html -def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, *args): p = StableDiffusionProcessingTxt2Img( sd_model=shared.sd_model, outpath_samples=opts.outdir_samples or opts.outdir_txt2img_samples, diff --git a/modules/ui.py b/modules/ui.py index 2425c66f..ff33236b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -356,7 +356,7 @@ def create_toprow(is_img2img): button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru") with gr.Column(scale=1): - with gr.Row(): + with gr.Row(elem_id=f"{id_part}_generate_box"): skip = gr.Button('Skip', elem_id=f"{id_part}_skip") interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt") submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary') @@ -384,9 +384,7 @@ def create_toprow(is_img2img): def setup_progressbar(*args, **kwargs): - import modules.ui_progress - - modules.ui_progress.setup_progressbar(*args, **kwargs) + pass def apply_setting(key, value): @@ -479,8 +477,8 @@ Requested path was: {f} else: sp.Popen(["xdg-open", path]) - with gr.Column(variant='panel'): - with gr.Group(): + with gr.Column(variant='panel', elem_id=f"{tabname}_results"): + with gr.Group(elem_id=f"{tabname}_gallery_container"): result_gallery = gr.Gallery(label='Output', show_label=False, elem_id=f"{tabname}_gallery").style(grid=4) generation_info = None @@ -595,15 +593,6 @@ def create_ui(): dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="bytes", visible=False) - with gr.Row(elem_id='txt2img_progress_row'): - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="txt2img_progressbar") - txt2img_preview = gr.Image(elem_id='txt2img_preview', visible=False) - setup_progressbar(progressbar, txt2img_preview, 'txt2img') - with gr.Row().style(equal_height=False): with gr.Column(variant='panel', elem_id="txt2img_settings"): for category in ordered_ui_categories(): @@ -682,6 +671,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.txt2img.txt2img, extra_outputs=[None, '', '']), _js="submit", inputs=[ + dummy_component, txt2img_prompt, txt2img_negative_prompt, txt2img_prompt_style, @@ -782,16 +772,7 @@ def create_ui(): with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste,token_counter, token_button = create_toprow(is_img2img=True) - with gr.Row(elem_id='img2img_progress_row'): - img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) - - with gr.Column(scale=1): - pass - - with gr.Column(scale=1): - progressbar = gr.HTML(elem_id="img2img_progressbar") - img2img_preview = gr.Image(elem_id='img2img_preview', visible=False) - setup_progressbar(progressbar, img2img_preview, 'img2img') + img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="bytes", visible=False) with FormRow().style(equal_height=False): with gr.Column(variant='panel', elem_id="img2img_settings"): @@ -958,6 +939,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.img2img.img2img, extra_outputs=[None, '', '']), _js="submit_img2img", inputs=[ + dummy_component, dummy_component, img2img_prompt, img2img_negative_prompt, @@ -1335,15 +1317,11 @@ def create_ui(): script_callbacks.ui_train_tabs_callback(params) - with gr.Column(): - progressbar = gr.HTML(elem_id="ti_progressbar") + with gr.Column(elem_id='ti_gallery_container'): ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) - ti_gallery = gr.Gallery(label='Output', show_label=False, elem_id='ti_gallery').style(grid=4) - ti_preview = gr.Image(elem_id='ti_preview', visible=False) ti_progress = gr.HTML(elem_id="ti_progress", value="") ti_outcome = gr.HTML(elem_id="ti_error", value="") - setup_progressbar(progressbar, ti_preview, 'ti', textinfo=ti_progress) create_embedding.click( fn=modules.textual_inversion.ui.create_embedding, @@ -1384,6 +1362,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, process_src, process_dst, process_width, @@ -1411,6 +1390,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_embedding_name, embedding_learn_rate, batch_size, @@ -1443,6 +1423,7 @@ def create_ui(): fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ + dummy_component, train_hypernetwork_name, hypernetwork_learn_rate, batch_size, diff --git a/modules/ui_progress.py b/modules/ui_progress.py deleted file mode 100644 index 7cd312e4..00000000 --- a/modules/ui_progress.py +++ /dev/null @@ -1,101 +0,0 @@ -import time - -import gradio as gr - -from modules.shared import opts - -import modules.shared as shared - - -def calc_time_left(progress, threshold, label, force_display, show_eta): - if progress == 0: - return "" - else: - time_since_start = time.time() - shared.state.time_start - eta = (time_since_start/progress) - eta_relative = eta-time_since_start - if (eta_relative > threshold and show_eta) or force_display: - if eta_relative > 3600: - return label + time.strftime('%H:%M:%S', time.gmtime(eta_relative)) - elif eta_relative > 60: - return label + time.strftime('%M:%S', time.gmtime(eta_relative)) - else: - return label + time.strftime('%Ss', time.gmtime(eta_relative)) - else: - return "" - - -def check_progress_call(id_part): - if shared.state.job_count == 0: - return "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False) - - progress = 0 - - if shared.state.job_count > 0: - progress += shared.state.job_no / shared.state.job_count - if shared.state.sampling_steps > 0: - progress += 1 / shared.state.job_count * shared.state.sampling_step / shared.state.sampling_steps - - # Show progress percentage and time left at the same moment, and base it also on steps done - show_eta = progress >= 0.01 or shared.state.sampling_step >= 10 - - time_left = calc_time_left(progress, 1, " ETA: ", shared.state.time_left_force_display, show_eta) - if time_left != "": - shared.state.time_left_force_display = True - - progress = min(progress, 1) - - progressbar = "" - if opts.show_progressbar: - progressbar = f"""
{" " * 2 + str(int(progress*100))+"%" + time_left if show_eta else ""}
""" - - image = gr.update(visible=False) - preview_visibility = gr.update(visible=False) - - if opts.live_previews_enable: - shared.state.set_current_image() - image = shared.state.current_image - - if image is None: - image = gr.update(value=None) - else: - preview_visibility = gr.update(visible=True) - - if shared.state.textinfo is not None: - textinfo_result = gr.HTML.update(value=shared.state.textinfo, visible=True) - else: - textinfo_result = gr.update(visible=False) - - return f"

{progressbar}

", preview_visibility, image, textinfo_result - - -def check_progress_call_initial(id_part): - shared.state.job_count = -1 - shared.state.current_latent = None - shared.state.current_image = None - shared.state.textinfo = None - shared.state.time_start = time.time() - shared.state.time_left_force_display = False - - return check_progress_call(id_part) - - -def setup_progressbar(progressbar, preview, id_part, textinfo=None): - if textinfo is None: - textinfo = gr.HTML(visible=False) - - check_progress = gr.Button('Check progress', elem_id=f"{id_part}_check_progress", visible=False) - check_progress.click( - fn=lambda: check_progress_call(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) - - check_progress_initial = gr.Button('Check progress (first)', elem_id=f"{id_part}_check_progress_initial", visible=False) - check_progress_initial.click( - fn=lambda: check_progress_call_initial(id_part), - show_progress=False, - inputs=[], - outputs=[progressbar, preview, preview, textinfo], - ) diff --git a/style.css b/style.css index 2d484e06..786b71d1 100644 --- a/style.css +++ b/style.css @@ -305,26 +305,42 @@ input[type="range"]{ } .progressDiv{ - width: 100%; - height: 20px; - background: #b4c0cc; - border-radius: 8px; + position: absolute; + height: 20px; + top: -20px; + background: #b4c0cc; + border-radius: 8px !important; } .dark .progressDiv{ - background: #424c5b; + background: #424c5b; } .progressDiv .progress{ - width: 0%; - height: 20px; - background: #0060df; - color: white; - font-weight: bold; - line-height: 20px; - padding: 0 8px 0 0; - text-align: right; - border-radius: 8px; + width: 0%; + height: 20px; + background: #0060df; + color: white; + font-weight: bold; + line-height: 20px; + padding: 0 8px 0 0; + text-align: right; + border-radius: 8px; + overflow: visible; + white-space: nowrap; +} + +.livePreview{ + position: absolute; + z-index: 300; + background-color: white; + margin: -4px; +} + +.livePreview img{ + object-fit: contain; + width: 100%; + height: 100%; } #lightboxModal{ @@ -450,23 +466,25 @@ input[type="range"]{ display:none } -#txt2img_interrupt, #img2img_interrupt{ - position: absolute; - width: 50%; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; +#txt2img_generate_box, #img2img_generate_box{ + position: relative; +} + +#txt2img_interrupt, #img2img_interrupt, #txt2img_skip, #img2img_skip{ + position: absolute; + width: 50%; + height: 100%; + background: #b4c0cc; + display: none; } +#txt2img_interrupt, #img2img_interrupt{ + right: 0; + border-radius: 0 0.5rem 0.5rem 0; +} #txt2img_skip, #img2img_skip{ - position: absolute; - width: 50%; - right: 0px; - height: 72px; - background: #b4c0cc; - border-radius: 0px; - display: none; + left: 0; + border-radius: 0.5rem 0 0 0.5rem; } .red { diff --git a/webui.py b/webui.py index 1fff80da..4624fe18 100644 --- a/webui.py +++ b/webui.py @@ -34,6 +34,7 @@ import modules.sd_vae import modules.txt2img import modules.script_callbacks import modules.textual_inversion.textual_inversion +import modules.progress import modules.ui from modules import modelloader @@ -181,6 +182,8 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) + modules.progress.setup_progress_api(app) + if launch_api: create_api(app) -- cgit v1.2.3