aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion
diff options
context:
space:
mode:
Diffstat (limited to 'modules/textual_inversion')
-rw-r--r--modules/textual_inversion/dataset.py10
-rw-r--r--modules/textual_inversion/learn_schedule.py11
-rw-r--r--modules/textual_inversion/logging.py24
-rw-r--r--modules/textual_inversion/preprocess.py1
-rw-r--r--modules/textual_inversion/textual_inversion.py149
5 files changed, 145 insertions, 50 deletions
diff --git a/modules/textual_inversion/dataset.py b/modules/textual_inversion/dataset.py
index 2dc64c3c..88d68c76 100644
--- a/modules/textual_inversion/dataset.py
+++ b/modules/textual_inversion/dataset.py
@@ -28,9 +28,9 @@ class DatasetEntry:
class PersonalizedBase(Dataset):
- def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
+ def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):
re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
-
+
self.placeholder_token = placeholder_token
self.width = width
@@ -50,14 +50,14 @@ class PersonalizedBase(Dataset):
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
-
+
self.shuffle_tags = shuffle_tags
self.tag_drop_out = tag_drop_out
print("Preparing dataset...")
for path in tqdm.tqdm(self.image_paths):
if shared.state.interrupted:
- raise Exception("inturrupted")
+ raise Exception("interrupted")
try:
image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
except Exception:
@@ -144,7 +144,7 @@ class PersonalizedDataLoader(DataLoader):
self.collate_fn = collate_wrapper_random
else:
self.collate_fn = collate_wrapper
-
+
class BatchLoader:
def __init__(self, data):
diff --git a/modules/textual_inversion/learn_schedule.py b/modules/textual_inversion/learn_schedule.py
index dd0c0ad1..f63fc72f 100644
--- a/modules/textual_inversion/learn_schedule.py
+++ b/modules/textual_inversion/learn_schedule.py
@@ -58,14 +58,19 @@ class LearnRateScheduler:
self.finished = False
- def apply(self, optimizer, step_number):
+ def step(self, step_number):
if step_number < self.end_step:
- return
+ return False
try:
(self.learn_rate, self.end_step) = next(self.schedules)
- except Exception:
+ except StopIteration:
self.finished = True
+ return False
+ return True
+
+ def apply(self, optimizer, step_number):
+ if not self.step(step_number):
return
if self.verbose:
diff --git a/modules/textual_inversion/logging.py b/modules/textual_inversion/logging.py
new file mode 100644
index 00000000..8b1981d5
--- /dev/null
+++ b/modules/textual_inversion/logging.py
@@ -0,0 +1,24 @@
+import datetime
+import json
+import os
+
+saved_params_shared = {"model_name", "model_hash", "initial_step", "num_of_dataset_images", "learn_rate", "batch_size", "data_root", "log_directory", "training_width", "training_height", "steps", "create_image_every", "template_file"}
+saved_params_ti = {"embedding_name", "num_vectors_per_token", "save_embedding_every", "save_image_with_stored_embedding"}
+saved_params_hypernet = {"hypernetwork_name", "layer_structure", "activation_func", "weight_init", "add_layer_norm", "use_dropout", "save_hypernetwork_every"}
+saved_params_all = saved_params_shared | saved_params_ti | saved_params_hypernet
+saved_params_previews = {"preview_prompt", "preview_negative_prompt", "preview_steps", "preview_sampler_index", "preview_cfg_scale", "preview_seed", "preview_width", "preview_height"}
+
+
+def save_settings_to_file(log_directory, all_params):
+ now = datetime.datetime.now()
+ params = {"datetime": now.strftime("%Y-%m-%d %H:%M:%S")}
+
+ keys = saved_params_all
+ if all_params.get('preview_from_txt2img'):
+ keys = keys | saved_params_previews
+
+ params.update({k: v for k, v in all_params.items() if k in keys})
+
+ filename = f'settings-{now.strftime("%Y-%m-%d-%H-%M-%S")}.json'
+ with open(os.path.join(log_directory, filename), "w") as file:
+ json.dump(params, file, indent=4)
diff --git a/modules/textual_inversion/preprocess.py b/modules/textual_inversion/preprocess.py
index 56b9b2eb..feb876c6 100644
--- a/modules/textual_inversion/preprocess.py
+++ b/modules/textual_inversion/preprocess.py
@@ -124,6 +124,7 @@ def preprocess_work(process_src, process_dst, process_width, process_height, pre
files = listfiles(src)
+ shared.state.job = "preprocess"
shared.state.textinfo = "Preprocessing..."
shared.state.job_count = len(files)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e28c357a..45882ed6 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,6 +1,7 @@
import os
import sys
import traceback
+import inspect
import torch
import tqdm
@@ -17,15 +18,20 @@ from modules.textual_inversion.learn_schedule import LearnRateScheduler
from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
insert_image_data_embed, extract_image_data_embed,
caption_image_overlay)
+from modules.textual_inversion.logging import save_settings_to_file
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -39,6 +45,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -57,14 +70,15 @@ class EmbeddingDatabase:
def __init__(self, embeddings_dir):
self.ids_lookup = {}
self.word_embeddings = {}
+ self.skipped_embeddings = {}
self.dir_mtime = None
self.embeddings_dir = embeddings_dir
+ self.expected_shape = -1
def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
- # TODO changing between clip and open clip changes tokenization, which will cause embeddings to stop working
ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
@@ -75,21 +89,26 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
+ def get_expected_shape(self):
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
+
+ def load_textual_inversion_embeddings(self, force_reload = False):
mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
+ if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
return
self.dir_mtime = mt
self.ids_lookup.clear()
self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- data = []
-
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
embed_image = Image.open(path)
if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -97,8 +116,10 @@ class EmbeddingDatabase:
else:
data = extract_image_data_embed(embed_image)
name = data.get('name', name)
- else:
+ elif ext in ['.BIN', '.PT']:
data = torch.load(path, map_location="cpu")
+ else:
+ return
# textual inversion embeddings
if 'string_to_param' in data:
@@ -122,23 +143,31 @@ class EmbeddingDatabase:
embedding.step = data.get('step', None)
embedding.sd_checkpoint = data.get('sd_checkpoint', None)
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
- self.register_embedding(embedding, shared.sd_model)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
- for fn in os.listdir(self.embeddings_dir):
- try:
- fullfn = os.path.join(self.embeddings_dir, fn)
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
+ self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
- if os.stat(fullfn).st_size == 0:
- continue
+ for root, dirs, fns in os.walk(self.embeddings_dir):
+ for fn in fns:
+ try:
+ fullfn = os.path.join(root, fn)
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
+ if os.stat(fullfn).st_size == 0:
+ continue
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ process_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+ continue
+
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -194,7 +223,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
csv_writer.writeheader()
epoch = (step - 1) // epoch_len
- epoch_step = (step - 1) % epoch_len
+ epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
"step": step,
@@ -203,6 +232,7 @@ def write_loss(log_directory, filename, step, epoch_len, values):
**values,
})
+
def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_model_every, create_image_every, log_directory, name="embedding"):
assert model_name, f"{name} not selected"
assert learn_rate, "Learning rate is empty or 0"
@@ -225,11 +255,12 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
if save_model_every or create_image_every:
assert log_directory, "Log directory is empty"
-def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
save_embedding_every = save_embedding_every or 0
create_image_every = create_image_every or 0
validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
@@ -263,18 +294,26 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
initial_step = embedding.step or 0
if initial_step >= steps:
- shared.state.textinfo = f"Model has already been trained beyond specified max steps"
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
-
- # dataset loading may take a while, so input validations and early returns should be done before this
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ # dataset loading may take a while, so input validations and early returns should be done before this
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
old_parallel_processing_allowed = shared.parallel_processing_allowed
-
+
pin_memory = shared.opts.pin_memory
-
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
+
latent_sampling_method = ds.latent_sampling_method
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
@@ -285,6 +324,19 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
embedding.vec.requires_grad = True
optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
scaler = torch.cuda.amp.GradScaler()
batch_size = ds.batch_size
@@ -295,12 +347,14 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
loss_step = 0
_loss_step = 0 #internal
-
last_saved_file = "<none>"
last_saved_image = "<none>"
forced_filename = "<none>"
embedding_yet_to_be_embedded = False
-
+
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
pbar = tqdm.tqdm(total=steps - initial_step)
try:
for i in range((steps-initial_step) * gradient_step):
@@ -318,22 +372,34 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
if shared.state.interrupted:
break
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
with devices.autocast():
- # c = stack_conds(batch.cond).to(devices.device)
- # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
- # print(mask)
- # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
c = shared.sd_model.cond_stage_model(batch.cond_text)
- loss = shared.sd_model(x, c)[0] / gradient_step
+
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
+
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
+
+ loss = shared.sd_model(x, cond)[0] / gradient_step
del x
-
+
_loss_step += loss.item()
scaler.scale(loss).backward()
-
+
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
scaler.step(optimizer)
scaler.update()
embedding.step += 1
@@ -352,9 +418,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# Before saving, change name to match current checkpoint.
embedding_name_every = f'{embedding_name}-{steps_done}'
last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
- #if shared.opts.save_optimizer_state:
- #embedding.optimizer_state_dict = optimizer.state_dict()
- save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
embedding_yet_to_be_embedded = True
write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -444,7 +508,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
- save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
except Exception:
print(traceback.format_exc(), file=sys.stderr)
pass
@@ -456,7 +520,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
return embedding, filename
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
old_embedding_name = embedding.name
old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -467,6 +531,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
if remove_cached_checksum:
embedding.cached_checksum = None
embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
embedding.save(filename)
except:
embedding.sd_checkpoint = old_sd_checkpoint