aboutsummaryrefslogtreecommitdiffstats
path: root/modules/hypernetworks/hypernetwork.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2022-12-03 07:20:17 +0000
committerGitHub <noreply@github.com>2022-12-03 07:20:17 +0000
commit5267414319ef89c18061127fab971ffc1b5b24ad (patch)
treeeabdca1b7665e0ee00f130e9f8544ffd23e474a2 /modules/hypernetworks/hypernetwork.py
parent5cd5a672f7889dcc018c3873ec557d645ebe35d0 (diff)
parentc9a2cfdf2a53d37c2de1908423e4f548088667ef (diff)
downloadstable-diffusion-webui-gfx803-5267414319ef89c18061127fab971ffc1b5b24ad.tar.gz
stable-diffusion-webui-gfx803-5267414319ef89c18061127fab971ffc1b5b24ad.tar.bz2
stable-diffusion-webui-gfx803-5267414319ef89c18061127fab971ffc1b5b24ad.zip
Merge pull request #4271 from MarkovInequality/racecond_fix
Fixes #4137 caused by race condition in training when VAE is unloaded
Diffstat (limited to 'modules/hypernetworks/hypernetwork.py')
-rw-r--r--modules/hypernetworks/hypernetwork.py5
1 files changed, 5 insertions, 0 deletions
diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py
index eb5ae372..c406ffb3 100644
--- a/modules/hypernetworks/hypernetwork.py
+++ b/modules/hypernetworks/hypernetwork.py
@@ -433,7 +433,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)
+ old_parallel_processing_allowed = shared.parallel_processing_allowed
+
if unload:
+ shared.parallel_processing_allowed = False
shared.sd_model.cond_stage_model.to(devices.cpu)
shared.sd_model.first_stage_model.to(devices.cpu)
@@ -612,10 +615,12 @@ Last saved image: {html.escape(last_saved_image)}<br/>
if shared.opts.save_optimizer_state:
hypernetwork.optimizer_state_dict = optimizer.state_dict()
save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
+
del optimizer
hypernetwork.optimizer_state_dict = None # dereference it after saving, to save memory.
shared.sd_model.cond_stage_model.to(devices.device)
shared.sd_model.first_stage_model.to(devices.device)
+ shared.parallel_processing_allowed = old_parallel_processing_allowed
return hypernetwork, filename