aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 17dfb223..f272e536 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -205,7 +205,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
})
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
@@ -254,6 +254,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
ititial_step = embedding.step or 0
if ititial_step > steps:
return embedding, filename
+
+ clip_grad_mode_value = clip_grad_mode == "value"
+ clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -269,6 +275,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(embedding.step)
+
with torch.autocast("cuda"):
c = cond_model([entry.cond_text for entry in entries])
x = torch.stack([entry.latent for entry in entries]).to(devices.device)
@@ -279,6 +288,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
+
+ if clip_grad_mode_value:
+ torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
+ elif clip_grad_mode_norm:
+ torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
+
optimizer.step()
steps_done = embedding.step + 1