diff options
Diffstat (limited to 'modules/extras.py')
-rw-r--r-- | modules/extras.py | 30 |
1 files changed, 29 insertions, 1 deletions
diff --git a/modules/extras.py b/modules/extras.py index 7407bfe3..a03d558e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -3,6 +3,7 @@ import math import os
import sys
import traceback
+import shutil
import numpy as np
from PIL import Image
@@ -248,7 +249,32 @@ def run_pnginfo(image): return '', geninfo, info
-def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
+def create_config(ckpt_result, config_source, a, b, c):
+ def config(x):
+ return sd_models.find_checkpoint_config(x) if x else None
+
+ if config_source == 0:
+ cfg = config(a) or config(b) or config(c)
+ elif config_source == 1:
+ cfg = config(b)
+ elif config_source == 2:
+ cfg = config(c)
+ else:
+ cfg = None
+
+ if cfg is None:
+ return
+
+ filename, _ = os.path.splitext(ckpt_result)
+ checkpoint_filename = filename + ".yaml"
+
+ print("Copying config:")
+ print(" from:", cfg)
+ print(" to:", checkpoint_filename)
+ shutil.copyfile(cfg, checkpoint_filename)
+
+
+def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
shared.state.begin()
shared.state.job = 'model-merge'
@@ -356,6 +382,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam sd_models.list_models()
+ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
+
print("Checkpoint saved.")
shared.state.textinfo = "Checkpoint saved to " + output_modelname
shared.state.end()
|