path: root/modules/hypernetworks/hypernetwork.py
author     Fampai <>  2022-11-04 08:50:22 +0000
committer  Fampai <>  2022-11-04 08:50:22 +0000
commit     39541d7725bc42f456a604b07c50aba503a5a09a (patch)
tree       8f6e1866d19adc089e4d93847df4606cc252a033 /modules/hypernetworks/hypernetwork.py
parent     f2b69709eaff88fc3a2bd49585556ec0883bf5ea (diff)
Fixes race condition in training when VAE is unloaded

set_current_image can attempt to use the VAE while it is unloaded to the CPU during training.
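In context, the change amounts to the save/disable/restore pattern sketched below (a simplified sketch, not the full function; only the lines relevant to the race are shown and the training loop itself is elided):

    old_parallel_processing_allowed = shared.parallel_processing_allowed

    if unload:
        shared.parallel_processing_allowed = False         # keep the live-preview path away from the VAE
        shared.sd_model.cond_stage_model.to(devices.cpu)   # text encoder moved to the CPU
        shared.sd_model.first_stage_model.to(devices.cpu)  # VAE moved to the CPU

    # ... training loop runs while the VAE lives on the CPU ...

    shared.parallel_processing_allowed = old_parallel_processing_allowed
    return hypernetwork, filename

With parallel processing disabled, the preview path driven by set_current_image no longer tries to decode latents through the VAE from another thread while the model sits on the CPU; the previous value is restored once training finishes, so normal generation is unaffected.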
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--  modules/hypernetworks/hypernetwork.py  | 4 ++++
1 file changed, 4 insertions(+), 0 deletions(-)
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index 6e1a10cf..fcb96059 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -390,7 +390,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
 
+    old_parallel_processing_allowed = shared.parallel_processing_allowed
+
     if unload:
+        shared.parallel_processing_allowed = False
         shared.sd_model.cond_stage_model.to(devices.cpu)
         shared.sd_model.first_stage_model.to(devices.cpu)
@@ -531,6 +534,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
     filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
     save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+    shared.parallel_processing_allowed = old_parallel_processing_allowed
 
     return hypernetwork, filename
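A design note on the restore: as the patch stands, the saved flag is written back just before the function returns, so an exception that escapes the training code could leave parallel processing disabled for the rest of the session. A hypothetical context-manager wrapper (not part of this commit; the helper name is invented, shared is the webui module) would make the restore unconditional:

    import contextlib

    from modules import shared


    @contextlib.contextmanager
    def preview_decoding_disabled():
        # Hypothetical helper, not in the repository: switch off the parallel
        # live-preview path and restore the previous setting even if the body raises.
        old_parallel_processing_allowed = shared.parallel_processing_allowed
        shared.parallel_processing_allowed = False
        try:
            yield
        finally:
            shared.parallel_processing_allowed = old_parallel_processing_allowed

train_hypernetwork could then wrap the section that runs with the VAE unloaded in "with preview_decoding_disabled():" instead of restoring the flag by hand.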