From 07be13caa357b14f6afa247566d53339522b8e66 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 1 Aug 2023 08:27:54 +0300 Subject: add metadata to checkpoint merger --- modules/extras.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) (limited to 'modules/extras.py') diff --git a/modules/extras.py b/modules/extras.py index e9c0263e..2a310ae3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import json import torch import tqdm -from modules import shared, images, sd_models, sd_vae, sd_models_config +from modules import shared, images, sd_models, sd_vae, sd_models_config, errors from modules.ui_common import plaintext_to_html import gradio as gr import safetensors.torch @@ -72,7 +72,20 @@ def to_half(tensor, enable): return tensor -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata): +def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name): + metadata = {} + + for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]: + checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None) + if checkpoint_info is None: + continue + + metadata.update(checkpoint_info.metadata) + + return json.dumps(metadata, indent=4, ensure_ascii=False) + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json): shared.state.begin(job="model-merge") def fail(message): @@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") - metadata = None + metadata = {} + + if save_metadata and copy_metadata_fields: + if primary_model_info: + metadata.update(primary_model_info.metadata) + if secondary_model_info: + metadata.update(secondary_model_info.metadata) + if tertiary_model_info: + metadata.update(tertiary_model_info.metadata) if save_metadata: - metadata = {"format": "pt"} + try: + metadata.update(json.loads(metadata_json)) + except Exception as e: + errors.display(e, "readin metadata from json") + + metadata["format"] = "pt" + if save_metadata and add_merge_recipe: merge_recipe = { "type": "webui", # indicate this model was merged with webui's built-in merger "primary_model_hash": primary_model_info.sha256, @@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ "is_inpainting": result_is_inpainting_model, "is_instruct_pix2pix": result_is_instruct_pix2pix_model } - metadata["sd_merge_recipe"] = json.dumps(merge_recipe) sd_merge_models = {} @@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if tertiary_model_info: add_model_metadata(tertiary_model_info) + metadata["sd_merge_recipe"] = json.dumps(merge_recipe) metadata["sd_merge_models"] = json.dumps(sd_merge_models) _, extension = os.path.splitext(output_modelname) if extension.lower() == ".safetensors": - safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata) + safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None) else: torch.save(theta_0, output_modelname) -- cgit v1.2.3