diff options
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/dataset.py | 2 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 6 |
2 files changed, 5 insertions, 3 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 23bb4b6a..68ceffe3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -49,7 +49,7 @@ class PersonalizedBase(Dataset): print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
try:
- image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
+ image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.Resampling.BICUBIC)
except Exception:
continue
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 2ed345b1..d2a389c9 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -270,9 +270,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc losses[embedding.step % losses.shape[0]] = loss.item()
- optimizer.zero_grad()
loss.backward()
- optimizer.step()
+ if ((i + 1) % gradient_accumulation == 0) or (i + 1 == steps - ititial_step):
+ optimizer.step()
+ optimizer.zero_grad()
+
epoch_num = embedding.step // len(ds)
epoch_step = embedding.step - (epoch_num * len(ds)) + 1
|