author     AUTOMATIC1111 <16777216c@gmail.com>   2023-01-04 16:57:02 +0000
committer  GitHub <noreply@github.com>           2023-01-04 16:57:02 +0000
commit     9092e1ca7756642692197316d3dec47c23322381 (patch)
tree       22f5d5e7417f24599a415fd64c9f1652495ce5a3 /modules/textual_inversion
parent     b7deea47eeb033052062621b0005d4321b53bff7 (diff)
parent     eeb1de4388773ba92b9920a4f64eb91add2e02ca (diff)
Merge pull request #3842 from R-N/gradient-clipping
Gradient clipping in train tab
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--   modules/textual_inversion/learn_schedule.py      | 11
-rw-r--r--   modules/textual_inversion/textual_inversion.py   | 15
2 files changed, 21 insertions, 5 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index dd0c0ad1..f63fc72f 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -58,14 +58,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number < self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
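For reference, here is a minimal, self-contained sketch of the step()/apply() split introduced above, assuming a toy (rate, end_step) schedule in place of the real LearnScheduleIterator; ToySchedule and its constructor are illustrative and not part of the repository. The point of the split is that callers that own an optimizer keep using apply(), while callers that only need the current value (such as the gradient-clipping schedule added to the train loop below) can call step() and read .learn_rate directly.

class ToySchedule:
    # Stand-in for LearnRateScheduler; the real class parses schedule strings.
    def __init__(self, schedule, verbose=True):
        # schedule: iterable of (rate, end_step) pairs, e.g. [(5e-3, 100), (1e-3, 200)]
        self.schedules = iter(schedule)
        (self.learn_rate, self.end_step) = next(self.schedules)
        self.verbose = verbose
        self.finished = False

    def step(self, step_number):
        # Advance to the next (rate, end_step) pair once end_step is reached;
        # return True only when the rate actually changed.
        if step_number < self.end_step:
            return False
        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        # Same behavior as before the refactor: push the (possibly updated)
        # rate into the optimizer's param groups.
        if not self.step(step_number):
            return
        if self.verbose:
            print(f"Learn rate: {self.learn_rate}, end at step {self.end_step}")
        for pg in optimizer.param_groups:
            pg["lr"] = self.learn_rate

sched = ToySchedule([(5e-3, 100), (1e-3, 200)], verbose=False)
assert sched.step(50) is False      # still inside the first segment
assert sched.step(100) is True      # rate switches to 1e-3 at step 100
assert sched.learn_rate == 1e-3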
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 2250e41b..71e07bcc 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -251,8 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -295,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
return embedding, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -361,6 +365,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
@@ -382,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
embedding.step += 1
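And a standalone sketch of the clipping hookup, assuming a plain (non-AMP) loop with a toy linear model: only clip_grad_mode, clip_grad_value, gradient_step, and the clip_grad selection mirror the diff; the model, the data, and the use of a fixed clip value instead of a scheduled one (clip_grad_sched.learn_rate above) are illustrative simplifications.

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-3)

clip_grad_mode = "value"     # "value", "norm", or anything else to disable clipping
clip_grad_value = 0.1
gradient_step = 4            # gradient accumulation steps, as in the train loop

# Same selection pattern as the diff: pick the torch clipping function by mode.
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
    torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
    None

for j in range(16):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean() / gradient_step
    loss.backward()
    if (j + 1) % gradient_step != 0:
        continue                      # keep accumulating gradients
    if clip_grad:
        # Clip the accumulated gradients right before the optimizer step,
        # matching where the diff calls clip_grad(embedding.vec, ...).
        clip_grad(model.parameters(), clip_grad_value)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)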