From fca42949a3593c5a2f646e30cc99be2c02566aa2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 3 Aug 2023 07:18:55 +0300 Subject: rework torchsde._brownian.brownian_interval replacement to use device.randn_local and respect the NV setting. --- modules/sd_samplers_common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 763829f1..5deda761 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,10 +2,8 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd - +from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state -import modules.shared as shared SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -85,11 +83,13 @@ class InterruptedException(BaseException): pass -if opts.randn_source == "CPU": +def replace_torchsde_browinan(): import torchsde._brownian.brownian_interval def torchsde_randn(size, dtype, device, seed): - generator = torch.Generator(devices.cpu).manual_seed(int(seed)) - return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device) + return devices.randn_local(seed, size).to(device=device, dtype=dtype) torchsde._brownian.brownian_interval._randn = torchsde_randn + + +replace_torchsde_browinan() -- cgit v1.2.3 From 75336dfc84cae280036bc52a6805eb10d9ae30ba Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:38:52 +0800 Subject: add TAESD for i2i and t2i --- modules/processing.py | 13 +++++------ modules/sd_samplers_common.py | 38 ++++++++++++++++++++++++++----- modules/sd_vae_approx.py | 2 +- modules/sd_vae_taesd.py | 52 ++++++++++++++++++++++++++++++++++++------- modules/shared.py | 2 ++ 5 files changed, 86 insertions(+), 21 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 8f34c8b4..099d86b7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -573,9 +573,10 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x + from modules.sd_samplers_common import samples_to_images_tensor, approximation_indexes + x = x.to(devices.dtype_vae) + approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0) + return samples_to_images_tensor(x, approx_index, model) def get_fixed_seed(seed): @@ -1344,10 +1345,8 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) - image = 2. * image - 1. - image = image.to(shared.device, dtype=devices.dtype_vae) - - self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image)) + from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes + self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) devices.torch_gc() if self.resize_mode == 3: diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5deda761..5a45e8eb 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -23,19 +23,29 @@ def setup_img2img_steps(p, steps=None): approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3} -def single_sample_to_image(sample, approximation=None): +def samples_to_images_tensor(sample, approximation=None, model=None): + '''latents -> images [-1, 1]''' if approximation is None: approximation = approximation_indexes.get(opts.show_progress_type, 0) if approximation == 2: - x_sample = sd_vae_approx.cheap_approximation(sample) * 0.5 + 0.5 + x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: - x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() * 0.5 + 0.5 + x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: x_sample = sample * 1.5 - x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() + x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach() + x_sample = x_sample * 2 - 1 else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 + if model is None: + model = shared.sd_model + x_sample = model.decode_first_stage(sample) + + return x_sample + + +def single_sample_to_image(sample, approximation=None): + x_sample = samples_to_images_tensor(sample.unsqueeze(0), approximation)[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -52,6 +62,24 @@ def samples_to_image_grid(samples, approximation=None): return images.image_grid([single_sample_to_image(sample, approximation) for sample in samples]) +def images_tensor_to_samples(image, approximation=None, model=None): + '''image[0, 1] -> latent''' + if approximation is None: + approximation = approximation_indexes.get(opts.sd_vae_encode_method, 0) + + if approximation == 3: + image = image.to(devices.device, devices.dtype) + x_latent = sd_vae_taesd.encoder_model()(image) / 1.5 + else: + if model is None: + model = shared.sd_model + image = image.to(shared.device, dtype=devices.dtype_vae) + image = image * 2 - 1 + x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) + + return x_latent + + def store_latent(decoded): state.current_latent = decoded diff --git a/modules/sd_vae_approx.py b/modules/sd_vae_approx.py index 86bd658a..3965e223 100644 --- a/modules/sd_vae_approx.py +++ b/modules/sd_vae_approx.py @@ -81,6 +81,6 @@ def cheap_approximation(sample): coefs = torch.tensor(coeffs).to(sample.device) - x_sample = torch.einsum("lxy,lr -> rxy", sample, coefs) + x_sample = torch.einsum("...lxy,lr -> ...rxy", sample, coefs) return x_sample diff --git a/modules/sd_vae_taesd.py b/modules/sd_vae_taesd.py index 5bf7c76e..808eb362 100644 --- a/modules/sd_vae_taesd.py +++ b/modules/sd_vae_taesd.py @@ -44,7 +44,17 @@ def decoder(): ) -class TAESD(nn.Module): +def encoder(): + return nn.Sequential( + conv(3, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 64, stride=2, bias=False), Block(64, 64), Block(64, 64), Block(64, 64), + conv(64, 4), + ) + + +class TAESDDecoder(nn.Module): latent_magnitude = 3 latent_shift = 0.5 @@ -55,21 +65,28 @@ class TAESD(nn.Module): self.decoder.load_state_dict( torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) - @staticmethod - def unscale_latents(x): - """[0, 1] -> raw latents""" - return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude) + +class TAESDEncoder(nn.Module): + latent_magnitude = 3 + latent_shift = 0.5 + + def __init__(self, encoder_path="taesd_encoder.pth"): + """Initialize pretrained TAESD on the given device from the given checkpoints.""" + super().__init__() + self.encoder = encoder() + self.encoder.load_state_dict( + torch.load(encoder_path, map_location='cpu' if devices.device.type != 'cuda' else None)) def download_model(model_path, model_url): if not os.path.exists(model_path): os.makedirs(os.path.dirname(model_path), exist_ok=True) - print(f'Downloading TAESD decoder to: {model_path}') + print(f'Downloading TAESD model to: {model_path}') torch.hub.download_url_to_file(model_url, model_path) -def model(): +def decoder_model(): model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth" loaded_model = sd_vae_taesd_models.get(model_name) @@ -78,7 +95,7 @@ def model(): download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) if os.path.exists(model_path): - loaded_model = TAESD(model_path) + loaded_model = TAESDDecoder(model_path) loaded_model.eval() loaded_model.to(devices.device, devices.dtype) sd_vae_taesd_models[model_name] = loaded_model @@ -86,3 +103,22 @@ def model(): raise FileNotFoundError('TAESD model not found') return loaded_model.decoder + + +def encoder_model(): + model_name = "taesdxl_encoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_encoder.pth" + loaded_model = sd_vae_taesd_models.get(model_name) + + if loaded_model is None: + model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name) + download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name) + + if os.path.exists(model_path): + loaded_model = TAESDEncoder(model_path) + loaded_model.eval() + loaded_model.to(devices.device, devices.dtype) + sd_vae_taesd_models[model_name] = loaded_model + else: + raise FileNotFoundError('TAESD model not found') + + return loaded_model.encoder diff --git a/modules/shared.py b/modules/shared.py index cec030f7..61ba9347 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -430,6 +430,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "auto_vae_precision": OptionInfo(True, "Automaticlly revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to encode image to latent (use in img2img or inpaint mask)"), + "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}).info("method to decode latent to image"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { -- cgit v1.2.3 From c134a480164bef017cd4b33fae57a31a86556beb Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 13:40:20 +0800 Subject: Fix code style --- modules/sd_samplers_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5a45e8eb..d444cac1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -40,7 +40,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): if model is None: model = shared.sd_model x_sample = model.decode_first_stage(sample) - + return x_sample -- cgit v1.2.3 From f0c1063a707a4a43823b0ed00e2a8eeb22a9ed0a Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 09:09:09 +0300 Subject: resolve some of circular import issues for kohaku --- modules/hypernetworks/hypernetwork.py | 5 ++--- modules/processing.py | 7 +------ modules/sd_hijack.py | 6 +++--- modules/sd_samplers_common.py | 10 ++++++++-- modules/textual_inversion/textual_inversion.py | 4 +++- 5 files changed, 17 insertions(+), 15 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/hypernetworks/hypernetwork.py b/modules/hypernetworks/hypernetwork.py index c4821d21..70f1cbd2 100644 --- a/modules/hypernetworks/hypernetwork.py +++ b/modules/hypernetworks/hypernetwork.py @@ -10,7 +10,7 @@ import torch import tqdm from einops import rearrange, repeat from ldm.util import default -from modules import devices, processing, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors +from modules import devices, sd_models, shared, sd_samplers, hashes, sd_hijack_checkpoint, errors from modules.textual_inversion import textual_inversion, logging from modules.textual_inversion.learn_schedule import LearnRateScheduler from torch import einsum @@ -469,8 +469,7 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None, def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_hypernetwork_every, template_filename, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): - # images allows training previews to have infotext. Importing it at the top causes a circular import problem. - from modules import images + from modules import images, processing save_hypernetwork_every = save_hypernetwork_every or 0 create_image_every = create_image_every or 0 diff --git a/modules/processing.py b/modules/processing.py index 8f34c8b4..8086a2b0 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -30,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType +decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -572,12 +573,6 @@ def decode_latent_batch(model, batch, target_device=None, check_for_nans=False): return samples -def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x - - def get_fixed_seed(seed): if seed is None or seed == '' or seed == -1: return int(random.randrange(4294967294)) diff --git a/modules/sd_hijack.py b/modules/sd_hijack.py index cfa5f0eb..609fd56c 100644 --- a/modules/sd_hijack.py +++ b/modules/sd_hijack.py @@ -2,7 +2,6 @@ import torch from torch.nn.functional import silu from types import MethodType -import modules.textual_inversion.textual_inversion from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet from modules.hypernetworks import hypernetwork from modules.shared import cmd_opts @@ -164,12 +163,13 @@ class StableDiffusionModelHijack: clip = None optimization_method = None - embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() - def __init__(self): + import modules.textual_inversion.textual_inversion + self.extra_generation_params = {} self.comments = [] + self.embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase() self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir) def apply_optimizations(self, option=None): diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 5deda761..b3d344e7 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -35,7 +35,7 @@ def single_sample_to_image(sample, approximation=None): x_sample = sample * 1.5 x_sample = sd_vae_taesd.model()(x_sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach() else: - x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 + x_sample = decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0] * 0.5 + 0.5 x_sample = torch.clamp(x_sample, min=0.0, max=1.0) x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2) @@ -44,6 +44,12 @@ def single_sample_to_image(sample, approximation=None): return Image.fromarray(x_sample) +def decode_first_stage(model, x): + x = model.decode_first_stage(x.to(devices.dtype_vae)) + + return x + + def sample_to_image(samples, index=0, approximation=None): return single_sample_to_image(samples[index], approximation) diff --git a/modules/textual_inversion/textual_inversion.py b/modules/textual_inversion/textual_inversion.py index 4713bc2d..aa79dc09 100644 --- a/modules/textual_inversion/textual_inversion.py +++ b/modules/textual_inversion/textual_inversion.py @@ -13,7 +13,7 @@ import numpy as np from PIL import Image, PngImagePlugin from torch.utils.tensorboard import SummaryWriter -from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes +from modules import shared, devices, sd_hijack, sd_models, images, sd_samplers, sd_hijack_checkpoint, errors, hashes import modules.textual_inversion.dataset from modules.textual_inversion.learn_schedule import LearnRateScheduler @@ -387,6 +387,8 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, varsize, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, use_weight, create_image_every, save_embedding_every, template_filename, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height): + from modules import processing + save_embedding_every = save_embedding_every or 0 create_image_every = create_image_every or 0 template_file = textual_inversion_templates.get(template_filename, None) -- cgit v1.2.3 From 1f6bfdea80f58f292aeebb9a001689a118d71c01 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 14:38:52 +0800 Subject: move the modified decode into smapler_common --- modules/sd_samplers_common.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 2cfa4ac6..7269514f 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -55,9 +55,9 @@ def single_sample_to_image(sample, approximation=None): def decode_first_stage(model, x): - x = model.decode_first_stage(x.to(devices.dtype_vae)) - - return x + x = x.to(devices.dtype_vae) + approx_index = approximation_indexes.get(opts.sd_vae_decode_method, 0) + return samples_to_images_tensor(x, approx_index, model) def sample_to_image(samples, index=0, approximation=None): -- cgit v1.2.3 From 094c416a801b16c7d8e1944e2e9fae2c9e98bf12 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:53:16 +0800 Subject: change all encode --- modules/processing.py | 14 ++++++-------- modules/sd_samplers_common.py | 2 +- run.ps1 | 1 + run_local.ps1 | 3 +++ update.ps1 | 1 + 5 files changed, 12 insertions(+), 9 deletions(-) create mode 100644 run.ps1 create mode 100644 run_local.ps1 create mode 100644 update.ps1 (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index aae39866..544667a4 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,6 +16,7 @@ from typing import Any, Dict, List import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack +from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -30,7 +31,6 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType -decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -203,7 +203,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -216,7 +216,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = self.sd_model.encode_first_stage(source_image).mode() + conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) return conditioning_image @@ -255,7 +255,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + conditioning_image = images_tensor_to_samples(conditioning_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -1099,9 +1099,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = torch.from_numpy(np.array(batch_images)) decoded_samples = decoded_samples.to(shared.device) - decoded_samples = 2. * decoded_samples - 1. - samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) + samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method)) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) @@ -1339,7 +1338,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) - from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) devices.torch_gc() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7269514f..42a29fc9 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -75,7 +75,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): if approximation == 3: image = image.to(devices.device, devices.dtype) - x_latent = sd_vae_taesd.encoder_model()(image) / 1.5 + x_latent = sd_vae_taesd.encoder_model()(image) else: if model is None: model = shared.sd_model diff --git a/run.ps1 b/run.ps1 new file mode 100644 index 00000000..82c1660b --- /dev/null +++ b/run.ps1 @@ -0,0 +1 @@ +.\venv\Scripts\accelerate-launch.exe --num_cpu_threads_per_process=6 --api .\launch.py --listen --port 17415 --xformers --opt-channelslast \ No newline at end of file diff --git a/run_local.ps1 b/run_local.ps1 new file mode 100644 index 00000000..e2ac43db --- /dev/null +++ b/run_local.ps1 @@ -0,0 +1,3 @@ +.\venv\Scripts\Activate.ps1 +python .\launch.py --xformers --opt-channelslast --api +. $PSCommandPath \ No newline at end of file diff --git a/update.ps1 b/update.ps1 new file mode 100644 index 00000000..9960bead --- /dev/null +++ b/update.ps1 @@ -0,0 +1 @@ +git stash push && git pull --rebase && git stash pop \ No newline at end of file -- cgit v1.2.3 From 6346d8eeaa17ba0f7e41618908519f6e9bfe07e0 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:53:30 +0800 Subject: Revert "change all encode" This reverts commit 094c416a801b16c7d8e1944e2e9fae2c9e98bf12. --- modules/processing.py | 14 ++++++++------ modules/sd_samplers_common.py | 2 +- run.ps1 | 1 - run_local.ps1 | 3 --- update.ps1 | 1 - 5 files changed, 9 insertions(+), 12 deletions(-) delete mode 100644 run.ps1 delete mode 100644 run_local.ps1 delete mode 100644 update.ps1 (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 544667a4..aae39866 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -16,7 +16,6 @@ from typing import Any, Dict, List import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors from modules.sd_hijack import model_hijack -from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state import modules.shared as shared import modules.paths as paths @@ -31,6 +30,7 @@ from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion from einops import repeat, rearrange from blendmodes.blend import blendLayers, BlendType +decode_first_stage = sd_samplers_common.decode_first_stage # some of those options should not be changed at all because they would break the model, so I removed them from options. opt_C = 4 @@ -84,7 +84,7 @@ def txt2img_image_conditioning(sd_model, x, width, height): # The "masked-image" in this case will just be all zeros since the entire image is masked. image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = images_tensor_to_samples(image_conditioning, approximation_indexes.get(opts.sd_vae_encode_method)) + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) # Add the fake full 1s mask to the first dimension. image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) @@ -203,7 +203,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device) midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size) - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image)) conditioning = torch.nn.functional.interpolate( self.sd_model.depth_model(midas_in), size=conditioning_image.shape[2:], @@ -216,7 +216,7 @@ class StableDiffusionProcessing: return conditioning def edit_image_conditioning(self, source_image): - conditioning_image = images_tensor_to_samples(source_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = self.sd_model.encode_first_stage(source_image).mode() return conditioning_image @@ -255,7 +255,7 @@ class StableDiffusionProcessing: ) # Encode the new masked image using first stage of network. - conditioning_image = images_tensor_to_samples(conditioning_image*0.5+0.5, approximation_indexes.get(opts.sd_vae_encode_method)) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) # Create the concatenated conditioning tensor to be fed to `c_concat` conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:]) @@ -1099,8 +1099,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): decoded_samples = torch.from_numpy(np.array(batch_images)) decoded_samples = decoded_samples.to(shared.device) + decoded_samples = 2. * decoded_samples - 1. - samples = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method)) + samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples)) image_conditioning = self.img2img_image_conditioning(decoded_samples, samples) @@ -1338,6 +1339,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") image = torch.from_numpy(batch_images) + from modules.sd_samplers_common import images_tensor_to_samples, approximation_indexes self.init_latent = images_tensor_to_samples(image, approximation_indexes.get(opts.sd_vae_encode_method), self.sd_model) devices.torch_gc() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 42a29fc9..7269514f 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -75,7 +75,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): if approximation == 3: image = image.to(devices.device, devices.dtype) - x_latent = sd_vae_taesd.encoder_model()(image) + x_latent = sd_vae_taesd.encoder_model()(image) / 1.5 else: if model is None: model = shared.sd_model diff --git a/run.ps1 b/run.ps1 deleted file mode 100644 index 82c1660b..00000000 --- a/run.ps1 +++ /dev/null @@ -1 +0,0 @@ -.\venv\Scripts\accelerate-launch.exe --num_cpu_threads_per_process=6 --api .\launch.py --listen --port 17415 --xformers --opt-channelslast \ No newline at end of file diff --git a/run_local.ps1 b/run_local.ps1 deleted file mode 100644 index e2ac43db..00000000 --- a/run_local.ps1 +++ /dev/null @@ -1,3 +0,0 @@ -.\venv\Scripts\Activate.ps1 -python .\launch.py --xformers --opt-channelslast --api -. $PSCommandPath \ No newline at end of file diff --git a/update.ps1 b/update.ps1 deleted file mode 100644 index 9960bead..00000000 --- a/update.ps1 +++ /dev/null @@ -1 +0,0 @@ -git stash push && git pull --rebase && git stash pop \ No newline at end of file -- cgit v1.2.3 From 073342c8878adc208be1eaab2705ba865d7b3ea1 Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Fri, 4 Aug 2023 17:55:52 +0800 Subject: remove noneed scale --- modules/sd_samplers_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7269514f..42a29fc9 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -75,7 +75,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): if approximation == 3: image = image.to(devices.device, devices.dtype) - x_latent = sd_vae_taesd.encoder_model()(image) / 1.5 + x_latent = sd_vae_taesd.encoder_model()(image) else: if model is None: model = shared.sd_model -- cgit v1.2.3 From aa42c0ff8e51c1dde50a313aa2c12b357b287b50 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 07:41:17 +0300 Subject: repair broken live previews if using VAE with half --- modules/sd_samplers_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 42a29fc9..39586b40 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -39,7 +39,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - x_sample = model.decode_first_stage(sample) + x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample -- cgit v1.2.3 From f1975b0213f5be400889ec04b3891d1cb571fe20 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:01:07 +0300 Subject: initial refiner support --- modules/processing.py | 4 ++++ modules/sd_models.py | 18 +++++++++++++++++- modules/sd_samplers_common.py | 19 ++++++++++++++++++- modules/sd_samplers_compvis.py | 12 +++++++++++- modules/sd_samplers_kdiffusion.py | 30 ++++++++++++++++++++++++------ modules/shared.py | 2 ++ 6 files changed, 76 insertions(+), 9 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 31745006..f4748d6d 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -666,6 +666,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed: stored_opts = {k: opts.data[k] for k in p.override_settings.keys()} try: + # after running refiner, the refiner model is not unloaded - webui swaps back to main model here + if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint: + sd_models.reload_model_weights() + # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None: p.override_settings.pop('sd_model_checkpoint', None) diff --git a/modules/sd_models.py b/modules/sd_models.py index f6051604..981aa93d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -289,11 +289,27 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer): return res +class SkipWritingToConfig: + """This context manager prevents load_model_weights from writing checkpoint name to the config when it loads weight.""" + + skip = False + previous = None + + def __enter__(self): + self.previous = SkipWritingToConfig.skip + SkipWritingToConfig.skip = True + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + SkipWritingToConfig.skip = self.previous + + def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer): sd_model_hash = checkpoint_info.calculate_shorthash() timer.record("calculate hash") - shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title + if not SkipWritingToConfig.skip: + shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title if state_dict is None: state_dict = get_checkpoint_state_dict(checkpoint_info, timer) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40..3f3e83e3 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -2,7 +2,7 @@ from collections import namedtuple import numpy as np import torch from PIL import Image -from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared +from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models from modules.shared import opts, state SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -127,3 +127,20 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() + + +def apply_refiner(sampler): + completed_ratio = sampler.step / sampler.steps + if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint: + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + + sampler.update_inner_model() + + sampler.p.setup_conds() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 4a8396f9..5df926d3 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -19,7 +19,8 @@ samplers_data_compvis = [ class VanillaStableDiffusionSampler: def __init__(self, constructor, sd_model): - self.sampler = constructor(sd_model) + self.p = None + self.sampler = constructor(shared.sd_model) self.is_ddim = hasattr(self.sampler, 'p_sample_ddim') self.is_plms = hasattr(self.sampler, 'p_sample_plms') self.is_unipc = isinstance(self.sampler, modules.models.diffusion.uni_pc.UniPCSampler) @@ -32,6 +33,7 @@ class VanillaStableDiffusionSampler: self.nmask = None self.init_latent = None self.sampler_noises = None + self.steps = None self.step = 0 self.stop_at = None self.eta = None @@ -44,6 +46,7 @@ class VanillaStableDiffusionSampler: return 0 def launch_sampling(self, steps, func): + self.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -61,10 +64,15 @@ class VanillaStableDiffusionSampler: return res + def update_inner_model(self): + self.sampler.model = shared.sd_model + def before_sample(self, x, ts, cond, unconditional_conditioning): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + sd_samplers_common.apply_refiner(self) + if self.stop_at is not None and self.step > self.stop_at: raise sd_samplers_common.InterruptedException @@ -134,6 +142,8 @@ class VanillaStableDiffusionSampler: self.update_step(x) def initialize(self, p): + self.p = p + if self.is_ddim: self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim else: diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index db71a549..be1bd35e 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -2,7 +2,7 @@ from collections import deque import torch import inspect import k_diffusion.sampling -from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra +from modules import prompt_parser, devices, sd_samplers_common, sd_samplers_extra, sd_models from modules.processing import StableDiffusionProcessing from modules.shared import opts, state @@ -87,15 +87,25 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model): + def __init__(self): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False + self.p = None + + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -113,10 +123,15 @@ class CFGDenoiser(torch.nn.Module): return denoised + def update_inner_model(self): + self.model_wrap = None + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + sd_samplers_common.apply_refiner(self) + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 @@ -267,13 +282,13 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) + self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser(self.model_wrap) + self.model_wrap_cfg = CFGDenoiser() + self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None self.eta = None @@ -305,6 +320,7 @@ class KDiffusionSampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -324,6 +340,8 @@ class KDiffusionSampler: return p.steps def initialize(self, p: StableDiffusionProcessing): + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/shared.py b/modules/shared.py index 078e8135..ed8395dc 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -461,6 +461,8 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), + "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), + "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { -- cgit v1.2.3 From 5a0db84b6c7322082c7532df11a29a95a59a612b Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 6 Aug 2023 17:53:33 +0300 Subject: add infotext add proper support for recalculating conds in k-diffusion samplers remove support for compvis samplers --- modules/generation_parameters_copypaste.py | 2 ++ modules/processing.py | 10 ++++++++++ modules/sd_samplers_common.py | 29 ++++++++++++++++++++--------- modules/sd_samplers_compvis.py | 2 -- modules/sd_samplers_kdiffusion.py | 24 ++++++++++++++++-------- 5 files changed, 48 insertions(+), 19 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index e71c9601..6711ca16 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [ ('Pad conds', 'pad_cond_uncond'), ('VAE Encoder', 'sd_vae_encode_method'), ('VAE Decoder', 'sd_vae_decode_method'), + ('Refiner', 'sd_refiner_checkpoint'), + ('Refiner switch at', 'sd_refiner_switch_at'), ] diff --git a/modules/processing.py b/modules/processing.py index f4748d6d..ec66fd8e 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -370,6 +370,9 @@ class StableDiffusionProcessing: self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + def get_conds(self): + return self.c, self.uc + def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast(): extra_networks.activate(self, self.extra_network_data) + def get_conds(self): + if self.is_hr_pass: + return self.hr_c, self.hr_uc + + return super().get_conds() + + def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 3f3e83e3..92bf0ca1 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -131,16 +131,27 @@ replace_torchsde_browinan() def apply_refiner(sampler): completed_ratio = sampler.step / sampler.steps - if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint: - refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) - if refiner_checkpoint_info is None: - raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') - with sd_models.SkipWritingToConfig(): - sd_models.reload_model_weights(info=refiner_checkpoint_info) + if completed_ratio <= shared.opts.sd_refiner_switch_at: + return False + + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: + return False + + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') + + sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at + + with sd_models.SkipWritingToConfig(): + sd_models.reload_model_weights(info=refiner_checkpoint_info) + + devices.torch_gc() + sampler.p.setup_conds() + sampler.update_inner_model() - devices.torch_gc() + return True - sampler.update_inner_model() - sampler.p.setup_conds() diff --git a/modules/sd_samplers_compvis.py b/modules/sd_samplers_compvis.py index 5df926d3..2eeec18a 100644 --- a/modules/sd_samplers_compvis.py +++ b/modules/sd_samplers_compvis.py @@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler: if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) - if self.stop_at is not None and self.step > self.stop_at: raise sd_samplers_common.InterruptedException diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3aee4e3a..46da0a97 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self): + def __init__(self, sampler): super().__init__() + self.sampler = sampler self.model_wrap = None self.mask = None self.nmask = None @@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module): def update_inner_model(self): self.model_wrap = None + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException - sd_samplers_common.apply_refiner(self) + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. @@ -282,12 +289,12 @@ class TorchHijack: class KDiffusionSampler: def __init__(self, funcname, sd_model): - self.p = None self.funcname = funcname self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = CFGDenoiser() + self.sampler_extra_args = {} + self.model_wrap_cfg = CFGDenoiser(self) self.model_wrap = self.model_wrap_cfg.inner_model self.sampler_noises = None self.stop_at = None @@ -476,7 +483,7 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -484,7 +491,7 @@ class KDiffusionSampler: 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -514,13 +521,14 @@ class KDiffusionSampler: extra_params_kwargs['noise_sampler'] = noise_sampler self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- cgit v1.2.3 From 8285a149d8c488ae6c7a566eb85fb5e825145464 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 19:20:11 +0300 Subject: add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them) --- modules/sd_samplers.py | 3 +- modules/sd_samplers_cfg_denoiser.py | 50 ++++------- modules/sd_samplers_common.py | 140 ++++++++++++++++++++++++++++++- modules/sd_samplers_kdiffusion.py | 152 ++++------------------------------ modules/sd_samplers_timesteps.py | 147 ++++++++++++++++++++++++++++++++ modules/sd_samplers_timesteps_impl.py | 135 ++++++++++++++++++++++++++++++ 6 files changed, 455 insertions(+), 172 deletions(-) create mode 100644 modules/sd_samplers_timesteps.py create mode 100644 modules/sd_samplers_timesteps_impl.py (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index bea2684c..fe206894 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -1,4 +1,4 @@ -from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared +from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared # imports for functions that previously were here and are used by other modules from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # noqa: F401 @@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image # all_samplers = [ *sd_samplers_kdiffusion.samplers_data_k_diffusion, *sd_samplers_compvis.samplers_data_compvis, + *sd_samplers_timesteps.samplers_data_timesteps, ] all_samplers_map = {x.name: x for x in all_samplers} diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index 33a49783..166a00c7 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model): + def __init__(self, model, sampler): super().__init__() self.inner_model = model self.mask = None @@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module): self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False + self.sampler = sampler def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module): return denoised + def get_pred_x0(self, x_in, x_out, sigma): + return x_out + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException @@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module): assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)" + if self.mask is not None: + x = self.init_latent * self.mask + self.nmask * x + batch_size = len(conds_list) repeats = [len(conds_list[i]) for i in range(batch_size)] @@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module): devices.test_for_nans(x_out, "unet") - if opts.live_preview_content == "Prompt": - sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes])) - elif opts.live_preview_content == "Negative prompt": - sd_samplers_common.store_latent(x_out[-uncond.shape[0]:]) - if is_edit_model: denoised = self.combine_denoised_for_edit_model(x_out, cond_scale) elif skip_uncond: @@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module): else: denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale) - if self.mask is not None: - denoised = self.init_latent * self.mask + self.nmask * denoised + self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma) + + if opts.live_preview_content == "Prompt": + preview = self.sampler.last_latent + elif opts.live_preview_content == "Negative prompt": + preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma) + else: + preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma) + + sd_samplers_common.store_latent(preview) after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps) cfg_after_cfg_callback(after_cfg_callback_params) @@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module): self.step += 1 return denoised - -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) - diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 39586b40..adda963b 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,9 +1,11 @@ -from collections import namedtuple +import inspect +from collections import namedtuple, deque import numpy as np import torch from PIL import Image from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared from modules.shared import opts, state +import k_diffusion.sampling SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) @@ -127,3 +129,139 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() + + +class TorchHijack: + def __init__(self, sampler_noises): + # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based + # implementation. + self.sampler_noises = deque(sampler_noises) + + def __getattr__(self, item): + if item == 'randn_like': + return self.randn_like + + if hasattr(torch, item): + return getattr(torch, item) + + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") + + def randn_like(self, x): + if self.sampler_noises: + noise = self.sampler_noises.popleft() + if noise.shape == x.shape: + return noise + + return devices.randn_like(x) + + +class Sampler: + def __init__(self, funcname): + self.funcname = funcname + self.func = funcname + self.extra_params = [] + self.sampler_noises = None + self.stop_at = None + self.eta = None + self.config = None # set by the function calling the constructor + self.last_latent = None + self.s_min_uncond = None + self.s_churn = 0.0 + self.s_tmin = 0.0 + self.s_tmax = float('inf') + self.s_noise = 1.0 + + self.eta_option_field = 'eta_ancestral' + self.eta_infotext_field = 'Eta' + + self.conditioning_key = shared.sd_model.model.conditioning_key + + self.model_wrap = None + self.model_wrap_cfg = None + + def callback_state(self, d): + step = d['i'] + + if self.stop_at is not None and step > self.stop_at: + raise InterruptedException + + state.sampling_step = step + shared.total_tqdm.update() + + def launch_sampling(self, steps, func): + state.sampling_steps = steps + state.sampling_step = 0 + + try: + return func() + except RecursionError: + print( + 'Encountered RecursionError during sampling, returning last latent. ' + 'rho >5 with a polyexponential scheduler may cause this error. ' + 'You should try to use a smaller rho value instead.' + ) + return self.last_latent + except InterruptedException: + return self.last_latent + + def number_of_needed_noises(self, p): + return p.steps + + def initialize(self, p) -> dict: + self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None + self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None + self.model_wrap_cfg.step = 0 + self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) + self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) + self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) + + k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + + extra_params_kwargs = {} + for param_name in self.extra_params: + if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: + extra_params_kwargs[param_name] = getattr(p, param_name) + + if 'eta' in inspect.signature(self.func).parameters: + if self.eta != 1.0: + p.extra_generation_params[self.eta_infotext_field] = self.eta + + extra_params_kwargs['eta'] = self.eta + + if len(self.extra_params) > 0: + s_churn = getattr(opts, 's_churn', p.s_churn) + s_tmin = getattr(opts, 's_tmin', p.s_tmin) + s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf + s_noise = getattr(opts, 's_noise', p.s_noise) + + if s_churn != self.s_churn: + extra_params_kwargs['s_churn'] = s_churn + p.s_churn = s_churn + p.extra_generation_params['Sigma churn'] = s_churn + if s_tmin != self.s_tmin: + extra_params_kwargs['s_tmin'] = s_tmin + p.s_tmin = s_tmin + p.extra_generation_params['Sigma tmin'] = s_tmin + if s_tmax != self.s_tmax: + extra_params_kwargs['s_tmax'] = s_tmax + p.s_tmax = s_tmax + p.extra_generation_params['Sigma tmax'] = s_tmax + if s_noise != self.s_noise: + extra_params_kwargs['s_noise'] = s_noise + p.s_noise = s_noise + p.extra_generation_params['Sigma noise'] = s_noise + + return extra_params_kwargs + + def create_noise_sampler(self, x, sigmas, p): + """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" + if shared.opts.no_dpmpp_sde_batch_determinism: + return None + + from k_diffusion.sampling import BrownianTreeNoiseSampler + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] + return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + + + diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 9c9b46d1..3a2e01b7 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -4,8 +4,7 @@ import inspect import k_diffusion.sampling from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser -from modules.processing import StableDiffusionProcessing -from modules.shared import opts, state +from modules.shared import opts import modules.shared as shared samplers_k_diffusion = [ @@ -54,133 +53,17 @@ k_diffusion_scheduler = { } -class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) - - def __getattr__(self, item): - if item == 'randn_like': - return self.randn_like - - if hasattr(torch, item): - return getattr(torch, item) - - raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") - - def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise +class KDiffusionSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): - return devices.randn_like(x) + super().__init__(funcname) + self.extra_params = sampler_extra_params.get(funcname, []) + self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) -class KDiffusionSampler: - def __init__(self, funcname, sd_model): denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.funcname = funcname - self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - self.extra_params = sampler_extra_params.get(funcname, []) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap) - self.sampler_noises = None - self.stop_at = None - self.eta = None - self.config = None # set by the function calling the constructor - self.last_latent = None - self.s_min_uncond = None - - # NOTE: These are also defined in the StableDiffusionProcessing class. - # They should have been here to begin with but we're going to - # leave that class __init__ signature alone. - self.s_churn = 0.0 - self.s_tmin = 0.0 - self.s_tmax = float('inf') - self.s_noise = 1.0 - - self.conditioning_key = sd_model.model.conditioning_key - - def callback_state(self, d): - step = d['i'] - latent = d["denoised"] - if opts.live_preview_content == "Combined": - sd_samplers_common.store_latent(latent) - self.last_latent = latent - - if self.stop_at is not None and step > self.stop_at: - raise sd_samplers_common.InterruptedException - - state.sampling_step = step - shared.total_tqdm.update() - - def launch_sampling(self, steps, func): - state.sampling_steps = steps - state.sampling_step = 0 - - try: - return func() - except RecursionError: - print( - 'Encountered RecursionError during sampling, returning last latent. ' - 'rho >5 with a polyexponential scheduler may cause this error. ' - 'You should try to use a smaller rho value instead.' - ) - return self.last_latent - except sd_samplers_common.InterruptedException: - return self.last_latent - - def number_of_needed_noises(self, p): - return p.steps - - def initialize(self, p: StableDiffusionProcessing): - self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None - self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None - self.model_wrap_cfg.step = 0 - self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None) - self.eta = p.eta if p.eta is not None else opts.eta_ancestral - self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) - - extra_params_kwargs = {} - for param_name in self.extra_params: - if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters: - extra_params_kwargs[param_name] = getattr(p, param_name) - - if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: - p.extra_generation_params["Eta"] = self.eta - - extra_params_kwargs['eta'] = self.eta - - if len(self.extra_params) > 0: - s_churn = getattr(opts, 's_churn', p.s_churn) - s_tmin = getattr(opts, 's_tmin', p.s_tmin) - s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf - s_noise = getattr(opts, 's_noise', p.s_noise) - - if s_churn != self.s_churn: - extra_params_kwargs['s_churn'] = s_churn - p.s_churn = s_churn - p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: - extra_params_kwargs['s_tmin'] = s_tmin - p.s_tmin = s_tmin - p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: - extra_params_kwargs['s_tmax'] = s_tmax - p.s_tmax = s_tmax - p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: - extra_params_kwargs['s_noise'] = s_noise - p.s_noise = s_noise - p.extra_generation_params['Sigma noise'] = s_noise - - return extra_params_kwargs + self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -232,22 +115,12 @@ class KDiffusionSampler: return sigmas - def create_noise_sampler(self, x, sigmas, p): - """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes""" - if shared.opts.no_dpmpp_sde_batch_determinism: - return None - - from k_diffusion.sampling import BrownianTreeNoiseSampler - sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() - current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] - return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) sigmas = self.get_sigmas(p, steps) - sigma_sched = sigmas[steps - t_enc - 1:] + xi = x + noise * sigma_sched[0] extra_params_kwargs = self.initialize(p) @@ -296,12 +169,14 @@ class KDiffusionSampler: extra_params_kwargs = self.initialize(p) parameters = inspect.signature(self.func).parameters + if 'n' in parameters: + extra_params_kwargs['n'] = steps + if 'sigma_min' in parameters: extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item() extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item() - if 'n' in parameters: - extra_params_kwargs['n'] = steps - else: + + if 'sigmas' in parameters: extra_params_kwargs['sigmas'] = sigmas if self.config.options.get('brownian_noise', False): @@ -322,3 +197,4 @@ class KDiffusionSampler: return samples + diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py new file mode 100644 index 00000000..8560d009 --- /dev/null +++ b/modules/sd_samplers_timesteps.py @@ -0,0 +1,147 @@ +import torch +import inspect +from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl +from modules.sd_samplers_cfg_denoiser import CFGDenoiser + +from modules.shared import opts +import modules.shared as shared + +samplers_timesteps = [ + ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}), + ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}), + ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}), +] + + +samplers_data_timesteps = [ + sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options) + for label, funcname, aliases, options in samplers_timesteps +] + + +class CompVisTimestepsDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def forward(self, input, timesteps, **kwargs): + return self.inner_model.apply_model(input, timesteps, **kwargs) + + +class CompVisTimestepsVDenoiser(torch.nn.Module): + def __init__(self, model, *args, **kwargs): + super().__init__(*args, **kwargs) + self.inner_model = model + + def predict_eps_from_z_and_v(self, x_t, t, v): + return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t + + def forward(self, input, timesteps, **kwargs): + model_output = self.inner_model.apply_model(input, timesteps, **kwargs) + e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output) + return e_t + + +class CFGDenoiserTimesteps(CFGDenoiser): + + def __init__(self, model, sampler): + super().__init__(model, sampler) + + self.alphas = model.inner_model.alphas_cumprod + + def get_pred_x0(self, x_in, x_out, sigma): + ts = int(sigma.item()) + + s_in = x_in.new_ones([x_in.shape[0]]) + a_t = self.alphas[ts].item() * s_in + sqrt_one_minus_at = (1 - a_t).sqrt() + + pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt() + + return pred_x0 + + +class CompVisSampler(sd_samplers_common.Sampler): + def __init__(self, funcname, sd_model): + super().__init__(funcname) + + self.eta_option_field = 'eta_ddim' + self.eta_infotext_field = 'Eta DDIM' + + denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(sd_model) + self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + + def get_timesteps(self, p, steps): + discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) + if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma: + discard_next_to_last_sigma = True + p.extra_generation_params["Discard penultimate sigma"] = True + + steps += 1 if discard_next_to_last_sigma else 0 + + timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999) + + return timesteps + + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps) + + timesteps = self.get_timesteps(p, steps) + timesteps_sched = timesteps[:t_enc] + + alphas_cumprod = shared.sd_model.alphas_cumprod + sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]]) + sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]]) + + xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps_sched + if 'is_img2img' in parameters: + extra_params_kwargs['is_img2img'] = True + + self.model_wrap_cfg.init_latent = x + self.last_latent = x + extra_args = { + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + } + + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + steps = steps or p.steps + timesteps = self.get_timesteps(p, steps) + + extra_params_kwargs = self.initialize(p) + parameters = inspect.signature(self.func).parameters + + if 'timesteps' in parameters: + extra_params_kwargs['timesteps'] = timesteps + + self.last_latent = x + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale, + 's_min_uncond': self.s_min_uncond + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + + if self.model_wrap_cfg.padded_cond_uncond: + p.extra_generation_params["Pad conds"] = True + + return samples + diff --git a/modules/sd_samplers_timesteps_impl.py b/modules/sd_samplers_timesteps_impl.py new file mode 100644 index 00000000..48d7e649 --- /dev/null +++ b/modules/sd_samplers_timesteps_impl.py @@ -0,0 +1,135 @@ +import torch +import tqdm +import k_diffusion.sampling +import numpy as np + +from modules import shared +from modules.models.diffusion.uni_pc import uni_pc + + +@torch.no_grad() +def ddim(model, x, timesteps, extra_args=None, callback=None, disable=None, eta=0.0): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + sigmas = eta * np.sqrt((1 - alphas_prev.cpu().numpy()) / (1 - alphas.cpu()) * (1 - alphas.cpu() / alphas_prev.cpu().numpy())) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + + e_t = model(x, timesteps[index].item() * s_in, **extra_args) + + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sigma_t = sigmas[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + dir_xt = (1. - a_prev - sigma_t ** 2).sqrt() * e_t + noise = sigma_t * k_diffusion.sampling.torch.randn_like(x) + x = a_prev.sqrt() * pred_x0 + dir_xt + noise + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +@torch.no_grad() +def plms(model, x, timesteps, extra_args=None, callback=None, disable=None): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + alphas = alphas_cumprod[timesteps] + alphas_prev = alphas_cumprod[torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))].to(torch.float64) + sqrt_one_minus_alphas = torch.sqrt(1 - alphas) + + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + old_eps = [] + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = alphas[index].item() * s_in + a_prev = alphas_prev[index].item() * s_in + sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_in + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + + # direction pointing to x_t + dir_xt = (1. - a_prev).sqrt() * e_t + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + return x_prev, pred_x0 + + for i in tqdm.trange(len(timesteps) - 1, disable=disable): + index = len(timesteps) - 1 - i + ts = timesteps[index].item() * s_in + t_next = timesteps[max(index - 1, 0)].item() * s_in + + e_t = model(x, ts, **extra_args) + + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = model(x_prev, t_next, **extra_args) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + else: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + old_eps.append(e_t) + if len(old_eps) >= 4: + old_eps.pop(0) + + x = x_prev + + if callback is not None: + callback({'x': x, 'i': i, 'sigma': 0, 'sigma_hat': 0, 'denoised': pred_x0}) + + return x + + +class UniPCCFG(uni_pc.UniPC): + def __init__(self, cfg_model, extra_args, callback, *args, **kwargs): + super().__init__(None, *args, **kwargs) + + def after_update(x, model_x): + callback({'x': x, 'i': self.index, 'sigma': 0, 'sigma_hat': 0, 'denoised': model_x}) + self.index += 1 + + self.cfg_model = cfg_model + self.extra_args = extra_args + self.callback = callback + self.index = 0 + self.after_update = after_update + + def get_model_input_time(self, t_continuous): + return (t_continuous - 1. / self.noise_schedule.total_N) * 1000. + + def model(self, x, t): + t_input = self.get_model_input_time(t) + + res = self.cfg_model(x, t_input, **self.extra_args) + + return res + + +def unipc(model, x, timesteps, extra_args=None, callback=None, disable=None, is_img2img=False): + alphas_cumprod = model.inner_model.inner_model.alphas_cumprod + + ns = uni_pc.NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod) + t_start = timesteps[-1] / 1000 + 1 / 1000 if is_img2img else None # this is likely off by a bit - if someone wants to fix it please by all means + unipc_sampler = UniPCCFG(model, extra_args, callback, ns, predict_x0=True, thresholding=False, variant=shared.opts.uni_pc_variant) + x = unipc_sampler.sample(x, steps=len(timesteps), t_start=t_start, skip_type=shared.opts.uni_pc_skip_type, method="multistep", order=shared.opts.uni_pc_order, lower_order_final=shared.opts.uni_pc_lower_order_final) + + return x -- cgit v1.2.3 From f8ff8c0638997fd0aef217db1505598846f14782 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:09:40 +0300 Subject: merge errors --- modules/sd_samplers_cfg_denoiser.py | 23 +++++++++++++++++++++-- modules/sd_samplers_common.py | 6 +++++- modules/sd_samplers_kdiffusion.py | 17 ++++++++++++----- modules/sd_samplers_timesteps.py | 27 +++++++++++++++++---------- 4 files changed, 55 insertions(+), 18 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index d826222c..a532e013 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -38,16 +38,24 @@ class CFGDenoiser(torch.nn.Module): negative prompt. """ - def __init__(self, model, sampler): + def __init__(self, sampler): super().__init__() - self.inner_model = model + self.model_wrap = None self.mask = None self.nmask = None self.init_latent = None + self.steps = None self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False self.sampler = sampler + self.model_wrap = None + self.p = None + + @property + def inner_model(self): + raise NotImplementedError() + def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] @@ -68,10 +76,21 @@ class CFGDenoiser(torch.nn.Module): def get_pred_x0(self, x_in, x_out, sigma): return x_out + def update_inner_model(self): + self.model_wrap = None + + c, uc = self.p.get_conds() + self.sampler.sampler_extra_args['cond'] = c + self.sampler.sampler_extra_args['uncond'] = uc + def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond): if state.interrupted or state.skipped: raise sd_samplers_common.InterruptedException + if sd_samplers_common.apply_refiner(self): + cond = self.sampler.sampler_extra_args['cond'] + uncond = self.sampler.sampler_extra_args['uncond'] + # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling, # so is_edit_model is set to False to support AND composition. is_edit_model = shared.sd_model.cond_stage_key == "edit" and self.image_cfg_scale is not None and self.image_cfg_scale != 1.0 diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 15f27970..fa3614ff 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -202,8 +202,9 @@ class Sampler: self.conditioning_key = shared.sd_model.model.conditioning_key - self.model_wrap = None + self.p = None self.model_wrap_cfg = None + self.sampler_extra_args = None def callback_state(self, d): step = d['i'] @@ -215,6 +216,7 @@ class Sampler: shared.total_tqdm.update() def launch_sampling(self, steps, func): + self.model_wrap_cfg.steps = steps state.sampling_steps = steps state.sampling_step = 0 @@ -234,6 +236,8 @@ class Sampler: return p.steps def initialize(self, p) -> dict: + self.p = p + self.model_wrap_cfg.p = p self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None self.model_wrap_cfg.step = 0 diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index 3ff4b634..95a43cef 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -52,17 +52,24 @@ k_diffusion_scheduler = { } +class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = k_diffusion.external.CompVisVDenoiser if shared.sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser + self.model_wrap = denoiser(shared.sd_model, quantize=shared.opts.enable_quantization) + + return self.model_wrap + + class KDiffusionSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): - super().__init__(funcname) - self.extra_params = sampler_extra_params.get(funcname, []) self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) - denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser - self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization) - self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserKDiffusion(self) + self.model_wrap = self.model_wrap_cfg.inner_model def get_sigmas(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index d89d0efb..965e61c6 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -44,10 +44,10 @@ class CompVisTimestepsVDenoiser(torch.nn.Module): class CFGDenoiserTimesteps(CFGDenoiser): - def __init__(self, model, sampler): - super().__init__(model, sampler) + def __init__(self, sampler): + super().__init__(sampler) - self.alphas = model.inner_model.alphas_cumprod + self.alphas = shared.sd_model.alphas_cumprod def get_pred_x0(self, x_in, x_out, sigma): ts = int(sigma.item()) @@ -60,6 +60,14 @@ class CFGDenoiserTimesteps(CFGDenoiser): return pred_x0 + @property + def inner_model(self): + if self.model_wrap is None: + denoiser = CompVisTimestepsVDenoiser if shared.sd_model.parameterization == "v" else CompVisTimestepsDenoiser + self.model_wrap = denoiser(shared.sd_model) + + return self.model_wrap + class CompVisSampler(sd_samplers_common.Sampler): def __init__(self, funcname, sd_model): @@ -68,9 +76,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_option_field = 'eta_ddim' self.eta_infotext_field = 'Eta DDIM' - denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser - self.model_wrap = denoiser(sd_model) - self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self) + self.model_wrap_cfg = CFGDenoiserTimesteps(self) def get_timesteps(self, p, steps): discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False) @@ -106,7 +112,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.model_wrap_cfg.init_latent = x self.last_latent = x - extra_args = { + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, @@ -114,7 +120,7 @@ class CompVisSampler(sd_samplers_common.Sampler): 's_min_uncond': self.s_min_uncond } - samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True @@ -132,13 +138,14 @@ class CompVisSampler(sd_samplers_common.Sampler): extra_params_kwargs['timesteps'] = timesteps self.last_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + self.sampler_extra_args = { 'cond': conditioning, 'image_cond': image_conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale, 's_min_uncond': self.s_min_uncond - }, disable=False, callback=self.callback_state, **extra_params_kwargs)) + } + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs)) if self.model_wrap_cfg.padded_cond_uncond: p.extra_generation_params["Pad conds"] = True -- cgit v1.2.3 From 1aefb5025929818b2a96cbb6148fcc2db7b947ec Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 8 Aug 2023 22:17:25 +0300 Subject: add None refiner option --- modules/sd_samplers_common.py | 3 +++ modules/shared.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index fa3614ff..b6ad6830 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -137,6 +137,9 @@ def apply_refiner(sampler): if completed_ratio <= shared.opts.sd_refiner_switch_at: return False + if shared.opts.sd_refiner_checkpoint == "None": + return False + if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: return False diff --git a/modules/shared.py b/modules/shared.py index 2fd29904..9935d2a7 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -462,7 +462,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}).link("wiki", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#clip-skip").info("ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer"), "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), - "sd_refiner_checkpoint": OptionInfo(None, "Refiner checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), + "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + list_checkpoint_tiles()}, refresh=refresh_checkpoints).info("switch to another model in the middle of generation"), "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}).info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) -- cgit v1.2.3 From 0d5dc9a6e7f6362e423a06bf0e75dd5854025394 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 08:43:31 +0300 Subject: rework RNG to use generators instead of generating noises beforehand --- modules/devices.py | 81 +------------------- modules/processing.py | 89 +++------------------- modules/rng.py | 171 ++++++++++++++++++++++++++++++++++++++++++ modules/sd_samplers_common.py | 24 +++--- modules/shared.py | 2 +- 5 files changed, 196 insertions(+), 171 deletions(-) create mode 100644 modules/rng.py (limited to 'modules/sd_samplers_common.py') diff --git a/modules/devices.py b/modules/devices.py index 00a00b18..ce59dc53 100644 --- a/modules/devices.py +++ b/modules/devices.py @@ -3,7 +3,7 @@ import contextlib from functools import lru_cache import torch -from modules import errors, rng_philox +from modules import errors if sys.platform == "darwin": from modules import mac_specific @@ -96,84 +96,6 @@ def cond_cast_float(input): nv_rng = None -def randn(seed, shape): - """Generate a tensor with random numbers from a normal distribution using seed. - - Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" - - from modules.shared import opts - - manual_seed(seed) - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(shape), device=device) - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - - return torch.randn(shape, device=device) - - -def randn_local(seed, shape): - """Generate a tensor with random numbers from a normal distribution using seed. - - Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - rng = rng_philox.Generator(seed) - return torch.asarray(rng.randn(shape), device=device) - - local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device - local_generator = torch.Generator(local_device).manual_seed(int(seed)) - return torch.randn(shape, device=local_device, generator=local_generator).to(device) - - -def randn_like(x): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. - - Use either randn() or manual_seed() to initialize the generator.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) - - if opts.randn_source == "CPU" or x.device.type == 'mps': - return torch.randn_like(x, device=cpu).to(x.device) - - return torch.randn_like(x) - - -def randn_without_seed(shape): - """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. - - Use either randn() or manual_seed() to initialize the generator.""" - - from modules.shared import opts - - if opts.randn_source == "NV": - return torch.asarray(nv_rng.randn(shape), device=device) - - if opts.randn_source == "CPU" or device.type == 'mps': - return torch.randn(shape, device=cpu).to(device) - - return torch.randn(shape, device=device) - - -def manual_seed(seed): - """Set up a global random number generator using the specified seed.""" - from modules.shared import opts - - if opts.randn_source == "NV": - global nv_rng - nv_rng = rng_philox.Generator(seed) - return - - torch.manual_seed(seed) - - def autocast(disable=False): from modules import shared @@ -236,3 +158,4 @@ def first_time_calculation(): x = torch.zeros((1, 1, 3, 3)).to(device, dtype) conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype) conv2d(x) + diff --git a/modules/processing.py b/modules/processing.py index aa72b132..2df5e8c7 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -14,7 +14,7 @@ from skimage import exposure from typing import Any, Dict, List import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng from modules.sd_hijack import model_hijack from modules.sd_samplers_common import images_tensor_to_samples, decode_first_stage, approximation_indexes from modules.shared import opts, cmd_opts, state @@ -186,6 +186,7 @@ class StableDiffusionProcessing: self.cached_c = StableDiffusionProcessing.cached_c self.uc = None self.c = None + self.rng: rng.ImageRNG = None self.user = None @@ -475,82 +476,9 @@ class Processed: return self.token_merging_ratio_hr if for_hr else self.token_merging_ratio -# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 -def slerp(val, low, high): - low_norm = low/torch.norm(low, dim=1, keepdim=True) - high_norm = high/torch.norm(high, dim=1, keepdim=True) - dot = (low_norm*high_norm).sum(1) - - if dot.mean() > 0.9995: - return low * val + high * (1 - val) - - omega = torch.acos(dot) - so = torch.sin(omega) - res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high - return res - - def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0, p=None): - eta_noise_seed_delta = opts.eta_noise_seed_delta or 0 - xs = [] - - # if we have multiple seeds, this means we are working with batch size>1; this then - # enables the generation of additional tensors with noise that the sampler will use during its processing. - # Using those pre-generated tensors instead of simple torch.randn allows a batch with seeds [100, 101] to - # produce the same images as with two batches [100], [101]. - if p is not None and p.sampler is not None and (len(seeds) > 1 and opts.enable_batch_seeds or eta_noise_seed_delta > 0): - sampler_noises = [[] for _ in range(p.sampler.number_of_needed_noises(p))] - else: - sampler_noises = None - - for i, seed in enumerate(seeds): - noise_shape = shape if seed_resize_from_h <= 0 or seed_resize_from_w <= 0 else (shape[0], seed_resize_from_h//8, seed_resize_from_w//8) - - subnoise = None - if subseeds is not None and subseed_strength != 0: - subseed = 0 if i >= len(subseeds) else subseeds[i] - - subnoise = devices.randn(subseed, noise_shape) - - # randn results depend on device; gpu and cpu get different results for same seed; - # the way I see it, it's better to do this on CPU, so that everyone gets same result; - # but the original script had it like this, so I do not dare change it for now because - # it will break everyone's seeds. - noise = devices.randn(seed, noise_shape) - - if subnoise is not None: - noise = slerp(subseed_strength, noise, subnoise) - - if noise_shape != shape: - x = devices.randn(seed, shape) - dx = (shape[2] - noise_shape[2]) // 2 - dy = (shape[1] - noise_shape[1]) // 2 - w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx - h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy - tx = 0 if dx < 0 else dx - ty = 0 if dy < 0 else dy - dx = max(-dx, 0) - dy = max(-dy, 0) - - x[:, ty:ty+h, tx:tx+w] = noise[:, dy:dy+h, dx:dx+w] - noise = x - - if sampler_noises is not None: - cnt = p.sampler.number_of_needed_noises(p) - - if eta_noise_seed_delta > 0: - devices.manual_seed(seed + eta_noise_seed_delta) - - for j in range(cnt): - sampler_noises[j].append(devices.randn_without_seed(tuple(noise_shape))) - - xs.append(noise) - - if sampler_noises is not None: - p.sampler.sampler_noises = [torch.stack(n).to(shared.device) for n in sampler_noises] - - x = torch.stack(xs).to(shared.device) - return x + g = rng.ImageRNG(shape, seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=seed_resize_from_h, seed_resize_from_w=seed_resize_from_w) + return g.next() class DecodedSamples(list): @@ -769,6 +697,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + if p.scripts is not None: p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) @@ -1072,7 +1002,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = self.rng.next() samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x)) del x @@ -1160,7 +1090,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2] - noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=self) + self.rng = rng.ImageRNG(samples.shape[1:], self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w) + noise = self.rng.next() # GC now before running the next img2img to prevent running out of memory devices.torch_gc() @@ -1418,7 +1349,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): self.image_conditioning = self.img2img_image_conditioning(image, self.init_latent, image_mask) def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts): - x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) + x = self.rng.next() if self.initial_noise_multiplier != 1.0: self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier diff --git a/modules/rng.py b/modules/rng.py new file mode 100644 index 00000000..2d7baea5 --- /dev/null +++ b/modules/rng.py @@ -0,0 +1,171 @@ +import torch + +from modules import devices, rng_philox, shared + + +def randn(seed, shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using seed. + + Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed.""" + + manual_seed(seed) + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def randn_local(seed, shape): + """Generate a tensor with random numbers from a normal distribution using seed. + + Does not change the global random number generator. You can only generate the seed's first tensor using this function.""" + + if shared.opts.randn_source == "NV": + rng = rng_philox.Generator(seed) + return torch.asarray(rng.randn(shape), device=devices.device) + + local_device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + local_generator = torch.Generator(local_device).manual_seed(int(seed)) + return torch.randn(shape, device=local_device, generator=local_generator).to(devices.device) + + +def randn_like(x): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray(nv_rng.randn(x.shape), device=x.device, dtype=x.dtype) + + if shared.opts.randn_source == "CPU" or x.device.type == 'mps': + return torch.randn_like(x, device=devices.cpu).to(x.device) + + return torch.randn_like(x) + + +def randn_without_seed(shape, generator=None): + """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator. + + Use either randn() or manual_seed() to initialize the generator.""" + + if shared.opts.randn_source == "NV": + return torch.asarray((generator or nv_rng).randn(shape), device=devices.device) + + if shared.opts.randn_source == "CPU" or devices.device.type == 'mps': + return torch.randn(shape, device=devices.cpu, generator=generator).to(devices.device) + + return torch.randn(shape, device=devices.device, generator=generator) + + +def manual_seed(seed): + """Set up a global random number generator using the specified seed.""" + from modules.shared import opts + + if opts.randn_source == "NV": + global nv_rng + nv_rng = rng_philox.Generator(seed) + return + + torch.manual_seed(seed) + + +def create_generator(seed): + if shared.opts.randn_source == "NV": + return rng_philox.Generator(seed) + + device = devices.cpu if shared.opts.randn_source == "CPU" or devices.device.type == 'mps' else devices.device + generator = torch.Generator(device).manual_seed(int(seed)) + return generator + + +# from https://discuss.pytorch.org/t/help-regarding-slerp-function-for-generative-model-sampling/32475/3 +def slerp(val, low, high): + low_norm = low/torch.norm(low, dim=1, keepdim=True) + high_norm = high/torch.norm(high, dim=1, keepdim=True) + dot = (low_norm*high_norm).sum(1) + + if dot.mean() > 0.9995: + return low * val + high * (1 - val) + + omega = torch.acos(dot) + so = torch.sin(omega) + res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high + return res + + +class ImageRNG: + def __init__(self, shape, seeds, subseeds=None, subseed_strength=0.0, seed_resize_from_h=0, seed_resize_from_w=0): + self.shape = shape + self.seeds = seeds + self.subseeds = subseeds + self.subseed_strength = subseed_strength + self.seed_resize_from_h = seed_resize_from_h + self.seed_resize_from_w = seed_resize_from_w + + self.generators = [create_generator(seed) for seed in seeds] + + self.is_first = True + + def first(self): + noise_shape = self.shape if self.seed_resize_from_h <= 0 or self.seed_resize_from_w <= 0 else (self.shape[0], self.seed_resize_from_h // 8, self.seed_resize_from_w // 8) + + xs = [] + + for i, (seed, generator) in enumerate(zip(self.seeds, self.generators)): + subnoise = None + if self.subseeds is not None and self.subseed_strength != 0: + subseed = 0 if i >= len(self.subseeds) else self.subseeds[i] + subnoise = randn(subseed, noise_shape) + + if noise_shape != self.shape: + noise = randn(seed, noise_shape) + else: + noise = randn(seed, self.shape, generator=generator) + + if subnoise is not None: + noise = slerp(self.subseed_strength, noise, subnoise) + + if noise_shape != self.shape: + x = randn(seed, self.shape, generator=generator) + dx = (self.shape[2] - noise_shape[2]) // 2 + dy = (self.shape[1] - noise_shape[1]) // 2 + w = noise_shape[2] if dx >= 0 else noise_shape[2] + 2 * dx + h = noise_shape[1] if dy >= 0 else noise_shape[1] + 2 * dy + tx = 0 if dx < 0 else dx + ty = 0 if dy < 0 else dy + dx = max(-dx, 0) + dy = max(-dy, 0) + + x[:, ty:ty + h, tx:tx + w] = noise[:, dy:dy + h, dx:dx + w] + noise = x + + xs.append(noise) + + eta_noise_seed_delta = shared.opts.eta_noise_seed_delta or 0 + if eta_noise_seed_delta: + self.generators = [create_generator(seed + eta_noise_seed_delta) for seed in self.seeds] + + return torch.stack(xs).to(shared.device) + + def next(self): + if self.is_first: + self.is_first = False + return self.first() + + xs = [] + for generator in self.generators: + x = randn_without_seed(self.shape, generator=generator) + xs.append(x) + + return torch.stack(xs).to(shared.device) + + +devices.randn = randn +devices.randn_local = randn_local +devices.randn_like = randn_like +devices.randn_without_seed = randn_without_seed +devices.manual_seed = manual_seed diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index adda963b..97bc0804 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -1,5 +1,5 @@ import inspect -from collections import namedtuple, deque +from collections import namedtuple import numpy as np import torch from PIL import Image @@ -132,10 +132,15 @@ replace_torchsde_browinan() class TorchHijack: - def __init__(self, sampler_noises): - # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based - # implementation. - self.sampler_noises = deque(sampler_noises) + """This is here to replace torch.randn_like of k-diffusion. + + k-diffusion has random_sampler argument for most samplers, but not for all, so + this is needed to properly replace every use of torch.randn_like. + + We need to replace to make images generated in batches to be same as images generated individually.""" + + def __init__(self, p): + self.rng = p.rng def __getattr__(self, item): if item == 'randn_like': @@ -147,12 +152,7 @@ class TorchHijack: raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'") def randn_like(self, x): - if self.sampler_noises: - noise = self.sampler_noises.popleft() - if noise.shape == x.shape: - return noise - - return devices.randn_like(x) + return self.rng.next() class Sampler: @@ -215,7 +215,7 @@ class Sampler: self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0) self.s_min_uncond = getattr(p, 's_min_uncond', 0.0) - k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else []) + k_diffusion.sampling.torch = TorchHijack(p) extra_params_kwargs = {} for param_name in self.extra_params: diff --git a/modules/shared.py b/modules/shared.py index e34847ce..e9b980a4 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -16,7 +16,7 @@ import modules.interrogate import modules.memmon import modules.styles import modules.devices as devices -from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args +from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args, rng # noqa: F401 from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir # noqa: F401 from ldm.models.diffusion.ddpm import LatentDiffusion from typing import Optional -- cgit v1.2.3 From 64311faa6848d641cc452115e4e1eb47d2a7b519 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 12 Aug 2023 12:39:59 +0300 Subject: put refiner into main UI, into the new accordions section add VAE from main model into infotext, not from refiner model option to make scripts UI without gr.Group fix inconsistencies with refiner when usings samplers that do more denoising than steps --- modules/processing.py | 22 ++++++++----- modules/processing_scripts/refiner.py | 55 +++++++++++++++++++++++++++++++++ modules/scripts.py | 24 ++++++++++----- modules/sd_models.py | 3 ++ modules/sd_samplers_cfg_denoiser.py | 6 +++- modules/sd_samplers_common.py | 40 ++++++++++++++---------- modules/sd_samplers_kdiffusion.py | 3 +- modules/shared_items.py | 4 +-- modules/shared_options.py | 2 -- modules/ui.py | 58 ++++++++++++++++++++--------------- modules/ui_components.py | 18 ++++++++--- style.css | 32 +++++++++++-------- 12 files changed, 188 insertions(+), 79 deletions(-) create mode 100644 modules/processing_scripts/refiner.py (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 131c4c3c..5996cbac 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -373,9 +373,10 @@ class StableDiffusionProcessing: negative_prompts = prompt_parser.SdConditioning(self.negative_prompts, width=self.width, height=self.height, is_negative_prompt=True) sampler_config = sd_samplers.find_sampler_config(self.sampler_name) - self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + total_steps = sampler_config.total_steps(self.steps) if sampler_config else self.steps + self.step_multiplier = total_steps // self.steps + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, total_steps, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, total_steps, [self.cached_c], self.extra_network_data) def get_conds(self): return self.c, self.uc @@ -579,8 +580,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Size": f"{p.width}x{p.height}", "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra), - "VAE hash": sd_vae.get_loaded_vae_hash() if opts.add_model_hash_to_info else None, - "VAE": sd_vae.get_loaded_vae_name() if opts.add_model_name_to_info else None, + "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None, + "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -669,6 +670,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.tiling is None: p.tiling = opts.tiling + p.loaded_vae_name = sd_vae.get_loaded_vae_name() + p.loaded_vae_hash = sd_vae.get_loaded_vae_hash() + modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.clear_comments() @@ -1188,8 +1192,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + sampler_config = sd_samplers.find_sampler_config(self.hr_sampler_name or self.sampler_name) + steps = self.hr_second_pass_steps or self.steps + total_steps = sampler_config.total_steps(steps) if sampler_config else steps + + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, total_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py new file mode 100644 index 00000000..5a82991a --- /dev/null +++ b/modules/processing_scripts/refiner.py @@ -0,0 +1,55 @@ +import gradio as gr + +from modules import scripts, sd_models +from modules.ui_common import create_refresh_button +from modules.ui_components import InputAccordion + + +class ScriptRefiner(scripts.Script): + section = "accordions" + create_group = False + + def __init__(self): + pass + + def title(self): + return "Refiner" + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + with InputAccordion(False, label="Refiner", elem_id=self.elem_id("enable")) as enable_refiner: + with gr.Row(): + refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") + create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) + + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + + def lookup_checkpoint(title): + info = sd_models.get_closet_checkpoint_match(title) + return None if info is None else info.title + + self.infotext_fields = [ + (enable_refiner, lambda d: 'Refiner' in d), + (refiner_checkpoint, lambda d: lookup_checkpoint(d.get('Refiner'))), + (refiner_switch_at, 'Refiner switch at'), + ] + + return enable_refiner, refiner_checkpoint, refiner_switch_at + + def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): + # the actual implementation is in sd_samplers_common.py, apply_refiner + + p.refiner_checkpoint_info = None + p.refiner_switch_at = None + + if not enable_refiner or refiner_checkpoint in (None, "", "None"): + return + + refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint) + if refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}') + + p.refiner_checkpoint_info = refiner_checkpoint_info + p.refiner_switch_at = refiner_switch_at diff --git a/modules/scripts.py b/modules/scripts.py index f7d060aa..51da732a 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -37,7 +37,10 @@ class Script: is_img2img = False group = None - """A gr.Group component that has all script's UI inside it""" + """A gr.Group component that has all script's UI inside it.""" + + create_group = True + """If False, for alwayson scripts, a group component will not be created.""" infotext_fields = None """if set in ui(), this is a list of pairs of gradio component + text; the text will be used when @@ -232,6 +235,7 @@ class Script: """ pass + current_basedir = paths.script_path @@ -250,7 +254,7 @@ postprocessing_scripts_data = [] ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"]) -def list_scripts(scriptdirname, extension): +def list_scripts(scriptdirname, extension, *, include_extensions=True): scripts_list = [] basedir = os.path.join(paths.script_path, scriptdirname) @@ -258,8 +262,9 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - for ext in extensions.active(): - scripts_list += ext.list_files(scriptdirname, extension) + if include_extensions: + for ext in extensions.active(): + scripts_list += ext.list_files(scriptdirname, extension) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -288,7 +293,7 @@ def load_scripts(): postprocessing_scripts_data.clear() script_callbacks.clear_callbacks() - scripts_list = list_scripts("scripts", ".py") + scripts_list = list_scripts("scripts", ".py") + list_scripts("modules/processing_scripts", ".py", include_extensions=False) syspath = sys.path @@ -429,10 +434,13 @@ class ScriptRunner: if script.alwayson and script.section != section: continue - with gr.Group(visible=script.alwayson) as group: - self.create_script_ui(script) + if script.create_group: + with gr.Group(visible=script.alwayson) as group: + self.create_script_ui(script) - script.group = group + script.group = group + else: + self.create_script_ui(script) def prepare_ui(self): self.inputs = [None] diff --git a/modules/sd_models.py b/modules/sd_models.py index a178adca..f6fbdcd6 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -147,6 +147,9 @@ re_strip_checksum = re.compile(r"\s*\[[^]]+]\s*$") def get_closet_checkpoint_match(search_string): + if not search_string: + return None + checkpoint_info = checkpoint_aliases.get(search_string, None) if checkpoint_info is not None: return checkpoint_info diff --git a/modules/sd_samplers_cfg_denoiser.py b/modules/sd_samplers_cfg_denoiser.py index a532e013..113425b2 100644 --- a/modules/sd_samplers_cfg_denoiser.py +++ b/modules/sd_samplers_cfg_denoiser.py @@ -45,6 +45,11 @@ class CFGDenoiser(torch.nn.Module): self.nmask = None self.init_latent = None self.steps = None + """number of steps as specified by user in UI""" + + self.total_steps = None + """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler""" + self.step = 0 self.image_cfg_scale = None self.padded_cond_uncond = False @@ -56,7 +61,6 @@ class CFGDenoiser(torch.nn.Module): def inner_model(self): raise NotImplementedError() - def combine_denoised(self, x_out, conds_list, uncond, cond_scale): denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 35c4d657..85f3c7e0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -7,7 +7,16 @@ from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, s from modules.shared import opts, state import k_diffusion.sampling -SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + +SamplerDataTuple = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options']) + + +class SamplerData(SamplerDataTuple): + def total_steps(self, steps): + if self.options.get("second_order", False): + steps = steps * 2 + + return steps def setup_img2img_steps(p, steps=None): @@ -131,31 +140,26 @@ def replace_torchsde_browinan(): replace_torchsde_browinan() -def apply_refiner(sampler): - completed_ratio = sampler.step / sampler.steps +def apply_refiner(cfg_denoiser): + completed_ratio = cfg_denoiser.step / cfg_denoiser.total_steps + refiner_switch_at = cfg_denoiser.p.refiner_switch_at + refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info - if completed_ratio <= shared.opts.sd_refiner_switch_at: + if refiner_switch_at is not None and completed_ratio <= refiner_switch_at: return False - if shared.opts.sd_refiner_checkpoint == "None": + if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: return False - if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint: - return False - - refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint) - if refiner_checkpoint_info is None: - raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}') - - sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title - sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at + cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title + cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at with sd_models.SkipWritingToConfig(): sd_models.reload_model_weights(info=refiner_checkpoint_info) devices.torch_gc() - sampler.p.setup_conds() - sampler.update_inner_model() + cfg_denoiser.p.setup_conds() + cfg_denoiser.update_inner_model() return True @@ -192,7 +196,7 @@ class Sampler: self.sampler_noises = None self.stop_at = None self.eta = None - self.config = None # set by the function calling the constructor + self.config: SamplerData = None # set by the function calling the constructor self.last_latent = None self.s_min_uncond = None self.s_churn = 0.0 @@ -208,6 +212,7 @@ class Sampler: self.p = None self.model_wrap_cfg = None self.sampler_extra_args = None + self.options = {} def callback_state(self, d): step = d['i'] @@ -220,6 +225,7 @@ class Sampler: def launch_sampling(self, steps, func): self.model_wrap_cfg.steps = steps + self.model_wrap_cfg.total_steps = self.config.total_steps(steps) state.sampling_steps = steps state.sampling_step = 0 diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index d10fe12e..1f8e9c4b 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -64,9 +64,10 @@ class CFGDenoiserKDiffusion(sd_samplers_cfg_denoiser.CFGDenoiser): class KDiffusionSampler(sd_samplers_common.Sampler): - def __init__(self, funcname, sd_model): + def __init__(self, funcname, sd_model, options=None): super().__init__(funcname) + self.options = options or {} self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname) self.model_wrap_cfg = CFGDenoiserKDiffusion(self) diff --git a/modules/shared_items.py b/modules/shared_items.py index e4ec40a8..754166d2 100644 --- a/modules/shared_items.py +++ b/modules/shared_items.py @@ -69,8 +69,8 @@ def reload_hypernetworks(): ui_reorder_categories_builtin_items = [ "inpaint", "sampler", + "accordions", "checkboxes", - "hires_fix", "dimensions", "cfg", "seed", @@ -86,7 +86,7 @@ def ui_reorder_categories(): sections = {} for script in scripts.scripts_txt2img.scripts + scripts.scripts_img2img.scripts: - if isinstance(script.section, str): + if isinstance(script.section, str) and script.section not in ui_reorder_categories_builtin_items: sections[script.section] = 1 yield from sections diff --git a/modules/shared_options.py b/modules/shared_options.py index 1e5b64ea..9ae51f18 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -140,8 +140,6 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), { "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"), "randn_source": OptionInfo("GPU", "Random number generator source.", gr.Radio, {"choices": ["GPU", "CPU", "NV"]}).info("changes seeds drastically; use CPU to produce the same picture across different videocard vendors; use NV to produce same picture as on NVidia videocards"), "tiling": OptionInfo(False, "Tiling", infotext='Tiling').info("produce a tileable picture"), - "sd_refiner_checkpoint": OptionInfo("None", "Refiner checkpoint", gr.Dropdown, lambda: {"choices": ["None"] + shared_items.list_checkpoint_tiles()}, refresh=shared_items.refresh_checkpoints, infotext="Refiner").info("switch to another model in the middle of generation"), - "sd_refiner_switch_at": OptionInfo(1.0, "Refiner switch at", gr.Slider, {"minimum": 0.01, "maximum": 1.0, "step": 0.01}, infotext='Refiner switch at').info("fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation"), })) options_templates.update(options_section(('sdxl', "Stable Diffusion XL"), { diff --git a/modules/ui.py b/modules/ui.py index 05292734..3321b94d 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -438,35 +438,38 @@ def create_ui(): with FormRow(elem_classes="checkboxes-row", variant="compact"): pass - elif category == "hires_fix": - with InputAccordion(False, label="Hires. fix") as enable_hr: - with enable_hr.extra(): - hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) + elif category == "accordions": + with gr.Row(elem_id="txt2img_accordions", elem_classes="accordions"): + with InputAccordion(False, label="Hires. fix", elem_id="txt2img_hr") as enable_hr: + with enable_hr.extra(): + hr_final_resolution = FormHTML(value="", elem_id="txtimg_hr_finalres", label="Upscaled resolution", interactive=False, min_width=0) - with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): - hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) - hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") - denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") + with FormRow(elem_id="txt2img_hires_fix_row1", variant="compact"): + hr_upscaler = gr.Dropdown(label="Upscaler", elem_id="txt2img_hr_upscaler", choices=[*shared.latent_upscale_modes, *[x.name for x in shared.sd_upscalers]], value=shared.latent_upscale_default_mode) + hr_second_pass_steps = gr.Slider(minimum=0, maximum=150, step=1, label='Hires steps', value=0, elem_id="txt2img_hires_steps") + denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.7, elem_id="txt2img_denoising_strength") - with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): - hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") - hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") - hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") + with FormRow(elem_id="txt2img_hires_fix_row2", variant="compact"): + hr_scale = gr.Slider(minimum=1.0, maximum=4.0, step=0.05, label="Upscale by", value=2.0, elem_id="txt2img_hr_scale") + hr_resize_x = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize width to", value=0, elem_id="txt2img_hr_resize_x") + hr_resize_y = gr.Slider(minimum=0, maximum=2048, step=8, label="Resize height to", value=0, elem_id="txt2img_hr_resize_y") - with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: + with FormRow(elem_id="txt2img_hires_fix_row3", variant="compact", visible=opts.hires_fix_show_sampler) as hr_sampler_container: - hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") - create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") + hr_checkpoint_name = gr.Dropdown(label='Hires checkpoint', elem_id="hr_checkpoint", choices=["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True), value="Use same checkpoint") + create_refresh_button(hr_checkpoint_name, modules.sd_models.list_models, lambda: {"choices": ["Use same checkpoint"] + modules.sd_models.checkpoint_tiles(use_short=True)}, "hr_checkpoint_refresh") - hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") + hr_sampler_name = gr.Dropdown(label='Hires sampling method', elem_id="hr_sampler", choices=["Use same sampler"] + sd_samplers.visible_sampler_names(), value="Use same sampler") - with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: - with gr.Column(scale=80): - with gr.Row(): - hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) - with gr.Column(scale=80): - with gr.Row(): - hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + with FormRow(elem_id="txt2img_hires_fix_row4", variant="compact", visible=opts.hires_fix_show_prompts) as hr_prompts_container: + with gr.Column(scale=80): + with gr.Row(): + hr_prompt = gr.Textbox(label="Hires prompt", elem_id="hires_prompt", show_label=False, lines=3, placeholder="Prompt for hires fix pass.\nLeave empty to use the same prompt as in first pass.", elem_classes=["prompt"]) + with gr.Column(scale=80): + with gr.Row(): + hr_negative_prompt = gr.Textbox(label="Hires negative prompt", elem_id="hires_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt for hires fix pass.\nLeave empty to use the same negative prompt as in first pass.", elem_classes=["prompt"]) + + scripts.scripts_txt2img.setup_ui_for_section(category) elif category == "batch": if not opts.dimensions_and_batch_together: @@ -482,7 +485,7 @@ def create_ui(): with FormGroup(elem_id="txt2img_script_container"): custom_inputs = scripts.scripts_txt2img.setup_ui() - else: + if category not in {"accordions"}: scripts.scripts_txt2img.setup_ui_for_section(category) hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -794,6 +797,10 @@ def create_ui(): with FormRow(elem_classes="checkboxes-row", variant="compact"): pass + elif category == "accordions": + with gr.Row(elem_id="img2img_accordions", elem_classes="accordions"): + scripts.scripts_img2img.setup_ui_for_section(category) + elif category == "batch": if not opts.dimensions_and_batch_together: with FormRow(elem_id="img2img_column_batch"): @@ -836,7 +843,8 @@ def create_ui(): inputs=[], outputs=[inpaint_controls, mask_alpha], ) - else: + + if category not in {"accordions"}: scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) diff --git a/modules/ui_components.py b/modules/ui_components.py index bfe2fbd9..d08b2b99 100644 --- a/modules/ui_components.py +++ b/modules/ui_components.py @@ -87,13 +87,23 @@ class InputAccordion(gr.Checkbox): self.accordion_id = f"input-accordion-{InputAccordion.global_index}" InputAccordion.global_index += 1 - kwargs['elem_id'] = self.accordion_id + "-checkbox" - kwargs['visible'] = False - super().__init__(value, **kwargs) + kwargs_checkbox = { + **kwargs, + "elem_id": f"{self.accordion_id}-checkbox", + "visible": False, + } + super().__init__(value, **kwargs_checkbox) self.change(fn=None, _js='function(checked){ inputAccordionChecked("' + self.accordion_id + '", checked); }', inputs=[self]) - self.accordion = gr.Accordion(kwargs.get('label', 'Accordion'), open=value, elem_id=self.accordion_id, elem_classes=['input-accordion']) + kwargs_accordion = { + **kwargs, + "elem_id": self.accordion_id, + "label": kwargs.get('label', 'Accordion'), + "elem_classes": ['input-accordion'], + "open": value, + } + self.accordion = gr.Accordion(**kwargs_accordion) def extra(self): """Allows you to put something into the label of the accordion. diff --git a/style.css b/style.css index 4cdce87c..260b1056 100644 --- a/style.css +++ b/style.css @@ -166,16 +166,6 @@ a{ color: var(--button-secondary-text-color-hover); } -.checkboxes-row{ - margin-bottom: 0.5em; - margin-left: 0em; -} -.checkboxes-row > div{ - flex: 0; - white-space: nowrap; - min-width: auto !important; -} - button.custom-button{ border-radius: var(--button-large-radius); padding: var(--button-large-padding); @@ -352,7 +342,7 @@ div.block.gradio-accordion { } div.dimensions-tools{ - min-width: 0 !important; + min-width: 1.6em !important; max-width: fit-content; flex-direction: column; place-content: center; @@ -1012,10 +1002,28 @@ div.block.gradio-box.popup-dialog > div:last-child, .popup-dialog > div:last-chi } div.block.input-accordion{ - margin-bottom: 0.4em; + } .input-accordion-extra{ flex: 0 0 auto !important; margin: 0 0.5em 0 auto; } + +div.accordions > div.input-accordion{ + min-width: fit-content !important; +} + +div.accordions > div.gradio-accordion .label-wrap span{ + white-space: nowrap; + margin-right: 0.25em; +} + +div.accordions{ + gap: 0.5em; +} + +div.accordions > div.input-accordion.input-accordion-open{ + flex: 1 auto; +} + -- cgit v1.2.3 From b293ed30610c040e621e1840d63047ae298f0650 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 12 Aug 2023 12:54:32 +0300 Subject: make it possible to use hires fix together with refiner --- modules/processing.py | 6 ++++++ modules/processing_scripts/refiner.py | 2 +- modules/sd_samplers_common.py | 3 +++ 3 files changed, 10 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 5996cbac..6ad105d7 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -1200,6 +1200,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, total_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): + if self.is_hr_pass: + # if we are in hr pass right now, the call is being made from the refiner, and we don't need to setup firstpass cons or switch model + self.hr_c = None + self.calculate_hr_conds() + return + super().setup_conds() self.hr_uc = None diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 5a82991a..773ec5d0 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -24,7 +24,7 @@ class ScriptRefiner(scripts.Script): refiner_checkpoint = gr.Dropdown(label='Checkpoint', elem_id=self.elem_id("checkpoint"), choices=sd_models.checkpoint_tiles(), value='', tooltip="switch to another model in the middle of generation") create_refresh_button(refiner_checkpoint, sd_models.list_models, lambda: {"choices": sd_models.checkpoint_tiles()}, self.elem_id("checkpoint_refresh")) - refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the swtch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") + refiner_switch_at = gr.Slider(value=0.8, label="Switch at", minimum=0.01, maximum=1.0, step=0.01, elem_id=self.elem_id("switch_at"), tooltip="fraction of sampling steps when the switch to refiner model should happen; 1=never, 0.5=switch in the middle of generation") def lookup_checkpoint(title): info = sd_models.get_closet_checkpoint_match(title) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 85f3c7e0..40c7aae0 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -151,6 +151,9 @@ def apply_refiner(cfg_denoiser): if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: return False + if getattr(cfg_denoiser.p, "enable_hr", False) and not cfg_denoiser.p.is_hr_pass: + return False + cfg_denoiser.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title cfg_denoiser.p.extra_generation_params['Refiner switch at'] = refiner_switch_at -- cgit v1.2.3 From fa9370b7411166c19e8e386400dc4e6082f47b2d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 13 Aug 2023 06:07:30 +0300 Subject: add refiner to StableDiffusionProcessing class write out correct model name in infotext, rather than the refiner model --- modules/processing.py | 38 +++++++++++++++++++++++++++-------- modules/processing_scripts/refiner.py | 16 +++++---------- modules/sd_samplers_common.py | 2 +- 3 files changed, 36 insertions(+), 20 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index 6ad105d7..b47ddaa8 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -111,7 +111,7 @@ class StableDiffusionProcessing: cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, script_args: list = None): + def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None): if sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) @@ -153,10 +153,14 @@ class StableDiffusionProcessing: self.s_noise = s_noise if s_noise is not None else opts.s_noise self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} self.override_settings_restore_afterwards = override_settings_restore_afterwards + self.refiner_checkpoint = refiner_checkpoint + self.refiner_switch_at = refiner_switch_at + self.is_using_inpainting_conditioning = False self.disable_extra_networks = False self.token_merging_ratio = 0 self.token_merging_ratio_hr = 0 + self.refiner_checkpoint_info = None if not seed_enable_extras: self.subseed = -1 @@ -191,6 +195,11 @@ class StableDiffusionProcessing: self.user = None + self.sd_model_name = None + self.sd_model_hash = None + self.sd_vae_name = None + self.sd_vae_hash = None + @property def sd_model(self): return shared.sd_model @@ -408,7 +417,10 @@ class Processed: self.batch_size = p.batch_size self.restore_faces = p.restore_faces self.face_restoration_model = opts.face_restoration_model if p.restore_faces else None - self.sd_model_hash = shared.sd_model.sd_model_hash + self.sd_model_name = p.sd_model_name + self.sd_model_hash = p.sd_model_hash + self.sd_vae_name = p.sd_vae_name + self.sd_vae_hash = p.sd_vae_hash self.seed_resize_from_w = p.seed_resize_from_w self.seed_resize_from_h = p.seed_resize_from_h self.denoising_strength = getattr(p, 'denoising_strength', None) @@ -459,7 +471,10 @@ class Processed: "batch_size": self.batch_size, "restore_faces": self.restore_faces, "face_restoration_model": self.face_restoration_model, + "sd_model_name": self.sd_model_name, "sd_model_hash": self.sd_model_hash, + "sd_vae_name": self.sd_vae_name, + "sd_vae_hash": self.sd_vae_hash, "seed_resize_from_w": self.seed_resize_from_w, "seed_resize_from_h": self.seed_resize_from_h, "denoising_strength": self.denoising_strength, @@ -578,10 +593,10 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index], "Face restoration": opts.face_restoration_model if p.restore_faces else None, "Size": f"{p.width}x{p.height}", - "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash), - "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra), - "VAE hash": p.loaded_vae_hash if opts.add_model_hash_to_info else None, - "VAE": p.loaded_vae_name if opts.add_model_name_to_info else None, + "Model hash": p.sd_model_hash if opts.add_model_hash_to_info else None, + "Model": p.sd_model_name if opts.add_model_name_to_info else None, + "VAE hash": p.sd_vae_hash if opts.add_model_hash_to_info else None, + "VAE": p.sd_vae_name if opts.add_model_name_to_info else None, "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])), "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength), "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"), @@ -670,8 +685,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.tiling is None: p.tiling = opts.tiling - p.loaded_vae_name = sd_vae.get_loaded_vae_name() - p.loaded_vae_hash = sd_vae.get_loaded_vae_hash() + if p.refiner_checkpoint not in (None, "", "None"): + p.refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(p.refiner_checkpoint) + if p.refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}') + + p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra + p.sd_model_hash = shared.sd_model.sd_model_hash + p.sd_vae_name = sd_vae.get_loaded_vae_name() + p.sd_vae_hash = sd_vae.get_loaded_vae_hash() modules.sd_hijack.model_hijack.apply_circular(p.tiling) modules.sd_hijack.model_hijack.clear_comments() diff --git a/modules/processing_scripts/refiner.py b/modules/processing_scripts/refiner.py index 773ec5d0..7b946d05 100644 --- a/modules/processing_scripts/refiner.py +++ b/modules/processing_scripts/refiner.py @@ -41,15 +41,9 @@ class ScriptRefiner(scripts.Script): def before_process(self, p, enable_refiner, refiner_checkpoint, refiner_switch_at): # the actual implementation is in sd_samplers_common.py, apply_refiner - p.refiner_checkpoint_info = None - p.refiner_switch_at = None - if not enable_refiner or refiner_checkpoint in (None, "", "None"): - return - - refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(refiner_checkpoint) - if refiner_checkpoint_info is None: - raise Exception(f'Could not find checkpoint with name {refiner_checkpoint}') - - p.refiner_checkpoint_info = refiner_checkpoint_info - p.refiner_switch_at = refiner_switch_at + p.refiner_checkpoint_info = None + p.refiner_switch_at = None + else: + p.refiner_checkpoint = refiner_checkpoint + p.refiner_switch_at = refiner_switch_at diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 40c7aae0..380cdd5f 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -145,7 +145,7 @@ def apply_refiner(cfg_denoiser): refiner_switch_at = cfg_denoiser.p.refiner_switch_at refiner_checkpoint_info = cfg_denoiser.p.refiner_checkpoint_info - if refiner_switch_at is not None and completed_ratio <= refiner_switch_at: + if refiner_switch_at is not None and completed_ratio < refiner_switch_at: return False if refiner_checkpoint_info is None or shared.sd_model.sd_checkpoint_info == refiner_checkpoint_info: -- cgit v1.2.3 From 599f61a1e0bddf463dd3c6adb84509b3d9db1941 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 13 Aug 2023 08:24:16 +0300 Subject: use dataclass for StableDiffusionProcessing --- modules/processing.py | 318 +++++++++++++++++++++++------------------- modules/sd_samplers_common.py | 5 +- 2 files changed, 176 insertions(+), 147 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/processing.py b/modules/processing.py index b47ddaa8..007a4e05 100755 --- a/modules/processing.py +++ b/modules/processing.py @@ -1,9 +1,11 @@ +from __future__ import annotations import json import logging import math import os import sys import hashlib +from dataclasses import dataclass, field import torch import numpy as np @@ -11,7 +13,7 @@ from PIL import Image, ImageOps import random import cv2 from skimage import exposure -from typing import Any, Dict, List +from typing import Any import modules.sd_hijack from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet, errors, rng @@ -104,106 +106,126 @@ def txt2img_image_conditioning(sd_model, x, width, height): return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) +@dataclass(repr=False) class StableDiffusionProcessing: - """ - The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing - """ + sd_model: object = None + outpath_samples: str = None + outpath_grids: str = None + prompt: str = "" + prompt_for_display: str = None + negative_prompt: str = "" + styles: list[str] = field(default_factory=list) + seed: int = -1 + subseed: int = -1 + subseed_strength: float = 0 + seed_resize_from_h: int = -1 + seed_resize_from_w: int = -1 + seed_enable_extras: bool = True + sampler_name: str = None + batch_size: int = 1 + n_iter: int = 1 + steps: int = 50 + cfg_scale: float = 7.0 + width: int = 512 + height: int = 512 + restore_faces: bool = None + tiling: bool = None + do_not_save_samples: bool = False + do_not_save_grid: bool = False + extra_generation_params: dict[str, Any] = None + overlay_images: list = None + eta: float = None + do_not_reload_embeddings: bool = False + denoising_strength: float = 0 + ddim_discretize: str = None + s_min_uncond: float = None + s_churn: float = None + s_tmax: float = None + s_tmin: float = None + s_noise: float = None + override_settings: dict[str, Any] = None + override_settings_restore_afterwards: bool = True + sampler_index: int = None + refiner_checkpoint: str = None + refiner_switch_at: float = None + token_merging_ratio = 0 + token_merging_ratio_hr = 0 + disable_extra_networks: bool = False + + script_args: list = None + cached_uc = [None, None] cached_c = [None, None] - def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt: str = "", styles: List[str] = None, seed: int = -1, subseed: int = -1, subseed_strength: float = 0, seed_resize_from_h: int = -1, seed_resize_from_w: int = -1, seed_enable_extras: bool = True, sampler_name: str = None, batch_size: int = 1, n_iter: int = 1, steps: int = 50, cfg_scale: float = 7.0, width: int = 512, height: int = 512, restore_faces: bool = None, tiling: bool = None, do_not_save_samples: bool = False, do_not_save_grid: bool = False, extra_generation_params: Dict[Any, Any] = None, overlay_images: Any = None, negative_prompt: str = None, eta: float = None, do_not_reload_embeddings: bool = False, denoising_strength: float = 0, ddim_discretize: str = None, s_min_uncond: float = 0.0, s_churn: float = 0.0, s_tmax: float = None, s_tmin: float = 0.0, s_noise: float = None, override_settings: Dict[str, Any] = None, override_settings_restore_afterwards: bool = True, sampler_index: int = None, refiner_checkpoint: str = None, refiner_switch_at: float = None, script_args: list = None): - if sampler_index is not None: + sampler: sd_samplers_common.Sampler | None = field(default=None, init=False) + is_using_inpainting_conditioning: bool = field(default=False, init=False) + paste_to: tuple | None = field(default=None, init=False) + + is_hr_pass: bool = field(default=False, init=False) + + c: tuple = field(default=None, init=False) + uc: tuple = field(default=None, init=False) + + rng: rng.ImageRNG | None = field(default=None, init=False) + step_multiplier: int = field(default=1, init=False) + color_corrections: list = field(default=None, init=False) + + scripts: list = field(default=None, init=False) + all_prompts: list = field(default=None, init=False) + all_negative_prompts: list = field(default=None, init=False) + all_seeds: list = field(default=None, init=False) + all_subseeds: list = field(default=None, init=False) + iteration: int = field(default=0, init=False) + main_prompt: str = field(default=None, init=False) + main_negative_prompt: str = field(default=None, init=False) + + prompts: list = field(default=None, init=False) + negative_prompts: list = field(default=None, init=False) + seeds: list = field(default=None, init=False) + subseeds: list = field(default=None, init=False) + extra_network_data: dict = field(default=None, init=False) + + user: str = field(default=None, init=False) + + sd_model_name: str = field(default=None, init=False) + sd_model_hash: str = field(default=None, init=False) + sd_vae_name: str = field(default=None, init=False) + sd_vae_hash: str = field(default=None, init=False) + + def __post_init__(self): + if self.sampler_index is not None: print("sampler_index argument for StableDiffusionProcessing does not do anything; use sampler_name", file=sys.stderr) - self.outpath_samples: str = outpath_samples - self.outpath_grids: str = outpath_grids - self.prompt: str = prompt - self.prompt_for_display: str = None - self.negative_prompt: str = (negative_prompt or "") - self.styles: list = styles or [] - self.seed: int = seed - self.subseed: int = subseed - self.subseed_strength: float = subseed_strength - self.seed_resize_from_h: int = seed_resize_from_h - self.seed_resize_from_w: int = seed_resize_from_w - self.sampler_name: str = sampler_name - self.batch_size: int = batch_size - self.n_iter: int = n_iter - self.steps: int = steps - self.cfg_scale: float = cfg_scale - self.width: int = width - self.height: int = height - self.restore_faces: bool = restore_faces - self.tiling: bool = tiling - self.do_not_save_samples: bool = do_not_save_samples - self.do_not_save_grid: bool = do_not_save_grid - self.extra_generation_params: dict = extra_generation_params or {} - self.overlay_images = overlay_images - self.eta = eta - self.do_not_reload_embeddings = do_not_reload_embeddings - self.paste_to = None - self.color_corrections = None - self.denoising_strength: float = denoising_strength self.sampler_noise_scheduler_override = None - self.ddim_discretize = ddim_discretize or opts.ddim_discretize - self.s_min_uncond = s_min_uncond or opts.s_min_uncond - self.s_churn = s_churn or opts.s_churn - self.s_tmin = s_tmin or opts.s_tmin - self.s_tmax = (s_tmax if s_tmax is not None else opts.s_tmax) or float('inf') - self.s_noise = s_noise if s_noise is not None else opts.s_noise - self.override_settings = {k: v for k, v in (override_settings or {}).items() if k not in shared.restricted_opts} - self.override_settings_restore_afterwards = override_settings_restore_afterwards - self.refiner_checkpoint = refiner_checkpoint - self.refiner_switch_at = refiner_switch_at - - self.is_using_inpainting_conditioning = False - self.disable_extra_networks = False - self.token_merging_ratio = 0 - self.token_merging_ratio_hr = 0 + self.s_min_uncond = self.s_min_uncond if self.s_min_uncond is not None else opts.s_min_uncond + self.s_churn = self.s_churn if self.s_churn is not None else opts.s_churn + self.s_tmin = self.s_tmin if self.s_tmin is not None else opts.s_tmin + self.s_tmax = (self.s_tmax if self.s_tmax is not None else opts.s_tmax) or float('inf') + self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise + + self.extra_generation_params = self.extra_generation_params or {} + self.override_settings = self.override_settings or {} + self.script_args = self.script_args or {} + self.refiner_checkpoint_info = None - if not seed_enable_extras: + if not self.seed_enable_extras: self.subseed = -1 self.subseed_strength = 0 self.seed_resize_from_h = 0 self.seed_resize_from_w = 0 - self.scripts = None - self.script_args = script_args - self.all_prompts = None - self.all_negative_prompts = None - self.all_seeds = None - self.all_subseeds = None - self.iteration = 0 - self.is_hr_pass = False - self.sampler = None - self.main_prompt = None - self.main_negative_prompt = None - - self.prompts = None - self.negative_prompts = None - self.extra_network_data = None - self.seeds = None - self.subseeds = None - - self.step_multiplier = 1 self.cached_uc = StableDiffusionProcessing.cached_uc self.cached_c = StableDiffusionProcessing.cached_c - self.uc = None - self.c = None - self.rng: rng.ImageRNG = None - - self.user = None - - self.sd_model_name = None - self.sd_model_hash = None - self.sd_vae_name = None - self.sd_vae_hash = None @property def sd_model(self): return shared.sd_model + @sd_model.setter + def sd_model(self, value): + pass + def txt2img_image_conditioning(self, x, width=None, height=None): self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'} @@ -932,49 +954,51 @@ def old_hires_fix_first_pass_dimensions(width, height): return width, height +@dataclass(repr=False) class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): - sampler = None + enable_hr: bool = False + denoising_strength: float = 0.75 + firstphase_width: int = 0 + firstphase_height: int = 0 + hr_scale: float = 2.0 + hr_upscaler: str = None + hr_second_pass_steps: int = 0 + hr_resize_x: int = 0 + hr_resize_y: int = 0 + hr_checkpoint_name: str = None + hr_sampler_name: str = None + hr_prompt: str = '' + hr_negative_prompt: str = '' + cached_hr_uc = [None, None] cached_hr_c = [None, None] - def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_checkpoint_name: str = None, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs): - super().__init__(**kwargs) - self.enable_hr = enable_hr - self.denoising_strength = denoising_strength - self.hr_scale = hr_scale - self.hr_upscaler = hr_upscaler - self.hr_second_pass_steps = hr_second_pass_steps - self.hr_resize_x = hr_resize_x - self.hr_resize_y = hr_resize_y - self.hr_upscale_to_x = hr_resize_x - self.hr_upscale_to_y = hr_resize_y - self.hr_checkpoint_name = hr_checkpoint_name - self.hr_checkpoint_info = None - self.hr_sampler_name = hr_sampler_name - self.hr_prompt = hr_prompt - self.hr_negative_prompt = hr_negative_prompt - self.all_hr_prompts = None - self.all_hr_negative_prompts = None - self.latent_scale_mode = None - - if firstphase_width != 0 or firstphase_height != 0: + hr_checkpoint_info: dict = field(default=None, init=False) + hr_upscale_to_x: int = field(default=0, init=False) + hr_upscale_to_y: int = field(default=0, init=False) + truncate_x: int = field(default=0, init=False) + truncate_y: int = field(default=0, init=False) + applied_old_hires_behavior_to: tuple = field(default=None, init=False) + latent_scale_mode: dict = field(default=None, init=False) + hr_c: tuple | None = field(default=None, init=False) + hr_uc: tuple | None = field(default=None, init=False) + all_hr_prompts: list = field(default=None, init=False) + all_hr_negative_prompts: list = field(default=None, init=False) + hr_prompts: list = field(default=None, init=False) + hr_negative_prompts: list = field(default=None, init=False) + hr_extra_network_data: list = field(default=None, init=False) + + def __post_init__(self): + super().__post_init__() + + if self.firstphase_width != 0 or self.firstphase_height != 0: self.hr_upscale_to_x = self.width self.hr_upscale_to_y = self.height - self.width = firstphase_width - self.height = firstphase_height - - self.truncate_x = 0 - self.truncate_y = 0 - self.applied_old_hires_behavior_to = None - - self.hr_prompts = None - self.hr_negative_prompts = None - self.hr_extra_network_data = None + self.width = self.firstphase_width + self.height = self.firstphase_height self.cached_hr_uc = StableDiffusionProcessingTxt2Img.cached_hr_uc self.cached_hr_c = StableDiffusionProcessingTxt2Img.cached_hr_c - self.hr_c = None - self.hr_uc = None def calculate_target_resolution(self): if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height): @@ -1252,7 +1276,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return super().get_conds() - def parse_extra_network_prompts(self): res = super().parse_extra_network_prompts() @@ -1265,32 +1288,37 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return res +@dataclass(repr=False) class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): - sampler = None - - def __init__(self, init_images: list = None, resize_mode: int = 0, denoising_strength: float = 0.75, image_cfg_scale: float = None, mask: Any = None, mask_blur: int = None, mask_blur_x: int = 4, mask_blur_y: int = 4, inpainting_fill: int = 0, inpaint_full_res: bool = True, inpaint_full_res_padding: int = 0, inpainting_mask_invert: int = 0, initial_noise_multiplier: float = None, **kwargs): - super().__init__(**kwargs) - - self.init_images = init_images - self.resize_mode: int = resize_mode - self.denoising_strength: float = denoising_strength - self.image_cfg_scale: float = image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None - self.init_latent = None - self.image_mask = mask - self.latent_mask = None - self.mask_for_overlay = None - self.mask_blur_x = mask_blur_x - self.mask_blur_y = mask_blur_y - if mask_blur is not None: - self.mask_blur = mask_blur - self.inpainting_fill = inpainting_fill - self.inpaint_full_res = inpaint_full_res - self.inpaint_full_res_padding = inpaint_full_res_padding - self.inpainting_mask_invert = inpainting_mask_invert - self.initial_noise_multiplier = opts.initial_noise_multiplier if initial_noise_multiplier is None else initial_noise_multiplier + init_images: list = None + resize_mode: int = 0 + denoising_strength: float = 0.75 + image_cfg_scale: float = None + mask: Any = None + mask_blur_x: int = 4 + mask_blur_y: int = 4 + mask_blur: int = None + inpainting_fill: int = 0 + inpaint_full_res: bool = True + inpaint_full_res_padding: int = 0 + inpainting_mask_invert: int = 0 + initial_noise_multiplier: float = None + latent_mask: Image = None + + image_mask: Any = field(default=None, init=False) + + nmask: torch.Tensor = field(default=None, init=False) + image_conditioning: torch.Tensor = field(default=None, init=False) + init_img_hash: str = field(default=None, init=False) + mask_for_overlay: Image = field(default=None, init=False) + init_latent: torch.Tensor = field(default=None, init=False) + + def __post_init__(self): + super().__post_init__() + + self.image_mask = self.mask self.mask = None - self.nmask = None - self.image_conditioning = None + self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier @property def mask_blur(self): @@ -1300,15 +1328,13 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): @mask_blur.setter def mask_blur(self, value): - self.mask_blur_x = value - self.mask_blur_y = value - - @mask_blur.deleter - def mask_blur(self): - del self.mask_blur_x - del self.mask_blur_y + if isinstance(value, int): + self.mask_blur_x = value + self.mask_blur_y = value def init(self, all_prompts, all_seeds, all_subseeds): + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None + self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) crop_region = None diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 380cdd5f..09d1e11e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -305,5 +305,8 @@ class Sampler: current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size] return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds) + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + raise NotImplementedError() - + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): + raise NotImplementedError() -- cgit v1.2.3 From 822597db49218de17e105e62075096284dfcfd41 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 04:16:48 -0400 Subject: Encode batches separately Significantly reduces VRAM. This makes encoding more inline with how decoding currently functions. --- modules/sd_samplers_common.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 09d1e11e..f9d034ca 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -92,7 +92,15 @@ def images_tensor_to_samples(image, approximation=None, model=None): model = shared.sd_model image = image.to(shared.device, dtype=devices.dtype_vae) image = image * 2 - 1 - x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) + if len(image) > 1: + x_latent = torch.stack([ + model.get_first_stage_encoding( + model.encode_first_stage(torch.unsqueeze(img, 0)) + )[0] + for img in image + ]) + else: + x_latent = model.get_first_stage_encoding(model.encode_first_stage(image)) return x_latent -- cgit v1.2.3 From d1a70c3f0534b88665d55bdac1e6b48a63f7f035 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 13 Aug 2023 08:22:24 -0400 Subject: Add s_noise param to more samplers --- modules/sd_samplers_common.py | 8 ++++---- modules/sd_samplers_kdiffusion.py | 6 ++++++ 2 files changed, 10 insertions(+), 4 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 09d1e11e..d2fb21f4 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -276,19 +276,19 @@ class Sampler: s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf s_noise = getattr(opts, 's_noise', p.s_noise) - if s_churn != self.s_churn: + if 's_churn' in extra_params_kwargs and s_churn != self.s_churn: extra_params_kwargs['s_churn'] = s_churn p.s_churn = s_churn p.extra_generation_params['Sigma churn'] = s_churn - if s_tmin != self.s_tmin: + if 's_tmin' in extra_params_kwargs and s_tmin != self.s_tmin: extra_params_kwargs['s_tmin'] = s_tmin p.s_tmin = s_tmin p.extra_generation_params['Sigma tmin'] = s_tmin - if s_tmax != self.s_tmax: + if 's_tmax' in extra_params_kwargs and s_tmax != self.s_tmax: extra_params_kwargs['s_tmax'] = s_tmax p.s_tmax = s_tmax p.extra_generation_params['Sigma tmax'] = s_tmax - if s_noise != self.s_noise: + if 's_noise' in extra_params_kwargs and s_noise != self.s_noise: extra_params_kwargs['s_noise'] = s_noise p.s_noise = s_noise p.extra_generation_params['Sigma noise'] = s_noise diff --git a/modules/sd_samplers_kdiffusion.py b/modules/sd_samplers_kdiffusion.py index a48a563f..9f5dfd6d 100644 --- a/modules/sd_samplers_kdiffusion.py +++ b/modules/sd_samplers_kdiffusion.py @@ -45,6 +45,12 @@ sampler_extra_params = { 'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'], 'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'], + 'sample_dpm_fast': ['s_noise'], + 'sample_dpm_2_ancestral': ['s_noise'], + 'sample_dpmpp_2s_ancestral': ['s_noise'], + 'sample_dpmpp_sde': ['s_noise'], + 'sample_dpmpp_2m_sde': ['s_noise'], + 'sample_dpmpp_3m_sde': ['s_noise'], } k_diffusion_samplers_map = {x.name: x for x in samplers_data_k_diffusion} -- cgit v1.2.3 From 45be87afc6b173ebda94f427b8d396a2842ddf3c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 14 Aug 2023 21:48:05 +0300 Subject: correctly add Eta DDIM to infotext when it's 1.0 and do not add it when it's 0.0. --- modules/sd_samplers_common.py | 3 ++- modules/sd_samplers_timesteps.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 07fc4434..8886de7e 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -217,6 +217,7 @@ class Sampler: self.eta_option_field = 'eta_ancestral' self.eta_infotext_field = 'Eta' + self.eta_default = 1.0 self.conditioning_key = shared.sd_model.model.conditioning_key @@ -273,7 +274,7 @@ class Sampler: extra_params_kwargs[param_name] = getattr(p, param_name) if 'eta' in inspect.signature(self.func).parameters: - if self.eta != 1.0: + if self.eta != self.eta_default: p.extra_generation_params[self.eta_infotext_field] = self.eta extra_params_kwargs['eta'] = self.eta diff --git a/modules/sd_samplers_timesteps.py b/modules/sd_samplers_timesteps.py index c1f534ed..66e83ff7 100644 --- a/modules/sd_samplers_timesteps.py +++ b/modules/sd_samplers_timesteps.py @@ -76,6 +76,7 @@ class CompVisSampler(sd_samplers_common.Sampler): self.eta_option_field = 'eta_ddim' self.eta_infotext_field = 'Eta DDIM' + self.eta_default = 0.0 self.model_wrap_cfg = CFGDenoiserTimesteps(self) -- cgit v1.2.3 From d9ddc5d4cd9e707c018d245572310ca6f6cdf1ef Mon Sep 17 00:00:00 2001 From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com> Date: Wed, 16 Aug 2023 11:21:12 +0800 Subject: Remove wrong scale --- modules/sd_samplers_common.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 8886de7e..7dc79ea8 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -44,8 +44,7 @@ def samples_to_images_tensor(sample, approximation=None, model=None): elif approximation == 1: x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype)).detach() elif approximation == 3: - x_sample = sample * 1.5 - x_sample = sd_vae_taesd.decoder_model()(x_sample.to(devices.device, devices.dtype)).detach() + x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach() x_sample = x_sample * 2 - 1 else: if model is None: -- cgit v1.2.3 From 3003b10e0ab84abb6298285e707b3618ceec51dc Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Thu, 17 Aug 2023 18:10:55 -0400 Subject: Attempt to resolve NaN issue with unstable VAEs in fp32 mk2 --- modules/sd_samplers_common.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7dc79ea8..67deefa5 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -49,7 +49,8 @@ def samples_to_images_tensor(sample, approximation=None, model=None): else: if model is None: model = shared.sd_model - x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) + with devices.without_autocast(): # fixes an issue with unstable VAEs that are flaky even in fp32 + x_sample = model.decode_first_stage(sample.to(model.first_stage_model.dtype)) return x_sample -- cgit v1.2.3 From 3ce5fb8e5c3a6577172e89e964d570f8b31a998b Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Thu, 17 Aug 2023 20:03:26 -0400 Subject: Add option for faster live interrupt --- modules/sd_samplers_common.py | 2 +- modules/shared_options.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 7dc79ea8..f0bc8e6a 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -36,7 +36,7 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": def samples_to_images_tensor(sample, approximation=None, model=None): '''latents -> images [-1, 1]''' - if approximation is None: + if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt): approximation = approximation_indexes.get(opts.show_progress_type, 0) if approximation == 2: diff --git a/modules/shared_options.py b/modules/shared_options.py index 79cbb92e..3e4bcaef 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -281,6 +281,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), + "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), })) options_templates.update(options_section(('sampler-params', "Sampler parameters"), { -- cgit v1.2.3 From 953c3eab7b3b952f7e96d728413a531d7fb521a2 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 21 Aug 2023 15:54:30 +0300 Subject: forbid Full live preview method for medvram and add a setting to undo the forbidding --- modules/sd_samplers_common.py | 7 ++++++- modules/shared_options.py | 1 + 2 files changed, 7 insertions(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index feb1a9db..60fa161c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -35,10 +35,15 @@ approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": def samples_to_images_tensor(sample, approximation=None, model=None): - '''latents -> images [-1, 1]''' + """Transforms 4-channel latent space images into 3-channel RGB image tensors, with values in range [-1, 1].""" + if approximation is None or (shared.state.interrupted and opts.live_preview_fast_interrupt): approximation = approximation_indexes.get(opts.show_progress_type, 0) + from modules import lowvram + if approximation == 0 and lowvram.is_enabled(shared.sd_model) and not shared.opts.live_preview_allow_lowvram_full: + approximation = 1 + if approximation == 2: x_sample = sd_vae_approx.cheap_approximation(sample) elif approximation == 1: diff --git a/modules/shared_options.py b/modules/shared_options.py index 095cf479..88f6b334 100644 --- a/modules/shared_options.py +++ b/modules/shared_options.py @@ -287,6 +287,7 @@ options_templates.update(options_section(('ui', "Live previews"), { "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"), "show_progress_every_n_steps": OptionInfo(10, "Live preview display period", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}).info("in sampling steps - show new live preview image every N sampling steps; -1 = only show after completion of batch"), "show_progress_type": OptionInfo("Approx NN", "Live preview method", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap", "TAESD"]}).info("Full = slow but pretty; Approx NN and TAESD = fast but low quality; Approx cheap = super fast but terrible otherwise"), + "live_preview_allow_lowvram_full": OptionInfo(False, "Allow Full live preview method with lowvram/medvram").info("If not, Approx NN will be used instead; Full live preview method is very detrimental to speed if lowvram/medvram optimizations are enabled"), "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}), "live_preview_refresh_period": OptionInfo(1000, "Progressbar and preview update period").info("in milliseconds"), "live_preview_fast_interrupt": OptionInfo(False, "Return image with chosen live preview method on interrupt").info("makes interrupts faster"), -- cgit v1.2.3 From d7c9c6142071359470347564913feb5b7a6c5c13 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 22 Aug 2023 09:55:20 +0300 Subject: attemped solution to the uncommon hanging problem that is seemingly caused by live previews working on the tensor as denoising --- modules/sd_samplers_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 60fa161c..64845ea4 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -111,7 +111,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): def store_latent(decoded): - state.current_latent = decoded + state.current_latent = decoded.clone() if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: -- cgit v1.2.3 From a459075d26eecc38d6d58116e38f453450191460 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 22 Aug 2023 10:41:10 +0300 Subject: actual solution to the uncommon hanging problem that is seemingly caused by multiple progress requests working on same tensor --- modules/progress.py | 45 +++++++++++++++++++++---------------------- modules/sd_samplers_common.py | 2 +- modules/shared_state.py | 2 +- 3 files changed, 24 insertions(+), 25 deletions(-) (limited to 'modules/sd_samplers_common.py') diff --git a/modules/progress.py b/modules/progress.py index a25a0113..69921de7 100644 --- a/modules/progress.py +++ b/modules/progress.py @@ -95,31 +95,30 @@ def progressapi(req: ProgressRequest): predicted_duration = elapsed_since_start / progress if progress > 0 else None eta = predicted_duration - elapsed_since_start if predicted_duration is not None else None + live_preview = None id_live_preview = req.id_live_preview - shared.state.set_current_image() - if opts.live_previews_enable and req.live_preview and shared.state.id_live_preview != req.id_live_preview: - image = shared.state.current_image - if image is not None: - buffered = io.BytesIO() - - if opts.live_previews_image_format == "png": - # using optimize for large images takes an enormous amount of time - if max(*image.size) <= 256: - save_kwargs = {"optimize": True} + + if opts.live_previews_enable and req.live_preview: + shared.state.set_current_image() + if shared.state.id_live_preview != req.id_live_preview: + image = shared.state.current_image + if image is not None: + buffered = io.BytesIO() + + if opts.live_previews_image_format == "png": + # using optimize for large images takes an enormous amount of time + if max(*image.size) <= 256: + save_kwargs = {"optimize": True} + else: + save_kwargs = {"optimize": False, "compress_level": 1} + else: - save_kwargs = {"optimize": False, "compress_level": 1} - - else: - save_kwargs = {} - - image.save(buffered, format=opts.live_previews_image_format, **save_kwargs) - base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') - live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}" - id_live_preview = shared.state.id_live_preview - else: - live_preview = None - else: - live_preview = None + save_kwargs = {} + + image.save(buffered, format=opts.live_previews_image_format, **save_kwargs) + base64_image = base64.b64encode(buffered.getvalue()).decode('ascii') + live_preview = f"data:image/{opts.live_previews_image_format};base64,{base64_image}" + id_live_preview = shared.state.id_live_preview return ProgressResponse(active=active, queued=queued, completed=completed, progress=progress, eta=eta, live_preview=live_preview, id_live_preview=id_live_preview, textinfo=shared.state.textinfo) diff --git a/modules/sd_samplers_common.py b/modules/sd_samplers_common.py index 64845ea4..60fa161c 100644 --- a/modules/sd_samplers_common.py +++ b/modules/sd_samplers_common.py @@ -111,7 +111,7 @@ def images_tensor_to_samples(image, approximation=None, model=None): def store_latent(decoded): - state.current_latent = decoded.clone() + state.current_latent = decoded if opts.live_previews_enable and opts.show_progress_every_n_steps > 0 and shared.state.sampling_step % opts.show_progress_every_n_steps == 0: if not shared.parallel_processing_allowed: diff --git a/modules/shared_state.py b/modules/shared_state.py index 3dc9c788..d272ee5b 100644 --- a/modules/shared_state.py +++ b/modules/shared_state.py @@ -128,7 +128,7 @@ class State: devices.torch_gc() def set_current_image(self): - """sets self.current_image from self.current_latent if enough sampling steps have been made after the last call to this""" + """if enough sampling steps have been made after the last call to this, sets self.current_image from self.current_latent, and modifies self.id_live_preview accordingly""" if not shared.parallel_processing_allowed: return -- cgit v1.2.3