aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/learn_schedule.py
diff options
context:
space:
mode:
authoryfszzx <yfszzx@gmail.com>2022-10-12 13:24:40 +0000
committeryfszzx <yfszzx@gmail.com>2022-10-12 13:24:40 +0000
commitc87c3b9c1169f8a9b632d6d8c8675d98956c387c (patch)
treeeeeb4ff5e05af265686ce3a7916a0df2f30113e4 /modules/textual_inversion/learn_schedule.py
parent511ca57e37483aac0cf260c89838ad2948509101 (diff)
parent429442f4a6aab7301efb89d27bef524fe827e81a (diff)
downloadstable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.tar.gz
stable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.tar.bz2
stable-diffusion-webui-gfx803-c87c3b9c1169f8a9b632d6d8c8675d98956c387c.zip
test
Diffstat (limited to 'modules/textual_inversion/learn_schedule.py')
-rw-r--r--modules/textual_inversion/learn_schedule.py34
1 files changed, 34 insertions, 0 deletions
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
new file mode 100644
index 00000000..db720271
--- /dev/null
+++ b/modules/textual_inversion/learn_schedule.py
@@ -0,0 +1,34 @@
+
+class LearnSchedule:
+ def __init__(self, learn_rate, max_steps, cur_step=0):
+ pairs = learn_rate.split(',')
+ self.rates = []
+ self.it = 0
+ self.maxit = 0
+ for i, pair in enumerate(pairs):
+ tmp = pair.split(':')
+ if len(tmp) == 2:
+ step = int(tmp[1])
+ if step > cur_step:
+ self.rates.append((float(tmp[0]), min(step, max_steps)))
+ self.maxit += 1
+ if step > max_steps:
+ return
+ elif step == -1:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+ else:
+ self.rates.append((float(tmp[0]), max_steps))
+ self.maxit += 1
+ return
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.it < self.maxit:
+ self.it += 1
+ return self.rates[self.it - 1]
+ else:
+ raise StopIteration