aboutsummaryrefslogtreecommitdiffstats
path: root/modules/processing.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2024-01-01 13:28:58 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2024-01-01 13:28:58 +0000
commitac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d (patch)
treee02a744d39bccc8a112431fc9e9cb0b1b3c8b7ff /modules/processing.py
parent0743ee9b3eda8dd4ceea625d710031577201f4ad (diff)
downloadstable-diffusion-webui-gfx803-ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d.tar.gz
stable-diffusion-webui-gfx803-ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d.tar.bz2
stable-diffusion-webui-gfx803-ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d.zip
option to convert VAE to bfloat16 (implementation of #9295)
Diffstat (limited to 'modules/processing.py')
-rw-r--r--modules/processing.py23
1 files changed, 18 insertions, 5 deletions
diff --git a/modules/processing.py b/modules/processing.py
index 846e4796..f0656882 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -628,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
sample = decode_first_stage(model, batch[i:i + 1])[0]
if check_for_nans:
+
try:
devices.test_for_nans(sample, "vae")
except devices.NansException as e:
- if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision:
+ if shared.opts.auto_vae_precision_bfloat16:
+ autofix_dtype = torch.bfloat16
+ autofix_dtype_text = "bfloat16"
+ autofix_dtype_setting = "Automatically convert VAE to bfloat16"
+ autofix_dtype_comment = ""
+ elif shared.opts.auto_vae_precision:
+ autofix_dtype = torch.float32
+ autofix_dtype_text = "32-bit float"
+ autofix_dtype_setting = "Automatically revert VAE to 32-bit floats"
+ autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag."
+ else:
+ raise e
+
+ if devices.dtype_vae == autofix_dtype:
raise e
errors.print_error_explanation(
"A tensor with all NaNs was produced in VAE.\n"
- "Web UI will now convert VAE into 32-bit float and retry.\n"
- "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n"
- "To always start with 32-bit VAE, use --no-half-vae commandline flag."
+ f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n"
+ f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}"
)
- devices.dtype_vae = torch.float32
+ devices.dtype_vae = autofix_dtype
model.first_stage_model.to(devices.dtype_vae)
batch = batch.to(devices.dtype_vae)