aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-17 10:49:36 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-17 10:49:36 +0000
commit304222ef94d1c3c60fab466a96c448868f391bce (patch)
tree7b5aa24ac29865f3c4c14eb755ef3f5d6d7421c3
parent99585b3514e2d7e987651d5c6a0806f933af012b (diff)
downloadstable-diffusion-webui-gfx803-304222ef94d1c3c60fab466a96c448868f391bce.tar.gz
stable-diffusion-webui-gfx803-304222ef94d1c3c60fab466a96c448868f391bce.tar.bz2
stable-diffusion-webui-gfx803-304222ef94d1c3c60fab466a96c448868f391bce.zip
X/Y plot support for switching checkpoints.
-rw-r--r--modules/sd_models.py4
-rw-r--r--script.js2
-rw-r--r--scripts/xy_grid.py15
3 files changed, 19 insertions, 2 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 036af0e4..4bd70fc5 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -127,9 +127,9 @@ def load_model():
return sd_model
-def reload_model_weights(sd_model):
+def reload_model_weights(sd_model, info=None):
from modules import lowvram, devices
- checkpoint_info = select_checkpoint()
+ checkpoint_info = info or select_checkpoint()
if sd_model.sd_model_checkpint == checkpoint_info.filename:
return
diff --git a/script.js b/script.js
index 4a70e51d..e63e0695 100644
--- a/script.js
+++ b/script.js
@@ -66,6 +66,8 @@ titles = {
"Style 2": "Style to apply; styles have components for both positive and negative prompts and apply to both",
"Apply style": "Insert selected styles into prompt fields",
"Create style": "Save current prompts as a style. If you add the token {prompt} to the text, the style use that as placeholder for your prompt when you use the style in the future.",
+
+ "Checkpoint name": "Loads weights from checkpoint before making images. You can either use hash or a part of filename (as seen in settings) for checkpoint name. Recommended to use with Y axis for less switching.",
}
function gradioApp(){
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py
index eccfda87..680dd702 100644
--- a/scripts/xy_grid.py
+++ b/scripts/xy_grid.py
@@ -10,7 +10,9 @@ import gradio as gr
from modules import images
from modules.processing import process_images, Processed
from modules.shared import opts, cmd_opts, state
+import modules.shared as shared
import modules.sd_samplers
+import modules.sd_models
import re
@@ -41,6 +43,15 @@ def apply_sampler(p, x, xs):
p.sampler_index = sampler_index
+def apply_checkpoint(p, x, xs):
+ applicable = [info for info in modules.sd_models.checkpoints_list.values() if x in info.title]
+ assert len(applicable) > 0, f'Checkpoint {x} for found'
+
+ info = applicable[0]
+
+ modules.sd_models.reload_model_weights(shared.sd_model, info)
+
+
def format_value_add_label(p, opt, x):
if type(x) == float:
x = round(x, 8)
@@ -74,6 +85,7 @@ axis_options = [
AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label),
AxisOption("Prompt S/R", str, apply_prompt, format_value),
AxisOption("Sampler", str, apply_sampler, format_value),
+ AxisOption("Checkpoint name", str, apply_checkpoint, format_value),
AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label), # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]
@@ -215,4 +227,7 @@ class Script(scripts.Script):
if opts.grid_save:
images.save_image(processed.images[0], p.outpath_grids, "xy_grid", prompt=p.prompt, seed=processed.seed, grid=True, p=p)
+ # restore checkpoint in case it was changed by axes
+ modules.sd_models.reload_model_weights(shared.sd_model)
+
return processed