diff options
author | flamelaw <flamelaw.com3d2@gmail.com> | 2022-11-21 01:15:46 +0000 |
---|---|---|
committer | flamelaw <flamelaw.com3d2@gmail.com> | 2022-11-21 01:15:46 +0000 |
commit | 5b57f61ba47f8b11d19a5b46e7fb5a52458abae5 (patch) | |
tree | 474be8cd4ca69e75ba95516b100874a88d235fde /modules/textual_inversion/dataset.py | |
parent | 2d22d72cdaaf2b78b2986b841d478c11ac855dd2 (diff) | |
download | stable-diffusion-webui-gfx803-5b57f61ba47f8b11d19a5b46e7fb5a52458abae5.tar.gz stable-diffusion-webui-gfx803-5b57f61ba47f8b11d19a5b46e7fb5a52458abae5.tar.bz2 stable-diffusion-webui-gfx803-5b57f61ba47f8b11d19a5b46e7fb5a52458abae5.zip |
fix pin_memory with different latent sampling method
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r-- | modules/textual_inversion/dataset.py | 23 |
1 files changed, 19 insertions, 4 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 110c0e09..f470324a 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -138,9 +138,12 @@ class PersonalizedBase(Dataset): return entry
class PersonalizedDataLoader(DataLoader):
- def __init__(self, *args, **kwargs):
- super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
- self.collate_fn = collate_wrapper
+ def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False):
+ super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory)
+ if latent_sampling_method == "random":
+ self.collate_fn = collate_wrapper_random
+ else:
+ self.collate_fn = collate_wrapper
class BatchLoader:
@@ -148,10 +151,22 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data]
self.cond = [entry.cond for entry in data]
self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)
+ #self.emb_index = [entry.emb_index for entry in data]
+ #print(self.latent_sample.device)
def pin_memory(self):
self.latent_sample = self.latent_sample.pin_memory()
return self
def collate_wrapper(batch):
- return BatchLoader(batch)
\ No newline at end of file + return BatchLoader(batch)
+
+class BatchLoaderRandom(BatchLoader):
+ def __init__(self, data):
+ super().__init__(data)
+
+ def pin_memory(self):
+ return self
+
+def collate_wrapper_random(batch):
+ return BatchLoaderRandom(batch)
\ No newline at end of file |