author      AUTOMATIC1111 <16777216c@gmail.com>    2022-10-15 07:47:26 +0000
committer   GitHub <noreply@github.com>            2022-10-15 07:47:26 +0000
commit      f42e0aae6de6b9a7f8da4eaf13594a13502b4fa9 (patch)
tree        472025101577ff5cbd45a3bcb524e6e4accb75ec /modules/textual_inversion/learn_schedule.py
parent      0e77ee24b0b651d6a564245243850e4fb9831e31 (diff)
parent      d13ce89e203d76ab2b54a3406a93a5e4304f529e (diff)
Merge branch 'master' into master
Diffstat (limited to 'modules/textual_inversion/learn_schedule.py')
-rw-r--r--   modules/textual_inversion/learn_schedule.py | 69
1 file changed, 69 insertions, 0 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..2062726a
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,69 @@
+import tqdm
+
+
+class LearnScheduleIterator:
+    def __init__(self, learn_rate, max_steps, cur_step=0):
+        """
+        specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until step 1000, and 1e-5 until step 10000
+        """
+
+        pairs = learn_rate.split(',')
+        self.rates = []
+        self.it = 0
+        self.maxit = 0
+        for i, pair in enumerate(pairs):
+            tmp = pair.split(':')
+            if len(tmp) == 2:
+                step = int(tmp[1])
+                if step > cur_step:
+                    self.rates.append((float(tmp[0]), min(step, max_steps)))
+                    self.maxit += 1
+                    if step > max_steps:
+                        return
+                elif step == -1:
+                    self.rates.append((float(tmp[0]), max_steps))
+                    self.maxit += 1
+                    return
+            else:
+                self.rates.append((float(tmp[0]), max_steps))
+                self.maxit += 1
+                return
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.it < self.maxit:
+            self.it += 1
+            return self.rates[self.it - 1]
+        else:
+            raise StopIteration
+
+
+class LearnRateScheduler:
+    def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+        self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+        (self.learn_rate, self.end_step) = next(self.schedules)
+        self.verbose = verbose
+
+        if self.verbose:
+            print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+        self.finished = False
+
+    def apply(self, optimizer, step_number):
+        if step_number <= self.end_step:
+            return
+
+        try:
+            (self.learn_rate, self.end_step) = next(self.schedules)
+        except StopIteration:
+            self.finished = True
+            return
+
+        if self.verbose:
+            tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+        for pg in optimizer.param_groups:
+            pg['lr'] = self.learn_rate
+
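As a quick illustration of the schedule string format described in the docstring above, the following standalone sketch feeds a made-up schedule to LearnScheduleIterator and prints the (learn_rate, end_step) pairs it yields. The schedule values are invented for the example, and the import assumes the script is run from the webui root so the modules package resolves.

from modules.textual_inversion.learn_schedule import LearnScheduleIterator

# "rate:step" pairs; a trailing pair without a step runs until max_steps.
schedule = "0.005:100, 1e-3:1000, 1e-5"

# The iterator yields (learn_rate, end_step) tuples, clamping end_step to max_steps.
for learn_rate, end_step in LearnScheduleIterator(schedule, max_steps=5000):
    print(f"lr {learn_rate} until step {end_step}")

# prints:
#   lr 0.005 until step 100
#   lr 0.001 until step 1000
#   lr 1e-05 until step 5000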
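And a minimal sketch of how LearnRateScheduler is meant to sit inside a training loop. The torch.optim.SGD optimizer and the dummy parameter here are stand-ins for the embedding training code that actually uses this class, not part of this commit.

import torch
from modules.textual_inversion.learn_schedule import LearnRateScheduler

# Dummy parameter and optimizer standing in for the real training setup.
param = torch.nn.Parameter(torch.zeros(8))
optimizer = torch.optim.SGD([param], lr=0.005)

scheduler = LearnRateScheduler("0.005:100, 1e-3:1000, 1e-5", max_steps=5000)

for step in range(5000):
    # Updates every param_group's lr once the current schedule boundary is crossed.
    scheduler.apply(optimizer, step)
    if scheduler.finished:
        break
    # ... one training step with `optimizer` would go here ...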