From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 14:32:15 +0300 Subject: print bucket sizes for training without resizing images #6620 fix an error when generating a picture with embedding in it --- modules/textual_inversion/dataset.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b47414f3..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,6 +118,12 @@ class PersonalizedBase(Dataset): self.gradient_step = min(gradient_step, self.length // self.batch_size) self.latent_sampling_method = latent_sampling_method + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + def create_text(self, filename_text): text = random.choice(self.lines) tags = filename_text.split(',') @@ -140,8 +146,11 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry + class GroupedBatchSampler(Sampler): def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + n = len(data_source) self.groups = data_source.groups self.len = n_batch = n // batch_size @@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler): self.n_rand_batches = nrb = n_batch - sum(self.base) self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len + def __iter__(self): b = self.batch_size + for g in self.groups: shuffle(g) + batches = [] for g in self.groups: batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) for _ in range(self.n_rand_batches): rand_group = choices(self.groups, self.probs)[0] batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) -- cgit v1.2.3