-rw-r--r--  modules/textual_inversion/dataset.py            7
-rw-r--r--  modules/textual_inversion/preprocess.py         5
-rw-r--r--  modules/textual_inversion/textual_inversion.py  9
3 files changed, 11 insertions(+), 10 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index bcf772d2..4d006366 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -35,9 +35,10 @@ class PersonalizedBase(Dataset):
         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
         print("Preparing dataset...")
         for path in tqdm.tqdm(self.image_paths):
-            image = Image.open(path)
-            image = image.convert('RGB')
-            image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+            try:
+                image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+            except Exception:
+                continue

             filename = os.path.basename(path)
             filename_tokens = os.path.splitext(filename)[0]
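
Note: the hunk above collapses the three-step open/convert/resize into one chained call wrapped in try/except, so a single unreadable or non-image file no longer aborts dataset preparation. A minimal standalone sketch of the same pattern, with a placeholder directory and target size (not values from the repo):

import os

import PIL.Image
from PIL import Image

data_root = "/path/to/training/images"   # placeholder; the Dataset receives this as data_root
width, height = 512, 512                 # placeholder target size

images = []
for path in (os.path.join(data_root, f) for f in os.listdir(data_root)):
    try:
        # Any of open/convert/resize can raise on a truncated or non-image file.
        image = Image.open(path).convert('RGB').resize((width, height), PIL.Image.BICUBIC)
    except Exception:
        continue  # skip unreadable files instead of crashing the whole run
    images.append(image)
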
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index d7efdef2..1a672725 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -46,7 +46,10 @@ def preprocess(process_src, process_dst, process_width, process_height, process_
     for index, imagefile in enumerate(tqdm.tqdm(files)):
         subindex = [0]
         filename = os.path.join(src, imagefile)
-        img = Image.open(filename).convert("RGB")
+        try:
+            img = Image.open(filename).convert("RGB")
+        except Exception:
+            continue

         if shared.state.interrupted:
             break
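
Note: preprocess() gets the same guard, so one corrupt file in the source directory is skipped rather than killing the whole batch, while the shared.state.interrupted check still lets the user stop early. A rough sketch of the loop's two exits, with a plain boolean standing in for shared.state.interrupted:

import os
from PIL import Image

interrupted = False          # stand-in for shared.state.interrupted
src = "/path/to/src"         # placeholder source directory

for imagefile in os.listdir(src):
    filename = os.path.join(src, imagefile)
    try:
        img = Image.open(filename).convert("RGB")
    except Exception:
        continue             # unreadable file: skip it, keep going
    if interrupted:
        break                # user requested stop: abandon remaining files
    # ... cropping/flipping/saving would follow here ...
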
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index d6977950..bb05cdc6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -200,9 +200,6 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
     if ititial_step > steps:
         return embedding, filename

-    tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
-    epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
     pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
     for i, (x, text) in pbar:
         embedding.step = i + ititial_step
@@ -226,10 +223,10 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
loss.backward()
optimizer.step()
- epoch_num = embedding.step // epoch_len
- epoch_step = embedding.step - (epoch_num * epoch_len) + 1
+ epoch_num = embedding.step // len(ds)
+ epoch_step = embedding.step - (epoch_num * len(ds)) + 1
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
+ pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
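
Note: with the precomputed epoch_len gone, epoch bookkeeping is derived from the global step and len(ds) alone. A worked sketch of that arithmetic with an illustrative dataset size (not a value from the repo):

ds_len = 15  # illustrative; plays the role of len(ds)

for step in (0, 14, 15, 29):
    epoch_num = step // ds_len                    # completed passes over the data
    epoch_step = step - (epoch_num * ds_len) + 1  # 1-based position within the epoch
    print(f"[Epoch {epoch_num}: {epoch_step}/{ds_len}]")

# prints: [Epoch 0: 1/15], [Epoch 0: 15/15], [Epoch 1: 1/15], [Epoch 1: 15/15]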