From aefe1325df60a925b3a75a2cb58bf74e8ca86df4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 29 Jul 2023 08:11:59 +0300 Subject: split the new sampler into a different file --- modules/sd_samplers_kdiffusion.py | 75 +++------------------------------------ 1 file changed, 4 insertions(+), 71 deletions(-) (limited to 'modules/sd_samplers_kdiffusion.py') diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a54673eb..e0da3425 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra from modules.shared import opts, state import modules.shared as shared @@ -30,81 +30,14 @@ samplers_k_diffusion = [ ('DPM++ 2M Karras', 'sample_dpmpp_2m', ['k_dpmpp_2m_ka'], {'scheduler': 'karras'}), ('DPM++ SDE Karras', 'sample_dpmpp_sde', ['k_dpmpp_sde_ka'], {'scheduler': 'karras', "second_order": True, "brownian_noise": True}), ('DPM++ 2M SDE Karras', 'sample_dpmpp_2m_sde', ['k_dpmpp_2m_sde_ka'], {'scheduler': 'karras', "brownian_noise": True}), - ('Restart (new)', 'restart_sampler', ['restart'], {'scheduler': 'karras', "second_order": True}), + ('Restart', sd_samplers_extra.restart_sampler, ['restart'], {'scheduler': 'karras'}), ] -@torch.no_grad() -def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=None, s_noise=1., restart_list = None): - """Implements restart sampling in Restart Sampling for Improving Generative Processes (2023)""" - '''Restart_list format: {min_sigma: [ restart_steps, restart_times, max_sigma]}''' - '''If restart_list is None: will choose restart_list automatically, otherwise will use the given restart_list''' - from tqdm.auto import trange - extra_args = {} if extra_args is None else extra_args - s_in = x.new_ones([x.shape[0]]) - step_id = 0 - from k_diffusion.sampling import to_d, get_sigmas_karras - def heun_step(x, old_sigma, new_sigma, second_order = True): - nonlocal step_id - denoised = model(x, old_sigma * s_in, **extra_args) - d = to_d(x, old_sigma, denoised) - if callback is not None: - callback({'x': x, 'i': step_id, 'sigma': new_sigma, 'sigma_hat': old_sigma, 'denoised': denoised}) - dt = new_sigma - old_sigma - if new_sigma == 0 or not second_order: - # Euler method - x = x + d * dt - else: - # Heun's method - x_2 = x + d * dt - denoised_2 = model(x_2, new_sigma * s_in, **extra_args) - d_2 = to_d(x_2, new_sigma, denoised_2) - d_prime = (d + d_2) / 2 - x = x + d_prime * dt - step_id += 1 - return x - steps = sigmas.shape[0] - 1 - if restart_list is None: - if steps >= 20: - restart_steps = 9 - restart_times = 1 - if steps >= 36: - restart_steps = steps // 4 - restart_times = 2 - sigmas = get_sigmas_karras(steps - restart_steps * restart_times, sigmas[-2].item(), sigmas[0].item(), device=sigmas.device) - restart_list = {0.1: [restart_steps + 1, restart_times, 2]} - else: - restart_list = dict() - temp_list = dict() - for key, value in restart_list.items(): - temp_list[int(torch.argmin(abs(sigmas - key), dim=0))] = value - restart_list = temp_list - step_list = [] - for i in range(len(sigmas) - 1): - step_list.append((sigmas[i], sigmas[i + 1])) - if i + 1 in restart_list: - restart_steps, restart_times, restart_max = restart_list[i + 1] - min_idx = i + 1 - max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0)) - if max_idx < min_idx: - sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx].item(), sigmas[max_idx].item(), device=sigmas.device)[:-1] - while restart_times > 0: - restart_times -= 1 - step_list.extend([(old_sigma, new_sigma) for (old_sigma, new_sigma) in zip(sigma_restart[:-1], sigma_restart[1:])]) - last_sigma = None - for i in trange(len(step_list), disable=disable): - if last_sigma is None: - last_sigma = step_list[i][0] - elif last_sigma < step_list[i][0]: - x = x + k_diffusion.sampling.torch.randn_like(x) * s_noise * (step_list[i][0] ** 2 - last_sigma ** 2) ** 0.5 - x = heun_step(x, step_list[i][0], step_list[i][1]) - last_sigma = step_list[i][1] - return x - samplers_data_k_diffusion = [ sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: KDiffusionSampler(funcname, model), aliases, options) for label, funcname, aliases, options in samplers_k_diffusion - if (hasattr(k_diffusion.sampling, funcname) or funcname == 'restart_sampler') + if callable(funcname) or hasattr(k_diffusion.sampling, funcname) ] sampler_extra_params = { @@ -339,7 +272,7 @@ class KDiffusionSampler: self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) self.funcname = funcname - self.func = getattr(k_diffusion.sampling, self.funcname) if funcname != "restart_sampler" else restart_sampler + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) self.model_wrap_cfg = CFGDenoiser(self.model_wrap) self.sampler_noises = None -- cgit v1.2.3