diff options
-rw-r--r-- | modules/sd_models.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/modules/sd_models.py b/modules/sd_models.py index 44d4038b..69395294 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -407,6 +407,13 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer module.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet for cpu")
else:
+ if model.is_sdxl:
+ cond_stage = model.conditioner
+ else:
+ cond_stage = model.cond_stage_model
+ for module in cond_stage.modules():
+ if isinstance(module, torch.nn.Linear):
+ module.to(torch.float8_e4m3fn)
model.model.diffusion_model = model.model.diffusion_model.to(torch.float8_e4m3fn)
timer.record("apply fp8 unet")
|