diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:14:50 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:14:50 +0000 |
commit | 326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8 (patch) | |
tree | 6824dd89c15ea7d470d266f413707537a5e22090 /modules/hypernetworks | |
parent | 989a552de3d1fcd1f178fe873713b884e192dd61 (diff) | |
parent | 8636b50aea83f9c743f005722d9f3f8ee9303e00 (diff) | |
download | stable-diffusion-webui-gfx803-326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8.tar.gz stable-diffusion-webui-gfx803-326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8.tar.bz2 stable-diffusion-webui-gfx803-326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8.zip |
Merge remote-tracking branch 'Melanpan/master'
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index e5cb1817..edb8cba1 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -5,6 +5,7 @@ import os import sys
import traceback
import tqdm
+import csv
import torch
@@ -262,6 +263,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
+ if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
+ write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
+
+ with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
+
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ csv_writer.writerow({"step": hypernetwork.step,
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate})
+
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|