diff options
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 769682ea..5965c5a0 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -6,7 +6,6 @@ import torch import tqdm
import html
import datetime
-import math
from modules import shared, devices, sd_hijack, processing, sd_models
@@ -157,7 +156,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'): return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_size, steps, num_repeats, create_image_every, save_embedding_every, template_file):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -183,7 +182,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=training_size, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -227,7 +226,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini loss.backward()
optimizer.step()
- epoch_num = math.floor(embedding.step / epoch_len)
+ epoch_num = embedding.step // epoch_len
epoch_step = embedding.step - (epoch_num * epoch_len) + 1
pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{epoch_len}]loss: {losses.mean():.7f}")
@@ -243,8 +242,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini sd_model=shared.sd_model,
prompt=text,
steps=20,
- height=training_size,
- width=training_size,
+ height=training_height,
+ width=training_width,
do_not_save_grid=True,
do_not_save_samples=True,
)
|