aboutsummaryrefslogtreecommitdiffstats
path: root/modules/devices.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-02 21:00:23 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-02 21:00:23 +0000
commit84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014 (patch)
tree1be8f712b6f87a29e34204acc54c0b05bc06dfb8 /modules/devices.py
parentccb92339348f6973de39cde062982a51a4cd0818 (diff)
downloadstable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.tar.gz
stable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.tar.bz2
stable-diffusion-webui-gfx803-84b6fcd02ca6d6ab48c4b6be4bb8724b1c2e7014.zip
add NV option for Random number generator source setting, which allows to generate same pictures on CPU/AMD/Mac as on NVidia videocards.
Diffstat (limited to 'modules/devices.py')
-rw-r--r--modules/devices.py39
1 files changed, 37 insertions, 2 deletions
diff --git a/modules/devices.py b/modules/devices.py
index 57e51da3..b58776d8 100644
--- a/modules/devices.py
+++ b/modules/devices.py
@@ -3,7 +3,7 @@ import contextlib
from functools import lru_cache
import torch
-from modules import errors
+from modules import errors, rng_philox
if sys.platform == "darwin":
from modules import mac_specific
@@ -90,23 +90,58 @@ def cond_cast_float(input):
return input.float() if unet_needs_upcast else input
+nv_rng = None
+
+
def randn(seed, shape):
from modules.shared import opts
- torch.manual_seed(seed)
+ manual_seed(seed)
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def randn_like(x):
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype)
+
+ if opts.randn_source == "CPU" or x.device.type == 'mps':
+ return torch.randn_like(x, device=cpu).to(x.device)
+
+ return torch.randn_like(x)
+
+
def randn_without_seed(shape):
from modules.shared import opts
+ if opts.randn_source == "NV":
+ return torch.asarray(nv_rng.randn(shape), device=device)
+
if opts.randn_source == "CPU" or device.type == 'mps':
return torch.randn(shape, device=cpu).to(device)
+
return torch.randn(shape, device=device)
+def manual_seed(seed):
+ from modules.shared import opts
+
+ if opts.randn_source == "NV":
+ global nv_rng
+ nv_rng = rng_philox.Generator(seed)
+ return
+
+ torch.manual_seed(seed)
+
+
def autocast(disable=False):
from modules import shared