diff options
author | TinkTheBoush <TinkTheBoush@github.com> | 2022-11-01 14:29:12 +0000 |
---|---|---|
committer | TinkTheBoush <TinkTheBoush@github.com> | 2022-11-01 14:29:12 +0000 |
commit | 467cae167a3066ffa2b2a5e6f16dd42642219aba (patch) | |
tree | 91441ba33d4d03b546d942a335d6a3d08d4e90a0 /modules/textual_inversion/dataset.py | |
parent | c28de154b0ffb143019387f9fc169953347a60f4 (diff) | |
download | stable-diffusion-webui-gfx803-467cae167a3066ffa2b2a5e6f16dd42642219aba.tar.gz stable-diffusion-webui-gfx803-467cae167a3066ffa2b2a5e6f16dd42642219aba.tar.bz2 stable-diffusion-webui-gfx803-467cae167a3066ffa2b2a5e6f16dd42642219aba.zip |
append_tag_shuffle
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r-- | modules/textual_inversion/dataset.py | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index ad726577..e9d97cc1 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -33,6 +33,7 @@ class PersonalizedBase(Dataset): self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+ self.shuffle_tags = shuffle_tags
self.dataset = []
@@ -98,7 +99,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- text = text.replace("[filewords]", filename_text)
+ if self.tag_shuffle:
+ tags = filename_text.split(',')
+ random.shuffle(tags)
+ text = text.replace("[filewords]", ','.join(tags))
+ else:
+ text = text.replace("[filewords]", filename_text)
return text
def __len__(self):
|