From a07f054c86f33360ff620d6a3fffdee366ab2d99 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 00:49:29 +0700 Subject: Add missing info on hypernetwork/embedding model log Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513 Also group the saving into one --- modules/hypernetworks/hypernetwork.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) (limited to 'modules/hypernetworks') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..86daf825 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log images_dir = None hypernetwork = shared.loaded_hypernetwork + checkpoint = sd_models.select_checkpoint() ititial_step = hypernetwork.step or 0 if ititial_step > steps: @@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0: # Before saving, change name to match current checkpoint. - hypernetwork.name = f'{hypernetwork_name}-{steps_done}' - last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(last_saved_file) + hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}' + last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file) textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), { "loss": f"{previous_mean_loss:.7f}", @@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}
""" report_statistics(loss_dict) - checkpoint = sd_models.select_checkpoint() - hypernetwork.sd_checkpoint = checkpoint.hash - hypernetwork.sd_checkpoint_name = checkpoint.model_name - # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention). - hypernetwork.name = hypernetwork_name - filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt') - hypernetwork.save(filename) + filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt') + save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename) return hypernetwork, filename + +def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename): + old_hypernetwork_name = hypernetwork.name + old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None + old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None + try: + hypernetwork.sd_checkpoint = checkpoint.hash + hypernetwork.sd_checkpoint_name = checkpoint.model_name + hypernetwork.name = hypernetwork_name + hypernetwork.save(filename) + except: + hypernetwork.sd_checkpoint = old_sd_checkpoint + hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name + hypernetwork.name = old_hypernetwork_name + raise -- cgit v1.2.3