author    | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-05 04:48:38 +0000
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-05 04:48:38 +0000
commit    | bb832d7725187f8a8ab44faa6ee1b38cb5f600aa (patch)
tree      | 7b577e55bc4fa5044d63e6471a562c816f19fba6 /modules/textual_inversion/textual_inversion.py
parent    | 3277f90e933485d2590a55998480d02f9499be5c (diff)
Simplify grad clip
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 16
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index c567ec3f..687d97bb 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(embedding.step)
with torch.autocast("cuda"):
@@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
optimizer.zero_grad()
loss.backward()
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
optimizer.step()
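
For reference, a minimal sketch of the pattern this commit adopts: resolve the torch clipping function once from the mode string, then call it with the scheduled limit at each step. The helper name select_clip_grad, the toy parameter, and the optimizer setup below are illustrative assumptions, not webui code; only the torch.nn.utils calls mirror the diff.

```python
import torch

def select_clip_grad(clip_grad_mode):
    """Resolve the clipping function once from the mode string."""
    if clip_grad_mode == "value":
        return torch.nn.utils.clip_grad_value_
    if clip_grad_mode == "norm":
        return torch.nn.utils.clip_grad_norm_
    return None  # clipping disabled

# Hypothetical usage inside one training step:
vec = torch.nn.Parameter(torch.randn(4, 768))   # stands in for embedding.vec
optimizer = torch.optim.AdamW([vec], lr=5e-3)
clip_grad = select_clip_grad("norm")
clip_value = 1.0                                 # stands in for clip_grad_sched.learn_rate

loss = (vec ** 2).mean()
optimizer.zero_grad()
loss.backward()
if clip_grad:
    # clip_grad_value_(params, clip_value) and clip_grad_norm_(params, max_norm)
    # both take their limit as the second positional argument, so a single call
    # covers either mode.
    clip_grad(vec, clip_value)
optimizer.step()
```

Both clipping functions share the same (parameters, limit) calling shape, which is what lets the two per-step branches in the old code collapse into one call.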