diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2024-03-02 04:03:13 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2024-03-02 04:03:13 +0000 |
commit | bef51aed032c0aaa5cfd80445bc4cf0d85b408b5 (patch) | |
tree | 42957c454a4ac8d98488f19811b60359d05d88ba /modules/devices.py | |
parent | cf2772fab0af5573da775e7437e6acdca424f26e (diff) | |
parent | 13984857890401e8605a3e53bd671e900a18d73f (diff) | |
download | stable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.tar.gz stable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.tar.bz2 stable-diffusion-webui-gfx803-bef51aed032c0aaa5cfd80445bc4cf0d85b408b5.zip |
Merge branch 'release_candidate'
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 114 |
1 files changed, 109 insertions, 5 deletions
diff --git a/modules/devices.py b/modules/devices.py index ea1f712f..28c0c54d 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, shared +from modules import errors, shared, npu_specific if sys.platform == "darwin": from modules import mac_specific @@ -23,6 +23,23 @@ def has_mps() -> bool: return mac_specific.has_mps +def cuda_no_autocast(device_id=None) -> bool: + if device_id is None: + device_id = get_cuda_device_id() + return ( + torch.cuda.get_device_capability(device_id) == (7, 5) + and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") + ) + + +def get_cuda_device_id(): + return ( + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + else 0 + ) or torch.cuda.current_device() + + def get_cuda_device_string(): if shared.cmd_opts.device_id is not None: return f"cuda:{shared.cmd_opts.device_id}" @@ -40,6 +57,9 @@ def get_optimal_device_name(): if has_xpu(): return xpu_specific.get_xpu_device_string() + if npu_specific.has_npu: + return npu_specific.get_npu_device_string() + return "cpu" @@ -67,14 +87,23 @@ def torch_gc(): if has_xpu(): xpu_specific.torch_xpu_gc() + if npu_specific.has_npu: + torch_npu_set_device() + npu_specific.torch_npu_gc() + + +def torch_npu_set_device(): + # Work around due to bug in torch_npu, revert me after fixed, @see https://gitee.com/ascend/pytorch/issues/I8KECW?from=project-issue + if npu_specific.has_npu: + torch.npu.set_device(0) + def enable_tf32(): if torch.cuda.is_available(): # enabling benchmark option seems to enable a range of cards to do fp16 when they otherwise can't # see https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/4407 - device_id = (int(shared.cmd_opts.device_id) if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0) or torch.cuda.current_device() - if torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16"): + if cuda_no_autocast(): torch.backends.cudnn.benchmark = True torch.backends.cuda.matmul.allow_tf32 = True @@ -84,6 +113,7 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") cpu: torch.device = torch.device("cpu") +fp8: bool = False device: torch.device = None device_interrogate: torch.device = None device_gfpgan: torch.device = None @@ -92,6 +122,7 @@ device_codeformer: torch.device = None dtype: torch.dtype = torch.float16 dtype_vae: torch.dtype = torch.float16 dtype_unet: torch.dtype = torch.float16 +dtype_inference: torch.dtype = torch.float16 unet_needs_upcast = False @@ -104,15 +135,89 @@ def cond_cast_float(input): nv_rng = None +patch_module_list = [ + torch.nn.Linear, + torch.nn.Conv2d, + torch.nn.MultiheadAttention, + torch.nn.GroupNorm, + torch.nn.LayerNorm, +] + + +def manual_cast_forward(target_dtype): + def forward_wrapper(self, *args, **kwargs): + if any( + isinstance(arg, torch.Tensor) and arg.dtype != target_dtype + for arg in args + ): + args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + + org_dtype = target_dtype + for param in self.parameters(): + if param.dtype != target_dtype: + org_dtype = param.dtype + break + + if org_dtype != target_dtype: + self.to(target_dtype) + result = self.org_forward(*args, **kwargs) + if org_dtype != target_dtype: + self.to(org_dtype) + + if target_dtype != dtype_inference: + if isinstance(result, tuple): + result = tuple( + i.to(dtype_inference) + if isinstance(i, torch.Tensor) + else i + for i in result + ) + elif isinstance(result, torch.Tensor): + result = result.to(dtype_inference) + return result + return forward_wrapper + + +@contextlib.contextmanager +def manual_cast(target_dtype): + applied = False + for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + continue + applied = True + org_forward = module_type.forward + if module_type == torch.nn.MultiheadAttention: + module_type.forward = manual_cast_forward(torch.float32) + else: + module_type.forward = manual_cast_forward(target_dtype) + module_type.org_forward = org_forward + try: + yield None + finally: + if applied: + for module_type in patch_module_list: + if hasattr(module_type, "org_forward"): + module_type.forward = module_type.org_forward + delattr(module_type, "org_forward") def autocast(disable=False): if disable: return contextlib.nullcontext() - if dtype == torch.float32 or shared.cmd_opts.precision == "full": + if fp8 and device==cpu: + return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True) + + if fp8 and dtype_inference == torch.float32: + return manual_cast(dtype) + + if dtype == torch.float32 or dtype_inference == torch.float32: return contextlib.nullcontext() + if has_xpu() or has_mps() or cuda_no_autocast(): + return manual_cast(dtype) + return torch.autocast("cuda") @@ -164,4 +269,3 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) - |