aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
committerunknown <mcgpapu@gmail.com>2022-12-25 08:03:55 +0000
commit876da1259965130603f2a7fea505cfa0fce09e2e (patch)
treeccb8b89d64480a4bd224b311702ffeb13b8fe754 /modules/textual_inversion/dataset.py
parentd6fdfde9d70f1b86b696240fb0a0c8f2a4d024f6 (diff)
parentc6f347b81f584b6c0d44af7a209983284dbb52d2 (diff)
downloadstable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.gz
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.tar.bz2
stable-diffusion-webui-gfx803-876da1259965130603f2a7fea505cfa0fce09e2e.zip
Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 2dc64c3c..88d68c76 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,9 +28,9 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
+
self.placeholder_token = placeholder_token
self.width = width
@@ -50,14 +50,14 @@ class PersonalizedBase(Dataset):
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
+
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
- raise Exception("inturrupted")
+ raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader):
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
-
+
class BatchLoader:
def __init__(self, data):