diff options
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 9 |
1 files changed, 3 insertions, 6 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index d6977950..bb05cdc6 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -200,9 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
for i, (x, text) in pbar:
embedding.step = i + ititial_step
@@ -226,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward()
optimizer.step()
- epoch_num = embedding.step // epoch_len
- epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|