author     AUTOMATIC1111 <16777216c@gmail.com>   2023-05-11 18:25:15 +0000
committer  GitHub <noreply@github.com>           2023-05-11 18:25:15 +0000
commit     abe32cefa39dee36d7f661d4e63c28ea8dd60c4f (patch)
tree       1f1d817b59b49c6d3944c959151ce4c67d9041da /modules/hypernetworks/hypernetwork.py
parent     b4aaa339d529c81859858f0bedcc72b44fccd3d0 (diff)
parent     49a55b410b66b7dd9be9335d8a2e3a71e4f8b15c (diff)
Merge pull request #10285 from akx/ruff-spacing
Indentation + ruff whitespace fixes
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--  modules/hypernetworks/hypernetwork.py  12
1 file changed, 6 insertions, 6 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 38ef074f..570b5603 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -540,7 +540,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
return hypernetwork, filename
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
+
clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
if clip_grad:
clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
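The conditional in the hunk above selects between PyTorch's two built-in gradient-clipping helpers. For context, here is a minimal sketch of the same selection pattern detached from the webui code; the toy model and threshold below are illustrative, not taken from this diff.

```python
import torch

def make_clipper(clip_grad_mode):
    # Same selection as in the hunk above: "value" clamps each gradient
    # element to +/- the threshold, "norm" rescales the whole gradient.
    if clip_grad_mode == "value":
        return torch.nn.utils.clip_grad_value_
    if clip_grad_mode == "norm":
        return torch.nn.utils.clip_grad_norm_
    return None

# Illustrative usage with a toy model (not webui code):
model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()
loss.backward()
clip_grad = make_clipper("norm")
if clip_grad:
    clip_grad(model.parameters(), 1.0)  # clip_value for "value", max_norm for "norm"
```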
@@ -593,7 +593,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
print(e)
scaler = torch.cuda.amp.GradScaler()
-
+
batch_size = ds.batch_size
gradient_step = ds.gradient_step
# n steps = batch_size * gradient_step * n image processed
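The `# n steps` comment in this hunk is easiest to read with concrete numbers (illustrative, not from the diff): with a batch size of 2 and a gradient step of 4, each optimizer step accumulates 4 micro-batches of 2 images, i.e. 8 images per step.

```python
# Illustrative arithmetic for the comment above (values are made up):
batch_size = 2       # images per micro-batch from the dataset
gradient_step = 4    # micro-batches accumulated per optimizer step
images_per_optimizer_step = batch_size * gradient_step  # 8
```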
@@ -636,7 +636,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
if clip_grad:
clip_grad_sched.step(hypernetwork.step)
-
+
with devices.autocast():
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
if use_weight:
@@ -657,14 +657,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
loss_logging.append(_loss_step)
if clip_grad:
clip_grad(weights, clip_grad_sched.learn_rate)
-
+
scaler.step(optimizer)
scaler.update()
hypernetwork.step += 1
@@ -674,7 +674,7 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
_loss_step = 0
steps_done = hypernetwork.step + 1
-
+
epoch_num = hypernetwork.step // steps_per_epoch
epoch_step = hypernetwork.step % steps_per_epoch
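Taken together, the hunks above sit inside a standard mixed-precision training step: scale the loss, accumulate gradients over gradient_step micro-batches, optionally clip, then step the scaler and optimizer. Below is a self-contained sketch of that pattern with a toy model and random data standing in for the hypernetwork and latent batches; every name here is illustrative rather than webui API, and the scaler.unscale_ call is standard AMP practice before clipping rather than something this diff shows.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

gradient_step = 4                             # micro-batches per optimizer step
clip_grad = torch.nn.utils.clip_grad_norm_    # or clip_grad_value_, or None

data = [torch.randn(8, 16) for _ in range(16)]  # toy stand-in for latent batches

for j, batch in enumerate(data):
    batch = batch.to(device)
    # plays the role of devices.autocast() in the diff context
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = model(batch).pow(2).mean() / gradient_step
    scaler.scale(loss).backward()

    # keep accumulating until gradient_step micro-batches have been seen
    if (j + 1) % gradient_step != 0:
        continue

    if clip_grad:
        scaler.unscale_(optimizer)            # unscale before clipping under AMP
        clip_grad(model.parameters(), 1.0)

    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad(set_to_none=True)
```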