diff options
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 44 |
1 files changed, 4 insertions, 40 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 47a27faf..7717837d 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -10,6 +10,7 @@ import datetime from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
+from modules.textual_inversion.learn_schedule import LearnSchedule
class Embedding:
@@ -198,11 +199,8 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if ititial_step > steps:
return embedding, filename
- tr_img_len = len([os.path.join(data_root, file_path) for file_path in os.listdir(data_root)])
- epoch_len = (tr_img_len * num_repeats) + tr_img_len
-
- scheduleIter = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(scheduleIter)
+ schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
+ (learn_rate, end_step) = next(schedules)
print(f'Training at rate of {learn_rate} until step {end_step}')
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
@@ -213,7 +211,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini if embedding.step > end_step:
try:
- (learn_rate, end_step) = next(scheduleIter)
+ (learn_rate, end_step) = next(schedules)
except:
break
tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
@@ -288,37 +286,3 @@ Last saved image: {html.escape(last_saved_image)}<br/> embedding.save(filename)
return embedding, filename
-
-class LearnSchedule:
- def __init__(self, learn_rate, max_steps, cur_step=0):
- pairs = learn_rate.split(',')
- self.rates = []
- self.it = 0
- self.maxit = 0
- for i, pair in enumerate(pairs):
- tmp = pair.split(':')
- if len(tmp) == 2:
- step = int(tmp[1])
- if step > cur_step:
- self.rates.append((float(tmp[0]), min(step, max_steps)))
- self.maxit += 1
- if step > max_steps:
- return
- elif step == -1:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
- else:
- self.rates.append((float(tmp[0]), max_steps))
- self.maxit += 1
- return
-
- def __iter__(self):
- return self
-
- def __next__(self):
- if self.it < self.maxit:
- self.it += 1
- return self.rates[self.it - 1]
- else:
- raise StopIteration
|