diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-19 07:39:51 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-19 07:39:51 +0000 |
commit | 0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135 (patch) | |
tree | 0e81a16c42f716c704d6aa63458f7c3c1894c56e /modules/extras.py | |
parent | c7e50425f63c07242068f8dcccce70a4ef28a17f (diff) | |
download | stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.tar.gz stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.tar.bz2 stable-diffusion-webui-gfx803-0f5dbfffd0b7202a48e404d8e74b5cc9a3e5b135.zip |
allow baking in VAE in checkpoint merger tab
do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
Diffstat (limited to 'modules/extras.py')
-rw-r--r-- | modules/extras.py | 112 |
1 files changed, 68 insertions, 44 deletions
diff --git a/modules/extras.py b/modules/extras.py
index 034f28e4..fe701a0e 100644
--- a/modules/extras.py
+++ b/modules/extras.py
@@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass
-from modules import processing, shared, images, devices, sd_models, sd_samplers
+from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image):
def create_config(ckpt_result, config_source, a, b, c):
def config(x):
- return sd_models.find_checkpoint_config(x) if x else None
+ res = sd_models.find_checkpoint_config(x) if x else None
+ return res if res != shared.sd_default_config else None
if config_source == 0:
cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
shutil.copyfile(cfg, checkpoint_filename)
-def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
+chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
+
+
+def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
shared.state.begin()
shared.state.job = 'model-merge'
- shared.state.job_count = 1
def fail(message):
shared.state.textinfo = message
@@ -293,41 +296,68 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
def add_difference(theta0, theta1_2_diff, alpha):
return theta0 + (alpha * theta1_2_diff)
+ def filename_weighed_sum():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ Ma = round(1 - multiplier, 2)
+ Mb = round(multiplier, 2)
+
+ return f"{Ma}({a}) + {Mb}({b})"
+
+ def filename_add_differnece():
+ a = primary_model_info.model_name
+ b = secondary_model_info.model_name
+ c = tertiary_model_info.model_name
+ M = round(multiplier, 2)
+
+ return f"{a} + {M}({b} - {c})"
+
+ def filename_nothing():
+ return primary_model_info.model_name
+
+ theta_funcs = {
+ "Weighted sum": (filename_weighed_sum, None, weighted_sum),
+ "Add difference": (filename_add_differnece, get_difference, add_difference),
+ "No interpolation": (filename_nothing, None, None),
+ }
+ filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
+ shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)
+
if not primary_model_name:
return fail("Failed: Merging requires a primary model.")
primary_model_info = sd_models.checkpoints_list[primary_model_name]
- if not secondary_model_name:
+ if theta_func2 and not secondary_model_name:
return fail("Failed: Merging requires a secondary model.")
-
- secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
- theta_funcs = {
- "Weighted sum": (None, weighted_sum),
- "Add difference": (get_difference, add_difference),
- }
- theta_func1, theta_func2 = theta_funcs[interp_method]
+ secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None
if theta_func1 and not tertiary_model_name:
return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
-
+
tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None
result_is_inpainting_model = False
- shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
- print(f"Loading {secondary_model_info.filename}...")
- theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ if theta_func2:
+ shared.state.textinfo = f"Loading B"
+ print(f"Loading {secondary_model_info.filename}...")
+ theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
+ else:
+ theta_1 = None
if theta_func1:
- shared.state.job_count += 1
-
+ shared.state.textinfo = f"Loading C"
print(f"Loading {tertiary_model_info.filename}...")
theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')
+ shared.state.textinfo = 'Merging B and C'
shared.state.sampling_steps = len(theta_1.keys())
for key in tqdm.tqdm(theta_1.keys()):
+ if key in chckpoint_dict_skip_on_merge:
+ continue
+
if 'model' in key:
if key in theta_2:
t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
@@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
print("Merging...")
-
- chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]
-
+ shared.state.textinfo = 'Merging A and B'
shared.state.sampling_steps = len(theta_0.keys())
for key in tqdm.tqdm(theta_0.keys()):
- if 'model' in key and key in theta_1:
+ if theta_1 and 'model' in key and key in theta_1:
if key in chckpoint_dict_skip_on_merge:
continue
@@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
a = theta_0[key]
b = theta_1[key]
- shared.state.textinfo = f'Merging layer {key}'
# this enables merging an inpainting model (A) with another one (B);
# where normal model would have 4 channels, for latenst space, inpainting model would
# have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
shared.state.sampling_step += 1
- # I believe this part should be discarded, but I'll leave it for now until I am sure
- for key in theta_1.keys():
- if 'model' in key and key not in theta_0:
+ del theta_1
+
+ bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
+ if bake_in_vae_filename is not None:
+ print(f"Baking in VAE from {bake_in_vae_filename}")
+ shared.state.textinfo = 'Baking in VAE'
+ vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')
- if key in chckpoint_dict_skip_on_merge:
- continue
+ for key in vae_dict.keys():
+ theta_0_key = 'first_stage_model.' + key
+ if theta_0_key in theta_0:
+ theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]
- theta_0[key] = theta_1[key]
- if save_as_half:
- theta_0[key] = theta_0[key].half()
- del theta_1
+ del vae_dict
ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
- filename = \
- primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
- secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
- interp_method.replace(" ", "_") + \
- '-merged.' + \
- ("inpainting." if result_is_inpainting_model else "") + \
- checkpoint_format
-
- filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
+ filename = filename_generator() if custom_name == '' else custom_name
+ filename += ".inpainting" if result_is_inpainting_model else ""
+ filename += "." + checkpoint_format
output_modelname = os.path.join(ckpt_dir, filename)
shared.state.nextjob()
- shared.state.textinfo = f"Saving to {output_modelname}..."
+ shared.state.textinfo = "Saving"
print(f"Saving to {output_modelname}...")
_, extension = os.path.splitext(output_modelname)
@@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)
- print("Checkpoint saved.")
- shared.state.textinfo = "Checkpoint saved to " + output_modelname
+ print(f"Checkpoint saved to {output_modelname}.")
+ shared.state.textinfo = "Checkpoint saved"
shared.state.end()
return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
|