diff options
author | brkirch <brkirch@users.noreply.github.com> | 2022-11-30 13:02:39 +0000 |
---|---|---|
committer | brkirch <brkirch@users.noreply.github.com> | 2022-11-30 15:33:42 +0000 |
commit | 0fddb4a1c06a6e2122add7eee3b001a6d473baee (patch) | |
tree | 1e8673eb008616320d85f3a11c6e2453d78d9c1f /modules/devices.py | |
parent | 4d5f1691dda971ec7b461dd880426300fd54ccee (diff) | |
download | stable-diffusion-webui-gfx803-0fddb4a1c06a6e2122add7eee3b001a6d473baee.tar.gz stable-diffusion-webui-gfx803-0fddb4a1c06a6e2122add7eee3b001a6d473baee.tar.bz2 stable-diffusion-webui-gfx803-0fddb4a1c06a6e2122add7eee3b001a6d473baee.zip |
Rework MPS randn fix, add randn_like fix
torch.manual_seed() already sets a CPU generator, so there is no reason to create a CPU generator manually. torch.randn_like also needs a MPS fix for k-diffusion, but a torch hijack with randn_like already exists so it can also be used for that.
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 15 |
1 files changed, 3 insertions, 12 deletions
diff --git a/modules/devices.py b/modules/devices.py index f00079c6..046460fa 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -66,24 +66,15 @@ dtype_vae = torch.float16 def randn(seed, shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. - if device.type == 'mps': - generator = torch.Generator(device=cpu) - generator.manual_seed(seed) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - torch.manual_seed(seed) + if device.type == 'mps': + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) def randn_without_seed(shape): - # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. if device.type == 'mps': - generator = torch.Generator(device=cpu) - noise = torch.randn(shape, generator=generator, device=cpu).to(device) - return noise - + return torch.randn(shape, device=cpu).to(device) return torch.randn(shape, device=device) |