aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authorMalumaDev <piano.lu92@gmail.com>2022-10-15 22:24:05 +0000
committerMalumaDev <piano.lu92@gmail.com>2022-10-15 22:24:05 +0000
commitb694bba39a2f9f9069201e27f0d312f4abe5b41f (patch)
tree08db9b88160496f4326397526200276c7bee1493 /modules/hypernetworks/hypernetwork.py
parent9325c85f780c569d1823e422eaf51b2e497e0d3e (diff)
parent97ceaa23d00f6a17ca752dda757e6016f99230cb (diff)
downloadstable-diffusion-webui-gfx803-b694bba39a2f9f9069201e27f0d312f4abe5b41f.tar.gz
stable-diffusion-webui-gfx803-b694bba39a2f9f9069201e27f0d312f4abe5b41f.tar.bz2
stable-diffusion-webui-gfx803-b694bba39a2f9f9069201e27f0d312f4abe5b41f.zip
Merge remote-tracking branch 'origin/test_resolve_conflicts' into test_resolve_conflicts
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index a2b3bc0a..4905710e 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -272,15 +272,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
optimizer.zero_grad()
loss.backward()
optimizer.step()
-
- pbar.set_description(f"loss: {losses.mean():.7f}")
+ mean_loss = losses.mean()
+ if torch.isnan(mean_loss):
+ raise RuntimeError("Loss diverged.")
+ pbar.set_description(f"loss: {mean_loss:.7f}")
if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
hypernetwork.save(last_saved_file)
textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
+ "loss": f"{mean_loss:.7f}",
"learn_rate": scheduler.learn_rate
})
@@ -328,7 +330,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
shared.state.textinfo = f"""
<p>
-Loss: {losses.mean():.7f}<br/>
+Loss: {mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>