Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--  modules/hypernetworks/hypernetwork.py | 16 +++++++---------
1 file changed, 7 insertions(+), 9 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index f4c2668f..02b624e1 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
- clip_grad_mode_value = clip_grad_mode == "value"
- clip_grad_mode_norm = clip_grad_mode == "norm"
- clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
- if clip_grad_enabled:
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
# dataset loading may take a while, so input validations and early returns should be done before this
@@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
- if clip_grad_enabled:
+ if clip_grad:
clip_grad_sched.step(hypernetwork.step)
with torch.autocast("cuda"):
@@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
- if clip_grad_mode_value:
- torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
- elif clip_grad_mode_norm:
- torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
+ if clip_grad:
+ clip_grad(weights, clip_grad_sched.learn_rate)
optimizer.step()
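
For context, here is a minimal standalone sketch of the dispatch pattern this change adopts: the clip_grad_mode string is mapped directly to the matching torch clipping function, and the clip site makes a single positional call instead of branching on two booleans. Only torch.nn.utils.clip_grad_value_ / clip_grad_norm_ and the call shape come from the diff; the model, optimizer, loop, and mode string below are placeholder assumptions, not part of the patch.

import torch

def select_clip_grad(clip_grad_mode):
    # Same selection as the patch: map the mode string to the matching torch
    # utility, or None when clipping is disabled.
    return torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
           torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
           None

# Placeholder training setup (not from the patch) to exercise the call site.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
clip_grad = select_clip_grad("norm")
clip_value = 1.0  # stands in for clip_grad_sched.learn_rate in the diff

for _ in range(3):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    if clip_grad:
        # The second positional argument is clip_value for clip_grad_value_
        # and max_norm for clip_grad_norm_, so one call covers both modes.
        clip_grad(model.parameters(), clip_value)
    optimizer.step()

The payoff is visible in the third hunk: the two-branch if/elif collapses into one call, and "clipping disabled" is simply clip_grad being None.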