diff options
author | random_thoughtss <random_thoughtss@proton.me> | 2022-10-20 23:01:27 +0000 |
---|---|---|
committer | random_thoughtss <random_thoughtss@proton.me> | 2022-10-20 23:01:27 +0000 |
commit | 49533eed9e3aad19e9868ee140708baec4fd44be (patch) | |
tree | ed3bc901e9e63da9e6ff005c0af937d331c03c83 | |
parent | 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f (diff) | |
download | stable-diffusion-webui-gfx803-49533eed9e3aad19e9868ee140708baec4fd44be.tar.gz stable-diffusion-webui-gfx803-49533eed9e3aad19e9868ee140708baec4fd44be.tar.bz2 stable-diffusion-webui-gfx803-49533eed9e3aad19e9868ee140708baec4fd44be.zip |
XY grid correctly re-assignes model when config changes
-rw-r--r-- | modules/sd_models.py | 6 | ||||
-rw-r--r-- | scripts/xy_grid.py | 1 |
2 files changed, 4 insertions, 3 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info
-def load_model():
+def load_model(checkpoint_info=None):
from modules import lowvram, sd_hijack
- checkpoint_info = select_checkpoint()
+ checkpoint_info = checkpoint_info or select_checkpoint()
if checkpoint_info.config != shared.cmd_opts.config:
print(f"Loading config from: {checkpoint_info.config}")
@@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
checkpoints_loaded.clear()
- shared.sd_model = load_model()
+ shared.sd_model = load_model(checkpoint_info)
return shared.sd_model
if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a..eff0c942 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None:
raise RuntimeError(f"Unknown checkpoint: {x}")
modules.sd_models.reload_model_weights(shared.sd_model, info)
+ p.sd_model = shared.sd_model
def confirm_checkpoints(p, xs):
|