aboutsummaryrefslogtreecommitdiffstats
path: root/modules/textual_inversion/textual_inversion.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-01-13 11:57:38 +0000
committerGitHub <noreply@github.com>2023-01-13 11:57:38 +0000
commit9cd7716753c5be47f76b8e5555cc3e7c0f17d34d (patch)
tree345be78dd1991b77fcf4519bc44097e975e0b0c4 /modules/textual_inversion/textual_inversion.py
parent18f86e41f6f289042c075bff1498e620ab997b8c (diff)
parent544e7a233e994f379dd67df08f5f519290b10293 (diff)
downloadstable-diffusion-webui-gfx803-9cd7716753c5be47f76b8e5555cc3e7c0f17d34d.tar.gz
stable-diffusion-webui-gfx803-9cd7716753c5be47f76b8e5555cc3e7c0f17d34d.tar.bz2
stable-diffusion-webui-gfx803-9cd7716753c5be47f76b8e5555cc3e7c0f17d34d.zip
Merge branch 'master' into tensorboard
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--modules/textual_inversion/textual_inversion.py639
1 files changed, 434 insertions, 205 deletions
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 589314fe..85210b0e 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -1,32 +1,54 @@
import os
import sys
import traceback
+import inspect
+from collections import namedtuple
import torch
import tqdm
import html
import datetime
import csv
-import numpy as np
+import safetensors.torch
from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter
-from modules import shared, devices, sd_hijack, processing, sd_models
+
+from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
from modules.textual_inversion.learn_schedule import LearnRateScheduler
-from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
- insert_image_data_embed, extract_image_data_embed,
- caption_image_overlay)
+from modules.textual_inversion.image_embedding import embedding_to_b64, embedding_from_b64, insert_image_data_embed, extract_image_data_embed, caption_image_overlay
+from modules.textual_inversion.logging import save_settings_to_file
+
+
+TextualInversionTemplate = namedtuple("TextualInversionTemplate", ["name", "path"])
+textual_inversion_templates = {}
+
+
+def list_textual_inversion_templates():
+ textual_inversion_templates.clear()
+
+ for root, dirs, fns in os.walk(shared.cmd_opts.textual_inversion_templates_dir):
+ for fn in fns:
+ path = os.path.join(root, fn)
+
+ textual_inversion_templates[fn] = TextualInversionTemplate(fn, path)
+
+ return textual_inversion_templates
+
class Embedding:
def __init__(self, vec, name, step=None):
self.vec = vec
self.name = name
self.step = step
+ self.shape = None
+ self.vectors = 0
self.cached_checksum = None
self.sd_checkpoint = None
self.sd_checkpoint_name = None
+ self.optimizer_state_dict = None
def save(self, filename):
embedding_data = {
@@ -40,6 +62,13 @@ class Embedding:
torch.save(embedding_data, filename)
+ if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+ optimizer_saved_dict = {
+ 'hash': self.checksum(),
+ 'optimizer_state_dict': self.optimizer_state_dict,
+ }
+ torch.save(optimizer_saved_dict, filename + '.optim')
+
def checksum(self):
if self.cached_checksum is not None:
return self.cached_checksum
@@ -54,18 +83,44 @@ class Embedding:
return self.cached_checksum
+class DirWithTextualInversionEmbeddings:
+ def __init__(self, path):
+ self.path = path
+ self.mtime = None
+
+ def has_changed(self):
+ if not os.path.isdir(self.path):
+ return False
+
+ mt = os.path.getmtime(self.path)
+ if self.mtime is None or mt > self.mtime:
+ return True
+
+ def update(self):
+ if not os.path.isdir(self.path):
+ return
+
+ self.mtime = os.path.getmtime(self.path)
+
+
class EmbeddingDatabase:
- def __init__(self, embeddings_dir):
+ def __init__(self):
self.ids_lookup = {}
self.word_embeddings = {}
- self.dir_mtime = None
- self.embeddings_dir = embeddings_dir
+ self.skipped_embeddings = {}
+ self.expected_shape = -1
+ self.embedding_dirs = {}
- def register_embedding(self, embedding, model):
+ def add_embedding_dir(self, path):
+ self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)
+ def clear_embedding_dirs(self):
+ self.embedding_dirs.clear()
+
+ def register_embedding(self, embedding, model):
self.word_embeddings[embedding.name] = embedding
- ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+ ids = model.cond_stage_model.tokenize([embedding.name])[0]
first_id = ids[0]
if first_id not in self.ids_lookup:
@@ -75,70 +130,104 @@ class EmbeddingDatabase:
return embedding
- def load_textual_inversion_embeddings(self):
- mt = os.path.getmtime(self.embeddings_dir)
- if self.dir_mtime is not None and mt <= self.dir_mtime:
- return
+ def get_expected_shape(self):
+ vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
+ return vec.shape[1]
- self.dir_mtime = mt
- self.ids_lookup.clear()
- self.word_embeddings.clear()
+ def load_from_file(self, path, filename):
+ name, ext = os.path.splitext(filename)
+ ext = ext.upper()
- def process_file(path, filename):
- name = os.path.splitext(filename)[0]
+ if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+ _, second_ext = os.path.splitext(name)
+ if second_ext.upper() == '.PREVIEW':
+ return
- data = []
-
- if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
- embed_image = Image.open(path)
- if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
- data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
- name = data.get('name', name)
- else:
- data = extract_image_data_embed(embed_image)
- name = data.get('name', name)
- else:
- data = torch.load(path, map_location="cpu")
-
- # textual inversion embeddings
- if 'string_to_param' in data:
- param_dict = data['string_to_param']
- if hasattr(param_dict, '_parameters'):
- param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
- assert len(param_dict) == 1, 'embedding file has multiple terms in it'
- emb = next(iter(param_dict.items()))[1]
- # diffuser concepts
- elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
- assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
-
- emb = next(iter(data.values()))
- if len(emb.shape) == 1:
- emb = emb.unsqueeze(0)
+ embed_image = Image.open(path)
+ if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
else:
- raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ elif ext in ['.BIN', '.PT']:
+ data = torch.load(path, map_location="cpu")
+ elif ext in ['.SAFETENSORS']:
+ data = safetensors.torch.load_file(path, device="cpu")
+ else:
+ return
- vec = emb.detach().to(devices.device, dtype=torch.float32)
- embedding = Embedding(vec, name)
- embedding.step = data.get('step', None)
- embedding.sd_checkpoint = data.get('hash', None)
- embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ # textual inversion embeddings
+ if 'string_to_param' in data:
+ param_dict = data['string_to_param']
+ if hasattr(param_dict, '_parameters'):
+ param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
+ assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+ emb = next(iter(param_dict.items()))[1]
+ # diffuser concepts
+ elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+ assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+ emb = next(iter(data.values()))
+ if len(emb.shape) == 1:
+ emb = emb.unsqueeze(0)
+ else:
+ raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
+
+ vec = emb.detach().to(devices.device, dtype=torch.float32)
+ embedding = Embedding(vec, name)
+ embedding.step = data.get('step', None)
+ embedding.sd_checkpoint = data.get('sd_checkpoint', None)
+ embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+ embedding.vectors = vec.shape[0]
+ embedding.shape = vec.shape[-1]
+
+ if self.expected_shape == -1 or self.expected_shape == embedding.shape:
self.register_embedding(embedding, shared.sd_model)
+ else:
+ self.skipped_embeddings[name] = embedding
+
+ def load_from_dir(self, embdir):
+ if not os.path.isdir(embdir.path):
+ return
+
+ for root, dirs, fns in os.walk(embdir.path):
+ for fn in fns:
+ try:
+ fullfn = os.path.join(root, fn)
- for fn in os.listdir(self.embeddings_dir):
- try:
- fullfn = os.path.join(self.embeddings_dir, fn)
+ if os.stat(fullfn).st_size == 0:
+ continue
- if os.stat(fullfn).st_size == 0:
+ self.load_from_file(fullfn, fn)
+ except Exception:
+ print(f"Error loading embedding {fn}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
continue
- process_file(fullfn, fn)
- except Exception:
- print(f"Error loading emedding {fn}:", file=sys.stderr)
- print(traceback.format_exc(), file=sys.stderr)
- continue
+ def load_textual_inversion_embeddings(self, force_reload=False):
+ if not force_reload:
+ need_reload = False
+ for path, embdir in self.embedding_dirs.items():
+ if embdir.has_changed():
+ need_reload = True
+ break
- print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
- print("Embeddings:", ', '.join(self.word_embeddings.keys()))
+ if not need_reload:
+ return
+
+ self.ids_lookup.clear()
+ self.word_embeddings.clear()
+ self.skipped_embeddings.clear()
+ self.expected_shape = self.get_expected_shape()
+
+ for path, embdir in self.embedding_dirs.items():
+ self.load_from_dir(embdir)
+ embdir.update()
+
+ print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
+ if len(self.skipped_embeddings) > 0:
+ print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
def find_embedding_at_position(self, tokens, offset):
token = tokens[offset]
@@ -154,19 +243,26 @@ class EmbeddingDatabase:
return None, None
-def create_embedding(name, num_vectors_per_token, init_text='*'):
+def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
cond_model = shared.sd_model.cond_stage_model
- embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
- ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
- embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+ with devices.autocast():
+ cond_model([""]) # will send cond model to GPU if lowvram/medvram is active
+
+ #cond_model expects at least some text, so we provide '*' as backup.
+ embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
- for i in range(num_vectors_per_token):
- vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ #Only copy if we provided an init_text, otherwise keep vectors as zeros
+ if init_text:
+ for i in range(num_vectors_per_token):
+ vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+ # Remove illegal characters from name.
+ name = "".join( x for x in name if (x.isalnum() or x in "._- "))
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
- assert not os.path.exists(fn), f"file {fn} already exists"
+ if not overwrite_old:
+ assert not os.path.exists(fn), f"file {fn} already exists"
embedding = Embedding(vec, name)
embedding.step = 0
@@ -181,7 +277,6 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if step % shared.opts.training_write_csv_every != 0:
return
-
write_csv_header = False if os.path.exists(os.path.join(log_directory, filename)) else True
with open(os.path.join(log_directory, filename), "a+", newline='') as fout:
@@ -190,13 +285,13 @@ def write_loss(log_directory, filename, step, epoch_len, values):
if write_csv_header:
csv_writer.writeheader()
- epoch = step // epoch_len
- epoch_step = step - epoch * epoch_len
+ epoch = (step - 1) // epoch_len
+ epoch_step = (step - 1) % epoch_len
csv_writer.writerow({
- "step": step + 1,
- "epoch": epoch + 1,
- "epoch_step": epoch_step + 1,
+ "step": step,
+ "epoch": epoch,
+ "epoch_step": epoch_step,
**values,
})
@@ -225,15 +320,45 @@ def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
tensorboard_writer.add_image(tag, img_tensor, global_step=step)
-def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
- assert embedding_name, 'embedding not selected'
-
+def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
+ assert model_name, f"{name} not selected"
+ assert learn_rate, "Learning rate is empty or 0"
+ assert isinstance(batch_size, int), "Batch size must be integer"
+ assert batch_size > 0, "Batch size must be positive"
+ assert isinstance(gradient_step, int), "Gradient accumulation step must be integer"
+ assert gradient_step > 0, "Gradient accumulation step must be positive"
+ assert data_root, "Dataset directory is empty"
+ assert os.path.isdir(data_root), "Dataset directory doesn't exist"
+ assert os.listdir(data_root), "Dataset directory is empty"
+ assert template_filename, "Prompt template file not selected"
+ assert template_file, f"Prompt template file {template_filename} not found"
+ assert os.path.isfile(template_file.path), f"Prompt template file {template_filename} doesn't exist"
+ assert steps, "Max steps is empty or 0"
+ assert isinstance(steps, int), "Max steps must be integer"
+ assert steps > 0, "Max steps must be positive"
+ assert isinstance(save_model_every, int), "Save {name} must be integer"
+ assert save_model_every >= 0, "Save {name} must be positive or 0"
+ assert isinstance(create_image_every, int), "Create image must be integer"
+ assert create_image_every >= 0, "Create image must be positive or 0"
+ if save_model_every or create_image_every:
+ assert log_directory, "Log directory is empty"
+
+
+def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
+ save_embedding_every = save_embedding_every or 0
+ create_image_every = create_image_every or 0
+ template_file = textual_inversion_templates.get(template_filename, None)
+ validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
+ template_file = template_file.path
+
+ shared.state.job = "train-embedding"
shared.state.textinfo = "Initializing textual inversion training..."
shared.state.job_count = steps
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+ unload = shared.opts.unload_models_when_training
if save_embedding_every > 0:
embedding_dir = os.path.join(log_directory, "embeddings")
@@ -252,160 +377,264 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
os.makedirs(images_embeds_dir, exist_ok=True)
else:
images_embeds_dir = None
-
- cond_model = shared.sd_model.cond_stage_model
-
- shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
- with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
hijack = sd_hijack.model_hijack
embedding = hijack.embedding_db.word_embeddings[embedding_name]
- embedding.vec.requires_grad = True
-
- losses = torch.zeros((32,))
-
- last_saved_file = "<none>"
- last_saved_image = "<none>"
- embedding_yet_to_be_embedded = False
+ checkpoint = sd_models.select_checkpoint()
initial_step = embedding.step or 0
- if initial_step > steps:
+ if initial_step >= steps:
+ shared.state.textinfo = "Model has already been trained beyond specified max steps"
return embedding, filename
-
+
scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
- optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
-
+ clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
+ torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
+ None
+ if clip_grad:
+ clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
+ # dataset loading may take a while, so input validations and early returns should be done before this
+ shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if shared.opts.training_enable_tensorboard:
tensorboard_writer = tensorboard_setup(log_directory)
- pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
- for i, entries in pbar:
- embedding.step = i + initial_step
-
- scheduler.apply(optimizer, embedding.step)
- if scheduler.finished:
- break
-
- if shared.state.interrupted:
- break
-
- with torch.autocast("cuda"):
- c = cond_model([entry.cond_text for entry in entries])
- x = torch.stack([entry.latent for entry in entries]).to(devices.device)
- loss = shared.sd_model(x, c)[0]
- del x
-
- losses[embedding.step % losses.shape[0]] = loss.item()
-
-
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
-
- epoch_num = embedding.step // len(ds)
- epoch_step = embedding.step - (epoch_num * len(ds)) + 1
-
- pbar.set_description(f"[Epoch {epoch_num}: {epoch_step}/{len(ds)}]loss: {losses.mean():.7f}")
-
- if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
- last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
- embedding.save(last_saved_file)
- embedding_yet_to_be_embedded = True
+ pin_memory = shared.opts.pin_memory
- if shared.opts.training_enable_tensorboard:
- tensorboard_add(tensorboard_writer, loss=losses.mean(), global_step=embedding.step,
- step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
- write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
- "loss": f"{losses.mean():.7f}",
- "learn_rate": scheduler.learn_rate
- })
-
- if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
- last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
-
- p = processing.StableDiffusionProcessingTxt2Img(
- sd_model=shared.sd_model,
- do_not_save_grid=True,
- do_not_save_samples=True,
- do_not_reload_embeddings=True,
- )
-
- if preview_from_txt2img:
- p.prompt = preview_prompt
- p.negative_prompt = preview_negative_prompt
- p.steps = preview_steps
- p.sampler_index = preview_sampler_index
- p.cfg_scale = preview_cfg_scale
- p.seed = preview_seed
- p.width = preview_width
- p.height = preview_height
- else:
- p.prompt = entries[0].cond_text
- p.steps = 20
- p.width = training_width
- p.height = training_height
-
- preview_text = p.prompt
-
- processed = processing.process_images(p)
- image = processed.images[0]
+ if shared.opts.save_training_settings_to_txt:
+ save_settings_to_file(log_directory, {**dict(model_name=checkpoint.model_name, model_hash=checkpoint.hash, num_of_dataset_images=len(ds), num_vectors_per_token=len(embedding.vec)), **locals()})
- shared.state.current_image = image
+ latent_sampling_method = ds.latent_sampling_method
- if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
- last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+ if unload:
+ shared.parallel_processing_allowed = False
+ shared.sd_model.first_stage_model.to(devices.cpu)
- info = PngImagePlugin.PngInfo()
- data = torch.load(last_saved_file)
- info.add_text("sd-ti-embedding", embedding_to_b64(data))
-
- title = "<{}>".format(data.get('name', '???'))
-
- try:
- vectorSize = list(data['string_to_param'].values())[0].shape[0]
- except Exception as e:
- vectorSize = '?'
-
- checkpoint = sd_models.select_checkpoint()
- footer_left = checkpoint.model_name
- footer_mid = '[{}]'.format(checkpoint.hash)
- footer_right = '{}v {}s'.format(vectorSize, embedding.step)
+ embedding.vec.requires_grad = True
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+ if shared.opts.save_optimizer_state:
+ optimizer_state_dict = None
+ if os.path.exists(filename + '.optim'):
+ optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+ if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+ optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+ if optimizer_state_dict is not None:
+ optimizer.load_state_dict(optimizer_state_dict)
+ print("Loaded existing optimizer from checkpoint")
+ else:
+ print("No saved optimizer exists in checkpoint")
+
+ scaler = torch.cuda.amp.GradScaler()
+
+ batch_size = ds.batch_size
+ gradient_step = ds.gradient_step
+ # n steps = batch_size * gradient_step * n image processed
+ steps_per_epoch = len(ds) // batch_size // gradient_step
+ max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
+ loss_step = 0
+ _loss_step = 0 #internal
- captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
- captioned_image = insert_image_data_embed(captioned_image, data)
+ last_saved_file = "<none>"
+ last_saved_image = "<none>"
+ forced_filename = "<none>"
+ embedding_yet_to_be_embedded = False
- captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
- embedding_yet_to_be_embedded = False
+ is_training_inpainting_model = shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}
+ img_c = None
+
+ pbar = tqdm.tqdm(total=steps - initial_step)
+ try:
+ for i in range((steps-initial_step) * gradient_step):
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+ for j, batch in enumerate(dl):
+ # works as a drop_last=True for gradient accumulation
+ if j == max_steps_per_epoch:
+ break
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
+ if shared.state.interrupted:
+ break
+
+ if clip_grad:
+ clip_grad_sched.step(embedding.step)
+
+ with devices.autocast():
+ x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+ c = shared.sd_model.cond_stage_model(batch.cond_text)
- image.save(last_saved_image)
+ if is_training_inpainting_model:
+ if img_c is None:
+ img_c = processing.txt2img_image_conditioning(shared.sd_model, c, training_width, training_height)
- if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
- tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}",
- image, embedding.step)
+ cond = {"c_concat": [img_c], "c_crossattn": [c]}
+ else:
+ cond = c
- last_saved_image += f", prompt: {preview_text}"
+ loss = shared.sd_model(x, cond)[0] / gradient_step
+ del x
- shared.state.job_no = embedding.step
+ _loss_step += loss.item()
+ scaler.scale(loss).backward()
- shared.state.textinfo = f"""
+ # go back until we reach gradient accumulation steps
+ if (j + 1) % gradient_step != 0:
+ continue
+
+ if clip_grad:
+ clip_grad(embedding.vec, clip_grad_sched.learn_rate)
+
+ scaler.step(optimizer)
+ scaler.update()
+ embedding.step += 1
+ pbar.update()
+ optimizer.zero_grad(set_to_none=True)
+ loss_step = _loss_step
+ _loss_step = 0
+
+ steps_done = embedding.step + 1
+
+ epoch_num = embedding.step // steps_per_epoch
+ epoch_step = embedding.step % steps_per_epoch
+
+ description = f"Training textual inversion [Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}] loss: {loss_step:.7f}"
+ pbar.set_description(description)
+ shared.state.textinfo = description
+ if embedding_dir is not None and steps_done % save_embedding_every == 0:
+ # Before saving, change name to match current checkpoint.
+ embedding_name_every = f'{embedding_name}-{steps_done}'
+ last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+ embedding_yet_to_be_embedded = True
+
+ write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
+ "loss": f"{loss_step:.7f}",
+ "learn_rate": scheduler.learn_rate
+ })
+
+ if images_dir is not None and steps_done % create_image_every == 0:
+ forced_filename = f'{embedding_name}-{steps_done}'
+ last_saved_image = os.path.join(images_dir, forced_filename)
+
+ shared.sd_model.first_stage_model.to(devices.device)
+
+ p = processing.StableDiffusionProcessingTxt2Img(
+ sd_model=shared.sd_model,
+ do_not_save_grid=True,
+ do_not_save_samples=True,
+ do_not_reload_embeddings=True,
+ )
+
+ if preview_from_txt2img:
+ p.prompt = preview_prompt
+ p.negative_prompt = preview_negative_prompt
+ p.steps = preview_steps
+ p.sampler_name = sd_samplers.samplers[preview_sampler_index].name
+ p.cfg_scale = preview_cfg_scale
+ p.seed = preview_seed
+ p.width = preview_width
+ p.height = preview_height
+ else:
+ p.prompt = batch.cond_text[0]
+ p.steps = 20
+ p.width = training_width
+ p.height = training_height
+
+ preview_text = p.prompt
+
+ processed = processing.process_images(p)
+ image = processed.images[0] if len(processed.images) > 0 else None
+
+ if unload:
+ shared.sd_model.first_stage_model.to(devices.cpu)
+
+ if image is not None:
+ shared.state.current_image = image
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
+ tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+
+ try:
+ vectorSize = list(data['string_to_param'].values())[0].shape[0]
+ except Exception as e:
+ vectorSize = '?'
+
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}v {}s'.format(vectorSize, steps_done)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+ embedding_yet_to_be_embedded = False
+
+ last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
+ last_saved_image += f", prompt: {preview_text}"
+
+ shared.state.job_no = embedding.step
+
+ shared.state.textinfo = f"""
<p>
-Loss: {losses.mean():.7f}<br/>
-Step: {embedding.step}<br/>
-Last prompt: {html.escape(entries[0].cond_text)}<br/>
+Loss: {loss_step:.7f}<br/>
+Step: {steps_done}<br/>
+Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
-
- checkpoint = sd_models.select_checkpoint()
-
- embedding.sd_checkpoint = checkpoint.hash
- embedding.sd_checkpoint_name = checkpoint.model_name
- embedding.cached_checksum = None
- embedding.save(filename)
+ filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+ save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+ except Exception:
+ print(traceback.format_exc(), file=sys.stderr)
+ pass
+ finally:
+ pbar.leave = False
+ pbar.close()
+ shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return embedding, filename
+
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+ old_embedding_name = embedding.name
+ old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
+ old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
+ old_cached_checksum = embedding.cached_checksum if hasattr(embedding, "cached_checksum") else None
+ try:
+ embedding.sd_checkpoint = checkpoint.hash
+ embedding.sd_checkpoint_name = checkpoint.model_name
+ if remove_cached_checksum:
+ embedding.cached_checksum = None
+ embedding.name = embedding_name
+ embedding.optimizer_state_dict = optimizer.state_dict()
+ embedding.save(filename)
+ except:
+ embedding.sd_checkpoint = old_sd_checkpoint
+ embedding.sd_checkpoint_name = old_sd_checkpoint_name
+ embedding.name = old_embedding_name
+ embedding.cached_checksum = old_cached_checksum
+ raise