path: root/modules/textual_inversion/textual_inversion.py
author    Zac Liu <liuguang@baai.ac.cn>    2022-12-06 01:16:15 +0000
committer GitHub <noreply@github.com>      2022-12-06 01:16:15 +0000
commit    3ebf977a6e4f478ab918e44506974beee32da276 (patch)
tree      f68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/textual_inversion/textual_inversion.py
parent    231fb72872191ffa8c446af1577c9003b3d19d4f (diff)
parent    44c46f0ed395967cd3830dd481a2db759fda5b3b (diff)
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/textual_inversion/textual_inversion.py')
-rw-r--r--    modules/textual_inversion/textual_inversion.py    5
1 file changed, 4 insertions, 1 deletion
diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py
index 4eb75cb5..e28c357a 100644
--- a/modules/textual_inversion/textual_inversion.py
+++ b/modules/textual_inversion/textual_inversion.py
@@ -269,6 +269,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     # dataset loading may take a while, so input validations and early returns should be done before this
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+    old_parallel_processing_allowed = shared.parallel_processing_allowed
     pin_memory = shared.opts.pin_memory
@@ -279,6 +280,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
     dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
     if unload:
+        shared.parallel_processing_allowed = False
         shared.sd_model.first_stage_model.to(devices.cpu)
     embedding.vec.requires_grad = True
@@ -316,7 +318,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                 if shared.state.interrupted:
                     break
-                with torch.autocast("cuda"):
+                with devices.autocast():
                     # c = stack_conds(batch.cond).to(devices.device)
                     # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
                     # print(mask)
@@ -450,6 +452,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
         pbar.leave = False
         pbar.close()
         shared.sd_model.first_stage_model.to(devices.device)
+        shared.parallel_processing_allowed = old_parallel_processing_allowed
     return embedding, filename