From c4ee6d9b73300d906b8df4602157d646df2415ef Mon Sep 17 00:00:00 2001 From: Robert Barron Date: Sun, 30 Jul 2023 00:41:10 -0700 Subject: xyz_grid: allow varying the seed along an axis along with the axis's other changes --- scripts/xyz_grid.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) (limited to 'scripts/xyz_grid.py') diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 1010845e..4eb1b197 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -416,6 +416,10 @@ class Script(scripts.Script): with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) + with gr.Column(): + vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) + vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) + vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) @@ -475,9 +479,9 @@ class Script(scripts.Script): (z_values_dropdown, lambda params:get_dropdown_update_from_params("Z",params)), ) - return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size] + return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size): if not no_fixed_seeds: modules.processing.fix_seed(p) @@ -648,6 +652,16 @@ class Script(scripts.Script): y_opt.apply(pc, y, ys) z_opt.apply(pc, z, zs) + xdim = len(xs) if vary_seeds_x else 1 + ydim = len(ys) if vary_seeds_y else 1 + + if vary_seeds_x: + pc.seed += ix + if vary_seeds_y: + pc.seed += iy * xdim + if vary_seeds_z: + pc.seed += iz * xdim * ydim + res = process_images(pc) # Sets subgrid infotexts -- cgit v1.2.3 From 8aa13d5dce2789a7d0bd802e6d62453b3c380496 Mon Sep 17 00:00:00 2001 From: Anthony Fu Date: Mon, 16 Oct 2023 14:12:18 +0800 Subject: Interrupt after current generation --- modules/call_queue.py | 1 + modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 1 + modules/shared_state.py | 11 +++++++++-- scripts/loopback.py | 6 +++--- scripts/xyz_grid.py | 2 +- 7 files changed, 17 insertions(+), 8 deletions(-) (limited to 'scripts/xyz_grid.py') diff --git a/modules/call_queue.py b/modules/call_queue.py index ddf0d573..01c6d17f 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,6 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): shared.state.skipped = False shared.state.interrupted = False + shared.state.interrupted_next = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 52cb577a..31f8c2aa 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -49,7 +49,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break try: diff --git a/modules/processing.py b/modules/processing.py index 40598f5c..e7eecd66 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -819,7 +819,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted: + if state.interrupted or state.interrupted_next: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 32bf7353..4638ef06 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -113,6 +113,7 @@ options_templates.update(options_section(('system', "System"), { "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), + "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index a68789cc..c72c3f63 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,6 +12,7 @@ log = logging.getLogger(__name__) class State: skipped = False interrupted = False + interrupted_next = False job = "" job_no = 0 job_count = 0 @@ -76,8 +77,12 @@ class State: log.info("Received skip request") def interrupt(self): - self.interrupted = True - log.info("Received interrupt request") + if shared.opts.interrupt_after_current and self.job_count > 1: + self.interrupted_next = True + log.info("Received interrupt request, interrupt after current job") + else: + self.interrupted = True + log.info("Received interrupt request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -91,6 +96,7 @@ class State: obj = { "skipped": self.skipped, "interrupted": self.interrupted, + "interrupted_next": self.interrupted_next, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -114,6 +120,7 @@ class State: self.id_live_preview = 0 self.skipped = False self.interrupted = False + self.interrupted_next = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/scripts/loopback.py b/scripts/loopback.py index 2d5feaf9..ad921269 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ class Script(scripts.Script): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted: + if state.interrupted or state.interrupted_next: break if initial_seed is None: @@ -122,8 +122,8 @@ class Script(scripts.Script): p.inpainting_fill = original_inpainting_fill - if state.interrupted: - break + if state.interrupted or state.interrupted_next: + break if len(history) > 1: grid = images.image_grid(history, rows=1) diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc..495008ad 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -688,7 +688,7 @@ class Script(scripts.Script): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted: + if shared.state.interrupted or state.interrupted_next: return Processed(p, [], p.seed, "") pc = copy(p) -- cgit v1.2.3 From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: Use options instead of cmd_args --- modules/cmd_args.py | 2 -- modules/devices.py | 25 ++++++++++--------- modules/initialize_util.py | 1 + modules/sd_models.py | 61 ++++++++++++++++++++++++---------------------- modules/shared_options.py | 1 + scripts/xyz_grid.py | 1 + 6 files changed, 49 insertions(+), 42 deletions(-) (limited to 'scripts/xyz_grid.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 088d5dea..a9fb9bfa 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -118,5 +118,3 @@ parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set time parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) parser.add_argument("--disable-extra-extensions", action='store_true', help="prevent all extensions except built-in from running regardless of any other settings", default=False) parser.add_argument("--skip-load-model-at-start", action='store_true', help="if load a model at web start, only take effect when --nowebui", ) -parser.add_argument("--opt-unet-fp8-storage", action='store_true', help="use fp8 for SD UNet to save vram", default=False) -parser.add_argument("--opt-unet-fp8-storage-xl", action='store_true', help="use fp8 for SD UNet to save vram", default=False) diff --git a/modules/devices.py b/modules/devices.py index d7c905c2..03e7bdb7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward diff --git a/modules/initialize_util.py b/modules/initialize_util.py index 2e9b6d89..1b11ead6 100644 --- a/modules/initialize_util.py +++ b/modules/initialize_util.py @@ -177,6 +177,7 @@ def configure_opts_onchange(): shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + shared.opts.onchange("fp8_storage", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) startup_timer.record("opts onchange") diff --git a/modules/sd_models.py b/modules/sd_models.py index a6c8b2fa..eb491434 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -339,10 +339,28 @@ class SkipWritingToConfig: SkipWritingToConfig.skip = self.previous +def check_fp8(model): + if model is None: + return None + if devices.get_optimal_device_name() == "mps": + enable_fp8 = False + elif shared.opts.fp8_storage == "Enable": + enable_fp8 = True + elif getattr(model, "is_sdxl", False) and shared.opts.fp8_storage == "Enable for SDXL": + enable_fp8 = True + else: + enable_fp8 = False + return enable_fp8 + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") + if not check_fp8(model) and devices.fp8: + # prevent model to load state dict in fp8 + model.half() + if not SkipWritingToConfig.skip: shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title @@ -395,34 +413,16 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer devices.dtype_unet = torch.float16 timer.record("apply half()") - if devices.get_optimal_device_name() == "mps": - enable_fp8 = False - elif shared.cmd_opts.opt_unet_fp8_storage: - enable_fp8 = True - elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: - enable_fp8 = True - else: - enable_fp8 = False - - if enable_fp8: + if check_fp8(model): devices.fp8 = True - if model.is_sdxl: - cond_stage = model.conditioner - else: - cond_stage = model.cond_stage_model - - for module in cond_stage.modules(): - if isinstance(module, torch.nn.Linear): + first_stage = model.first_stage_model + model.first_stage_model = None + for module in model.modules(): + if isinstance(module, torch.nn.Conv2d): module.to(torch.float8_e4m3fn) - - if devices.device == devices.cpu: - for module in model.model.diffusion_model.modules(): - if isinstance(module, torch.nn.Conv2d): - module.to(torch.float8_e4m3fn) - elif isinstance(module, torch.nn.Linear): - module.to(torch.float8_e4m3fn) - else: - model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + elif isinstance(module, torch.nn.Linear): + module.to(torch.float8_e4m3fn) + model.first_stage_model = first_stage timer.record("apply fp8") else: devices.fp8 = False @@ -769,7 +769,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer): return None -def reload_model_weights(sd_model=None, info=None): +def reload_model_weights(sd_model=None, info=None, forced_reload=False): checkpoint_info = info or select_checkpoint() timer = Timer() @@ -781,11 +781,14 @@ def reload_model_weights(sd_model=None, info=None): current_checkpoint_info = None else: current_checkpoint_info = sd_model.sd_checkpoint_info - if sd_model.sd_model_checkpoint == checkpoint_info.filename: + if check_fp8(sd_model) != devices.fp8: + # load from state dict again to prevent extra numerical errors + forced_reload = True + elif sd_model.sd_model_checkpoint == checkpoint_info.filename: return sd_model sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer) - if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: return sd_model if sd_model is not None: diff --git a/modules/shared_options.py b/modules/shared_options.py index f1003f21..d27f35e9 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -200,6 +200,7 @@ options_templates.update(options_section(('optimizations', "Optimizations"), { "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"), "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"), "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"), + "fp8_storage": OptionInfo("Disable", "FP8 weight", gr.Dropdown, {"choices": ["Disable", "Enable for SDXL", "Enable"]}).info("Use FP8 to store Linear/Conv layers' weight. Require pytorch>=2.1.0."), })) options_templates.update(options_section(('compatibility', "Compatibility"), { diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 0dc255bc..b2250c04 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -270,6 +270,7 @@ axis_options = [ AxisOption("Refiner checkpoint", str, apply_field('refiner_checkpoint'), format_value=format_remove_path, confirm=confirm_checkpoints_or_none, cost=1.0, choices=lambda: ['None'] + sorted(sd_models.checkpoints_list, key=str.casefold)), AxisOption("Refiner switch at", float, apply_field('refiner_switch_at')), AxisOption("RNG source", str, apply_override("randn_source"), choices=lambda: ["GPU", "CPU", "NV"]), + AxisOption("FP8 mode", str, apply_override("fp8_storage"), cost=0.9, choices=lambda: ["Disable", "Enable for SDXL", "Enable"]), ] -- cgit v1.2.3 From de1809bd14450cfc41623b6021c7087fb385ab6f Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Fri, 22 Dec 2023 00:08:35 +0900 Subject: handle axis_type is None --- scripts/xyz_grid.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'scripts/xyz_grid.py') diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index b2250c04..e5083874 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -476,6 +476,8 @@ class Script(scripts.Script): fill_z_button.click(fn=fill, inputs=[z_type, csv_mode], outputs=[z_values, z_values_dropdown]) def select_axis(axis_type, axis_values, axis_values_dropdown, csv_mode): + axis_type = axis_type or 0 # if axle type is None set to 0 + choices = self.current_axis_options[axis_type].choices has_choices = choices is not None @@ -526,6 +528,8 @@ class Script(scripts.Script): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode] def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 + if not no_fixed_seeds: modules.processing.fix_seed(p) -- cgit v1.2.3 From 0743ee9b3eda8dd4ceea625d710031577201f4ad Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 15:50:47 +0300 Subject: re-layout checkboxes for XYZ grid a bit --- scripts/xyz_grid.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'scripts/xyz_grid.py') diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 34267c2c..2d550994 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -438,17 +438,16 @@ class Script(scripts.Script): with gr.Column(): draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend")) no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds")) + with gr.Row(): + vary_seeds_x = gr.Checkbox(label='Vary seeds for X', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_x"), tooltip="Use different seeds for images along X axis.") + vary_seeds_y = gr.Checkbox(label='Vary seeds for Y', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_y"), tooltip="Use different seeds for images along Y axis.") + vary_seeds_z = gr.Checkbox(label='Vary seeds for Z', value=False, min_width=80, elem_id=self.elem_id("vary_seeds_z"), tooltip="Use different seeds for images along Z axis.") with gr.Column(): include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images")) include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids")) - with gr.Column(): - vary_seeds_x = gr.Checkbox(label='Vary seed on X axis', value=False, elem_id=self.elem_id("vary_seeds_x")) - vary_seeds_y = gr.Checkbox(label='Vary seed on Y axis', value=False, elem_id=self.elem_id("vary_seeds_y")) - vary_seeds_z = gr.Checkbox(label='Vary seed on Z axis', value=False, elem_id=self.elem_id("vary_seeds_z")) + csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Column(): margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size")) - with gr.Column(): - csv_mode = gr.Checkbox(label='Use text inputs instead of dropdowns', value=False, elem_id=self.elem_id("csv_mode")) with gr.Row(variant="compact", elem_id="swap_axes"): swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button") @@ -531,7 +530,7 @@ class Script(scripts.Script): return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode] - def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size, csv_mode): + def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, vary_seeds_x, vary_seeds_y, vary_seeds_z, margin_size, csv_mode): x_type, y_type, z_type = x_type or 0, y_type or 0, z_type or 0 # if axle type is None set to 0 if not no_fixed_seeds: -- cgit v1.2.3 From 0aa7c53c0b9469849377aff83f43c9f75c19b3fa Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:50:59 +0300 Subject: fix borked merge, rename fields to better match what they do, change setting default to true for #13653 --- modules/call_queue.py | 2 +- modules/img2img.py | 2 +- modules/processing.py | 2 +- modules/shared_options.py | 2 +- modules/shared_state.py | 12 ++++++------ modules/ui_toprow.py | 8 +++++++- scripts/loopback.py | 4 ++-- scripts/xyz_grid.py | 2 +- 8 files changed, 20 insertions(+), 14 deletions(-) (limited to 'scripts/xyz_grid.py') diff --git a/modules/call_queue.py b/modules/call_queue.py index 01c6d17f..bcd7c546 100644 --- a/modules/call_queue.py +++ b/modules/call_queue.py @@ -78,7 +78,7 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False): shared.state.skipped = False shared.state.interrupted = False - shared.state.interrupted_next = False + shared.state.stopping_generation = False shared.state.job_count = 0 if not add_stats: diff --git a/modules/img2img.py b/modules/img2img.py index 829faa81..e7e8e251 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -51,7 +51,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break try: diff --git a/modules/processing.py b/modules/processing.py index 00de2ed2..f55b85ed 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -865,7 +865,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.skipped: state.skipped = False - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break sd_models.reload_model_weights() # model can be changed for example by refiner diff --git a/modules/shared_options.py b/modules/shared_options.py index 7852e0ea..7581e276 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -120,7 +120,6 @@ options_templates.update(options_section(('system', "System", "system"), { "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"), "hide_ldm_prints": OptionInfo(True, "Prevent Stability-AI's ldm/sgm modules from printing noise to console."), "dump_stacks_on_signal": OptionInfo(False, "Print stack traces before exiting the program with ctrl+c."), - "interrupt_after_current": OptionInfo(False, "Interrupt generation after current image is finished on batch processing"), })) options_templates.update(options_section(('API', "API", "system"), { @@ -286,6 +285,7 @@ options_templates.update(options_section(('ui_alternatives', "UI alternatives", "hires_fix_show_prompts": OptionInfo(False, "Hires fix: show hires prompt and negative prompt").needs_reload_ui(), "txt2img_settings_accordion": OptionInfo(False, "Settings in txt2img hidden under Accordion").needs_reload_ui(), "img2img_settings_accordion": OptionInfo(False, "Settings in img2img hidden under Accordion").needs_reload_ui(), + "interrupt_after_current": OptionInfo(True, "Don't Interrupt in the middle").info("when using Interrupt button, if generating more than one image, stop after the generation of an image has finished, instead of immediately"), })) options_templates.update(options_section(('ui', "User interface", "ui"), { diff --git a/modules/shared_state.py b/modules/shared_state.py index 532fdcd8..33996691 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -12,7 +12,7 @@ log = logging.getLogger(__name__) class State: skipped = False interrupted = False - interrupted_next = False + stopping_generation = False job = "" job_no = 0 job_count = 0 @@ -80,9 +80,9 @@ class State: self.interrupted = True log.info("Received interrupt request") - def interrupt_next(self): - self.interrupted_next = True - log.info("Received interrupt request, interrupt after current job") + def stop_generating(self): + self.stopping_generation = True + log.info("Received stop generating request") def nextjob(self): if shared.opts.live_previews_enable and shared.opts.show_progress_every_n_steps == -1: @@ -96,7 +96,7 @@ class State: obj = { "skipped": self.skipped, "interrupted": self.interrupted, - "interrupted_next": self.interrupted_next, + "stopping_generation": self.stopping_generation, "job": self.job, "job_count": self.job_count, "job_timestamp": self.job_timestamp, @@ -120,7 +120,7 @@ class State: self.id_live_preview = 0 self.skipped = False self.interrupted = False - self.interrupted_next = False + self.stopping_generation = False self.textinfo = None self.job = job devices.torch_gc() diff --git a/modules/ui_toprow.py b/modules/ui_toprow.py index 9caf8faa..1abc9117 100644 --- a/modules/ui_toprow.py +++ b/modules/ui_toprow.py @@ -106,8 +106,14 @@ class Toprow: outputs=[], ) + def interrupt_function(): + if shared.state.job_count > 1 and shared.opts.interrupt_after_current: + shared.state.stop_generating() + else: + shared.state.interrupt() + self.interrupt.click( - fn=lambda: shared.state.interrupt(), + fn=interrupt_function, inputs=[], outputs=[], ) diff --git a/scripts/loopback.py b/scripts/loopback.py index ad921269..800ee882 100644 --- a/scripts/loopback.py +++ b/scripts/loopback.py @@ -95,7 +95,7 @@ class Script(scripts.Script): processed = processing.process_images(p) # Generation cancelled. - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if initial_seed is None: @@ -122,7 +122,7 @@ class Script(scripts.Script): p.inpainting_fill = original_inpainting_fill - if state.interrupted or state.interrupted_next: + if state.interrupted or state.stopping_generation: break if len(history) > 1: diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 2deff365..2f385ebf 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -696,7 +696,7 @@ class Script(scripts.Script): grid_infotext = [None] * (1 + len(zs)) def cell(x, y, z, ix, iy, iz): - if shared.state.interrupted or state.interrupted_next: + if shared.state.interrupted or state.stopping_generation: return Processed(p, [], p.seed, "") pc = copy(p) -- cgit v1.2.3 From 0466ee2a83680090e28de74881ccf306cb89667c Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Thu, 25 Jan 2024 05:45:45 +0900 Subject: xyz filter blank for number axes --- scripts/xyz_grid.py | 4 ++++ 1 file changed, 4 insertions(+) (limited to 'scripts/xyz_grid.py') diff --git a/scripts/xyz_grid.py b/scripts/xyz_grid.py index 2f385ebf..6d3e42c0 100644 --- a/scripts/xyz_grid.py +++ b/scripts/xyz_grid.py @@ -554,6 +554,8 @@ class Script(scripts.Script): valslist_ext = [] for val in valslist: + if val.strip() == '': + continue m = re_range.fullmatch(val) mc = re_range_count.fullmatch(val) if m is not None: @@ -576,6 +578,8 @@ class Script(scripts.Script): valslist_ext = [] for val in valslist: + if val.strip() == '': + continue m = re_range_float.fullmatch(val) mc = re_range_count_float.fullmatch(val) if m is not None: -- cgit v1.2.3