From 0d5dc9a6e7f6362e423a06bf0e75dd5854025394 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 08:43:31 +0300 Subject: rework RNG to use generators instead of generating noises beforehand --- modules/rng.py | 171 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 171 insertions(+) create mode 100644 modules/rng.py (limited to 'modules/rng.py') diff --git a/modules/rng.py b/modules/rng.py new file mode 100644 index 00000000..2d7baea5 --- /dev/null +++ b/modules/rng.py @@ -0,0 +1,171 @@ +import torch + +from modules import devices, rng_philox, shared + + +def randn(seed, shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + + manual_seed(seed) + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + if shared.opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=devices.device) + + local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if shared.opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + + return torch.randn_like(x) + + +def randn_without_seed(shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + +def create_generator(seed): + if shared.opts.randn_source == "NV": + return rng_philox.Generator(seed) + + device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + generator = torch.Generator(device).manual_seed(int(seed)) + return generator + + +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + dot = (low_norm*high_norm).sum(1) + + if dot.mean() > 0.9995: + return low * val + high * (1 - val) + + omega = torch.acos(dot) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res + + +class ImageRNG: + def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): + self.shape = shape + self.seeds = seeds + self.subseeds = subseeds + self.subseed_strength = subseed_strength + self.seed_resize_from_h = seed_resize_from_h + self.seed_resize_from_w = seed_resize_from_w + + self.generators = [create_generator(seed) for seed in seeds] + + self.is_first = True + + def first(self): + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + + xs = [] + + for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)): + subnoise = None + if self.subseeds is not None and self.subseed_strength != 0: + subseed = 0 if i >= len(self.subseeds) else self.subseeds[i] + subnoise = randn(subseed, noise_shape) + + if noise_shape != self.shape: + noise = randn(seed, noise_shape) + else: + noise = randn(seed, self.shape, generator=generator) + + if subnoise is not None: + noise = slerp(self.subseed_strength, noise, subnoise) + + if noise_shape != self.shape: + x = randn(seed, self.shape, generator=generator) + dx = (self.shape[2] - noise_shape[2]) // 2 + dy = (self.shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] + noise = x + + xs.append(noise) + + eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0 + if eta_noise_seed_delta: + self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds] + + return torch.stack(xs).to(shared.device) + + def next(self): + if self.is_first: + self.is_first = False + return self.first() + + xs = [] + for generator in self.generators: + x = randn_without_seed(self.shape, generator=generator) + xs.append(x) + + return torch.stack(xs).to(shared.device) + + +devices.randn = randn +devices.randn_local = randn_local +devices.randn_like = randn_like +devices.randn_without_seed = randn_without_seed +devices.manual_seed = manual_seed -- cgit v1.2.3