From 9324cdaa3199d65c182858785dd1eca42b192b8e Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Sun, 16 Oct 2022 17:53:56 +0200 Subject: ui fix, re organization of the code --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 3aa21ec1..8e4ee435 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -20,7 +20,7 @@ checkpoints_loaded = collections.OrderedDict() try: # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start. - from transformers import logging + from transformers import logging, CLIPModel logging.set_verbosity_error() except Exception: @@ -196,6 +196,9 @@ def load_model(): sd_hijack.model_hijack.hijack(sd_model) + if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: + shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) + sd_model.eval() print(f"Model loaded.") -- cgit v1.2.3 From 8e7097d06a6a261580d34375c9d2a9e4ffc63ffa Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Wed, 19 Oct 2022 13:47:45 -0700 Subject: Added support for RunwayML inpainting model --- modules/processing.py | 34 ++++++- modules/sd_hijack_inpainting.py | 208 ++++++++++++++++++++++++++++++++++++++++ modules/sd_models.py | 16 +++- modules/sd_samplers.py | 50 +++++++--- 4 files changed, 293 insertions(+), 15 deletions(-) create mode 100644 modules/sd_hijack_inpainting.py (limited to 'modules/sd_models.py') diff --git a/modules/processing.py b/modules/processing.py index bcb0c32c..a6c308f9 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -546,7 +546,16 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if not self.enable_hr: x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning) + + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, self.height, self.width, device=x.device) + image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning)) + + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) + + samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=image_conditioning) return samples x = create_random_tensors([opt_C, self.firstphase_height // opt_f, self.firstphase_width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) @@ -714,10 +723,31 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): elif self.inpainting_fill == 3: self.init_latent = self.init_latent * self.mask + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self) - samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning) + samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning) if self.mask is not None: samples = samples * self.nmask + self.init_latent * self.mask diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py new file mode 100644 index 00000000..7e5670d6 --- /dev/null +++ b/modules/sd_hijack_inpainting.py @@ -0,0 +1,208 @@ +import torch +import numpy as np + +from tqdm import tqdm +from einops import rearrange, repeat +from omegaconf import ListConfig + +from types import MethodType + +import ldm.models.diffusion.ddpm +import ldm.models.diffusion.ddim + +from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.ddim import DDIMSampler, noise_like + +# ================================================================================================= +# Monkey patch DDIMSampler methods from RunwayML repo directly. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py +# ================================================================================================= +@torch.no_grad() +def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = elf.inpainting_fill == 2: + self.init_latent = self.init_latent * self.mask + create_random_tensors(self.init_latent.shape[1:], all_seeds[0:self.init_latent.shape[0]]) * self.nmask + elif self.inpainting_fill == 3: + self.init_latent = self.init_latent * self.mask + + if self.image_mask is not None: + conditioning_mask = np.array(self.image_mask.convert("L")) + conditioning_mask = conditioning_mask.astype(np.float32) / 255.0 + conditioning_mask = torch.from_numpy(conditioning_mask[None, None]) + + # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0 + conditioning_mask = torch.round(conditioning_mask) + else: + conditioning_mask = torch.ones(1, 1, *image.shape[-2:]) + + # Create another latent image, this time with a masked version of the original input. + conditioning_mask = conditioning_mask.to(image.device) + conditioning_image = image * (1.0 - conditioning_mask) + conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image)) + + # Create the concatenated conditioning tensor to be fed to `c_concat` + conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=self.init_latent.shape[-2:]) + conditioning_mask = conditioning_mask.expand(conditioning_image.shape[0], -1, -1, -1) + self.image_conditioning = torch.cat([conditioning_mask, conditioning_image], dim=1) + self.image_conditioning = self.image_conditioning.to(shared.device).type(self.sd_model.dtype) + + def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength): + x = create_random_tensors([opctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + +# ================================================================================================= +# Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. +# Adapted from: +# https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddpm.py +# ================================================================================================= + +@torch.no_grad() +def get_unconditional_conditioning(self, batch_size, null_label=None): + if null_label is not None: + xc = null_label + if isinstance(xc, ListConfig): + xc = list(xc) + if isinstance(xc, dict) or isinstance(xc, list): + c = self.get_learned_conditioning(xc) + else: + if hasattr(xc, "to"): + xc = xc.to(self.device) + c = self.get_learned_conditioning(xc) + else: + # todo: get null label from cond_stage_model + raise NotImplementedError() + c = repeat(c, "1 ... -> b ...", b=batch_size).to(self.device) + return c + +class LatentInpaintDiffusion(LatentDiffusion): + def __init__( + self, + concat_keys=("mask", "masked_image"), + masked_image_key="masked_image", + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.masked_image_key = masked_image_key + assert self.masked_image_key in concat_keys + self.concat_keys = concat_keys + +def should_hijack_inpainting(checkpoint_info): + return str(checkpoint_info.filename).endswith("inpainting.ckpt") and not checkpoint_info.config.endswith("inpainting.yaml") + +def do_inpainting_hijack(): + ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning + ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim + ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index eae22e87..47836d25 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -9,6 +9,7 @@ from ldm.util import instantiate_from_config from modules import shared, modelloader, devices from modules.paths import models_path +from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) @@ -211,6 +212,19 @@ def load_model(): print(f"Loading config from: {checkpoint_info.config}") sd_config = OmegaConf.load(checkpoint_info.config) + + if should_hijack_inpainting(checkpoint_info): + do_inpainting_hijack() + + # Hardcoded config for now... + sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" + sd_config.model.params.use_ema = False + sd_config.model.params.conditioning_key = "hybrid" + sd_config.model.params.unet_config.params.in_channels = 9 + + # Create a "fake" config with a different name so that we know to unload it when switching models. + checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) @@ -234,7 +248,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_model_checkpoint == checkpoint_info.filename: return - if sd_model.sd_checkpoint_info.config != checkpoint_info.config: + if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() shared.sd_model = load_model() return shared.sd_model diff --git a/modules/sd_samplers.py b/modules/sd_samplers.py index b58e810b..9d3cf289 100644 --- a/modules/sd_samplers.py +++ b/modules/sd_samplers.py @@ -136,9 +136,15 @@ class VanillaStableDiffusionSampler: if self.stop_at is not None and self.step > self.stop_at: raise InterruptedException + # Have to unwrap the inpainting conditioning here to perform pre-preocessing + image_conditioning = None + if isinstance(cond, dict): + image_conditioning = cond["c_concat"][0] + cond = cond["c_crossattn"][0] + unconditional_conditioning = unconditional_conditioning["c_crossattn"][0] conds_list, tensor = prompt_parser.reconstruct_multicond_batch(cond, self.step) - unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) + unconditional_conditioning = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, self.step) assert all([len(conds) == 1 for conds in conds_list]), 'composition via AND is not supported for DDIM/PLMS samplers' cond = tensor @@ -157,6 +163,10 @@ class VanillaStableDiffusionSampler: img_orig = self.sampler.model.q_sample(self.init_latent, ts) x_dec = img_orig * self.mask + self.nmask * x_dec + if image_conditioning is not None: + cond = {"c_concat": [image_conditioning], "c_crossattn": [cond]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + res = self.orig_p_sample_ddim(x_dec, cond, ts, unconditional_conditioning=unconditional_conditioning, *args, **kwargs) if self.mask is not None: @@ -182,7 +192,7 @@ class VanillaStableDiffusionSampler: self.mask = p.mask if hasattr(p, 'mask') else None self.nmask = p.nmask if hasattr(p, 'nmask') else None - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) self.initialize(p) @@ -202,7 +212,7 @@ class VanillaStableDiffusionSampler: return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): self.initialize(p) self.init_latent = None @@ -210,6 +220,11 @@ class VanillaStableDiffusionSampler: steps = steps or p.steps + # Wrap the conditioning models with additional image conditioning for inpainting model + if image_conditioning is not None: + conditioning = {"c_concat": [image_conditioning], "c_crossattn": [conditioning]} + unconditional_conditioning = {"c_concat": [image_conditioning], "c_crossattn": [unconditional_conditioning]} + # existing code fails with certain step counts, like 9 try: samples_ddim = self.launch_sampling(steps, lambda: self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)[0]) @@ -228,7 +243,7 @@ class CFGDenoiser(torch.nn.Module): self.init_latent = None self.step = 0 - def forward(self, x, sigma, uncond, cond, cond_scale): + def forward(self, x, sigma, uncond, cond, cond_scale, image_cond): if state.interrupted or state.skipped: raise InterruptedException @@ -239,28 +254,29 @@ class CFGDenoiser(torch.nn.Module): repeats = [len(conds_list[i]) for i in range(batch_size)] x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x]) + image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond]) sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma]) if tensor.shape[1] == uncond.shape[1]: cond_in = torch.cat([tensor, uncond]) if shared.batch_cond_uncond: - x_out = self.inner_model(x_in, sigma_in, cond=cond_in) + x_out = self.inner_model(x_in, sigma_in, cond={"c_crossattn": [cond_in], "c_concat": [image_cond_in]}) else: x_out = torch.zeros_like(x_in) for batch_offset in range(0, x_out.shape[0], batch_size): a = batch_offset b = a + batch_size - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=cond_in[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [cond_in[a:b]], "c_concat": [image_cond_in[a:b]]}) else: x_out = torch.zeros_like(x_in) batch_size = batch_size*2 if shared.batch_cond_uncond else batch_size for batch_offset in range(0, tensor.shape[0], batch_size): a = batch_offset b = min(a + batch_size, tensor.shape[0]) - x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond=tensor[a:b]) + x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], cond={"c_crossattn": [tensor[a:b]], "c_concat": [image_cond_in[a:b]]}) - x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond=uncond) + x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], cond={"c_crossattn": [uncond], "c_concat": [image_cond_in[-uncond.shape[0]:]]}) denoised_uncond = x_out[-uncond.shape[0]:] denoised = torch.clone(denoised_uncond) @@ -361,7 +377,7 @@ class KDiffusionSampler: return extra_params_kwargs - def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None): + def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None): steps, t_enc = setup_img2img_steps(p, steps) if p.sampler_noise_scheduler_override: @@ -389,11 +405,16 @@ class KDiffusionSampler: self.model_wrap_cfg.init_latent = x - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, xi, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples - def sample(self, p, x, conditioning, unconditional_conditioning, steps=None): + def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning = None): steps = steps or p.steps if p.sampler_noise_scheduler_override: @@ -414,7 +435,12 @@ class KDiffusionSampler: else: extra_params_kwargs['sigmas'] = sigmas - samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)) + samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={ + 'cond': conditioning, + 'image_cond': image_conditioning, + 'uncond': unconditional_conditioning, + 'cond_scale': p.cfg_scale + }, disable=False, callback=self.callback_state, **extra_params_kwargs)) return samples -- cgit v1.2.3 From 708c3a7bd8ce68cbe1aa7c268e5a4b1980affc9f Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 13:28:43 -0700 Subject: Added PLMS hijack and made sure to always replace methods --- modules/sd_hijack_inpainting.py | 163 ++++++++++++++++++++++++++++++++++++++-- modules/sd_models.py | 3 +- 2 files changed, 157 insertions(+), 9 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_hijack_inpainting.py b/modules/sd_hijack_inpainting.py index d4d28d2e..43938071 100644 --- a/modules/sd_hijack_inpainting.py +++ b/modules/sd_hijack_inpainting.py @@ -1,16 +1,14 @@ import torch -import numpy as np -from tqdm import tqdm -from einops import rearrange, repeat +from einops import repeat from omegaconf import ListConfig -from types import MethodType - import ldm.models.diffusion.ddpm import ldm.models.diffusion.ddim +import ldm.models.diffusion.plms from ldm.models.diffusion.ddpm import LatentDiffusion +from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.ddim import DDIMSampler, noise_like # ================================================================================================= @@ -19,7 +17,7 @@ from ldm.models.diffusion.ddim import DDIMSampler, noise_like # https://github.com/runwayml/stable-diffusion/blob/main/ldm/models/diffusion/ddim.py # ================================================================================================= @torch.no_grad() -def sample(self, +def sample_ddim(self, S, batch_size, shape, @@ -132,6 +130,153 @@ def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=F return x_prev, pred_x0 +# ================================================================================================= +# Monkey patch PLMSSampler methods. +# This one was not actually patched correctly in the RunwayML repo, but we can replicate the changes. +# Adapted from: +# https://github.com/CompVis/stable-diffusion/blob/main/ldm/models/diffusion/plms.py +# ================================================================================================= +@torch.no_grad() +def sample_plms(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + ctmp = conditioning[list(conditioning.keys())[0]] + while isinstance(ctmp, list): + ctmp = ctmp[0] + cbs = ctmp.shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + print(f'Data shape for PLMS sampling is {size}') + + samples, intermediates = self.plms_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + +@torch.no_grad() +def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): + b, *_, device = *x.shape, x.device + + def get_model_output(x, t): + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + + if isinstance(c, dict): + assert isinstance(unconditional_conditioning, dict) + c_in = dict() + for k in c: + if isinstance(c[k], list): + c_in[k] = [ + torch.cat([unconditional_conditioning[k][i], c[k][i]]) + for i in range(len(c[k])) + ] + else: + c_in[k] = torch.cat([unconditional_conditioning[k], c[k]]) + else: + c_in = torch.cat([unconditional_conditioning, c]) + + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + return e_t + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + + def get_x_prev_and_pred_x0(e_t, index): + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 + + e_t = get_model_output(x, t) + if len(old_eps) == 0: + # Pseudo Improved Euler (2nd order) + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) + e_t_next = get_model_output(x_prev, t_next) + e_t_prime = (e_t + e_t_next) / 2 + elif len(old_eps) == 1: + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (3 * e_t - old_eps[-1]) / 2 + elif len(old_eps) == 2: + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 + elif len(old_eps) >= 3: + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 + + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) + + return x_prev, pred_x0, e_t + # ================================================================================================= # Monkey patch LatentInpaintDiffusion to load the checkpoint with a proper config. # Adapted from: @@ -175,5 +320,9 @@ def should_hijack_inpainting(checkpoint_info): def do_inpainting_hijack(): ldm.models.diffusion.ddpm.get_unconditional_conditioning = get_unconditional_conditioning ldm.models.diffusion.ddpm.LatentInpaintDiffusion = LatentInpaintDiffusion + ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim - ldm.models.diffusion.ddim.DDIMSampler.sample = sample \ No newline at end of file + ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim + + ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms + ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms \ No newline at end of file diff --git a/modules/sd_models.py b/modules/sd_models.py index 47836d25..7072db08 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -214,8 +214,6 @@ def load_model(): sd_config = OmegaConf.load(checkpoint_info.config) if should_hijack_inpainting(checkpoint_info): - do_inpainting_hijack() - # Hardcoded config for now... sd_config.model.target = "ldm.models.diffusion.ddpm.LatentInpaintDiffusion" sd_config.model.params.use_ema = False @@ -225,6 +223,7 @@ def load_model(): # Create a "fake" config with a different name so that we know to unload it when switching models. checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml")) + do_inpainting_hijack() sd_model = instantiate_from_config(sd_config.model) load_model_weights(sd_model, checkpoint_info) -- cgit v1.2.3 From 49533eed9e3aad19e9868ee140708baec4fd44be Mon Sep 17 00:00:00 2001 From: random_thoughtss Date: Thu, 20 Oct 2022 16:01:27 -0700 Subject: XY grid correctly re-assignes model when config changes --- modules/sd_models.py | 6 +++--- scripts/xy_grid.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 7072db08..fea84630 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -204,9 +204,9 @@ def load_model_weights(model, checkpoint_info): model.sd_checkpoint_info = checkpoint_info -def load_model(): +def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack - checkpoint_info = select_checkpoint() + checkpoint_info = checkpoint_info or select_checkpoint() if checkpoint_info.config != shared.cmd_opts.config: print(f"Loading config from: {checkpoint_info.config}") @@ -249,7 +249,7 @@ def reload_model_weights(sd_model, info=None): if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): checkpoints_loaded.clear() - shared.sd_model = load_model() + shared.sd_model = load_model(checkpoint_info) return shared.sd_model if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: diff --git a/scripts/xy_grid.py b/scripts/xy_grid.py index 5cca168a..eff0c942 100644 --- a/scripts/xy_grid.py +++ b/scripts/xy_grid.py @@ -89,6 +89,7 @@ def apply_checkpoint(p, x, xs): if info is None: raise RuntimeError(f"Unknown checkpoint: {x}") modules.sd_models.reload_model_weights(shared.sd_model, info) + p.sd_model = shared.sd_model def confirm_checkpoints(p, xs): -- cgit v1.2.3 From df5706409386cc2e88718bd9101045587c39f8bb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 16:10:51 +0300 Subject: do not load aesthetic clip model until it's needed add refresh button for aesthetic embeddings add aesthetic params to images' infotext --- modules/aesthetic_clip.py | 40 +++++++++++++++++++---- modules/generation_parameters_copypaste.py | 18 +++++++++-- modules/img2img.py | 5 +-- modules/processing.py | 4 +-- modules/sd_models.py | 3 -- modules/txt2img.py | 4 +-- modules/ui.py | 52 ++++++++++++++++++++---------- style.css | 2 +- 8 files changed, 89 insertions(+), 39 deletions(-) (limited to 'modules/sd_models.py') diff --git a/modules/aesthetic_clip.py b/modules/aesthetic_clip.py index 34efa931..8c828541 100644 --- a/modules/aesthetic_clip.py +++ b/modules/aesthetic_clip.py @@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1): def create_ui(): + import modules.ui + with gr.Group(): with gr.Accordion("Open for Clip Aesthetic!", open=False): with gr.Row(): @@ -55,6 +57,8 @@ def create_ui(): label="Aesthetic imgs embedding", value="None") + modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings") + with gr.Row(): aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs', placeholder="This text is used to rotate the feature space of the imgs embs", @@ -66,11 +70,21 @@ def create_ui(): return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative +aesthetic_clip_model = None + + +def aesthetic_clip(): + global aesthetic_clip_model + + if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path: + aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path) + aesthetic_clip_model.cpu() + + return aesthetic_clip_model + + def generate_imgs_embd(name, folder, batch_size): - # clipModel = CLIPModel.from_pretrained( - # shared.sd_model.cond_stage_model.clipModel.name_or_path - # ) - model = shared.clip_model.to(device) + model = aesthetic_clip().to(device) processor = CLIPProcessor.from_pretrained(model.name_or_path) with torch.no_grad(): @@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size): path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt") torch.save(embs, path) - model = model.cpu() + model.cpu() del processor del embs gc.collect() @@ -132,7 +146,7 @@ class AestheticCLIP: self.image_embs = None self.load_image_embs(None) - def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, + def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None, aesthetic_slerp=True, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False): @@ -145,6 +159,18 @@ class AestheticCLIP: self.aesthetic_steps = aesthetic_steps self.load_image_embs(image_embs_name) + if self.image_embs_name is not None: + p.extra_generation_params.update({ + "Aesthetic LR": aesthetic_lr, + "Aesthetic weight": aesthetic_weight, + "Aesthetic steps": aesthetic_steps, + "Aesthetic embedding": self.image_embs_name, + "Aesthetic slerp": aesthetic_slerp, + "Aesthetic text": aesthetic_imgs_text, + "Aesthetic text negative": aesthetic_text_negative, + "Aesthetic slerp angle": aesthetic_slerp_angle, + }) + def set_skip(self, skip): self.skip = skip @@ -168,7 +194,7 @@ class AestheticCLIP: tokens = torch.asarray(remade_batch_tokens).to(device) - model = copy.deepcopy(shared.clip_model).to(device) + model = copy.deepcopy(aesthetic_clip()).to(device) model.requires_grad_(True) if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0: text_embs_2 = model.get_text_features( diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 0f041449..f73647da 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -4,13 +4,22 @@ import gradio as gr from modules.shared import script_path from modules import shared -re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)" +re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)' re_param = re.compile(re_param_code) re_params = re.compile(r"^(?:" + re_param_code + "){3,}$") re_imagesize = re.compile(r"^(\d+)x(\d+)$") type_of_gr_update = type(gr.update()) +def quote(text): + if ',' not in str(text): + return text + + text = str(text) + text = text.replace('\\', '\\\\') + text = text.replace('"', '\\"') + return f'"{text}"' + def parse_generation_parameters(x: str): """parses generation parameters string, the one you see in text field under the picture in UI: ``` @@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None): else: try: valtype = type(output.value) - val = valtype(v) + + if valtype == bool and v == "False": + val = False + else: + val = valtype(v) + res.append(gr.update(value=val)) except Exception: res.append(gr.update()) diff --git a/modules/img2img.py b/modules/img2img.py index bc7c66bc..eea5199b 100644 --- a/modules/img2img.py +++ b/modules/img2img.py @@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro inpainting_mask_invert=inpainting_mask_invert, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, - aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if shared.cmd_opts.enable_console_prompts: print(f"\nimg2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/processing.py b/modules/processing.py index d1deffa9..f0852cd5 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -12,7 +12,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional import modules.sd_hijack -from modules import devices, prompt_parser, masking, sd_samplers, lowvram +from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste from modules.sd_hijack import model_hijack from modules.shared import opts, cmd_opts, state import modules.shared as shared @@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration generation_params.update(p.extra_generation_params) - generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None]) + generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None]) negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else "" diff --git a/modules/sd_models.py b/modules/sd_models.py index 05a1df28..b1c91b0d 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -234,9 +234,6 @@ def load_model(checkpoint_info=None): sd_hijack.model_hijack.hijack(sd_model) - if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path: - shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path) - sd_model.eval() print(f"Model loaded.") diff --git a/modules/txt2img.py b/modules/txt2img.py index 32ed1d8d..1761cfa2 100644 --- a/modules/txt2img.py +++ b/modules/txt2img.py @@ -36,9 +36,7 @@ def txt2img(prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: firstphase_height=firstphase_height if enable_hr else None, ) - shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), - aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, - aesthetic_text_negative) + shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative) if cmd_opts.enable_console_prompts: print(f"\ntxt2img: {prompt}", file=shared.progress_print_out) diff --git a/modules/ui.py b/modules/ui.py index 381ca925..0d020de6 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -597,27 +597,29 @@ def apply_setting(key, value): return value -def create_ui(wrap_gradio_gpu_call): - import modules.img2img - import modules.txt2img +def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): + def refresh(): + refresh_method() + args = refreshed_args() if callable(refreshed_args) else refreshed_args - def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id): - def refresh(): - refresh_method() - args = refreshed_args() if callable(refreshed_args) else refreshed_args + for k, v in args.items(): + setattr(refresh_component, k, v) - for k, v in args.items(): - setattr(refresh_component, k, v) + return gr.update(**(args or {})) - return gr.update(**(args or {})) + refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) + refresh_button.click( + fn=refresh, + inputs=[], + outputs=[refresh_component] + ) + return refresh_button + + +def create_ui(wrap_gradio_gpu_call): + import modules.img2img + import modules.txt2img - refresh_button = gr.Button(value=refresh_symbol, elem_id=elem_id) - refresh_button.click( - fn = refresh, - inputs = [], - outputs = [refresh_component] - ) - return refresh_button with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -802,6 +804,14 @@ def create_ui(wrap_gradio_gpu_call): (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)), (firstphase_width, "First pass size-1"), (firstphase_height, "First pass size-2"), + (aesthetic_lr, "Aesthetic LR"), + (aesthetic_weight, "Aesthetic weight"), + (aesthetic_steps, "Aesthetic steps"), + (aesthetic_imgs, "Aesthetic embedding"), + (aesthetic_slerp, "Aesthetic slerp"), + (aesthetic_imgs_text, "Aesthetic text"), + (aesthetic_text_negative, "Aesthetic text negative"), + (aesthetic_slerp_angle, "Aesthetic slerp angle"), ] txt2img_preview_params = [ @@ -1077,6 +1087,14 @@ def create_ui(wrap_gradio_gpu_call): (seed_resize_from_w, "Seed resize from-1"), (seed_resize_from_h, "Seed resize from-2"), (denoising_strength, "Denoising strength"), + (aesthetic_lr_im, "Aesthetic LR"), + (aesthetic_weight_im, "Aesthetic weight"), + (aesthetic_steps_im, "Aesthetic steps"), + (aesthetic_imgs_im, "Aesthetic embedding"), + (aesthetic_slerp_im, "Aesthetic slerp"), + (aesthetic_imgs_text_im, "Aesthetic text"), + (aesthetic_text_negative_im, "Aesthetic text negative"), + (aesthetic_slerp_angle_im, "Aesthetic slerp angle"), ] token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter]) diff --git a/style.css b/style.css index 26ae36a5..5d2bacc9 100644 --- a/style.css +++ b/style.css @@ -477,7 +477,7 @@ input[type="range"]{ padding: 0; } -#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization{ +#refresh_sd_model_checkpoint, #refresh_sd_hypernetwork, #refresh_train_hypernetwork_name, #refresh_train_embedding_name, #refresh_localization, #refresh_aesthetic_embeddings{ max-width: 2.5em; min-width: 2.5em; height: 2.4em; -- cgit v1.2.3 From ac0aa2b18efeeb9220a5994c8dd54c7cdda7cc40 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 21 Oct 2022 17:35:51 +0300 Subject: loading SD VAE, see PR #3303 --- modules/sd_models.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) (limited to 'modules/sd_models.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index b1c91b0d..d99dbce8 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd +vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + def load_model_weights(model, checkpoint_info): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash @@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info): if os.path.exists(vae_file): print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) - vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"} + vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} model.first_stage_model.load_state_dict(vae_dict) model.first_stage_model.to(devices.dtype_vae) -- cgit v1.2.3