aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
authorMelan <alexleander91@gmail.com>2022-10-12 21:36:29 +0000
committerMelan <alexleander91@gmail.com>2022-10-12 21:36:29 +0000
commit1cfc2a18981ee56bdb69a2de7b463a11ad05e329 (patch)
tree73129d944e10f46bc7181a0dfe6e0cbee19170f3 /modules/textual_inversion/textual_inversion.py
parent698d303b04e293635bfb49c525409f3bcf671dce (diff)
downloadstable-diffusion-webui-gfx803-1cfc2a18981ee56bdb69a2de7b463a11ad05e329.tar.gz
stable-diffusion-webui-gfx803-1cfc2a18981ee56bdb69a2de7b463a11ad05e329.tar.bz2
stable-diffusion-webui-gfx803-1cfc2a18981ee56bdb69a2de7b463a11ad05e329.zip
Save a csv containing the loss while training
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index fa0e33a2..25038a89 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -6,6 +6,7 @@ import torch
import tqdm
import html
import datetime
+import csv
from PIL import Image, PngImagePlugin
@@ -172,7 +173,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, write_csv_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -256,6 +257,20 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
embedding.save(last_saved_file)
+ if write_csv_every > 0 and log_directory is not None and embedding.step % write_csv_every == 0:
+ write_csv_header = False if os.path.exists(os.path.join(log_directory, "textual_inversion_loss.csv")) else True
+
+ with open(os.path.join(log_directory, "textual_inversion_loss.csv"), "a+") as fout:
+
+ csv_writer = csv.DictWriter(fout, fieldnames=["epoch", "epoch_step", "loss"])
+
+ if write_csv_header:
+ csv_writer.writeheader()
+
+ csv_writer.writerow({"epoch": epoch_num + 1,
+ "epoch_step": epoch_step - 1,
+ "loss": f"{losses.mean():.7f}"})
+
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')