diff options
Diffstat (limited to 'modules/textual_inversion/learn_schedule.py')
-rw-r--r-- | modules/textual_inversion/learn_schedule.py | 76 |
1 files changed, 76 insertions, 0 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py new file mode 100644 index 00000000..dd0c0ad1 --- /dev/null +++ b/modules/textual_inversion/learn_schedule.py @@ -0,0 +1,76 @@ +import tqdm
+
+
+class LearnScheduleIterator:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000
+ """
+
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ try:
+ for i, pair in enumerate(pairs):
+ if not pair.strip():
+ continue
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ assert self.rates
+ except (ValueError, AssertionError):
+ raise Exception('Invalid learning rate schedule. It should be a number or, for example, like "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, and 1e-5 until 10000.')
+
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number < self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+
|