aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/learn_schedule.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-11-04 06:02:15 +0000
committerGitHub <noreply@github.com>2022-11-04 06:02:15 +0000
commit4918eb6ce484caa4bc5a9f668bb466a5122a9c87 (patch)
tree76a0e42461d620764ad810c5b8dbd5b28d757519 /modules/textual_inversion/learn_schedule.py
parent80844ac861504e7c67a3d4dec0cbed9f6f4b3e24 (diff)
parent2cf3d2ac15530dbc8fdb486a4dac03b710972445 (diff)
downloadstable-diffusion-webui-gfx803-4918eb6ce484caa4bc5a9f668bb466a5122a9c87.tar.gz
stable-diffusion-webui-gfx803-4918eb6ce484caa4bc5a9f668bb466a5122a9c87.tar.bz2
stable-diffusion-webui-gfx803-4918eb6ce484caa4bc5a9f668bb466a5122a9c87.zip
Merge branch 'master' into hn-activation
Diffstat (limited to 'modules/textual_inversion/learn_schedule.py')
-rw-r--r--modules/textual_inversion/learn_schedule.py37
1 files changed, 22 insertions, 15 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index 2062726a..dd0c0ad1 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -4,30 +4,37 @@ import tqdm
class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
"""
- specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
"""
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
return
- elif step == -1:
+ else:
self.rates.append((float(tmp[0]), max_steps))
self.maxit += 1
return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
def __iter__(self):
return self
@@ -52,7 +59,7 @@ class LearnRateScheduler:
self.finished = False
def apply(self, optimizer, step_number):
- if step_number <= self.end_step:
+ if step_number < self.end_step:
return
try: