From c7e0e28ccd5c5075cc6b9c637df02864bd468c2f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 12 Sep 2022 20:09:32 +0300 Subject: changes for #294 --- modules/devices.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'modules/devices.py') diff --git a/modules/devices.py b/modules/devices.py index a93a245b..e4430e1a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -31,3 +31,20 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") + + +device = get_optimal_device() +device_codeformer = cpu if has_mps else device + + +def randn(seed, shape): + # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. + if device.type == 'mps': + generator = torch.Generator(device=cpu) + generator.manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=cpu).to(device) + return noise + + torch.manual_seed(seed) + return torch.randn(shape, device=device) + -- cgit v1.2.3