aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-03 04:18:55 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-03 04:18:55 +0000
commitfca42949a3593c5a2f646e30cc99be2c02566aa2 (patch)
tree6759c76a2c7e713a258c78f91111f2439dcfb9f6 /modules/devices.py
parent84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014 (diff)
downloadstable-diffusion-webui-gfx803-fca42949a3593c5a2f646e30cc99be2c02566aa2.tar.gz
stable-diffusion-webui-gfx803-fca42949a3593c5a2f646e30cc99be2c02566aa2.tar.bz2
stable-diffusion-webui-gfx803-fca42949a3593c5a2f646e30cc99be2c02566aa2.zip
rework torchsde._brownian.brownian_interval replacement to use device.randn_local and respect the NV setting.
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py44
1 files changed, 38 insertions, 6 deletions
diff --git a/modules/devices.py b/modules/devices.py
index b58776d8..00a00b18 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -71,14 +71,17 @@ def enable_tf32():
torch.backends.cudnn.allow_tf32 = True
-
errors.run(enable_tf32, "Enabling TF32")
-cpu = torch.device("cpu")
-device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
-dtype = torch.float16
-dtype_vae = torch.float16
-dtype_unet = torch.float16
+cpu: torch.device = torch.device("cpu")
+device: torch.device = None
+device_interrogate: torch.device = None
+device_gfpgan: torch.device = None
+device_esrgan: torch.device = None
+device_codeformer: torch.device = None
+dtype: torch.dtype = torch.float16
+dtype_vae: torch.dtype = torch.float16
+dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False
@@ -94,6 +97,10 @@ nv_rng = None
def randn(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""
+
from modules.shared import opts
manual_seed(seed)
@@ -107,7 +114,27 @@ def randn(seed, shape):
return torch.randn(shape, device=device)
+def randn_local(seed, shape):
+ """Generate a tensor with random numbers from a normal distribution using seed.
+
+ Does not change the global random number generator. You can only generate the seed's first tensor using this function."""
+
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ rng = rng_philox.Generator(seed)
+ return torch.asarray(rng.randn(shape), device=device)
+
+ local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
+ local_generator = torch.Generator(local_device).manual_seed(int(seed))
+ return torch.randn(shape, device=local_device, generator=local_generator).to(device)
+
+
def randn_like(x):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
if opts.randn_source == "NV":
@@ -120,6 +147,10 @@ def randn_like(x):
def randn_without_seed(shape):
+ """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.
+
+ Use either randn() or manual_seed() to initialize the generator."""
+
from modules.shared import opts
if opts.randn_source == "NV":
@@ -132,6 +163,7 @@ def randn_without_seed(shape):
def manual_seed(seed):
+ """Set up a global random number generator using the specified seed."""
from modules.shared import opts
if opts.randn_source == "NV":