From 542a3d3a4a00c1383fbdaf938ceefef87cf834bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 14:33:22 +0300 Subject: fix btoken hypernetworks in XY plot --- scripts/xy_grid.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) (limited to 'scripts') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index a8f53bef..fe949067 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -10,7 +10,7 @@ import numpy as np import modules.scripts as scripts import gradio as gr -from modules import images +from modules import images, hypernetwork from modules.processing import process_images, Processed, get_correct_sampler from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -80,8 +80,7 @@ def apply_checkpoint(p, x, xs): def apply_hypernetwork(p, x, xs): - hn = shared.hypernetworks.get(x, None) - opts.data["sd_hypernetwork"] = hn.name if hn is not None else 'None' + hypernetwork.load_hypernetwork(x) def format_value_add_label(p, opt, x): @@ -203,8 +202,6 @@ class Script(scripts.Script): p.batch_size = 1 - initial_hn = opts.sd_hypernetwork - def process_axis(opt, vals): if opt.label == 'Nothing': return [0] @@ -321,6 +318,6 @@ class Script(scripts.Script): # restore checkpoint in case it was changed by axes modules.sd_models.reload_model_weights(shared.sd_model) - opts.data["sd_hypernetwork"] = initial_hn + hypernetwork.load_hypernetwork(opts.sd_hypernetwork) return processed -- cgit v1.2.3 From 2c52f4da7ff80a3ec277105f4db6146c6379898a Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 9 Oct 2022 15:01:42 +0300 Subject: fix broken samplers in XY plot --- scripts/xy_grid.py | 1 + 1 file changed, 1 insertion(+) (limited to 'scripts') diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index fe949067..c89ca1a9 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -259,6 +259,7 @@ class Script(scripts.Script): # Confirm options are valid before starting if opt.label == "Sampler": + samplers_dict = build_samplers_dict(p) for sampler_val in valslist: if sampler_val.lower() not in samplers_dict.keys(): raise RuntimeError(f"Unknown sampler: {sampler_val}") -- cgit v1.2.3