diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-13 11:58:03 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-01-13 11:58:03 +0000 |
commit | 1849f6eb806f637f783b3beee3b48772da1cfab1 (patch) | |
tree | 345be78dd1991b77fcf4519bc44097e975e0b0c4 /modules/hypernetworks/hypernetwork.py | |
parent | 544e7a233e994f379dd67df08f5f519290b10293 (diff) | |
parent | 9cd7716753c5be47f76b8e5555cc3e7c0f17d34d (diff) | |
download | stable-diffusion-webui-gfx803-1849f6eb806f637f783b3beee3b48772da1cfab1.tar.gz stable-diffusion-webui-gfx803-1849f6eb806f637f783b3beee3b48772da1cfab1.tar.bz2 stable-diffusion-webui-gfx803-1849f6eb806f637f783b3beee3b48772da1cfab1.zip |
Merge pull request #3264 from Melanpan/tensorboard
Add support for Tensorboard (training)
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 15 |
1 file changed, 14 insertions, 1 deletion
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 194679e8..83cbb4f0 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -24,7 +24,6 @@ from statistics import stdev, mean optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}
-
class HypernetworkModule(torch.nn.Module):
multiplier = 1.0
activation_dict = {
@@ -498,6 +497,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ if shared.opts.training_enable_tensorboard:
+ tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)
+
# dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -632,6 +634,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
+
+
+ if shared.opts.training_enable_tensorboard:
+ epoch_num = hypernetwork.step // len(ds)
+ epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
+
+ textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
"loss": f"{loss_step:.7f}",
"learn_rate": scheduler.learn_rate
@@ -673,6 +683,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, processed = processing.process_images(p)
image = processed.images[0] if len(processed.images) > 0 else None
+
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
|