author     Dynamic <bradje@naver.com>       2022-10-29 13:33:06 +0000
committer  GitHub <noreply@github.com>      2022-10-29 13:33:06 +0000
commit     3d36d62d61425a2538270c489dc5fe827c125ad1 (patch)
tree       d3ebb7d5f5fa23f355a02288f0c645c102b8e6a0 /modules/hypernetworks/hypernetwork.py
parent     a668444110743cd163474ec563b0e69025dea3d2 (diff)
parent     35c45df28b303a05d56a13cb56d4046f08cf8c25 (diff)
Merge branch 'AUTOMATIC1111:master' into kr-localization
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--  modules/hypernetworks/hypernetwork.py  13
1 file changed, 8 insertions(+), 5 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 8113b35b..2e84583b 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -25,6 +25,7 @@ from statistics import stdev, mean
 class HypernetworkModule(torch.nn.Module):
     multiplier = 1.0
     activation_dict = {
+        "linear": torch.nn.Identity,
         "relu": torch.nn.ReLU,
         "leakyrelu": torch.nn.LeakyReLU,
         "elu": torch.nn.ELU,
@@ -428,7 +429,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

         optimizer.step()

-        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
+        steps_done = hypernetwork.step + 1
+
+        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
             raise RuntimeError("Loss diverged.")

         if len(previous_mean_losses) > 1:
@@ -438,9 +441,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
             pbar.set_description(dataset_loss_info)

-        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
+        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
             # Before saving, change name to match current checkpoint.
-            hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
+            hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
             last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
             hypernetwork.save(last_saved_file)

@@ -449,8 +452,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                 "learn_rate": scheduler.learn_rate
             })

-        if hypernetwork.step > 0 and images_dir is not None and hypernetwork.step % create_image_every == 0:
-            forced_filename = f'{hypernetwork_name}-{hypernetwork.step}'
+        if images_dir is not None and steps_done % create_image_every == 0:
+            forced_filename = f'{hypernetwork_name}-{steps_done}'
             last_saved_image = os.path.join(images_dir, forced_filename)

             optimizer.zero_grad()
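The first hunk adds a "linear" entry to HypernetworkModule's activation_dict, mapping the name to torch.nn.Identity so a hypernetwork layer can be configured with no nonlinearity. Below is a minimal sketch of what that mapping provides; the lookup-and-instantiate usage is illustrative, not code from this file.

```python
import torch

# Mirror of the registry after this change: "linear" resolves to a no-op module.
activation_dict = {
    "linear": torch.nn.Identity,
    "relu": torch.nn.ReLU,
    "leakyrelu": torch.nn.LeakyReLU,
    "elu": torch.nn.ELU,
}

# Hypothetical usage: look the class up by name and instantiate it, the way a
# layer-building loop might.
act = activation_dict["linear"]()
x = torch.randn(4, 8)
assert torch.equal(act(x), x)  # Identity passes the tensor through unchanged
```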
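The remaining hunks introduce steps_done = hypernetwork.step + 1 and use it in the periodic checkpoint and preview-image conditions, so the explicit hypernetwork.step > 0 guard is no longer needed and saved filenames count completed steps. The sketch below illustrates the shift, assuming hypernetwork.step is the 0-based index of the step that just ran and a hypothetical save_hypernetwork_every of 500.

```python
# Illustrative comparison only; not code from the repository.
save_hypernetwork_every = 500

def old_should_save(step: int) -> bool:
    # Old condition: needed "step > 0" to avoid saving before any training.
    return step > 0 and step % save_hypernetwork_every == 0

def new_should_save(step: int) -> bool:
    steps_done = step + 1  # 1-based count of completed steps (assumption)
    # steps_done is never 0, so no extra guard is required.
    return steps_done % save_hypernetwork_every == 0

# Old: fires when the raw counter reads 500, 1000, ... and names the file after
# that counter. New: fires once 500, 1000, ... steps have completed and names
# the file after steps_done, keeping the filename aligned with finished steps.
assert [s for s in range(1200) if old_should_save(s)] == [500, 1000]
assert [s + 1 for s in range(1200) if new_should_save(s)] == [500, 1000]
```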