From c62d17aee36b5f4ca24f9cfa7bf6d7aca0c923f8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:00:22 +0300 Subject: use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX --- modules/sd_hijack.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_hijack.py') diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 75b2d22d..97979d05 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -418,8 +418,7 @@ def register_buffer(self, name, attr): if type(attr) == torch.Tensor: if attr.device != devices.device: - # would this not break cuda when torch adds has_mps() to main version? - if getattr(torch, 'has_mps', False): + if devices.has_mps(): attr = attr.to(device="mps", dtype=torch.float32) else: attr = attr.to(devices.device) -- cgit v1.2.3