aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/learn_schedule.py
diff options
context:
space:
mode:
author不会画画的中医不是好程序员 <yfszzx@gmail.com>2022-10-13 04:35:39 +0000
committerGitHub <noreply@github.com>2022-10-13 04:35:39 +0000
commit0186db178e12b94eae559827594898c0611f1c0c (patch)
tree03390c1a8f8908002f39770d54626ef7b3fa565d /modules/textual_inversion/learn_schedule.py
parent716a9e034f1aff434083363b218bd6043a774fc2 (diff)
parent698d303b04e293635bfb49c525409f3bcf671dce (diff)
downloadstable-diffusion-webui-gfx803-0186db178e12b94eae559827594898c0611f1c0c.tar.gz
stable-diffusion-webui-gfx803-0186db178e12b94eae559827594898c0611f1c0c.tar.bz2
stable-diffusion-webui-gfx803-0186db178e12b94eae559827594898c0611f1c0c.zip
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/textual_inversion/learn_schedule.py')
-rw-r--r--modules/textual_inversion/learn_schedule.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index db720271..2062726a 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -1,6 +1,12 @@
+import tqdm
-class LearnSchedule:
+
+class LearnScheduleIterator:
def __init__(self, learn_rate, max_steps, cur_step=0):
+ """
+ specify learn_rate as "0.001:100, 0.00001:1000, 1e-5:10000" to have lr of 0.001 until step 100, 0.00001 until 1000, 1e-5:10000 until 10000
+ """
+
pairs = learn_rate.split(',')
self.rates = []
self.it = 0
@@ -32,3 +38,32 @@ class LearnSchedule:
return self.rates[self.it - 1]
else:
raise StopIteration
+
+
+class LearnRateScheduler:
+ def __init__(self, learn_rate, max_steps, cur_step=0, verbose=True):
+ self.schedules = LearnScheduleIterator(learn_rate, max_steps, cur_step)
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ self.verbose = verbose
+
+ if self.verbose:
+ print(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ self.finished = False
+
+ def apply(self, optimizer, step_number):
+ if step_number <= self.end_step:
+ return
+
+ try:
+ (self.learn_rate, self.end_step) = next(self.schedules)
+ except Exception:
+ self.finished = True
+ return
+
+ if self.verbose:
+ tqdm.tqdm.write(f'Training at rate of {self.learn_rate} until step {self.end_step}')
+
+ for pg in optimizer.param_groups:
+ pg['lr'] = self.learn_rate
+