aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetworks
diff options
context:
space:
mode:
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r--modules/hypernetworks/hypernetwork.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 2e84583b..f45ce199 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -328,7 +328,7 @@ def report_statistics(loss_info:dict):
-def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
# images allows training previews to have infotext. Importing it at the top causes a circular import problem.
from modules import images
@@ -384,8 +384,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
ititial_step = hypernetwork.step or 0
if ititial_step > steps:
return hypernetwork, filename
+
+ clip_grad_mode_value = clip_grad_mode == "value"
+ clip_grad_mode_norm = clip_grad_mode == "norm"
+ clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
+ if clip_grad_enabled:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+
# if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
@@ -405,6 +412,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
if shared.state.interrupted:
break
+ if clip_grad_enabled:
+ clip_grad_sched.step(hypernetwork.step)
+
with torch.autocast("cuda"):
c = stack_conds([entry.cond for entry in entries]).to(devices.device)
# c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
@@ -427,6 +437,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
steps_without_grad = 0
assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+ if clip_grad_mode_value:
+ torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
+ elif clip_grad_mode_norm:
+ torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
+
optimizer.step()
steps_done = hypernetwork.step + 1