diff options
author | AUTOMATIC <16777216c@gmail.com> | 2022-09-12 17:09:32 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2022-09-12 17:09:32 +0000 |
commit | c7e0e28ccd5c5075cc6b9c637df02864bd468c2f (patch) | |
tree | e9065f13fcb5dd200df19607e81b283005518eab /modules/devices.py | |
parent | 11e03b9abdb4dbf38151bbf290b77122ff20bddb (diff) | |
download | stable-diffusion-webui-gfx803-c7e0e28ccd5c5075cc6b9c637df02864bd468c2f.tar.gz stable-diffusion-webui-gfx803-c7e0e28ccd5c5075cc6b9c637df02864bd468c2f.tar.bz2 stable-diffusion-webui-gfx803-c7e0e28ccd5c5075cc6b9c637df02864bd468c2f.zip |
changes for #294
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/modules/devices.py b/modules/devices.py index a93a245b..e4430e1a 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -31,3 +31,20 @@ def enable_tf32(): errors.run(enable_tf32, "Enabling TF32") + + +device = get_optimal_device() +device_codeformer = cpu if has_mps else device + + +def randn(seed, shape): + # Pytorch currently doesn't handle setting randomness correctly when the metal backend is used. + if device.type == 'mps': + generator = torch.Generator(device=cpu) + generator.manual_seed(seed) + noise = torch.randn(shape, generator=generator, device=cpu).to(device) + return noise + + torch.manual_seed(seed) + return torch.randn(shape, device=device) + |