diff options
Diffstat (limited to 'modules/sd_samplers_common.py')
-rw-r--r-- | modules/sd_samplers_common.py | 24 |
1 files changed, 12 insertions, 12 deletions
diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index b6ad6830..35c4d657 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,5 +1,5 @@ import inspect
-from collections import namedtuple, deque
+from collections import namedtuple
import numpy as np
import torch
from PIL import Image
@@ -161,10 +161,15 @@ def apply_refiner(sampler): class TorchHijack:
- def __init__(self, sampler_noises):
- # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
- # implementation.
- self.sampler_noises = deque(sampler_noises)
+ """This is here to replace torch.randn_like of k-diffusion.
+
+ k-diffusion has random_sampler argument for most samplers, but not for all, so
+ this is needed to properly replace every use of torch.randn_like.
+
+ We need to replace to make images generated in batches to be same as images generated individually."""
+
+ def __init__(self, p):
+ self.rng = p.rng
def __getattr__(self, item):
if item == 'randn_like':
@@ -176,12 +181,7 @@ class TorchHijack: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")
def randn_like(self, x):
- if self.sampler_noises:
- noise = self.sampler_noises.popleft()
- if noise.shape == x.shape:
- return noise
-
- return devices.randn_like(x)
+ return self.rng.next()
class Sampler:
@@ -248,7 +248,7 @@ class Sampler: self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)
- k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])
+ k_diffusion.sampling.torch = TorchHijack(p)
extra_params_kwargs = {}
for param_name in self.extra_params:
|