diff options
author | w-e-w <40751091+w-e-w@users.noreply.github.com> | 2023-08-08 02:39:34 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-08 02:39:34 +0000 |
commit | f17c8c2eff63210f5e96e1e2b049b46ba9cfa389 (patch) | |
tree | 701056aec9ae11bc45df9b39b176a54fa4d34e19 /modules/devices.py | |
parent | c75bda867be5345bf959daf23bdc19eadc90841a (diff) | |
parent | 01997f45ba089af24b03a5f614147bb0f9d8d824 (diff) | |
download | stable-diffusion-webui-gfx803-f17c8c2eff63210f5e96e1e2b049b46ba9cfa389.tar.gz stable-diffusion-webui-gfx803-f17c8c2eff63210f5e96e1e2b049b46ba9cfa389.tar.bz2 stable-diffusion-webui-gfx803-f17c8c2eff63210f5e96e1e2b049b46ba9cfa389.zip |
Merge branch 'dev' into auro-autolaunch
Diffstat (limited to 'modules/devices.py')
-rw-r--r-- | modules/devices.py | 83 |
1 files changed, 75 insertions, 8 deletions
diff --git a/modules/devices.py b/modules/devices.py index 57e51da3..00a00b18 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors +from modules import errors, rng_philox if sys.platform == "darwin": from modules import mac_specific @@ -71,14 +71,17 @@ def enable_tf32(): torch.backends.cudnn.allow_tf32 = True - errors.run(enable_tf32, "Enabling TF32") -cpu = torch.device("cpu") -device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None -dtype = torch.float16 -dtype_vae = torch.float16 -dtype_unet = torch.float16 +cpu: torch.device = torch.device("cpu") +device: torch.device = None +device_interrogate: torch.device = None +device_gfpgan: torch.device = None +device_esrgan: torch.device = None +device_codeformer: torch.device = None +dtype: torch.dtype = torch.float16 +dtype_vae: torch.dtype = torch.float16 +dtype_unet: torch.dtype = torch.float16 unet_needs_upcast = False @@ -90,23 +93,87 @@ def cond_cast_float(input): return input.float() if unet_needs_upcast else input +nv_rng = None + + def randn(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + from modules.shared import opts - torch.manual_seed(seed) + manual_seed(seed) + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + from modules.shared import opts + + if opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=device) + + local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + from modules.shared import opts + + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=cpu).to(x.device) + + return torch.randn_like(x) + + def randn_without_seed(shape): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + from modules.shared import opts + if opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(shape), device=device) + if opts.randn_source == "CPU" or device.type == 'mps': return torch.randn(shape, device=cpu).to(device) + return torch.randn(shape, device=device) +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + def autocast(disable=False): from modules import shared |