diff options
author | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-19 07:50:06 +0000 |
---|---|---|
committer | Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> | 2023-11-19 07:50:06 +0000 |
commit | 598da5cd4928618b166886d3485ce30ce3a43490 (patch) | |
tree | 48bdd2fcf47bd88b9283d13fd53ecc39e5f5ff27 /modules/devices.py | |
parent | b60e1088db2497e945d36c7500dcbf03afceedf2 (diff) | |
download | stable-diffusion-webui-gfx803-598da5cd4928618b166886d3485ce30ce3a43490.tar.gz stable-diffusion-webui-gfx803-598da5cd4928618b166886d3485ce30ce3a43490.tar.bz2 stable-diffusion-webui-gfx803-598da5cd4928618b166886d3485ce30ce3a43490.zip |
Use options instead of cmd_args
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/modules/devices.py b/modules/devices.py index d7c905c2..03e7bdb7 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool: if device_id is None: device_id = get_cuda_device_id() return ( - torch.cuda.get_device_capability(device_id) == (7, 5) + torch.cuda.get_device_capability(device_id) == (7, 5) and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16") ) def get_cuda_device_id(): return ( - int(shared.cmd_opts.device_id) - if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() + int(shared.cmd_opts.device_id) + if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit() else 0 ) or torch.cuda.current_device() @@ -116,16 +116,19 @@ patch_module_list = [ torch.nn.LayerNorm, ] + +def manual_cast_forward(self, *args, **kwargs): + org_dtype = next(self.parameters()).dtype + self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + result = self.org_forward(*args, **kwargs) + self.to(org_dtype) + return result + + @contextlib.contextmanager def manual_autocast(): - def manual_cast_forward(self, *args, **kwargs): - org_dtype = next(self.parameters()).dtype - self.to(dtype) - args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] - kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) - return result for module_type in patch_module_list: org_forward = module_type.forward module_type.forward = manual_cast_forward |