From 8a34671fe91e142bce9e5556cca2258b3be9dd6e Mon Sep 17 00:00:00 2001 From: MrCheeze Date: Fri, 24 Mar 2023 22:48:16 -0400 Subject: Add support for the Variations models (unclip-h and unclip-l) --- modules/processing.py | 41 +++++++++++++++++++++++++++++------------ 1 file changed, 29 insertions(+), 12 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 59717b4c..1451811c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -78,21 +78,27 @@ def apply_overlay(image, paste_loc, index, overlays): def txt2img_image_conditioning(sd_model, x, width, height): - if sd_model.model.conditioning_key not in {'hybrid', 'concat'}: - # Dummy zero conditioning if we're not using inpainting model. - # Still takes up a bit of memory, but no encoder call. - # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. - return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) + if sd_model.model.conditioning_key in {'hybrid', 'concat'}: # Inpainting models - # The "masked-image" in this case will just be all zeros since the entire image is masked. - image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) - image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) + # The "masked-image" in this case will just be all zeros since the entire image is masked. + image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device) + image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning)) - # Add the fake full 1s mask to the first dimension. - image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) - image_conditioning = image_conditioning.to(x.dtype) + # Add the fake full 1s mask to the first dimension. + image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0) + image_conditioning = image_conditioning.to(x.dtype) - return image_conditioning + return image_conditioning + + elif sd_model.model.conditioning_key == "crossattn-adm": # UnCLIP models + + return x.new_zeros(x.shape[0], 2*sd_model.noise_augmentor.time_embed.dim, dtype=x.dtype, device=x.device) + + else: + # Dummy zero conditioning if we're not using inpainting or unclip models. + # Still takes up a bit of memory, but no encoder call. + # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size. + return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device) class StableDiffusionProcessing: @@ -190,6 +196,14 @@ class StableDiffusionProcessing: return conditioning_image + def unclip_image_conditioning(self, source_image): + c_adm = self.sd_model.embedder(source_image) + if self.sd_model.noise_augmentor is not None: + noise_level = 0 # TODO: Allow other noise levels? + c_adm, noise_level_emb = self.sd_model.noise_augmentor(c_adm, noise_level=repeat(torch.tensor([noise_level]).to(c_adm.device), '1 -> b', b=c_adm.shape[0])) + c_adm = torch.cat((c_adm, noise_level_emb), 1) + return c_adm + def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None): self.is_using_inpainting_conditioning = True @@ -241,6 +255,9 @@ class StableDiffusionProcessing: if self.sampler.conditioning_key in {'hybrid', 'concat'}: return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask) + if self.sampler.conditioning_key == "crossattn-adm": + return self.unclip_image_conditioning(source_image) + # Dummy zero conditioning if we're not using inpainting or depth model. return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1) -- cgit v1.2.3