diff options
author | KohakuBlueleaf <apolloyeh0123@gmail.com> | 2023-10-28 08:52:35 +0000 |
---|---|---|
committer | KohakuBlueleaf <apolloyeh0123@gmail.com> | 2023-10-28 08:52:35 +0000 |
commit | ddc2a3499b8cd120b4a42358bcd33137ce1d1e75 (patch) | |
tree | 8e2fcdbc2bcd53629c64335285bca8ab53b5d5ab | |
parent | d4d3134f6d2d232c7bcfa80900a362921e644976 (diff) | |
download | stable-diffusion-webui-gfx803-ddc2a3499b8cd120b4a42358bcd33137ce1d1e75.tar.gz stable-diffusion-webui-gfx803-ddc2a3499b8cd120b4a42358bcd33137ce1d1e75.tar.bz2 stable-diffusion-webui-gfx803-ddc2a3499b8cd120b4a42358bcd33137ce1d1e75.zip |
Add MPS manual cast
-rw-r--r-- | modules/devices.py | 6 |
1 files changed, 5 insertions, 1 deletions
diff --git a/modules/devices.py b/modules/devices.py index c05f2b35..d7c905c2 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -121,6 +121,8 @@ def manual_autocast(): def manual_cast_forward(self, *args, **kwargs): org_dtype = next(self.parameters()).dtype self.to(dtype) + args = [arg.to(dtype) if isinstance(arg, torch.Tensor) else arg for arg in args] + kwargs = {k: v.to(dtype) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()} result = self.org_forward(*args, **kwargs) self.to(org_dtype) return result @@ -136,7 +138,6 @@ def manual_autocast(): def autocast(disable=False): - print(fp8, dtype, shared.cmd_opts.precision, device) if disable: return contextlib.nullcontext() @@ -146,6 +147,9 @@ def autocast(disable=False): if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()): return manual_autocast() + if has_mps() and shared.cmd_opts.precision != "full": + return manual_autocast() + if dtype == torch.float32 or shared.cmd_opts.precision == "full": return contextlib.nullcontext() |