From 8285a149d8c488ae6c7a566eb85fb5e825145464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 19:20:11 +0300 Subject: add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them) --- modules/sd_samplers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index bea2684c..fe206894 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_compvis.samplers_data_compvis, + *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} -- cgit v1.2.3 From a8a256f9b5b445206818bfc8a363ed5a1ba50c86 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 21:07:18 +0300 Subject: REMOVE --- modules/processing.py | 3 - modules/sd_hijack.py | 4 +- modules/sd_hijack_inpainting.py | 95 --------------- modules/sd_samplers.py | 7 +- modules/sd_samplers_cfg_denoiser.py | 1 - modules/sd_samplers_compvis.py | 224 ------------------------------------ modules/sd_samplers_kdiffusion.py | 3 +- modules/sd_samplers_timesteps.py | 6 +- 8 files changed, 7 insertions(+), 336 deletions(-) delete mode 100644 modules/sd_hijack_inpainting.py delete mode 100644 modules/sd_samplers_compvis.py (limited to 'modules/sd_samplers.py') diff --git a/modules/processing.py b/modules/processing.py index 31745006..61ba5f11 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -1112,9 +1112,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): img2img_sampler_name = self.hr_sampler_name or self.sampler_name - if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM - img2img_sampler_name = 'DDIM' - self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model) if self.latent_scale_mode is not None: diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index 9ad98199..46652fbd 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -5,7 +5,7 @@ from types import MethodType from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts -from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting +from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr import ldm.modules.attention import ldm.modules.diffusionmodules.model @@ -34,8 +34,6 @@ ldm.modules.diffusionmodules.model.print = shared.ldm_print ldm.util.print = shared.ldm_print ldm.models.diffusion.ddpm.print = shared.ldm_print -sd_hijack_inpainting.do_inpainting_hijack() - optimizers = [] current_optimizer: sd_hijack_optimizations.SdOptimization = None diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py deleted file mode 100644 index 2d44b856..00000000 --- a/modules/sd_hijack_inpainting.py +++ /dev/null @@ -1,95 +0,0 @@ -import torch - -import ldm.models.diffusion.ddpm -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -from ldm.models.diffusion.ddim import noise_like -from ldm.models.diffusion.sampling_util import norm_thresholding - - -@torch.no_grad() -def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, - temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, - unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None, dynamic_threshold=None): - b, *_, device = *x.shape, x.device - - def get_model_output(x, t): - if unconditional_conditioning is None or unconditional_guidance_scale == 1.: - e_t = self.model.apply_model(x, t, c) - else: - x_in = torch.cat([x] * 2) - t_in = torch.cat([t] * 2) - - if isinstance(c, dict): - assert isinstance(unconditional_conditioning, dict) - c_in = {} - for k in c: - if isinstance(c[k], list): - c_in[k] = [ - torch.cat([unconditional_conditioning[k][i], c[k][i]]) - for i in range(len(c[k])) - ] - else: - c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) - else: - c_in = torch.cat([unconditional_conditioning, c]) - - e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) - e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) - - if score_corrector is not None: - assert self.model.parameterization == "eps" - e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) - - return e_t - - alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas - alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev - sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas - sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas - - def get_x_prev_and_pred_x0(e_t, index): - # select parameters corresponding to the currently considered timestep - a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) - a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) - sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) - sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) - - # current prediction for x_0 - pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() - if quantize_denoised: - pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) - if dynamic_threshold is not None: - pred_x0 = norm_thresholding(pred_x0, dynamic_threshold) - # direction pointing to x_t - dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t - noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature - if noise_dropout > 0.: - noise = torch.nn.functional.dropout(noise, p=noise_dropout) - x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise - return x_prev, pred_x0 - - e_t = get_model_output(x, t) - if len(old_eps) == 0: - # Pseudo Improved Euler (2nd order) - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) - e_t_next = get_model_output(x_prev, t_next) - e_t_prime = (e_t + e_t_next) / 2 - elif len(old_eps) == 1: - # 2nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (3 * e_t - old_eps[-1]) / 2 - elif len(old_eps) == 2: - # 3nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 - elif len(old_eps) >= 3: - # 4nd order Pseudo Linear Multistep (Adams-Bashforth) - e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 - - x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) - - return x_prev, pred_x0, e_t - - -def do_inpainting_hijack(): - ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index fe206894..05dbe2b5 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,11 +1,10 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared +from modules import sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, - *sd_samplers_compvis.samplers_data_compvis, *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} @@ -42,10 +41,8 @@ def set_samplers(): global samplers, samplers_for_img2img hidden = set(shared.opts.hide_samplers) - hidden_img2img = set(shared.opts.hide_samplers + ['PLMS', 'UniPC']) - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden_img2img] + samplers_for_img2img = [x for x in all_samplers if x.name not in hidden] samplers_map.clear() for sampler in all_samplers: diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 166a00c7..d826222c 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -1,4 +1,3 @@ -from collections import deque import torch from modules import prompt_parser, devices, sd_samplers_common diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py deleted file mode 100644 index 4a8396f9..00000000 --- a/modules/sd_samplers_compvis.py +++ /dev/null @@ -1,224 +0,0 @@ -import math -import ldm.models.diffusion.ddim -import ldm.models.diffusion.plms - -import numpy as np -import torch - -from modules.shared import state -from modules import sd_samplers_common, prompt_parser, shared -import modules.models.diffusion.uni_pc - - -samplers_data_compvis = [ - sd_samplers_common.SamplerData('DDIM', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.ddim.DDIMSampler, model), [], {"default_eta_is_0": True, "uses_ensd": True, "no_sdxl": True}), - sd_samplers_common.SamplerData('PLMS', lambda model: VanillaStableDiffusionSampler(ldm.models.diffusion.plms.PLMSSampler, model), [], {"no_sdxl": True}), - sd_samplers_common.SamplerData('UniPC', lambda model: VanillaStableDiffusionSampler(modules.models.diffusion.uni_pc.UniPCSampler, model), [], {"no_sdxl": True}), -] - - -class VanillaStableDiffusionSampler: - def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) - self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') - self.is_plms = hasattr(self.sampler, 'p_sample_plms') - self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) - self.orig_p_sample_ddim = None - if self.is_plms: - self.orig_p_sample_ddim = self.sampler.p_sample_plms - elif self.is_ddim: - self.orig_p_sample_ddim = self.sampler.p_sample_ddim - self.mask = None - self.nmask = None - self.init_latent = None - self.sampler_noises = None - self.step = 0 - self.stop_at = None - self.eta = None - self.config = None - self.last_latent = None - - self.conditioning_key = sd_model.model.conditioning_key - - def number_of_needed_noises(self, p): - return 0 - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except sd_samplers_common.InterruptedException: - return self.last_latent - - def p_sample_ddim_hook(self, x_dec, cond, ts, unconditional_conditioning, *args, **kwargs): - x_dec, ts, cond, unconditional_conditioning = self.before_sample(x_dec, ts, cond, unconditional_conditioning) - - res = self.orig_p_sample_ddim(x_dec, cond, ts, *args, unconditional_conditioning=unconditional_conditioning, **kwargs) - - x_dec, ts, cond, unconditional_conditioning, res = self.after_sample(x_dec, ts, cond, unconditional_conditioning, res) - - return res - - def before_sample(self, x, ts, cond, unconditional_conditioning): - if state.interrupted or state.skipped: - raise sd_samplers_common.InterruptedException - - if self.stop_at is not None and self.step > self.stop_at: - raise sd_samplers_common.InterruptedException - - # Have to unwrap the inpainting conditioning here to perform pre-processing - image_conditioning = None - uc_image_conditioning = None - if isinstance(cond, dict): - if self.conditioning_key == "crossattn-adm": - image_conditioning = cond["c_adm"] - uc_image_conditioning = unconditional_conditioning["c_adm"] - else: - image_conditioning = cond["c_concat"][0] - cond = cond["c_crossattn"][0] - unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] - - conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) - - assert all(len(conds) == 1 for conds in conds_list), 'composition via AND is not supported for DDIM/PLMS samplers' - cond = tensor - - # for DDIM, shapes must match, we can't just process cond and uncond independently; - # filling unconditional_conditioning with repeats of the last vector to match length is - # not 100% correct but should work well enough - if unconditional_conditioning.shape[1] < cond.shape[1]: - last_vector = unconditional_conditioning[:, -1:] - last_vector_repeated = last_vector.repeat([1, cond.shape[1] - unconditional_conditioning.shape[1], 1]) - unconditional_conditioning = torch.hstack([unconditional_conditioning, last_vector_repeated]) - elif unconditional_conditioning.shape[1] > cond.shape[1]: - unconditional_conditioning = unconditional_conditioning[:, :cond.shape[1]] - - if self.mask is not None: - img_orig = self.sampler.model.q_sample(self.init_latent, ts) - x = img_orig * self.mask + self.nmask * x - - # Wrap the image conditioning back up since the DDIM code can accept the dict directly. - # Note that they need to be lists because it just concatenates them later. - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - cond = {"c_adm": image_conditioning, "c_crossattn": [cond]} - unconditional_conditioning = {"c_adm": uc_image_conditioning, "c_crossattn": [unconditional_conditioning]} - else: - cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - return x, ts, cond, unconditional_conditioning - - def update_step(self, last_latent): - if self.mask is not None: - self.last_latent = self.init_latent * self.mask + self.nmask * last_latent - else: - self.last_latent = last_latent - - sd_samplers_common.store_latent(self.last_latent) - - self.step += 1 - state.sampling_step = self.step - shared.total_tqdm.update() - - def after_sample(self, x, ts, cond, uncond, res): - if not self.is_unipc: - self.update_step(res[1]) - - return x, ts, cond, uncond, res - - def unipc_after_update(self, x, model_x): - self.update_step(x) - - def initialize(self, p): - if self.is_ddim: - self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim - else: - self.eta = 0.0 - - if self.eta != 0.0: - p.extra_generation_params["Eta DDIM"] = self.eta - - if self.is_unipc: - keys = [ - ('UniPC variant', 'uni_pc_variant'), - ('UniPC skip type', 'uni_pc_skip_type'), - ('UniPC order', 'uni_pc_order'), - ('UniPC lower order final', 'uni_pc_lower_order_final'), - ] - - for name, key in keys: - v = getattr(shared.opts, key) - if v != shared.opts.get_default(key): - p.extra_generation_params[name] = v - - for fieldname in ['p_sample_ddim', 'p_sample_plms']: - if hasattr(self.sampler, fieldname): - setattr(self.sampler, fieldname, self.p_sample_ddim_hook) - if self.is_unipc: - self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx)) - - self.mask = p.mask if hasattr(p, 'mask') else None - self.nmask = p.nmask if hasattr(p, 'nmask') else None - - - def adjust_steps_if_invalid(self, p, num_steps): - if ((self.config.name == 'DDIM') and p.ddim_discretize == 'uniform') or (self.config.name == 'PLMS') or (self.config.name == 'UniPC'): - if self.config.name == 'UniPC' and num_steps < shared.opts.uni_pc_order: - num_steps = shared.opts.uni_pc_order - valid_step = 999 / (1000 // num_steps) - if valid_step == math.floor(valid_step): - return int(valid_step) + 1 - - return num_steps - - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) - steps = self.adjust_steps_if_invalid(p, steps) - self.initialize(p) - - self.sampler.make_schedule(ddim_num_steps=steps, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False) - x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise) - - self.init_latent = x - self.last_latent = x - self.step = 0 - - # Wrap the conditioning models with additional image conditioning for inpainting model - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"c_adm": image_conditioning, "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_adm": torch.zeros_like(image_conditioning), "c_crossattn": [unconditional_conditioning]} - else: - conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} - unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} - - samples = self.launch_sampling(t_enc + 1, lambda: self.sampler.decode(x1, conditioning, t_enc, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning)) - - return samples - - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): - self.initialize(p) - - self.init_latent = None - self.last_latent = x - self.step = 0 - - steps = self.adjust_steps_if_invalid(p, steps or p.steps) - - # Wrap the conditioning models with additional image conditioning for inpainting model - # dummy_for_plms is needed because PLMS code checks the first item in the dict to have the right shape - if image_conditioning is not None: - if self.conditioning_key == "crossattn-adm": - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_adm": image_conditioning} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_adm": torch.zeros_like(image_conditioning)} - else: - conditioning = {"dummy_for_plms": np.zeros((conditioning.shape[0],)), "c_crossattn": [conditioning], "c_concat": [image_conditioning]} - unconditional_conditioning = {"c_crossattn": [unconditional_conditioning], "c_concat": [image_conditioning]} - - samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) - - return samples_ddim diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3a2e01b7..27a73486 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -1,8 +1,7 @@ -from collections import deque import torch import inspect import k_diffusion.sampling -from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser +from modules import sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser from modules.shared import opts import modules.shared as shared diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index 8560d009..d89d0efb 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -7,9 +7,9 @@ from modules.shared import opts import modules.shared as shared samplers_timesteps = [ - ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}), - ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}), - ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}), + ('DDIM', sd_samplers_timesteps_impl.ddim, ['ddim'], {}), + ('PLMS', sd_samplers_timesteps_impl.plms, ['plms'], {}), + ('UniPC', sd_samplers_timesteps_impl.unipc, ['unipc'], {}), ] -- cgit v1.2.3 From 70c63c1208d33bf02e15c4e310bac83f12ee8625 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 21:28:34 +0300 Subject: pass samplers from UI by name, make it possible to use a sampler from infotext even if it's hidden in the dropdown --- modules/img2img.py | 6 +++--- modules/sd_samplers.py | 13 +++++++++---- modules/txt2img.py | 8 ++++---- modules/ui.py | 29 ++++++++++++++--------------- 4 files changed, 30 insertions(+), 26 deletions(-) (limited to 'modules/sd_samplers.py') diff --git a/modules/img2img.py b/modules/img2img.py index d8e1c534..e06ac1d6 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -6,7 +6,7 @@ import numpy as np from PIL import Image, ImageOps, ImageFilter, ImageEnhance, UnidentifiedImageError import gradio as gr -from modules import sd_samplers, images as imgutil +from modules import images as imgutil from modules.generation_parameters_copypaste import create_override_settings_dict, parse_generation_parameters from modules.processing import Processed, StableDiffusionProcessingImg2Img, process_images from modules.shared import opts, state @@ -116,7 +116,7 @@ def process_batch(p, input_dir, output_dir, inpaint_mask_dir, args, to_scale=Fal process_images(p) -def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_index: int, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): +def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_styles, init_img, sketch, init_img_with_mask, inpaint_color_sketch, inpaint_color_sketch_orig, init_img_inpaint, init_mask_inpaint, steps: int, sampler_name: str, mask_blur: int, mask_alpha: float, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, image_cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, selected_scale_tab: int, height: int, width: int, scale_by: float, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, img2img_batch_inpaint_mask_dir: str, override_settings_texts, img2img_batch_use_png_info: bool, img2img_batch_png_info_props: list, img2img_batch_png_info_dir: str, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) is_batch = mode == 5 @@ -172,7 +172,7 @@ def img2img(id_task: str, mode: int, prompt: str, negative_prompt: str, prompt_s seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w, seed_enable_extras=seed_enable_extras, - sampler_name=sd_samplers.samplers_for_img2img[sampler_index].name, + sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, steps=steps, diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index 05dbe2b5..45faae62 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -12,6 +12,7 @@ all_samplers_map = {x.name: x for x in all_samplers} samplers = [] samplers_for_img2img = [] samplers_map = {} +samplers_hidden = {} def find_sampler_config(name): @@ -38,11 +39,11 @@ def create_sampler(name, model): def set_samplers(): - global samplers, samplers_for_img2img + global samplers, samplers_for_img2img, samplers_hidden - hidden = set(shared.opts.hide_samplers) - samplers = [x for x in all_samplers if x.name not in hidden] - samplers_for_img2img = [x for x in all_samplers if x.name not in hidden] + samplers_hidden = set(shared.opts.hide_samplers) + samplers = all_samplers + samplers_for_img2img = all_samplers samplers_map.clear() for sampler in all_samplers: @@ -51,4 +52,8 @@ def set_samplers(): samplers_map[alias.lower()] = sampler.name +def visible_sampler_names(): + return [x.name for x in samplers if x.name not in samplers_hidden] + + set_samplers() diff --git a/modules/txt2img.py b/modules/txt2img.py index 935ed418..8fa389b5 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -1,7 +1,7 @@ from contextlib import closing import modules.scripts -from modules import sd_samplers, processing +from modules import processing from modules.generation_parameters_copypaste import create_override_settings_dict from modules.shared import opts, cmd_opts import modules.shared as shared @@ -9,7 +9,7 @@ from modules.ui import plaintext_to_html import gradio as gr -def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_index: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_index: int, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): +def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, steps: int, sampler_name: str, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, enable_hr: bool, denoising_strength: float, hr_scale: float, hr_upscaler: str, hr_second_pass_steps: int, hr_resize_x: int, hr_resize_y: int, hr_checkpoint_name: str, hr_sampler_name: str, hr_prompt: str, hr_negative_prompt, override_settings_texts, request: gr.Request, *args): override_settings = create_override_settings_dict(override_settings_texts) p = processing.StableDiffusionProcessingTxt2Img( @@ -25,7 +25,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w, seed_enable_extras=seed_enable_extras, - sampler_name=sd_samplers.samplers[sampler_index].name, + sampler_name=sampler_name, batch_size=batch_size, n_iter=n_iter, steps=steps, @@ -42,7 +42,7 @@ def txt2img(id_task: str, prompt: str, negative_prompt: str, prompt_styles, step hr_resize_x=hr_resize_x, hr_resize_y=hr_resize_y, hr_checkpoint_name=None if hr_checkpoint_name == 'Use same checkpoint' else hr_checkpoint_name, - hr_sampler_name=sd_samplers.samplers_for_img2img[hr_sampler_index - 1].name if hr_sampler_index != 0 else None, + hr_sampler_name=hr_sampler_name, hr_prompt=hr_prompt, hr_negative_prompt=hr_negative_prompt, override_settings=override_settings, diff --git a/modules/ui.py b/modules/ui.py index 5150dae4..e3753e97 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -13,7 +13,7 @@ from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call from modules import gradio_extensons # noqa: F401 -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts, sd_samplers from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.ui_common import create_refresh_button @@ -29,7 +29,6 @@ import modules.shared as shared import modules.images from modules import prompt_parser from modules.sd_hijack import model_hijack -from modules.sd_samplers import samplers, samplers_for_img2img from modules.generation_parameters_copypaste import image_from_url_text create_setting_component = ui_settings.create_setting_component @@ -360,14 +359,14 @@ def create_output_panel(tabname, outdir): def create_sampler_and_steps_selection(choices, tabname): if opts.samplers_in_dropdown: with FormRow(elem_id=f"sampler_selection_{tabname}"): - sampler_index = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_name = gr.Dropdown(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) else: with FormGroup(elem_id=f"sampler_selection_{tabname}"): steps = gr.Slider(minimum=1, maximum=150, step=1, elem_id=f"{tabname}_steps", label="Sampling steps", value=20) - sampler_index = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=[x.name for x in choices], value=choices[0].name, type="index") + sampler_name = gr.Radio(label='Sampling method', elem_id=f"{tabname}_sampling", choices=choices, value=choices[0]) - return steps, sampler_index + return steps, sampler_name def ordered_ui_categories(): @@ -414,7 +413,7 @@ def create_ui(): for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers, "txt2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "txt2img") elif category == "dimensions": with FormRow(): @@ -460,7 +459,7 @@ def create_ui(): hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_index = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + [x.name for x in samplers_for_img2img], value="Use same sampler", type="index") + hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: with gr.Column(scale=80): @@ -520,7 +519,7 @@ def create_ui(): toprow.negative_prompt, toprow.ui_styles.dropdown, steps, - sampler_index, + sampler_name, restore_faces, tiling, batch_count, @@ -538,7 +537,7 @@ def create_ui(): hr_resize_x, hr_resize_y, hr_checkpoint_name, - hr_sampler_index, + hr_sampler_name, hr_prompt, hr_negative_prompt, override_settings, @@ -583,7 +582,7 @@ def create_ui(): (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), - (sampler_index, "Sampler"), + (sampler_name, "Sampler"), (restore_faces, "Face restoration"), (cfg_scale, "CFG scale"), (seed, "Seed"), @@ -605,7 +604,7 @@ def create_ui(): (hr_resize_x, "Hires resize-1"), (hr_resize_y, "Hires resize-2"), (hr_checkpoint_name, "Hires checkpoint"), - (hr_sampler_index, "Hires sampler"), + (hr_sampler_name, "Hires sampler"), (hr_sampler_container, lambda d: gr.update(visible=True) if d.get("Hires sampler", "Use same sampler") != "Use same sampler" or d.get("Hires checkpoint", "Use same checkpoint") != "Use same checkpoint" else gr.update()), (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), @@ -621,7 +620,7 @@ def create_ui(): toprow.prompt, toprow.negative_prompt, steps, - sampler_index, + sampler_name, cfg_scale, seed, width, @@ -744,7 +743,7 @@ def create_ui(): for category in ordered_ui_categories(): if category == "sampler": - steps, sampler_index = create_sampler_and_steps_selection(samplers_for_img2img, "img2img") + steps, sampler_name = create_sampler_and_steps_selection(sd_samplers.visible_sampler_names(), "img2img") elif category == "dimensions": with FormRow(): @@ -876,7 +875,7 @@ def create_ui(): init_img_inpaint, init_mask_inpaint, steps, - sampler_index, + sampler_name, mask_blur, mask_alpha, inpainting_fill, @@ -972,7 +971,7 @@ def create_ui(): (toprow.prompt, "Prompt"), (toprow.negative_prompt, "Negative prompt"), (steps, "Steps"), - (sampler_index, "Sampler"), + (sampler_name, "Sampler"), (restore_faces, "Face restoration"), (cfg_scale, "CFG scale"), (image_cfg_scale, "Image CFG scale"), -- cgit v1.2.3