From 467cae167a3066ffa2b2a5e6f16dd42642219aba Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Tue, 1 Nov 2022 23:29:12 +0900 Subject: append_tag_shuffle --- modules/textual_inversion/dataset.py | 10 ++++++++-- modules/textual_inversion/textual_inversion.py | 4 ++-- 2 files changed, 10 insertions(+), 4 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index ad726577..e9d97cc1 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -33,6 +33,7 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) + self.shuffle_tags = shuffle_tags self.dataset = [] @@ -98,7 +99,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - text = text.replace("[filewords]", filename_text) + if self.tag_shuffle: + tags = filename_text.split(',') + random.shuffle(tags) + text = text.replace("[filewords]", ','.join(tags)) + else: + text = text.replace("[filewords]", filename_text) return text def __len__(self): diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index e0babb46..64700e23 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -271,7 +271,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) embedding.vec.requires_grad = True optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate) -- cgit v1.2.3 From 821e2b883dbb42a187bc37379175cd55b7cd7e81 Mon Sep 17 00:00:00 2001 From: TinkTheBoush Date: Fri, 4 Nov 2022 19:39:03 +0900 Subject: change option position to Training setting --- modules/textual_inversion/dataset.py | 5 ++--- modules/textual_inversion/textual_inversion.py | 4 ++-- 2 files changed, 4 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e9d97cc1..df278dc2 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -24,7 +24,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -33,7 +33,6 @@ class PersonalizedBase(Dataset): self.width = width self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) - self.shuffle_tags = shuffle_tags self.dataset = [] @@ -99,7 +98,7 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) - if self.tag_shuffle: + if shared.opts.shuffle_tags: tags = filename_text.split(',') random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 82dde931..0aeb0459 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -272,7 +272,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc # dataset loading may take a while, so input validations and early returns should be done before this shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..." with torch.autocast("cuda"): - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size) if unload: shared.sd_model.first_stage_model.to(devices.cpu) -- cgit v1.2.3 From 13a2f1dca32980339e1fb4d1995cde428db798c5 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:29:55 +0900 Subject: adding tag drop out option --- modules/textual_inversion/dataset.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index df278dc2..a95c7835 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -98,12 +98,12 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) + tags = filename_text.split(',') + if shared.opt.tag_drop_out != 0: + tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] if shared.opts.shuffle_tags: - tags = filename_text.split(',') random.shuffle(tags) - text = text.replace("[filewords]", ','.join(tags)) - else: - text = text.replace("[filewords]", filename_text) + text = text.replace("[filewords]", ','.join(tags)) return text def __len__(self): -- cgit v1.2.3 From b19af67d29356f97fea5cccfdfa12583f605243f Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:54:19 +0900 Subject: Update dataset.py --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index a95c7835..e2cb8428 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -99,7 +99,7 @@ class PersonalizedBase(Dataset): text = random.choice(self.lines) text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') - if shared.opt.tag_drop_out != 0: + if shared.opts.tag_drop_out != 0: tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) -- cgit v1.2.3 From a1e271207dfc3e89b1286ba41d96b459f210c4b2 Mon Sep 17 00:00:00 2001 From: KyuSeok Jung Date: Fri, 11 Nov 2022 10:56:53 +0900 Subject: Update dataset.py --- modules/textual_inversion/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/textual_inversion') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index e2cb8428..eb75c376 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -100,7 +100,7 @@ class PersonalizedBase(Dataset): text = text.replace("[name]", self.placeholder_token) tags = filename_text.split(',') if shared.opts.tag_drop_out != 0: - tags = [t for t in tags if random.random() > shared.opt.tag_drop_out] + tags = [t for t in tags if random.random() > shared.opts.tag_drop_out] if shared.opts.shuffle_tags: random.shuffle(tags) text = text.replace("[filewords]", ','.join(tags)) -- cgit v1.2.3