aboutsummaryrefslogtreecommitdiffstats
path: root/modules
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-11-02 11:41:29 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-11-02 11:41:29 +0000
commitf2a5cbe6f55592c4c5527b8e0bf99ea8d658f057 (patch)
treeeba917019e6b95ed19cea4eb4671e1d2b9db8599 /modules
parent675b51ebd3a0afbe097e0dc1384fd9ab8f5f1b38 (diff)
downloadstable-diffusion-webui-gfx803-f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057.tar.gz
stable-diffusion-webui-gfx803-f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057.tar.bz2
stable-diffusion-webui-gfx803-f2a5cbe6f55592c4c5527b8e0bf99ea8d658f057.zip
fix #3986 breaking --no-half-vae
Diffstat (limited to 'modules')
-rw-r--r--modules/sd_models.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py
index 883639d1..5075fadb 100644
--- a/modules/sd_models.py
+++ b/modules/sd_models.py
@@ -183,11 +183,20 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
model.to(memory_format=torch.channels_last)
if not shared.cmd_opts.no_half:
+ vae = model.first_stage_model
+
+ # with --no-half-vae, remove VAE from model when doing half() to prevent its weights from being converted to float16
+ if shared.cmd_opts.no_half_vae:
+ model.first_stage_model = None
+
model.half()
+ model.first_stage_model = vae
devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16
devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16
+ model.first_stage_model.to(devices.dtype_vae)
+
if shared.opts.sd_checkpoint_cache > 0:
# if PR #4035 were to get merged, restore base VAE first before caching
checkpoints_loaded[checkpoint_key] = model.state_dict().copy()