aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/dataset.py
diff options
context:
space:
mode:
authorTinkTheBoush <TinkTheBoush@github.com>2022-11-04 10:39:03 +0000
committerTinkTheBoush <TinkTheBoush@github.com>2022-11-04 10:39:03 +0000
commit821e2b883dbb42a187bc37379175cd55b7cd7e81 (patch)
treedae5c1757dbdc7130ebe16012c6d1bcf36f37223 /modules/textual_inversion/dataset.py
parentaf6fba247553e670ef5e2dcc1866279f9f065d6d (diff)
downloadstable-diffusion-webui-gfx803-821e2b883dbb42a187bc37379175cd55b7cd7e81.tar.gz
stable-diffusion-webui-gfx803-821e2b883dbb42a187bc37379175cd55b7cd7e81.tar.bz2
stable-diffusion-webui-gfx803-821e2b883dbb42a187bc37379175cd55b7cd7e81.zip
change option position to Training setting
Diffstat (limited to 'modules/textual_inversion/dataset.py')
-rw-r--r--modules/textual_inversion/dataset.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index e9d97cc1..df278dc2 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -24,7 +24,7 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
self.placeholder_token = placeholder_token
@@ -33,7 +33,6 @@ class PersonalizedBase(Dataset):
self.width = width
self.height = height
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
- self.shuffle_tags = shuffle_tags
self.dataset = []
@@ -99,7 +98,7 @@ class PersonalizedBase(Dataset):
def create_text(self, filename_text):
text = random.choice(self.lines)
text = text.replace("[name]", self.placeholder_token)
- if self.tag_shuffle:
+ if shared.opts.shuffle_tags:
tags = filename_text.split(',')
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))