From 598da5cd4928618b166886d3485ce30ce3a43490 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Sun, 19 Nov 2023 15:50:06 +0800 Subject: Use options instead of cmd_args --- modules/devices.py | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index d7c905c2..03e7bdb7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward -- cgit v1.2.3