diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-04 05:05:21 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-04 05:05:21 +0000 |
commit | 56c3f94ba30d76bf680db9bc765624b9a143d769 (patch) | |
tree | 6924fc635fe13fe59911a285313308be2d9acdc7 /modules/extras.py | |
parent | 952effa8b10dba2f2f7f2cf4191f987e3666e9f0 (diff) | |
parent | 073c0ebba380acbd73be8262feba41212165ff84 (diff) | |
download | stable-diffusion-webui-gfx803-56c3f94ba30d76bf680db9bc765624b9a143d769.tar.gz stable-diffusion-webui-gfx803-56c3f94ba30d76bf680db9bc765624b9a143d769.tar.bz2 stable-diffusion-webui-gfx803-56c3f94ba30d76bf680db9bc765624b9a143d769.zip |
Merge branch 'dev' into dev
Diffstat (limited to 'modules/extras.py')
-rw-r--r-- | modules/extras.py | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/modules/extras.py b/modules/extras.py index e9c0263e..2a310ae3 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -7,7 +7,7 @@ import json import torch
import tqdm
-from modules import shared, images, sd_models, sd_vae, sd_models_config
+from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable): return tensor
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
+def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
+ metadata = {}
+
+ for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
+ checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
+ if checkpoint_info is None:
+ continue
+
+ metadata.update(checkpoint_info.metadata)
+
+ return json.dumps(metadata, indent=4, ensure_ascii=False)
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
shared.state.begin(job="model-merge")
def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
- metadata = None
+ metadata = {}
+
+ if save_metadata and copy_metadata_fields:
+ if primary_model_info:
+ metadata.update(primary_model_info.metadata)
+ if secondary_model_info:
+ metadata.update(secondary_model_info.metadata)
+ if tertiary_model_info:
+ metadata.update(tertiary_model_info.metadata)
if save_metadata:
- metadata = {"format": "pt"}
+ try:
+ metadata.update(json.loads(metadata_json))
+ except Exception as e:
+ errors.display(e, "readin metadata from json")
+
+ metadata["format"] = "pt"
+ if save_metadata and add_merge_recipe:
merge_recipe = {
"type": "webui", # indicate this model was merged with webui's built-in merger
"primary_model_hash": primary_model_info.sha256,
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ "is_inpainting": result_is_inpainting_model,
"is_instruct_pix2pix": result_is_instruct_pix2pix_model
}
- metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
sd_merge_models = {}
@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ if tertiary_model_info:
add_model_metadata(tertiary_model_info)
+ metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
metadata["sd_merge_models"] = json.dumps(sd_merge_models)
_, extension = os.path.splitext(output_modelname)
if extension.lower() == ".safetensors":
- safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
+ safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
else:
torch.save(theta_0, output_modelname)
|