From bf5067f50ca32cd4764638702e3cc38bca8bfd8b Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 25 Oct 2023 12:54:28 +0800 Subject: Fix alphas cumprod --- modules/sd_models.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 23660454..7ed89a9c 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -396,6 +396,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer enable_fp8 = True elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: enable_fp8 = True + else: + enable_fp8 = False if enable_fp8: devices.fp8 = True @@ -416,7 +418,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn) model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") - model.alphas_cumprod = model.alphas_cumprod.to(torch.float32) devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.3