From 448b9cedab66e05b5b2800513ca334a769b42aa7 Mon Sep 17 00:00:00 2001 From: dan Date: Sat, 7 Jan 2023 21:07:27 +0800 Subject: Allow variable img size --- modules/textual_inversion/dataset.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 88d68c76..375178ed 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,6 +25,7 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values + self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -33,8 +34,6 @@ class PersonalizedBase(Dataset): self.placeholder_token = placeholder_token - self.width = width - self.height = height self.flip = transforms.RandomHorizontalFlip(p=flip_p) self.dataset = [] @@ -59,7 +58,11 @@ class PersonalizedBase(Dataset): if shared.state.interrupted: raise Exception("interrupted") try: - image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC) + image = Image.open(path).convert('RGB') + if width < 2000: + image = image.resize((width, height), PIL.Image.BICUBIC) + else: + assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue @@ -88,14 +91,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -151,6 +154,7 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) + self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) -- cgit v1.2.3 From 669fb18d5222f53ae48abe0f30393d846c50ad91 Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:34:52 +0800 Subject: Add checkbox for variable training dims --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 3 +++ 4 files changed, 8 insertions(+), 5 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index b0cfbe71..dba52841 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -403,7 +403,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, shared.reload_hypernetworks() -def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): # images allows training previews to have infotext. Importing it at the top causes a circular import problem. from modules import images diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 375178ed..7f8a314f 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -29,7 +29,7 @@ class DatasetEntry: class PersonalizedBase(Dataset): - def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'): + def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once', varsize=False): re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None self.placeholder_token = placeholder_token @@ -59,7 +59,7 @@ class PersonalizedBase(Dataset): raise Exception("interrupted") try: image = Image.open(path).convert('RGB') - if width < 2000: + if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) else: assert batch_size == 1, 'variable img size must have batch size 1' diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 9f96d0fd..110efd19 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -255,7 +255,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat if save_model_every or create_image_every: assert log_directory, "Log directory is empty" -def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): +def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding") @@ -309,7 +309,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()}) diff --git a/modules/ui.py b/modules/ui.py index 99483130..4e709a71 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1343,6 +1343,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") + varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): @@ -1449,6 +1450,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, @@ -1480,6 +1482,7 @@ def create_ui(): log_directory, training_width, training_height, + varsize, steps, clip_grad_mode, clip_grad_value, -- cgit v1.2.3 From 72497895b9b1948f86d9309fe897cbb70c20ba7e Mon Sep 17 00:00:00 2001 From: dan Date: Sun, 8 Jan 2023 01:36:00 +0800 Subject: Move batchsize check --- modules/hypernetworks/hypernetwork.py | 2 +- modules/textual_inversion/dataset.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index dba52841..32c67ccc 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -456,7 +456,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, pin_memory = shared.opts.pin_memory - ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method) + ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize) if shared.opts.save_training_settings_to_txt: saved_params = dict( diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index 7f8a314f..bcad6848 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -46,6 +46,8 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" + if varsize: + assert batch_size == 1, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] @@ -61,8 +63,6 @@ class PersonalizedBase(Dataset): image = Image.open(path).convert('RGB') if not varsize: image = image.resize((width, height), PIL.Image.BICUBIC) - else: - assert batch_size == 1, 'variable img size must have batch size 1' except Exception: continue -- cgit v1.2.3 From 43bb5190fc9e7ae479a5dc6640be202c9a71e464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 9 Jan 2023 22:52:23 +0300 Subject: remove/simplify some changes from #6481 --- modules/textual_inversion/dataset.py | 14 +++++--------- modules/textual_inversion/textual_inversion.py | 4 ++-- modules/ui.py | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index bcad6848..fa48708e 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -17,7 +17,7 @@ re_numbers_at_start = re.compile(r"^[-\d]+\s*") class DatasetEntry: - def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None, img_shape=None): + def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None): self.filename = filename self.filename_text = filename_text self.latent_dist = latent_dist @@ -25,7 +25,6 @@ class DatasetEntry: self.cond = cond self.cond_text = cond_text self.pixel_values = pixel_values - self.img_shape = img_shape class PersonalizedBase(Dataset): @@ -46,12 +45,10 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - if varsize: - assert batch_size == 1, 'variable img size must have batch size 1' + assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] - self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out @@ -91,14 +88,14 @@ class PersonalizedBase(Dataset): if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)): latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) latent_sampling_method = "once" - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "deterministic": # Works only for DiagonalGaussianDistribution latent_dist.std = 0 latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu) - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample) elif latent_sampling_method == "random": - entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist, img_shape=image.size) + entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist) if not (self.tag_drop_out != 0 or self.shuffle_tags): entry.cond_text = self.create_text(filename_text) @@ -154,7 +151,6 @@ class BatchLoader: self.cond_text = [entry.cond_text for entry in data] self.cond = [entry.cond for entry in data] self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1) - self.img_shape = [entry.img_shape for entry in data] #self.emb_index = [entry.emb_index for entry in data] #print(self.latent_sample.device) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index ad76297e..14be2c96 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -492,8 +492,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ else: p.prompt = batch.cond_text[0] p.steps = 20 - p.width = batch.img_shape[0][0] - p.height = batch.img_shape[0][1] + p.width = training_width + p.height = training_height preview_text = p.prompt diff --git a/modules/ui.py b/modules/ui.py index 9d6b141e..ddfe1b1a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1348,7 +1348,7 @@ def create_ui(): template_file = gr.Textbox(label='Prompt template file', value=os.path.join(script_path, "textual_inversion_templates", "style_filewords.txt"), elem_id="train_template_file") training_width = gr.Slider(minimum=64, maximum=2048, step=8, label="Width", value=512, elem_id="train_training_width") training_height = gr.Slider(minimum=64, maximum=2048, step=8, label="Height", value=512, elem_id="train_training_height") - varsize = gr.Checkbox(label="Ignore dimension settings and do not resize images", value=False, elem_id="train_varsize") + varsize = gr.Checkbox(label="Do not resize images", value=False, elem_id="train_varsize") steps = gr.Number(label='Max steps', value=100000, precision=0, elem_id="train_steps") with FormRow(): -- cgit v1.2.3 From 6be644fa04ce1542f3a01804310cbbc0a4a91620 Mon Sep 17 00:00:00 2001 From: dan Date: Wed, 11 Jan 2023 05:31:58 +0800 Subject: Enable batch_size>1 for mixed-sized training --- modules/textual_inversion/dataset.py | 36 ++++++++++++++++++++++++++++++++---- 1 file changed, 32 insertions(+), 4 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index fa48708e..b47414f3 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -3,8 +3,10 @@ import numpy as np import PIL import torch from PIL import Image -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset, DataLoader, Sampler from torchvision import transforms +from collections import defaultdict +from random import shuffle, choices import random import tqdm @@ -45,12 +47,12 @@ class PersonalizedBase(Dataset): assert data_root, 'dataset directory not specified' assert os.path.isdir(data_root), "Dataset directory doesn't exist" assert os.listdir(data_root), "Dataset directory is empty" - assert batch_size == 1 or not varsize, 'variable img size must have batch size 1' self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)] self.shuffle_tags = shuffle_tags self.tag_drop_out = tag_drop_out + groups = defaultdict(list) print("Preparing dataset...") for path in tqdm.tqdm(self.image_paths): @@ -103,13 +105,14 @@ class PersonalizedBase(Dataset): if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags): with devices.autocast(): entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0) - + groups[image.size].append(len(self.dataset)) self.dataset.append(entry) del torchdata del latent_dist del latent_sample self.length = len(self.dataset) + self.groups = list(groups.values()) assert self.length > 0, "No images have been found in the dataset." self.batch_size = min(batch_size, self.length) self.gradient_step = min(gradient_step, self.length // self.batch_size) @@ -137,9 +140,34 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry +class GroupedBatchSampler(Sampler): + def __init__(self, data_source: PersonalizedBase, batch_size: int): + n = len(data_source) + self.groups = data_source.groups + self.len = n_batch = n // batch_size + expected = [len(g) / n * n_batch * batch_size for g in data_source.groups] + self.base = [int(e) // batch_size for e in expected] + self.n_rand_batches = nrb = n_batch - sum(self.base) + self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] + self.batch_size = batch_size + def __len__(self): + return self.len + def __iter__(self): + b = self.batch_size + for g in self.groups: + shuffle(g) + batches = [] + for g in self.groups: + batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) + for _ in range(self.n_rand_batches): + rand_group = choices(self.groups, self.probs)[0] + batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): - super(PersonalizedDataLoader, self).__init__(dataset, shuffle=True, drop_last=True, batch_size=batch_size, pin_memory=pin_memory) + super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) if latent_sampling_method == "random": self.collate_fn = collate_wrapper_random else: -- cgit v1.2.3 From a176d89487d92f5a5b152401e5c424b34ff43b96 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 13 Jan 2023 14:32:15 +0300 Subject: print bucket sizes for training without resizing images #6620 fix an error when generating a picture with embedding in it --- modules/textual_inversion/dataset.py | 16 ++++++++++++++++ modules/textual_inversion/image_embedding.py | 4 ++-- modules/textual_inversion/textual_inversion.py | 2 +- 3 files changed, 19 insertions(+), 3 deletions(-) (limited to 'modules/textual_inversion/dataset.py') diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index b47414f3..d31963d4 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -118,6 +118,12 @@ class PersonalizedBase(Dataset): self.gradient_step = min(gradient_step, self.length // self.batch_size) self.latent_sampling_method = latent_sampling_method + if len(groups) > 1: + print("Buckets:") + for (w, h), ids in sorted(groups.items(), key=lambda x: x[0]): + print(f" {w}x{h}: {len(ids)}") + print() + def create_text(self, filename_text): text = random.choice(self.lines) tags = filename_text.split(',') @@ -140,8 +146,11 @@ class PersonalizedBase(Dataset): entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist).to(devices.cpu) return entry + class GroupedBatchSampler(Sampler): def __init__(self, data_source: PersonalizedBase, batch_size: int): + super().__init__(data_source) + n = len(data_source) self.groups = data_source.groups self.len = n_batch = n // batch_size @@ -150,21 +159,28 @@ class GroupedBatchSampler(Sampler): self.n_rand_batches = nrb = n_batch - sum(self.base) self.probs = [e%batch_size/nrb/batch_size if nrb>0 else 0 for e in expected] self.batch_size = batch_size + def __len__(self): return self.len + def __iter__(self): b = self.batch_size + for g in self.groups: shuffle(g) + batches = [] for g in self.groups: batches.extend(g[i*b:(i+1)*b] for i in range(len(g) // b)) for _ in range(self.n_rand_batches): rand_group = choices(self.groups, self.probs)[0] batches.append(choices(rand_group, k=b)) + shuffle(batches) + yield from batches + class PersonalizedDataLoader(DataLoader): def __init__(self, dataset, latent_sampling_method="once", batch_size=1, pin_memory=False): super(PersonalizedDataLoader, self).__init__(dataset, batch_sampler=GroupedBatchSampler(dataset, batch_size), pin_memory=pin_memory) diff --git a/modules/textual_inversion/image_embedding.py b/modules/textual_inversion/image_embedding.py index ea653806..5593f88c 100644 --- a/modules/textual_inversion/image_embedding.py +++ b/modules/textual_inversion/image_embedding.py @@ -76,10 +76,10 @@ def insert_image_data_embed(image, data): next_size = data_np_low.shape[0] + (h-(data_np_low.shape[0] % h)) next_size = next_size + ((h*d)-(next_size % (h*d))) - data_np_low.resize(next_size) + data_np_low = np.resize(data_np_low, next_size) data_np_low = data_np_low.reshape((h, -1, d)) - data_np_high.resize(next_size) + data_np_high = np.resize(data_np_high, next_size) data_np_high = data_np_high.reshape((h, -1, d)) edge_style = list(data['string_to_param'].values())[0].cpu().detach().numpy().tolist()[0][:1024] diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 853246a6..e23906ca 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -479,7 +479,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ epoch_num = embedding.step // steps_per_epoch epoch_step = embedding.step % steps_per_epoch - description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}" + description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}" pbar.set_description(description) shared.state.textinfo = description if embedding_dir is not None and steps_done % save_embedding_every == 0: -- cgit v1.2.3