diff options
author | thesved <2893181+thesved@users.noreply.github.com> | 2022-11-03 18:44:47 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-03 18:44:47 +0000 |
commit | 86b7fc6e5ed56327fa12b444ca2444b13eb98aa8 (patch) | |
tree | 6425c609836d92c8e8af64da2be11edb73844878 /modules/sd_hijack_inpainting.py | |
parent | c2465f67db2529d962e311b3a520bd5cd715058b (diff) | |
download | stable-diffusion-webui-gfx803-86b7fc6e5ed56327fa12b444ca2444b13eb98aa8.tar.gz stable-diffusion-webui-gfx803-86b7fc6e5ed56327fa12b444ca2444b13eb98aa8.tar.bz2 stable-diffusion-webui-gfx803-86b7fc6e5ed56327fa12b444ca2444b13eb98aa8.zip |
Make DDIM and PLMS work on Mac OS
Fix register_buffer error on Mac OS
Diffstat (limited to 'modules/sd_hijack_inpainting.py')
-rw-r--r-- | modules/sd_hijack_inpainting.py | 19 |
1 files changed, 18 insertions, 1 deletions
diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index fd92a335..202b42cf 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,4 +1,5 @@ import torch +import modules.devices as devices from einops import repeat from omegaconf import ListConfig @@ -314,6 +315,20 @@ class LatentInpaintDiffusion(LatentDiffusion): self.masked_image_key = masked_image_key assert self.masked_image_key in concat_keys self.concat_keys = concat_keys + + +# ================================================================================================= +# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here +# ================================================================================================= +def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + optimal_type = devices.get_optimal_device() + if attr.device != optimal_type: + if getattr(torch, 'has_mps', False): + attr = attr.to(device="mps", dtype=torch.float32) + else: + attr = attr.to(optimal_type) + setattr(self, name, attr) def should_hijack_inpainting(checkpoint_info): @@ -326,6 +341,8 @@ def do_inpainting_hijack(): ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms - ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
\ No newline at end of file + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms + ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer |