From 304222ef94d1c3c60fab466a96c448868f391bce Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 17 Sep 2022 13:49:36 +0300 Subject: X/Y plot support for switching checkpoints. --- scripts/xy_grid.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) (limited to 'scripts/xy_grid.py') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index eccfda87..680dd702 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,9 @@ import gradio as gr from modules import images from modules.processing import process_images, Processed from modules.shared import opts, cmd_opts, state +import modules.shared as shared import modules.sd_samplers +import modules.sd_models import re @@ -41,6 +43,15 @@ def apply_sampler(p, x, xs): p.sampler_index = sampler_index +def apply_checkpoint(p, x, xs): + applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title] + assert len(applicable) > 0, f'Checkpoint {x} for found' + + info = applicable[0] + + modules.sd_models.reload_model_weights(shared.sd_model, info) + + def format_value_add_label(p, opt, x): if type(x) == float: x = round(x, 8) @@ -74,6 +85,7 @@ axis_options = [ AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label), AxisOption("Prompt S/R", str, apply_prompt, format_value), AxisOption("Sampler", str, apply_sampler, format_value), + AxisOption("Checkpoint name", str, apply_checkpoint, format_value), AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones ] @@ -215,4 +227,7 @@ class Script(scripts.Script): if opts.grid_save: images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p) + # restore checkpoint in case it was changed by axes + modules.sd_models.reload_model_weights(shared.sd_model) + return processed -- cgit v1.2.3