diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:43:55 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-10-14 19:43:55 +0000 |
commit | 03d62538aebeff51713619fe808c953bdb70193d (patch) | |
tree | 1c01da8d5ec04779838ccecce04f217e3fcbef92 /modules/hypernetworks | |
parent | 326fe7d44ba7c813cd40166d15fdaa8e8eaf8be8 (diff) | |
download | stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.tar.gz stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.tar.bz2 stable-diffusion-webui-gfx803-03d62538aebeff51713619fe808c953bdb70193d.zip |
remove duplicate code for log loss, add step, make it read from options rather than gradio input
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 20 |
1 files changed, 6 insertions, 14 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index edb8cba1..59c7ac6e 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -15,6 +15,7 @@ import torch from torch import einsum
from einops import rearrange, repeat
import modules.textual_inversion.dataset
+from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
@@ -210,7 +211,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=1, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=512, height=512, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True)
if unload:
shared.sd_model.cond_stage_model.to(devices.cpu)
@@ -263,19 +264,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, data_root, log_directory, last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
- if write_csv_every > 0 and hypernetwork_dir is not None and hypernetwork.step % write_csv_every == 0:
- write_csv_header = False if os.path.exists(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv")) else True
-
- with open(os.path.join(hypernetwork_dir, "hypernetwork_loss.csv"), "a+") as fout:
-
- csv_writer = csv.DictWriter(fout, fieldnames=["step", "loss", "learn_rate"])
-
- if write_csv_header:
- csv_writer.writeheader()
-
- csv_writer.writerow({"step": hypernetwork.step,
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate})
+ textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
+ "loss": f"{losses.mean():.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{hypernetwork_name}-{hypernetwork.step}.png')
|