author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-01-04 11:09:13 +0000
committer | GitHub <noreply@github.com> | 2023-01-04 11:09:13 +0000
commit | 7bbd984dda22b1392b34ca5d48b7a412216e047b (patch)
tree | 44f23ebb82ac6c9f595f00c1beb717049447ee96 /modules/textual_inversion/textual_inversion.py
parent | 545ae8cb1c17a038a5bfd126546acf4b9cab334c (diff)
parent | bddebe09edeb6a18f2c06986d5658a7be3a563ea (diff)
Merge pull request #6253 from Shondoit/ti-optim
Save Optimizer next to TI embedding
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 40
1 file changed, 32 insertions(+), 8 deletions(-)
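In practice the patch does three things: `Embedding` gains an `optimizer_state_dict` attribute, `Embedding.save()` writes that state to a sibling `<name>.pt.optim` file tagged with the embedding's checksum, and `train_embedding()` reloads it on the next run when the hash still matches, so AdamW's moment estimates survive a restart. Below is a minimal standalone sketch of that round-trip, assuming only PyTorch; the helper names are illustrative, not part of the webui code.

import os
import torch

def save_with_optimizer(vec, optimizer, filename, checksum):
    # Save side (cf. Embedding.save in the diff below): the embedding tensor
    # goes into <filename>, the optimizer moments into a sibling
    # <filename>.optim tagged with the embedding's checksum so a stale
    # .optim file is never silently reused.
    torch.save({'string_to_param': {'*': vec}}, filename)
    torch.save({'hash': checksum, 'optimizer_state_dict': optimizer.state_dict()}, filename + '.optim')

def try_resume_optimizer(optimizer, filename, checksum):
    # Resume side (cf. train_embedding in the diff below): load the saved
    # state only if the hash recorded at save time matches the current embedding.
    optim_path = filename + '.optim'
    if not os.path.exists(optim_path):
        return False
    saved = torch.load(optim_path, map_location='cpu')
    if saved.get('hash', None) != checksum:
        return False
    optimizer.load_state_dict(saved['optimizer_state_dict'])
    return True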
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 2c1251d6..214db01c 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -28,6 +28,7 @@ class Embedding:
         self.cached_checksum = None
         self.sd_checkpoint = None
         self.sd_checkpoint_name = None
+        self.optimizer_state_dict = None
 
     def save(self, filename):
         embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:
         torch.save(embedding_data, filename)
 
+        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
+            optimizer_saved_dict = {
+                'hash': self.checksum(),
+                'optimizer_state_dict': self.optimizer_state_dict,
+            }
+            torch.save(optimizer_saved_dict, filename + '.optim')
+
     def checksum(self):
         if self.cached_checksum is not None:
             return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
         self.expected_shape = self.get_expected_shape()
 
         def process_file(path, filename):
-            name = os.path.splitext(filename)[0]
+            name, ext = os.path.splitext(filename)
+            ext = ext.upper()
 
-            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
+            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                 embed_image = Image.open(path)
                 if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                     data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
                 else:
                     data = extract_image_data_embed(embed_image)
                     name = data.get('name', name)
-            else:
+            elif ext in ['.BIN', '.PT']:
                 data = torch.load(path, map_location="cpu")
+            else:
+                return
 
             # textual inversion embeddings
             if 'string_to_param' in data:
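Aside from the optimizer work, the two hunks above also tighten `process_file`: the extension is split off once, `.bin`/`.pt` files go through `torch.load`, and anything else is skipped with an early `return` instead of being fed to `torch.load` blindly. A simplified sketch of that dispatch with a hypothetical `load_embedding_data` helper (the real code additionally decodes image-embedded payloads via `embedding_from_b64`/`extract_image_data_embed`):

import os
import torch
from PIL import Image

IMAGE_EXTS = {'.PNG', '.WEBP', '.JXL', '.AVIF'}
TENSOR_EXTS = {'.BIN', '.PT'}

def load_embedding_data(path):
    # Returns raw embedding data, or None for files that are neither image
    # carriers nor torch checkpoints (e.g. a stray .txt in the folder).
    ext = os.path.splitext(path)[1].upper()
    if ext in IMAGE_EXTS:
        image = Image.open(path)
        # PNG text chunks may carry a base64 'sd-ti-embedding' payload;
        # decoding it is omitted here for brevity.
        return getattr(image, 'text', {}).get('sd-ti-embedding')
    if ext in TENSOR_EXTS:
        return torch.load(path, map_location="cpu")
    return None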
@@ -301,6 +312,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     embedding.vec.requires_grad = True
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
+    if shared.opts.save_optimizer_state:
+        optimizer_state_dict = None
+        if os.path.exists(filename + '.optim'):
+            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
+            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
+                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
+
+        if optimizer_state_dict is not None:
+            optimizer.load_state_dict(optimizer_state_dict)
+            print("Loaded existing optimizer from checkpoint")
+        else:
+            print("No saved optimizer exists in checkpoint")
+
+
     scaler = torch.cuda.amp.GradScaler()
 
     batch_size = ds.batch_size
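The guard above restores the optimizer only when the `.optim` file's hash matches the current embedding, so stale moments are never applied to a different vector; without it, every resume would restart AdamW's running estimates from zero. For reference, this is roughly what an AdamW `state_dict()` carries per parameter (a throwaway sketch, not webui code; exact key order may vary by PyTorch version):

import torch

vec = torch.zeros(2, 768, requires_grad=True)    # stand-in for embedding.vec
optimizer = torch.optim.AdamW([vec], lr=5e-3, weight_decay=0.0)

vec.sum().backward()
optimizer.step()                                  # populates the running moments

sd = optimizer.state_dict()
print(list(sd.keys()))                            # ['state', 'param_groups']
print(list(sd['state'][0].keys()))                # typically ['step', 'exp_avg', 'exp_avg_sq']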
@@ -367,9 +392,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                 # Before saving, change name to match current checkpoint.
                 embedding_name_every = f'{embedding_name}-{steps_done}'
                 last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
-                #if shared.opts.save_optimizer_state:
-                #embedding.optimizer_state_dict = optimizer.state_dict()
-                save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
+                save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                 embedding_yet_to_be_embedded = True
 
             write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -459,7 +482,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
 </p>
 """
         filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
-        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
     except Exception:
         print(traceback.format_exc(), file=sys.stderr)
         pass
@@ -471,7 +494,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     return embedding, filename
 
 
-def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
+def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
     old_embedding_name = embedding.name
     old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
     old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -482,6 +505,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
         if remove_cached_checksum:
             embedding.cached_checksum = None
         embedding.name = embedding_name
+        embedding.optimizer_state_dict = optimizer.state_dict()
         embedding.save(filename)
     except:
         embedding.sd_checkpoint = old_sd_checkpoint