diff options
author | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-10-29 17:49:29 +0000 |
---|---|---|
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-10-29 17:49:29 +0000 |
commit | a07f054c86f33360ff620d6a3fffdee366ab2d99 (patch) | |
tree | f162e4a3621b30d641848cb8643475eaf7e6e146 /modules/hypernetworks | |
parent | ab05a74ead9fabb45dd099990e34061c7eb02ca3 (diff) | |
download | stable-diffusion-webui-gfx803-a07f054c86f33360ff620d6a3fffdee366ab2d99.tar.gz stable-diffusion-webui-gfx803-a07f054c86f33360ff620d6a3fffdee366ab2d99.tar.bz2 stable-diffusion-webui-gfx803-a07f054c86f33360ff620d6a3fffdee366ab2d99.zip |
Add missing info on hypernetwork/embedding model log
Mentioned here: https://github.com/AUTOMATIC1111/stable-diffusion-webui/discussions/1528#discussioncomment-3991513
Also group the saving into one
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 38f35c58..86daf825 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log images_dir = None
hypernetwork = shared.loaded_hypernetwork
+ checkpoint = sd_models.select_checkpoint()
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
@@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
# Before saving, change name to match current checkpoint.
- hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
- last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(last_saved_file)
+ hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
+ last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
"loss": f"{previous_mean_loss:.7f}",
@@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}<br/> """
report_statistics(loss_dict)
- checkpoint = sd_models.select_checkpoint()
- hypernetwork.sd_checkpoint = checkpoint.hash
- hypernetwork.sd_checkpoint_name = checkpoint.model_name
- # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
- hypernetwork.name = hypernetwork_name
- filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
- hypernetwork.save(filename)
+ filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
+ save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
return hypernetwork, filename
+
+def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+ old_hypernetwork_name = hypernetwork.name
+ old_sd_checkpoint = hypernetwork.sd_checkpoint if hasattr(hypernetwork, "sd_checkpoint") else None
+ old_sd_checkpoint_name = hypernetwork.sd_checkpoint_name if hasattr(hypernetwork, "sd_checkpoint_name") else None
+ try:
+ hypernetwork.sd_checkpoint = checkpoint.hash
+ hypernetwork.sd_checkpoint_name = checkpoint.model_name
+ hypernetwork.name = hypernetwork_name
+ hypernetwork.save(filename)
+ except:
+ hypernetwork.sd_checkpoint = old_sd_checkpoint
+ hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
+ hypernetwork.name = old_hypernetwork_name
+ raise
|