diff options
author | Fampai <> | 2022-11-04 08:50:22 +0000 |
---|---|---|
committer | Fampai <> | 2022-11-04 08:50:22 +0000 |
commit | 39541d7725bc42f456a604b07c50aba503a5a09a (patch) | |
tree | 8f6e1866d19adc089e4d93847df4606cc252a033 /modules/hypernetworks | |
parent | f2b69709eaff88fc3a2bd49585556ec0883bf5ea (diff) | |
download | stable-diffusion-webui-gfx803-39541d7725bc42f456a604b07c50aba503a5a09a.tar.gz stable-diffusion-webui-gfx803-39541d7725bc42f456a604b07c50aba503a5a09a.tar.bz2 stable-diffusion-webui-gfx803-39541d7725bc42f456a604b07c50aba503a5a09a.zip |
Fixes race condition in training when VAE is unloaded
set_current_image can attempt to use the VAE when it is unloaded to
the CPU while training
Diffstat (limited to 'modules/hypernetworks')
-rw-r--r-- | modules/hypernetworks/hypernetwork.py | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index 6e1a10cf..fcb96059 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -390,7 +390,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log with torch.autocast("cuda"):
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -531,6 +534,7 @@ Last saved image: {html.escape(last_saved_image)}<br/> filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename
|