path: root/modules/devices.py
author    Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-11-19 07:50:06 +0000
committer Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>  2023-11-19 07:50:06 +0000
commit    598da5cd4928618b166886d3485ce30ce3a43490 (patch)
tree      48bdd2fcf47bd88b9283d13fd53ecc39e5f5ff27 /modules/devices.py
parent    b60e1088db2497e945d36c7500dcbf03afceedf2 (diff)
Use options instead of cmd_args
Diffstat (limited to 'modules/devices.py')
-rw-r--r--  modules/devices.py  25
1 file changed, 14 insertions(+), 11 deletions(-)
diff --git a/modules/devices.py b/modules/devices.py
index d7c905c2..03e7bdb7 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -20,15 +20,15 @@ def cuda_no_autocast(device_id=None) -> bool:
     if device_id is None:
         device_id = get_cuda_device_id()
     return (
-        torch.cuda.get_device_capability(device_id) == (7, 5)
+        torch.cuda.get_device_capability(device_id) == (7, 5)
         and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
     )
 
 
 def get_cuda_device_id():
     return (
-        int(shared.cmd_opts.device_id)
-        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
+        int(shared.cmd_opts.device_id)
+        if shared.cmd_opts.device_id is not None and shared.cmd_opts.device_id.isdigit()
         else 0
     ) or torch.cuda.current_device()
 
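For reference, the two helpers in this hunk decide which CUDA device to inspect and whether it is a GTX 16xx card (compute capability 7.5), the case where autocast gets skipped. Below is a minimal standalone sketch of the same logic, assuming a CUDA-enabled torch build; the names resolve_cuda_device_id and is_gtx_16xx are illustrative, not from the repository, and shared.cmd_opts.device_id is replaced by a plain argument.

import torch


def resolve_cuda_device_id(device_id_arg=None):
    # A numeric id string wins; because 0 is falsy, the trailing `or` falls
    # back to torch.cuda.current_device() when no nonzero id was supplied.
    return (
        int(device_id_arg)
        if device_id_arg is not None and device_id_arg.isdigit()
        else 0
    ) or torch.cuda.current_device()


def is_gtx_16xx(device_id):
    # Mirrors the cuda_no_autocast() check: Turing compute capability (7, 5)
    # plus a GeForce GTX 16-series device name.
    return (
        torch.cuda.get_device_capability(device_id) == (7, 5)
        and torch.cuda.get_device_name(device_id).startswith("NVIDIA GeForce GTX 16")
    )


if __name__ == "__main__":
    if torch.cuda.is_available():
        device_id = resolve_cuda_device_id()
        print(device_id, is_gtx_16xx(device_id))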
@@ -116,16 +116,19 @@ patch_module_list = [
     torch.nn.LayerNorm,
 ]
 
+
+def manual_cast_forward(self, *args, **kwargs):
+    org_dtype = next(self.parameters()).dtype
+    self.to(dtype)
+    args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
+    kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
+    result = self.org_forward(*args, **kwargs)
+    self.to(org_dtype)
+    return result
+
+
 @contextlib.contextmanager
 def manual_autocast():
-    def manual_cast_forward(self, *args, **kwargs):
-        org_dtype = next(self.parameters()).dtype
-        self.to(dtype)
-        args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
-        kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
-        result = self.org_forward(*args, **kwargs)
-        self.to(org_dtype)
-        return result
     for module_type in patch_module_list:
         org_forward = module_type.forward
         module_type.forward = manual_cast_forward
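The hunk cuts off before the patching loop finishes, so the restore step is not visible here. The following is a minimal sketch of the overall pattern the diff moves toward: patch forward on each listed layer class, yield, then put the original back. It assumes the original forward is stashed on each class as org_forward and uses target_dtype as a stand-in for the module-level dtype referenced above; names not present in the diff are illustrative, not taken from the repository.

import contextlib

import torch

# Hypothetical stand-ins for the module-level globals in devices.py;
# bfloat16 is chosen so the demo also runs on CPU.
target_dtype = torch.bfloat16
patch_module_list = [torch.nn.Linear, torch.nn.Conv2d, torch.nn.LayerNorm]


def manual_cast_forward(self, *args, **kwargs):
    # Cast the module and its tensor inputs to target_dtype, call the saved
    # original forward, then restore the module's previous parameter dtype.
    org_dtype = next(self.parameters()).dtype
    self.to(target_dtype)
    args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args]
    kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
    result = self.org_forward(*args, **kwargs)
    self.to(org_dtype)
    return result


@contextlib.contextmanager
def manual_autocast():
    # Monkey-patch forward on each listed layer class; undo the patch on exit.
    try:
        for module_type in patch_module_list:
            module_type.org_forward = module_type.forward
            module_type.forward = manual_cast_forward
        yield
    finally:
        for module_type in patch_module_list:
            module_type.forward = module_type.org_forward


if __name__ == "__main__":
    with manual_autocast():
        layer = torch.nn.Linear(4, 4)     # created in float32
        out = layer(torch.randn(1, 4))    # forward runs in target_dtype
    print(out.dtype)                      # torch.bfloat16

Moving manual_cast_forward to module level, as the diff does, means the function is defined once instead of being recreated on every entry into the context manager, and it can be reused by other patching paths.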