author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-23 17:49:05 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-10-23 17:49:05 +0000 |
commit | eaa9f5162fbca2ebcb2682eb861bc7e5510a2b66 (patch) | |
tree | f8bf60786db8d42a0a0e85deb56c885780bda654 /modules/devices.py | |
parent | 5f9ddfa46f28ca2aa9e0bd832f6bbd67069be63e (diff) | |
Add CPU fp8 support
Since norm layers need fp32, I only convert the linear operation layers (Conv2d/Linear) to fp8.
The text encoder (TE) also uses some PyTorch functions that do not support bf16 amp on CPU, so I added a condition to indicate whether the autocast is for the unet.
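As a rough sketch of the conversion described above (not part of this diff, which only touches devices.py), the snippet below casts just the Linear/Conv2d parameters to fp8 storage while leaving norm layers untouched. It assumes PyTorch >= 2.1 for torch.float8_e4m3fn, and the helper name is invented for illustration.

```python
import torch
import torch.nn as nn


def cast_linear_ops_to_fp8(model: nn.Module) -> nn.Module:
    """Hypothetical helper: store only Linear/Conv2d parameters in fp8.

    Norm layers are skipped so they keep their original (fp32) precision,
    matching the approach described in the commit message.
    Requires PyTorch >= 2.1 for torch.float8_e4m3fn.
    """
    for module in model.modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            module.to(torch.float8_e4m3fn)
    return model


# Only the Linear ends up in fp8; the LayerNorm stays in fp32.
block = nn.Sequential(nn.Linear(16, 16), nn.LayerNorm(16))
cast_linear_ops_to_fp8(block)
print(block[0].weight.dtype)  # torch.float8_e4m3fn
print(block[1].weight.dtype)  # torch.float32
```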
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 6 |
1 file changed, 5 insertions, 1 deletion
```diff
diff --git a/modules/devices.py b/modules/devices.py
index 1d4eb563..0cd2b55d 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -71,6 +71,7 @@ def enable_tf32():
 errors.run(enable_tf32, "Enabling TF32")
 
 cpu: torch.device = torch.device("cpu")
+fp8: bool = False
 device: torch.device = None
 device_interrogate: torch.device = None
 device_gfpgan: torch.device = None
@@ -93,10 +94,13 @@ def cond_cast_float(input):
 
 nv_rng = None
 
 
-def autocast(disable=False):
+def autocast(disable=False, unet=False):
     if disable:
         return contextlib.nullcontext()
 
+    if unet and fp8 and device==cpu:
+        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
+
     if dtype == torch.float32 or shared.cmd_opts.precision == "full":
         return contextlib.nullcontext()
```
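For illustration, here is a minimal, self-contained sketch of how the patched autocast() behaves; the module-level names (fp8, device, dtype) stand in for the real globals in modules/devices.py, and the shared.cmd_opts.precision check is omitted. With fp8 enabled and the model on CPU, calling autocast(unet=True) yields a bf16 CPU autocast context instead of a null context.

```python
import contextlib

import torch

# Stand-ins for the module-level globals in modules/devices.py
cpu = torch.device("cpu")
fp8 = True             # set when the fp8 option is enabled
device = cpu           # model running on CPU
dtype = torch.float16


def autocast(disable=False, unet=False):
    # Mirrors the patched devices.autocast(): only the unet gets the bf16
    # CPU autocast path, since the text encoder hits ops that do not
    # support bf16 amp on CPU. (The shared.cmd_opts.precision check from
    # the real function is omitted here.)
    if disable:
        return contextlib.nullcontext()

    if unet and fp8 and device == cpu:
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if dtype == torch.float32:
        return contextlib.nullcontext()

    return torch.autocast("cuda")


# Under the unet path, autocast-eligible CPU ops compute in bfloat16.
a, b = torch.randn(4, 4), torch.randn(4, 4)
with autocast(unet=True):
    print((a @ b).dtype)  # torch.bfloat16
```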