diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:43:55 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:43:55 +0000 |
commit | 03d62538aebeff51713619fe808c953bdb70193d (patch) | |
tree | 1c01da8d5ec04779838ccecce04f217e3fcbef92 /modules/textual_inversion/textual_inversion.py | |
parent | 326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8 (diff) | |
download | stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.tar.gz stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.tar.bz2 stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.zip |
remove duplicate code for log loss, add step, make it read from options rather than gradio input
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 1f5ace6f..da0d77a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -173,6 +173,32 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn
+def write_loss(log_directory, filename, step, epoch_len, values):
+ if shared.opts.training_write_csv_every == 0:
+ return
+
+ if step % shared.opts.training_write_csv_every != 0:
+ return
+
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
+
+ with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
+ csv_writer = csv.DictWriter(fout, fieldnames=["step", "epoch", "epoch_step", *(values.keys())])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ epoch = step // epoch_len
+ epoch_step = step - epoch * epoch_len
+
+ csv_writer.writerow({
+ "step": step + 1,
+ "epoch": epoch + 1,
+ "epoch_step": epoch_step + 1,
+ **values,
+ })
+
+
def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
@@ -257,20 +283,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
- if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
- write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
-
- with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
-
- csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss", "learn_rate"])
-
- if write_csv_header:
- csv_writer.writeheader()
-
- csv_writer.writerow({"epoch": epoch_num + 1,
- "epoch_step": epoch_step - 1,
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate})
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|