diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-11-12 07:00:22 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-11-12 07:00:22 +0000 |
commit | c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8 (patch) | |
tree | 55f9909e52f61d09cd2f20597430f5db15ba4449 | |
parent | 526f0aa5569241aabf276a83af1a7216e825c6cc (diff) | |
download | stable-diffusion-webui-gfx803-c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8.tar.gz stable-diffusion-webui-gfx803-c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8.tar.bz2 stable-diffusion-webui-gfx803-c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8.zip |
use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX
-rw-r--r-- | modules/sd_hijack.py | 3 |
1 files changed, 1 insertions, 2 deletions
diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 75b2d22d..97979d05 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -418,8 +418,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor:
if attr.device != devices.device:
- # would this not break cuda when torch adds has_mps() to main version?
- if getattr(torch, 'has_mps', False):
+ if devices.has_mps():
attr = attr.to(device="mps", dtype=torch.float32)
else:
attr = attr.to(devices.device)
|