diff options
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r-- | modules/textual_inversion/autocrop.py | 6 | ||||
-rw-r--r-- | modules/textual_inversion/dataset.py | 16 | ||||
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 21 |
3 files changed, 23 insertions, 20 deletions
diff --git a/modules/textual_inversion/autocrop.py b/modules/textual_inversion/autocrop.py index 9859974a..68e1103c 100644 --- a/modules/textual_inversion/autocrop.py +++ b/modules/textual_inversion/autocrop.py @@ -276,8 +276,8 @@ def poi_average(pois, settings): weight += poi.weight
x += poi.x * poi.weight
y += poi.y * poi.weight
- avg_x = round(x / weight)
- avg_y = round(y / weight)
+ avg_x = round(weight and x / weight)
+ avg_y = round(weight and y / weight)
return PointOfInterest(avg_x, avg_y)
@@ -338,4 +338,4 @@ class Settings: self.face_points_weight = face_points_weight
self.annotate_image = annotate_image
self.destop_view_image = False
- self.dnn_model_path = dnn_model_path
\ No newline at end of file + self.dnn_model_path = dnn_model_path
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py index f470324a..88d68c76 100644 --- a/modules/textual_inversion/dataset.py +++ b/modules/textual_inversion/dataset.py @@ -28,9 +28,9 @@ class DatasetEntry: class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
+
self.placeholder_token = placeholder_token
self.width = width
@@ -50,14 +50,14 @@ class PersonalizedBase(Dataset): self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
+
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
- raise Exception("inturrupted")
+ raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@@ -82,7 +82,7 @@ class PersonalizedBase(Dataset): torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
latent_sample = None
- with torch.autocast("cuda"):
+ with devices.autocast():
latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))
if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
@@ -101,7 +101,7 @@ class PersonalizedBase(Dataset): entry.cond_text = self.create_text(filename_text)
if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
- with torch.autocast("cuda"):
+ with devices.autocast():
entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
self.dataset.append(entry)
@@ -117,13 +117,13 @@ class PersonalizedBase(Dataset): def create_text(self, filename_text):
text = random.choice(self.lines)
- text = text.replace("[name]", self.placeholder_token)
tags = filename_text.split(',')
if self.tag_drop_out != 0:
tags = [t for t in tags if random.random() > self.tag_drop_out]
if self.shuffle_tags:
random.shuffle(tags)
text = text.replace("[filewords]", ','.join(tags))
+ text = text.replace("[name]", self.placeholder_token)
return text
def __len__(self):
@@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader): self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
-
+
class BatchLoader:
def __init__(self, data):
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4eb75cb5..daf3997b 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -133,7 +133,7 @@ class EmbeddingDatabase: process_file(fullfn, fn)
except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
continue
@@ -194,7 +194,7 @@ def write_loss(log_directory, filename, step, epoch_len, values): csv_writer.writeheader()
epoch = (step - 1) // epoch_len
- epoch_step = (step - 1) % epoch_len
+ epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step,
@@ -269,9 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
-
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
pin_memory = shared.opts.pin_memory
-
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
latent_sampling_method = ds.latent_sampling_method
@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.first_stage_model.to(devices.cpu)
embedding.vec.requires_grad = True
@@ -293,12 +295,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ loss_step = 0
_loss_step = 0 #internal
-
+
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
-
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -316,7 +318,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ if shared.state.interrupted:
break
- with torch.autocast("cuda"):
+ with devices.autocast():
# c = stack_conds(batch.cond).to(devices.device)
# mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
# print(mask)
@@ -325,10 +327,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_ c = shared.sd_model.cond_stage_model(batch.cond_text)
loss = shared.sd_model(x, c)[0] / gradient_step
del x
-
+
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
@@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> pbar.leave = False
pbar.close()
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
|