diff options
author | KohakuBlueleaf <apolloyeh0123@gmail.com> | 2024-01-09 14:39:39 +0000 |
---|---|---|
committer | KohakuBlueleaf <apolloyeh0123@gmail.com> | 2024-01-09 14:39:39 +0000 |
commit | 42e6df723c68af775b73c9fa4f43f99345348689 (patch) | |
tree | 26f55dcda9cba2d1522001ad25d336e17a50e7bb | |
parent | 209c26a1cb9e4be357ab3c5e7613caf3cbc26183 (diff) | |
download | stable-diffusion-webui-gfx803-42e6df723c68af775b73c9fa4f43f99345348689.tar.gz stable-diffusion-webui-gfx803-42e6df723c68af775b73c9fa4f43f99345348689.tar.bz2 stable-diffusion-webui-gfx803-42e6df723c68af775b73c9fa4f43f99345348689.zip |
Fix bugs when arg dtype doesn't match
-rw-r--r-- | modules/devices.py | 25 |
1 files changed, 10 insertions, 15 deletions
diff --git a/modules/devices.py b/modules/devices.py index 6edfb127..e0574052 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -134,24 +134,19 @@ patch_module_list = [ def manual_cast_forward(target_dtype): def forward_wrapper(self, *args, **kwargs): + if any( + isinstance(arg, torch.Tensor) and arg.dtype != target_dtype + for arg in args + ): + args = [arg.to(target_dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(target_dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} + org_dtype = torch_utils.get_param(self).dtype - if not target_dtype == org_dtype == dtype_inference: + if org_dtype != target_dtype: self.to(target_dtype) - args = [ - arg.to(target_dtype) - if isinstance(arg, torch.Tensor) - else arg - for arg in args - ] - kwargs = { - k: v.to(target_dtype) - if isinstance(v, torch.Tensor) - else v - for k, v in kwargs.items() - } - result = self.org_forward(*args, **kwargs) - self.to(org_dtype) + if org_dtype != target_dtype: + self.to(org_dtype) if target_dtype != dtype_inference: if isinstance(result, tuple): |