author:    AUTOMATIC1111 <16777216c@gmail.com>  2022-11-01 15:17:56 +0000
committer: GitHub <noreply@github.com>  2022-11-01 15:17:56 +0000
commit:    f071a1d25aa8b35bb6406a133df1d03ae5ea8d01 (patch)
tree:      b18ed4c2113d5c23f5e13ed34f905675620de75b /modules/textual_inversion/textual_inversion.py
parent:    0e5d239f06110e875291b906e100de775c5d9658 (diff)
parent:    890e68aaf75ae80d5eb2fa95b4bf1adf78b96881 (diff)
Merge pull request #4056 from MarkovInequality/TI_optimizations
Allow TI training using 6GB VRAM when xformers is available
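The change itself is small: the VAE (shared.sd_model.first_stage_model) is only needed to encode the training set and to decode preview images, so when the unload_models_when_training option read at the top of the diff is enabled, the patch parks it in system RAM and moves it back to the GPU just long enough to render previews. Combined with xformers' memory-efficient attention, that is what brings textual inversion training under 6GB. A minimal sketch of the offload/reload pattern, assuming PyTorch; TinyVAE and every variable name here are stand-ins, not webui code:

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for webui's shared.sd_model.first_stage_model:
    # any nn.Module is enough to demonstrate the device shuttling.
    class TinyVAE(nn.Module):
        def __init__(self):
            super().__init__()
            self.decoder = nn.Linear(4, 3)

        def forward(self, latent):
            return self.decoder(latent)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    first_stage_model = TinyVAE().to(device)
    unload = True  # mirrors shared.opts.unload_models_when_training

    # Training phase: only the embedding vector receives gradients, and the
    # dataset latents were already encoded, so the VAE can sit in system RAM.
    if unload:
        first_stage_model.to("cpu")

    embedding_vec = torch.randn(8, 768, device=device, requires_grad=True)
    optimizer = torch.optim.AdamW([embedding_vec], lr=5e-3)
    loss = (embedding_vec ** 2).mean()  # stand-in loss; the real loop trains on latents
    loss.backward()
    optimizer.step()

    # Preview phase: decoding latents to pixels needs the VAE on the GPU.
    first_stage_model.to(device)
    with torch.no_grad():
        preview = first_stage_model(torch.randn(1, 4, device=device))

    # Offload again until the next preview; the diff moves the VAE back to
    # devices.device one final time before train_embedding returns.
    if unload:
        first_stage_model.to("cpu")

Note that moving a module to the device it already occupies is a no-op, which is why the diff guards the offloads behind unload but can call .to(devices.device) unconditionally before rendering previews.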
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r-- | modules/textual_inversion/textual_inversion.py | 10
1 file changed, 10 insertions(+), 0 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index e0babb46..0aeb0459 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -235,6 +235,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_embedding_every > 0:
         embedding_dir = os.path.join(log_directory, "embeddings")
@@ -272,6 +273,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+    if unload:
+        shared.sd_model.first_stage_model.to(devices.cpu)
 
     embedding.vec.requires_grad = True
     optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
@@ -328,6 +331,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if images_dir is not None and steps_done % create_image_every == 0:
             forced_filename = f'{embedding_name}-{steps_done}'
             last_saved_image = os.path.join(images_dir, forced_filename)
+
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 do_not_save_grid=True,
@@ -355,6 +361,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             processed = processing.process_images(p)
             image = processed.images[0]
 
+            if unload:
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
 
             if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
@@ -400,6 +409,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     checkpoint = sd_models.select_checkpoint()
 
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
     save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
+    shared.sd_model.first_stage_model.to(devices.device)
 
     return embedding, filename
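One thing to note about the pattern: the .to(devices.device) call before preview rendering sits on the normal control path, so an exception inside processing.process_images() would leave the VAE wherever it happened to be at that moment. The PR keeps the straight-line form; a context manager is one way to make the shuttling exception-safe. A minimal sketch, not part of this PR (vae_on and the commented usage names are illustrative):

    from contextlib import contextmanager

    import torch

    @contextmanager
    def vae_on(module: torch.nn.Module, device, offload_to="cpu", enabled=True):
        # Temporarily move `module` to `device`; the finally block returns it
        # to `offload_to` even if the body raises.
        if not enabled:
            yield module
            return
        module.to(device)
        try:
            yield module
        finally:
            module.to(offload_to)

    # Hypothetical usage, mirroring the names in the diff above:
    # with vae_on(shared.sd_model.first_stage_model, devices.device, enabled=unload):
    #     processed = processing.process_images(p)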