From 0fddb4a1c06a6e2122add7eee3b001a6d473baee Mon Sep 17 00:00:00 2001 From: brkirch Date: Wed, 30 Nov 2022 08:02:39 -0500 Subject: Rework MPS randn fix, add randn_like fix torch.manual_seed() already sets a CPU generator, so there is no reason to create a CPU generator manually. torch.randn_like also needs a MPS fix for k-diffusion, but a torch hijack with randn_like already exists so it can also be used for that. --- modules/sd_samplers.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) (limited to 'modules/sd_samplers.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 8b11f569..4c123d3b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -365,7 +365,10 @@ class TorchHijack: if noise.shape == x.shape: return noise - return torch.randn_like(x) + if x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + else: + return torch.randn_like(x) # MPS fix for randn in torchsde @@ -429,8 +432,7 @@ class KDiffusionSampler: self.model_wrap.step = 0 self.eta = p.eta or opts.eta_ancestral - if self.sampler_noises is not None: - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises) + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) extra_params_kwargs = {} for param_name in self.extra_params: -- cgit v1.2.3