author     AUTOMATIC <16777216c@gmail.com>    2023-01-13 11:32:15 +0000
committer  AUTOMATIC <16777216c@gmail.com>    2023-01-13 11:32:15 +0000
commit     a176d89487d92f5a5b152401e5c424b34ff43b96 (patch)
tree       914910f9e24f3f489d1a273ec5abf4531dcd502c /modules/textual_inversion/dataset.py
parent     486bda9b331d054870e2b3551f94ece7aa39574d (diff)
print bucket sizes for training without resizing images #6620
fix an error when generating a picture with an embedding in it
Diffstat (limited to 'modules/textual_inversion/dataset.py')
 modules/textual_inversion/dataset.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+), 0 deletions(-)
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index b47414f3..d31963d4 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -118,6 +118,12 @@ class PersonalizedBase(Dataset):
         self.gradient_step = min(gradient_step, self.length // self.batch_size)
         self.latent_sampling_method = latent_sampling_method
 
+        if len(groups) > 1:
+            print("Buckets:")
+            for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
+                print(f"  {w}x{h}: {len(ids)}")
+            print()
+
     def create_text(self, filename_text):
         text = random.choice(self.lines)
         tags = filename_text.split(',')
@@ -140,8 +146,11 @@ class PersonalizedBase(Dataset):
             entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu)
         return entry
 
+
 class GroupedBatchSampler(Sampler):
     def __init__(self, data_source: PersonalizedBase, batch_size: int):
+        super().__init__(data_source)
+
         n = len(data_source)
         self.groups = data_source.groups
         self.len = n_batch = n // batch_size
@@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler):
         self.n_rand_batches = nrb = n_batch - sum(self.base)
         self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected]
         self.batch_size = batch_size
+
     def __len__(self):
         return self.len
+
     def __iter__(self):
         b = self.batch_size
+
         for g in self.groups:
             shuffle(g)
+
         batches = []
         for g in self.groups:
             batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b))
         for _ in range(self.n_rand_batches):
             rand_group = choices(self.groups, self.probs)[0]
             batches.append(choices(rand_group, k=b))
+
         shuffle(batches)
+
         yield from batches
+
 
 class PersonalizedDataLoader(DataLoader):
     def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
         super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory)
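
The first hunk only reports buckets that PersonalizedBase has already collected: groups maps each (width, height) size to the list of dataset indices whose images have that resolution, and the sizes are printed whenever there is more than one bucket, so a user training without resizing can see how their images were split. A minimal, self-contained sketch of that grouping step (illustrative only: bucket_by_size and the sample sizes are assumptions, not names from this repository):

from collections import defaultdict

def bucket_by_size(image_sizes):
    # Hypothetical helper for illustration; PersonalizedBase builds its
    # groups internally while loading the dataset.
    # Map each (width, height) to the dataset indices of images with that size.
    groups = defaultdict(list)
    for idx, size in enumerate(image_sizes):
        groups[size].append(idx)
    return groups

groups = bucket_by_size([(512, 512), (512, 768), (512, 512)])
if len(groups) > 1:
    print("Buckets:")
    for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]):
        print(f"  {w}x{h}: {len(ids)}")
    print()

This prints "Buckets:" followed by "512x512: 2" and "512x768: 1", matching the format added in the diff.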
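GroupedBatchSampler is what makes the buckets usable for training: every batch it yields comes from a single bucket, so a batch never mixes resolutions. Each bucket contributes its guaranteed number of full batches per epoch (base), and the remaining n_rand_batches batches choose a bucket at random with probability proportional to that bucket's leftover fraction, then draw indices from it with replacement. The added super().__init__(data_source) call simply runs torch's Sampler initializer, which the subclass previously skipped. Below is a dependency-free sketch of the batching scheme, under the simplifying assumption that expected counts batches as len(g) / batch_size (the repository derives it from the epoch length instead):

from random import shuffle, choices

def grouped_batches(groups, batch_size):
    # One epoch of batches; every batch is drawn from a single bucket.
    expected = [len(g) / batch_size for g in groups]   # fractional batches per bucket
    base = [int(e) for e in expected]                  # guaranteed full batches per bucket
    n_batch = sum(len(g) for g in groups) // batch_size
    n_rand_batches = n_batch - sum(base)
    # Leftover batches pick a bucket with probability proportional to its remainder.
    probs = [(e - b) / n_rand_batches if n_rand_batches > 0 else 0
             for e, b in zip(expected, base)]

    for g in groups:
        shuffle(g)                                     # shuffle within each bucket
    batches = [g[i * batch_size:(i + 1) * batch_size]
               for g in groups for i in range(len(g) // batch_size)]
    for _ in range(n_rand_batches):
        bucket = choices(groups, probs)[0]
        batches.append(choices(bucket, k=batch_size))  # sampled with replacement
    shuffle(batches)                                   # mix buckets across the epoch
    return batches

# Three buckets of 5, 3 and 4 indices with batch size 2 yield 6 batches,
# each containing indices from only one bucket:
for batch in grouped_batches([[0, 1, 2, 3, 4], [5, 6, 7], [8, 9, 10, 11]], 2):
    print(batch)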