path: root/modules/textual_inversion/textual_inversion.py
author     MalumaDev <piano.lu92@gmail.com>          2022-10-15 22:06:36 +0000
committer  GitHub <noreply@github.com>               2022-10-15 22:06:36 +0000
commit     97ceaa23d00f6a17ca752dda757e6016f99230cb (patch)
tree       1c44afefae393779f57ede3578422d08330cceac /modules/textual_inversion/textual_inversion.py
parent     3d21684ee30ca5734126b8d08c05b3a0f513fe75 (diff)
parent     be1596ce30b1ead6998da0c62003003dcce5eb2c (diff)
Merge branch 'master' into test_resolve_conflicts
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--   modules/textual_inversion/textual_inversion.py   17
1 file changed, 13 insertions(+), 4 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index f59b47a9..d2a389c9 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -88,9 +88,9 @@ class EmbeddingDatabase:
             data = []
-            if filename.upper().endswith('.PNG'):
+            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
-                if 'sd-ti-embedding' in embed_image.text:
+                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
                     name = data.get('name', name)
                 else:
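This hunk loosens the loader so embedding data can also be pulled out of WEBP, JXL, and AVIF previews (JXL/AVIF additionally require the corresponding Pillow plugins), and guards the metadata lookup because not every Pillow image class exposes a .text mapping. A minimal standalone sketch of that check, where decode_embedding_b64 is a hypothetical stand-in for the repository's embedding_from_b64 helper:

import os
from PIL import Image

EMBED_IMAGE_EXTS = {'.PNG', '.WEBP', '.JXL', '.AVIF'}

def read_embedded_ti_data(path, decode_embedding_b64):
    # decode_embedding_b64 is a hypothetical stand-in for embedding_from_b64.
    name = os.path.splitext(os.path.basename(path))[0]
    if os.path.splitext(path.upper())[-1] not in EMBED_IMAGE_EXTS:
        return None, name
    embed_image = Image.open(path)
    # Only some Pillow image classes carry a .text dict of metadata chunks,
    # so hasattr() prevents an AttributeError on formats that lack it.
    if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
        data = decode_embedding_b64(embed_image.text['sd-ti-embedding'])
        return data, data.get('name', name)
    return None, name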
@@ -242,6 +242,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     last_saved_file = "<none>"
     last_saved_image = "<none>"
+    embedding_yet_to_be_embedded = False
     ititial_step = embedding.step or 0
     if ititial_step > steps:
@@ -283,6 +284,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
             last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
             embedding.save(last_saved_file)
+            embedding_yet_to_be_embedded = True
         write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
             "loss": f"{losses.mean():.7f}",
@@ -320,7 +322,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             shared.state.current_image = image
-            if save_image_with_stored_embedding and os.path.exists(last_saved_file):
+            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
                 last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{embedding.step}.png')
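The new embedding_yet_to_be_embedded flag acts as a dirty marker: a preview image only gets the checkpoint data embedded if a .pt file has been written since the last embedded image, so the same weights are never baked into two previews. A runnable toy sketch of that pattern (the intervals and stand-in functions here are hypothetical):

def save_checkpoint(step):
    print(f"wrote embedding checkpoint at step {step}")

def save_preview_with_embedding(step):
    print(f"embedded latest checkpoint into preview at step {step}")

save_embedding_every, create_image_every = 500, 250   # hypothetical intervals
embedding_yet_to_be_embedded = False

for step in range(1, 1501):
    if step % save_embedding_every == 0:
        save_checkpoint(step)
        embedding_yet_to_be_embedded = True    # fresh .pt awaits embedding
    if step % create_image_every == 0 and embedding_yet_to_be_embedded:
        save_preview_with_embedding(step)
        embedding_yet_to_be_embedded = False   # wait for the next checkpoint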
@@ -329,15 +331,22 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                 info.add_text("sd-ti-embedding", embedding_to_b64(data))
                 title = "<{}>".format(data.get('name', '???'))
+
+                try:
+                    vectorSize = list(data['string_to_param'].values())[0].shape[0]
+                except Exception as e:
+                    vectorSize = '?'
+
                 checkpoint = sd_models.select_checkpoint()
                 footer_left = checkpoint.model_name
                 footer_mid = '[{}]'.format(checkpoint.hash)
-                footer_right = '{}'.format(embedding.step)
+                footer_right = '{}v {}s'.format(vectorSize, embedding.step)
                 captioned_image = caption_image_overlay(image, title, footer_left, footer_mid, footer_right)
                 captioned_image = insert_image_data_embed(captioned_image, data)
                 captioned_image.save(last_saved_image_chunks, "PNG", pnginfo=info)
+                embedding_yet_to_be_embedded = False
             image.save(last_saved_image)
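The last hunk also annotates the preview caption with the embedding's vector count, read from the shape of the first tensor in string_to_param and rendered as, for example, "8v 5000s" (vectors, steps). A small sketch of that lookup, assuming torch and a typical saved-embedding layout where a placeholder token maps to a (num_vectors, embedding_dim) tensor:

import torch

# Hypothetical example of the structure a saved embedding carries:
# one placeholder token mapped to a (num_vectors, embedding_dim) tensor.
data = {'string_to_param': {'*': torch.zeros(8, 768)}}

try:
    vector_size = list(data['string_to_param'].values())[0].shape[0]
except Exception:
    vector_size = '?'   # fall back gracefully if the layout is unexpected

step = 5000             # hypothetical training step
footer_right = '{}v {}s'.format(vector_size, step)
print(footer_right)     # -> "8v 5000s"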