author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-01-09 16:33:00 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2024-01-09 16:33:00 +0000 |
commit | 905b14237fbc42f56ed20f0333ae5e1f06ce9bae (patch) | |
tree | 0e6cf7f344faed75b16150a4c4dfb66a9a9ec620 /modules/devices.py | |
parent | 6869d95890849c9b209bb66774539bfdf870df2c (diff) | |
parent | ca671e5d7b9d03227f01e6bcb350032b6d14e722 (diff) | |
download | stable-diffusion-webui-gfx803-905b14237fbc42f56ed20f0333ae5e1f06ce9bae.tar.gz stable-diffusion-webui-gfx803-905b14237fbc42f56ed20f0333ae5e1f06ce9bae.tar.bz2 stable-diffusion-webui-gfx803-905b14237fbc42f56ed20f0333ae5e1f06ce9bae.zip |
Merge pull request #14597 from AUTOMATIC1111/improved-manual-cast
Improve the implementation of Manual Cast and IPEX support
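The heart of the change is visible in the diff below: the single global `manual_cast_forward` (which always cast to the global `dtype`) becomes a factory that binds each patched module class to its own target dtype, casts inputs only when one actually has the wrong dtype, and converts results back to `dtype_inference`. A minimal, self-contained sketch of that pattern (simplified for illustration, not the upstream code; the names `manual_cast_forward` and `org_forward` follow the diff):

```python
import torch

def manual_cast_forward(target_dtype):
    """Build a forward() replacement bound to one target dtype."""
    def forward_wrapper(self, *args, **kwargs):
        # Cast tensor inputs only if at least one is in the wrong dtype.
        if any(isinstance(a, torch.Tensor) and a.dtype != target_dtype for a in args):
            args = [a.to(target_dtype) if isinstance(a, torch.Tensor) else a for a in args]
        return self.org_forward(*args, **kwargs)
    return forward_wrapper

# Patch a module class: keep the original forward, install a wrapper with a
# per-class dtype (the real code does this for every entry in patch_module_list).
torch.nn.Linear.org_forward = torch.nn.Linear.forward
torch.nn.Linear.forward = manual_cast_forward(torch.float16)

layer = torch.nn.Linear(4, 4).half()
out = layer(torch.randn(2, 4))  # float32 input is cast to float16 on the way in
print(out.dtype)                # torch.float16
```

Because the wrapper is created per module class, `manual_cast` can hand `torch.nn.MultiheadAttention` a float32 target on XPU while every other layer keeps the requested dtype, which is the IPEX fix named in the PR title.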
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 56 |
1 file changed, 40 insertions, 16 deletions
```diff
diff --git a/modules/devices.py b/modules/devices.py
index ff279ac5..0321d12c 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -110,6 +110,7 @@ device_codeformer: torch.device = None
 dtype: torch.dtype = torch.float16
 dtype_vae: torch.dtype = torch.float16
 dtype_unet: torch.dtype = torch.float16
+dtype_inference: torch.dtype = torch.float16
 unet_needs_upcast = False
 
 
@@ -131,21 +132,44 @@ patch_module_list = [
 ]
 
 
-def manual_cast_forward(self, *args, **kwargs):
-    org_dtype = torch_utils.get_param(self).dtype
-    self.to(dtype)
-    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-    result = self.org_forward(*args, **kwargs)
-    self.to(org_dtype)
-    return result
+def manual_cast_forward(target_dtype):
+    def forward_wrapper(self, *args, **kwargs):
+        if any(
+            isinstance(arg, torch.Tensor) and arg.dtype != target_dtype
+            for arg in args
+        ):
+            args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+            kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+
+        org_dtype = torch_utils.get_param(self).dtype
+        if org_dtype != target_dtype:
+            self.to(target_dtype)
+        result = self.org_forward(*args, **kwargs)
+        if org_dtype != target_dtype:
+            self.to(org_dtype)
+
+        if target_dtype != dtype_inference:
+            if isinstance(result, tuple):
+                result = tuple(
+                    i.to(dtype_inference)
+                    if isinstance(i, torch.Tensor)
+                    else i
+                    for i in result
+                )
+            elif isinstance(result, torch.Tensor):
+                result = result.to(dtype_inference)
+        return result
+    return forward_wrapper
 
 
 @contextlib.contextmanager
-def manual_cast():
+def manual_cast(target_dtype):
     for module_type in patch_module_list:
         org_forward = module_type.forward
-        module_type.forward = manual_cast_forward
+        if module_type == torch.nn.MultiheadAttention and has_xpu():
+            module_type.forward = manual_cast_forward(torch.float32)
+        else:
+            module_type.forward = manual_cast_forward(target_dtype)
         module_type.org_forward = org_forward
     try:
         yield None
@@ -161,15 +185,15 @@ def autocast(disable=False):
     if fp8 and device==cpu:
         return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)
 
-    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
-        return manual_cast()
+    if fp8 and dtype_inference == torch.float32:
+        return manual_cast(dtype)
 
-    if has_mps() and shared.cmd_opts.precision != "full":
-        return manual_cast()
-
-    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
+    if dtype == torch.float32 or dtype_inference == torch.float32:
         return contextlib.nullcontext()
 
+    if has_xpu() or has_mps() or cuda_no_autocast():
+        return manual_cast(dtype)
+
     return torch.autocast("cuda")
 
 
```
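For context on how these pieces are consumed: `manual_cast` is a context manager that patches `forward` on every class in `patch_module_list` and restores it on exit, while `autocast()` chooses between it, `torch.autocast`, and a null context based on the configured dtypes. A hypothetical usage sketch (`model` and `x` are placeholders, and the import only resolves inside the webui repo):

```python
import torch
from modules import devices  # the module this diff touches

model = torch.nn.Linear(8, 8)  # stand-in for a real network module
x = torch.randn(1, 8)

# After this change, autocast() returns manual_cast(dtype) on XPU/MPS or when
# CUDA autocast is unavailable, a nullcontext for full-precision runs, and
# torch.autocast("cuda") otherwise.
with devices.autocast():
    y = model(x)
```

Note also that the fp8 and full-precision branches now key off `dtype_inference` rather than re-checking `shared.cmd_opts.precision`, so the precision command-line flag no longer needs to be re-interpreted here.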