From 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:57:22 +0800 Subject: Add sdxl only arg --- modules/sd_models.py | 3 +++ 1 file changed, 3 insertions(+) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3b8ff820..08af128f 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -394,6 +394,9 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer if shared.cmd_opts.opt_unet_fp8_storage: model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) timer.record("apply fp8 unet") + elif model.is_sdxl and shared.cmd_opts.opt_unet_fp8_storage_xl: + model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn) + timer.record("apply fp8 unet for sdxl") devices.unet_needs_upcast = shared.cmd_opts.upcast_sampling and devices.dtype == torch.float16 and devices.dtype_unet == torch.float16 -- cgit v1.2.3