diff options
author | Zac Liu <liuguang@baai.ac.cn> | 2022-12-06 01:16:15 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-06 01:16:15 +0000 |
commit | 3ebf977a6e4f478ab918e44506974beee32da276 (patch) | |
tree | f68456207e5cd78718ec1e9c588ecdc22d568d81 /modules/sd_samplers.py | |
parent | 231fb72872191ffa8c446af1577c9003b3d19d4f (diff) | |
parent | 44c46f0ed395967cd3830dd481a2db759fda5b3b (diff) | |
download | stable-diffusion-webui-gfx803-3ebf977a6e4f478ab918e44506974beee32da276.tar.gz stable-diffusion-webui-gfx803-3ebf977a6e4f478ab918e44506974beee32da276.tar.bz2 stable-diffusion-webui-gfx803-3ebf977a6e4f478ab918e44506974beee32da276.zip |
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/sd_samplers.py')
-rw-r--r-- | modules/sd_samplers.py | 22 |
1 files changed, 19 insertions, 3 deletions
diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 5fefb227..4c123d3b 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -6,6 +6,7 @@ import tqdm from PIL import Image
import inspect
import k_diffusion.sampling
+import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
@@ -364,7 +365,23 @@ class TorchHijack: if noise.shape == x.shape:
return noise
- return torch.randn_like(x)
+ if x.device.type == 'mps':
+ return torch.randn_like(x, device=devices.cpu).to(x.device)
+ else:
+ return torch.randn_like(x)
+
+
+# MPS fix for randn in torchsde
+def torchsde_randn(size, dtype, device, seed):
+ if device.type == 'mps':
+ generator = torch.Generator(devices.cpu).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
+ else:
+ generator = torch.Generator(device).manual_seed(int(seed))
+ return torch.randn(size, dtype=dtype, device=device, generator=generator)
+
+
+torchsde._brownian.brownian_interval._randn = torchsde_randn
class KDiffusionSampler:
@@ -415,8 +432,7 @@ class KDiffusionSampler: self.model_wrap.step = 0
self.eta = p.eta or opts.eta_ancestral
- if self.sampler_noises is not None:
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
+ k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
extra_params_kwargs = {}
for param_name in self.extra_params:
|