author    | discus0434 <66945496+discus0434@users.noreply.github.com> | 2022-10-22 13:00:59 +0000
committer | GitHub <noreply@github.com> | 2022-10-22 13:00:59 +0000
commit    | 97749b7c7d9e0b27613aa79197f6094b4f6441d8 (patch)
tree      | 6f4402badf1253bb77d2e973d31ec1228e9fbbab /modules/hypernetworks/hypernetwork.py
parent    | 7912acef725832debef58c4c7bf8ec22fb446c0b (diff)
parent    | 7fd90128eb6d1820045bfe2c2c1269661023a712 (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 11 |
1 file changed, 11 insertions, 0 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 7d12e0ff..3372aae2 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -325,6 +325,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     # if optimizer == "AdamW": or else Adam / AdamW / SGD, etc...
     optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)
 
+    steps_without_grad = 0
+
     pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
     for i, entries in pbar:
         hypernetwork.step = i + ititial_step
@@ -347,8 +349,17 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
             losses[hypernetwork.step % losses.shape[0]] = loss.item()
 
             optimizer.zero_grad()
+            weights[0].grad = None
             loss.backward()
+
+            if weights[0].grad is None:
+                steps_without_grad += 1
+            else:
+                steps_without_grad = 0
+            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'
+
             optimizer.step()
+
         mean_loss = losses.mean()
         if torch.isnan(mean_loss):
             raise RuntimeError("Loss diverged.")
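For reference, below is a minimal, self-contained sketch of the guard this patch introduces, outside the webui codebase: the monitored weight's `.grad` is cleared before `backward()`, and if `backward()` leaves it unset for 10 steps in a row, training aborts. The model, data, and variable names here are illustrative assumptions, not the actual hypernetwork training code.

```python
# Standalone sketch of the no-gradient guard added by this patch (illustrative
# model and data, not webui code). weights[0] stands in for the trained weight
# that the patch monitors.
import torch

torch.manual_seed(0)

model = torch.nn.Linear(4, 1)           # stand-in for the hypernetwork layers
weights = list(model.parameters())      # weights[0] plays the role of the monitored weight
optimizer = torch.optim.AdamW(weights, lr=1e-3)

steps_without_grad = 0

for step in range(100):
    x = torch.randn(8, 4)
    target = torch.randn(8, 1)
    loss = torch.nn.functional.mse_loss(model(x), target)

    optimizer.zero_grad()
    weights[0].grad = None               # force backward() to repopulate the gradient
    loss.backward()

    if weights[0].grad is None:          # no gradient reached the monitored weight this step
        steps_without_grad += 1
    else:
        steps_without_grad = 0
    assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

    optimizer.step()
```

Setting `weights[0].grad = None` explicitly before `backward()` is what makes the check meaningful: after that, a populated `.grad` proves a gradient was actually computed on this step, rather than being a leftover (possibly zeroed) tensor from a previous one.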