From 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Thu, 19 Jan 2023 10:39:51 +0300 Subject: allow baking in VAE in checkpoint merger tab do not save config if it's the default for checkpoint merger tab change file naming scheme for checkpoint merger tab allow just saving A without any merging for checkpoint merger tab some stylistic changes for UI in checkpoint merger tab --- modules/extras.py | 112 +++++++++++++++++++++++++++++++++--------------------- 1 file changed, 68 insertions(+), 44 deletions(-) (limited to 'modules/extras.py') diff --git a/modules/extras.py b/modules/extras.py index 034f28e4..fe701a0e 100644 --- a/modules/extras.py +++ b/modules/extras.py @@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple from functools import partial from dataclasses import dataclass -from modules import processing, shared, images, devices, sd_models, sd_samplers +from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae from modules.shared import opts import modules.gfpgan_model from modules.ui import plaintext_to_html @@ -251,7 +251,8 @@ def run_pnginfo(image): def create_config(ckpt_result, config_source, a, b, c): def config(x): - return sd_models.find_checkpoint_config(x) if x else None + res = sd_models.find_checkpoint_config(x) if x else None + return res if res != shared.sd_default_config else None if config_source == 0: cfg = config(a) or config(b) or config(c) @@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c): shutil.copyfile(cfg, checkpoint_filename) -def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source): +chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] + + +def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae): shared.state.begin() shared.state.job = 'model-merge' - shared.state.job_count = 1 def fail(message): shared.state.textinfo = message @@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ def add_difference(theta0, theta1_2_diff, alpha): return theta0 + (alpha * theta1_2_diff) + def filename_weighed_sum(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + Ma = round(1 - multiplier, 2) + Mb = round(multiplier, 2) + + return f"{Ma}({a}) + {Mb}({b})" + + def filename_add_differnece(): + a = primary_model_info.model_name + b = secondary_model_info.model_name + c = tertiary_model_info.model_name + M = round(multiplier, 2) + + return f"{a} + {M}({b} - {c})" + + def filename_nothing(): + return primary_model_info.model_name + + theta_funcs = { + "Weighted sum": (filename_weighed_sum, None, weighted_sum), + "Add difference": (filename_add_differnece, get_difference, add_difference), + "No interpolation": (filename_nothing, None, None), + } + filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method] + shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0) + if not primary_model_name: return fail("Failed: Merging requires a primary model.") primary_model_info = sd_models.checkpoints_list[primary_model_name] - if not secondary_model_name: + if theta_func2 and not secondary_model_name: return fail("Failed: Merging requires a secondary model.") - - secondary_model_info = sd_models.checkpoints_list[secondary_model_name] - theta_funcs = { - "Weighted sum": (None, weighted_sum), - "Add difference": (get_difference, add_difference), - } - theta_func1, theta_func2 = theta_funcs[interp_method] + secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None if theta_func1 and not tertiary_model_name: return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.") - + tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None result_is_inpainting_model = False - shared.state.textinfo = f"Loading {secondary_model_info.filename}..." - print(f"Loading {secondary_model_info.filename}...") - theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + if theta_func2: + shared.state.textinfo = f"Loading B" + print(f"Loading {secondary_model_info.filename}...") + theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu') + else: + theta_1 = None if theta_func1: - shared.state.job_count += 1 - + shared.state.textinfo = f"Loading C" print(f"Loading {tertiary_model_info.filename}...") theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu') + shared.state.textinfo = 'Merging B and C' shared.state.sampling_steps = len(theta_1.keys()) for key in tqdm.tqdm(theta_1.keys()): + if key in chckpoint_dict_skip_on_merge: + continue + if 'model' in key: if key in theta_2: t2 = theta_2.get(key, torch.zeros_like(theta_1[key])) @@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu') print("Merging...") - - chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"] - + shared.state.textinfo = 'Merging A and B' shared.state.sampling_steps = len(theta_0.keys()) for key in tqdm.tqdm(theta_0.keys()): - if 'model' in key and key in theta_1: + if theta_1 and 'model' in key and key in theta_1: if key in chckpoint_dict_skip_on_merge: continue @@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ a = theta_0[key] b = theta_1[key] - shared.state.textinfo = f'Merging layer {key}' # this enables merging an inpainting model (A) with another one (B); # where normal model would have 4 channels, for latenst space, inpainting model would # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9 @@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ shared.state.sampling_step += 1 - # I believe this part should be discarded, but I'll leave it for now until I am sure - for key in theta_1.keys(): - if 'model' in key and key not in theta_0: + del theta_1 + + bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None) + if bake_in_vae_filename is not None: + print(f"Baking in VAE from {bake_in_vae_filename}") + shared.state.textinfo = 'Baking in VAE' + vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu') - if key in chckpoint_dict_skip_on_merge: - continue + for key in vae_dict.keys(): + theta_0_key = 'first_stage_model.' + key + if theta_0_key in theta_0: + theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key] - theta_0[key] = theta_1[key] - if save_as_half: - theta_0[key] = theta_0[key].half() - del theta_1 + del vae_dict ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path - filename = \ - primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \ - secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \ - interp_method.replace(" ", "_") + \ - '-merged.' + \ - ("inpainting." if result_is_inpainting_model else "") + \ - checkpoint_format - - filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format) + filename = filename_generator() if custom_name == '' else custom_name + filename += ".inpainting" if result_is_inpainting_model else "" + filename += "." + checkpoint_format output_modelname = os.path.join(ckpt_dir, filename) shared.state.nextjob() - shared.state.textinfo = f"Saving to {output_modelname}..." + shared.state.textinfo = "Saving" print(f"Saving to {output_modelname}...") _, extension = os.path.splitext(output_modelname) @@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_ create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info) - print("Checkpoint saved.") - shared.state.textinfo = "Checkpoint saved to " + output_modelname + print(f"Checkpoint saved to {output_modelname}.") + shared.state.textinfo = "Checkpoint saved" shared.state.end() return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname] -- cgit v1.2.3