diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-25 03:36:43 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-25 03:36:43 +0000 |
commit | 1df6c8bfec4715610d64684b6ad2fa38c76c1df6 (patch) | |
tree | 40ad7dc479ef592433981339f43968bcb48e7d2b /modules | |
parent | 9c1eba2af3a6f9cd6282b3a367656793cbe70c01 (diff) | |
download | stable-diffusion-webui-gfx803-1df6c8bfec4715610d64684b6ad2fa38c76c1df6.tar.gz stable-diffusion-webui-gfx803-1df6c8bfec4715610d64684b6ad2fa38c76c1df6.tar.bz2 stable-diffusion-webui-gfx803-1df6c8bfec4715610d64684b6ad2fa38c76c1df6.zip |
fp8 for TE
Diffstat (limited to 'modules')
-rw-r--r-- | modules/sd_models.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b..69395294 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")
|