diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-02 21:00:23 +0000 |
---|---|---|
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-08-02 21:00:23 +0000 |
commit | 84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014 (patch) | |
tree | 1be8f712b6f87a29e34204acc54c0b05bc06dfb8 /modules/devices.py | |
parent | ccb92339348f6973de39cde062982a51a4cd0818 (diff) | |
download | stable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.tar.gz stable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.tar.bz2 stable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.zip |
add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards.
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 39 |
1 files changed, 37 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py index 57e51da3..b58776d8 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors +from modules import errors, rng_philox if sys.platform == "darwin": from modules import mac_specific @@ -90,23 +90,58 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input +nv_rng = None + + def randn(seed, shape): from modules.shared import opts - torch.manual_seed(seed) + manual_seed(seed) + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def randn_like(x): + from modules.shared import opts + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=cpu).to(x.device) + + return torch.randn_like(x) + + def randn_without_seed(shape): from modules.shared import opts + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def manual_seed(seed): + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + def autocast(disable=False): from modules import shared |