author    flamelaw <flamelaw.com3d2@gmail.com>  2022-11-21 01:15:46 +0000
committer flamelaw <flamelaw.com3d2@gmail.com>  2022-11-21 01:15:46 +0000
commit    5b57f61ba47f8b11d19a5b46e7fb5a52458abae5 (patch)
tree      474be8cd4ca69e75ba95516b100874a88d235fde /modules/textual_inversion/textual_inversion.py
parent    2d22d72cdaaf2b78b2986b841d478c11ac855dd2 (diff)
fix pin_memory with different latent sampling method
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
 modules/textual_inversion/textual_inversion.py | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 1d5e3a32..3036e48a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -277,7 +277,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
latent_sampling_method = ds.latent_sampling_method
- dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=False)
+ dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
if unload:
shared.sd_model.first_stage_model.to(devices.cpu)
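
For context: PyTorch's DataLoader only pins memory for tensor types it recognizes, so a custom batch object has to expose its own pin_memory() hook; passing latent_sampling_method through presumably lets PersonalizedDataLoader pick a batch/collate type that carries that hook for each sampling mode. A minimal sketch of the pattern, with LatentBatch/collate/ToyDataset as hypothetical stand-ins for the real classes in modules/textual_inversion/dataset.py:

# Sketch only, not the webui's actual dataset code. When pin_memory=True,
# the DataLoader calls pin_memory() on any batch object that defines it.
import torch
from torch.utils.data import DataLoader, Dataset

class LatentBatch:
    def __init__(self, latents, cond):
        self.latents = latents
        self.cond = cond

    def pin_memory(self):
        # invoked by the DataLoader when pin_memory=True
        self.latents = self.latents.pin_memory()
        self.cond = self.cond.pin_memory()
        return self

def collate(samples):
    return LatentBatch(torch.stack([s[0] for s in samples]),
                       torch.stack([s[1] for s in samples]))

class ToyDataset(Dataset):
    def __len__(self):
        return 8
    def __getitem__(self, i):
        return torch.randn(4, 64, 64), torch.randn(77, 768)

dl = DataLoader(ToyDataset(), batch_size=2, collate_fn=collate, pin_memory=True)

if torch.cuda.is_available():          # pinning requires a CUDA build
    batch = next(iter(dl))
    assert batch.latents.is_pinned()   # ready for async .to(device, non_blocking=True)
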
@@ -333,11 +333,6 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
# go back until we reach gradient accumulation steps
if (j + 1) % gradient_step != 0:
continue
- #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
- #scaler.unscale_(optimizer)
- #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
- #torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=1.0)
- #print(f"grad:{embedding.vec.grad.detach().cpu().abs().mean().item():.7f}")
scaler.step(optimizer)
scaler.update()
embedding.step += 1
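
For reference, the surviving pattern around scaler.step/scaler.update matches standard torch.cuda.amp gradient accumulation: scale and backward every micro-batch, but only step the optimizer and update the scaler once gradient_step batches have accumulated. A self-contained sketch under that assumption; model, optimizer, and the data are placeholders, not the webui's actual identifiers:

# Minimal AMP gradient-accumulation loop, assuming standard torch.cuda.amp usage.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))
gradient_step = 4  # micro-batches accumulated per optimizer step

for j in range(16):
    x = torch.randn(8, 16, device=device)
    with torch.autocast(device, enabled=(device == "cuda")):
        # divide the loss so the accumulated gradient matches a full batch
        loss = model(x).pow(2).mean() / gradient_step
    scaler.scale(loss).backward()

    # go back until we reach gradient accumulation steps
    if (j + 1) % gradient_step != 0:
        continue

    scaler.step(optimizer)   # unscales gradients; skips the step on inf/nan
    scaler.update()          # adjusts the loss scale for the next iteration
    optimizer.zero_grad(set_to_none=True)

If the gradient clipping in the deleted comments were ever re-enabled, the commented-out scaler.unscale_(optimizer) would need to run before torch.nn.utils.clip_grad_norm_, per the usual AMP recipe, so clipping operates on unscaled gradients.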