From ac0ecf3b4b9d147743c04f0ff4ddc4cf4595e11d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 1 Jan 2024 16:28:58 +0300 Subject: option to convert VAE to bfloat16 (implementation of #9295) --- modules/processing.py | 23 ++++++++++++++++++----- 1 file changed, 18 insertions(+), 5 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 846e4796..f0656882 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -628,20 +628,33 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): sample = decode_first_stage(model, batch[i:i + 1])[0] if check_for_nans: + try: devices.test_for_nans(sample, "vae") except devices.NansException as e: - if devices.dtype_vae == torch.float32 or not shared.opts.auto_vae_precision: + if shared.opts.auto_vae_precision_bfloat16: + autofix_dtype = torch.bfloat16 + autofix_dtype_text = "bfloat16" + autofix_dtype_setting = "Automatically convert VAE to bfloat16" + autofix_dtype_comment = "" + elif shared.opts.auto_vae_precision: + autofix_dtype = torch.float32 + autofix_dtype_text = "32-bit float" + autofix_dtype_setting = "Automatically revert VAE to 32-bit floats" + autofix_dtype_comment = "\nTo always start with 32-bit VAE, use --no-half-vae commandline flag." + else: + raise e + + if devices.dtype_vae == autofix_dtype: raise e errors.print_error_explanation( "A tensor with all NaNs was produced in VAE.\n" - "Web UI will now convert VAE into 32-bit float and retry.\n" - "To disable this behavior, disable the 'Automatically revert VAE to 32-bit floats' setting.\n" - "To always start with 32-bit VAE, use --no-half-vae commandline flag." + f"Web UI will now convert VAE into {autofix_dtype_text} and retry.\n" + f"To disable this behavior, disable the '{autofix_dtype_setting}' setting.{autofix_dtype_comment}" ) - devices.dtype_vae = torch.float32 + devices.dtype_vae = autofix_dtype model.first_stage_model.to(devices.dtype_vae) batch = batch.to(devices.dtype_vae) -- cgit v1.2.3