author    Greg Fuller <gfuller23@gmail.com>  2022-10-12 19:44:41 +0000
committer Greg Fuller <gfuller23@gmail.com>  2022-10-12 19:44:41 +0000
commit    fb3cefb348600497d964c4fd3f99138d7dde0ec5 (patch)
tree      984c02da13dae6ab0d6c21304a551bbfffea02f2 /modules/textual_inversion/textual_inversion.py
parent    d717eb079cd6b7fa7a4f97c0a10d400bdec753fb (diff)
parent    698d303b04e293635bfb49c525409f3bcf671dce (diff)
Merge remote-tracking branch 'upstream/master' into interrogate_include_ranks_in_output
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--    modules/textual_inversion/textual_inversion.py    77
1 file changed, 55 insertions(+), 22 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 7717837d..fa0e33a2 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -7,11 +7,15 @@ import tqdm
import html
import datetime
+from PIL import Image, PngImagePlugin
from modules import shared, devices, sd_hijack, processing, sd_models
import modules.textual_inversion.dataset
-from modules.textual_inversion.learn_schedule import LearnSchedule
+from modules.textual_inversion.learn_schedule import LearnRateScheduler
+from modules.textual_inversion.image_embedding import (embedding_to_b64, embedding_from_b64,
+ insert_image_data_embed, extract_image_data_embed,
+ caption_image_overlay)
class Embedding:
def __init__(self, vec, name, step=None):
@@ -81,7 +85,18 @@ class EmbeddingDatabase:
def process_file(path, filename):
name = os.path.splitext(filename)[0]
- data = torch.load(path, map_location="cpu")
+ data = []
+
+ if filename.upper().endswith('.PNG'):
+ embed_image = Image.open(path)
+ if 'sd-ti-embedding' in embed_image.text:
+ data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
+ name = data.get('name', name)
+ else:
+ data = extract_image_data_embed(embed_image)
+ name = data.get('name', name)
+ else:
+ data = torch.load(path, map_location="cpu")
# textual inversion embeddings
if 'string_to_param' in data:
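
The hunk above lets process_file() load an embedding straight from a PNG: PIL exposes a PNG's text chunks through the .text mapping, so the base64 payload stored under 'sd-ti-embedding' can be decoded without torch.load. A minimal sketch of that branch, where 'some-embedding.png' is a placeholder path and the helpers are the ones this diff imports from modules.textual_inversion.image_embedding:

    from PIL import Image
    from modules.textual_inversion.image_embedding import embedding_from_b64, extract_image_data_embed

    embed_image = Image.open('some-embedding.png')
    if 'sd-ti-embedding' in embed_image.text:
        # fast path: the embedding travels as base64 in a PNG text chunk
        data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
    else:
        # fallback: recover the payload hidden in the image pixels
        data = extract_image_data_embed(embed_image)
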
@@ -157,7 +172,7 @@ def create_embedding(name, num_vectors_per_token, init_text='*'):
return fn
-def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, num_repeats, create_image_every, save_embedding_every, template_file, preview_image_prompt):
+def train_embedding(embedding_name, learn_rate, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_image_prompt):
assert embedding_name, 'embedding not selected'
shared.state.textinfo = "Initializing textual inversion training..."
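
Note the signature change above: num_repeats is gone (it now comes from shared.opts.training_image_repeats_per_epoch, as the dataset hunk below shows) and save_image_with_stored_embedding is new. A hypothetical call against the new signature; every value below is illustrative, not a repo default:

    train_embedding(
        embedding_name='my-style',
        learn_rate='5e-3:200, 5e-4',  # assumes the 'rate:step' schedule syntax
        data_root='train/images',
        log_directory='textual_inversion',
        training_width=512,
        training_height=512,
        steps=3000,
        create_image_every=500,
        save_embedding_every=500,
        template_file='textual_inversion_templates/style.txt',
        save_image_with_stored_embedding=True,
        preview_image_prompt='',
    )
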
@@ -179,11 +194,17 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
else:
images_dir = None
+ if create_image_every > 0 and save_image_with_stored_embedding:
+ images_embeds_dir = os.path.join(log_directory, "image_embeddings")
+ os.makedirs(images_embeds_dir, exist_ok=True)
+ else:
+ images_embeds_dir = None
+
cond_model = shared.sd_model.cond_stage_model
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
with torch.autocast("cuda"):
- ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=num_repeats, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+ ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
hijack = sd_hijack.model_hijack
@@ -199,32 +220,24 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if ititial_step > steps:
return embedding, filename
- schedules = iter(LearnSchedule(learn_rate, steps, ititial_step))
- (learn_rate, end_step) = next(schedules)
- print(f'Training at rate of {learn_rate} until step {end_step}')
-
- optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+ scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
+ optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
- for i, (x, text, _) in pbar:
+ for i, entry in pbar:
embedding.step = i + ititial_step
- if embedding.step > end_step:
- try:
- (learn_rate, end_step) = next(schedules)
- except:
- break
- tqdm.tqdm.write(f'Training at rate of {learn_rate} until step {end_step}')
- for pg in optimizer.param_groups:
- pg['lr'] = learn_rate
+ scheduler.apply(optimizer, embedding.step)
+ if scheduler.finished:
+ break
if shared.state.interrupted:
break
with torch.autocast("cuda"):
- c = cond_model([text])
+ c = cond_model([entry.cond_text])
- x = x.to(devices.device)
+ x = entry.latent.to(devices.device)
loss = shared.sd_model(x.unsqueeze(0), c)[0]
del x
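
The hunk above collapses the manual schedule iteration into a LearnRateScheduler with an apply()/finished contract. The real class lives in modules/textual_inversion/learn_schedule.py; the hypothetical reimplementation below only illustrates that contract, reusing the (learn_rate, end_step) pairs visible in the removed code. The comma-separated 'rate:step' syntax is an assumption:

    class MiniScheduler:
        """Hypothetical stand-in for LearnRateScheduler; illustration only."""

        def __init__(self, learn_rate, max_steps, initial_step=0):
            # parse 'rate:step, rate:step, ...' into (rate, end_step) pairs;
            # a bare rate runs until max_steps
            pairs = []
            for part in str(learn_rate).split(','):
                if ':' in part:
                    rate, end_step = part.split(':')
                    pairs.append((float(rate), int(end_step)))
                else:
                    pairs.append((float(part), max_steps))
            self.schedule = iter(pairs)
            self.learn_rate, self.end_step = next(self.schedule)
            self.finished = False
            # fast-forward past segments already completed when resuming
            while initial_step > self.end_step:
                try:
                    self.learn_rate, self.end_step = next(self.schedule)
                except StopIteration:
                    self.finished = True
                    break

        def apply(self, optimizer, step):
            # same logic the diff removed: once the current segment's end
            # step is passed, advance the schedule and update every param
            # group's lr; otherwise leave the optimizer untouched
            if step <= self.end_step:
                return
            try:
                self.learn_rate, self.end_step = next(self.schedule)
            except StopIteration:
                self.finished = True
                return
            for pg in optimizer.param_groups:
                pg['lr'] = self.learn_rate

Mirroring the hunk, the caller seeds the optimizer with scheduler.learn_rate, calls scheduler.apply(optimizer, step) once per iteration, and breaks out of the training loop as soon as scheduler.finished is set.
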
@@ -246,7 +259,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
- preview_text = text if preview_image_prompt == "" else preview_image_prompt
+ preview_text = entry.cond_text if preview_image_prompt == "" else preview_image_prompt
p = processing.StableDiffusionProcessingTxt2Img(
sd_model=shared.sd_model,
@@ -262,6 +275,26 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
image = processed.images[0]
shared.state.current_image = image
+
+ if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+
+ last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
+
+ info = PngImagePlugin.PngInfo()
+ data = torch.load(last_saved_file)
+ info.add_text("sd-ti-embedding", embedding_to_b64(data))
+
+ title = "<{}>".format(data.get('name', '???'))
+ checkpoint = sd_models.select_checkpoint()
+ footer_left = checkpoint.model_name
+ footer_mid = '[{}]'.format(checkpoint.hash)
+ footer_right = '{}'.format(embedding.step)
+
+ captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
+ captioned_image = insert_image_data_embed(captioned_image, data)
+
+ captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+
image.save(last_saved_image)
last_saved_image += f", prompt: {preview_text}"
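
The hunk above stores the embedding twice in the preview image: as a base64 text chunk (readable by the process_file() PNG branch earlier in this diff) and as pixel data via insert_image_data_embed(). A self-contained round-trip of the text-chunk half, using only PIL; the payload here is a placeholder string rather than a real base64 embedding:

    from PIL import Image, PngImagePlugin

    info = PngImagePlugin.PngInfo()
    info.add_text('sd-ti-embedding', 'bm90LWEtcmVhbC1lbWJlZGRpbmc=')  # placeholder payload

    img = Image.new('RGB', (64, 64))
    img.save('embed-demo.png', 'PNG', pnginfo=info)

    # reading it back mirrors the process_file() branch above
    print(Image.open('embed-demo.png').text['sd-ti-embedding'])
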
@@ -272,7 +305,7 @@ def train_embedding(embedding_name, learn_rate, data_root, log_directory, traini
<p>
Loss: {losses.mean():.7f}<br/>
Step: {embedding.step}<br/>
-Last prompt: {html.escape(text)}<br/>
+Last prompt: {html.escape(entry.cond_text)}<br/>
Last saved embedding: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
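
The diff imports embedding_to_b64()/embedding_from_b64() but never shows their bodies. A plausible codec, purely as an assumption (the real image_embedding module may serialize differently), would pair torch.save with base64:

    import base64
    import io

    import torch

    def embedding_to_b64_sketch(data: dict) -> str:
        # serialize the tensor dict with torch, then base64-encode the bytes
        buf = io.BytesIO()
        torch.save(data, buf)
        return base64.b64encode(buf.getvalue()).decode('ascii')

    def embedding_from_b64_sketch(payload: str) -> dict:
        # invert the encoding; map_location keeps loading CPU-safe
        buf = io.BytesIO(base64.b64decode(payload))
        return torch.load(buf, map_location='cpu')
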