From 1ed4f0e22807f3afef925210182cbbee51f0cb2c Mon Sep 17 00:00:00 2001
From: Jay Smith
Date: Thu, 8 Dec 2022 18:14:35 -0600
Subject: Depth2img model support

---
 modules/processing.py | 38 ++++++++++++++++++++++++++++++++++----
 1 file changed, 34 insertions(+), 4 deletions(-)

(limited to 'modules/processing.py')

diff --git a/modules/processing.py b/modules/processing.py
index 3d2c4dc9..0417ffc5 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -21,7 +21,10 @@ import modules.face_restoration
 import modules.images as images
 import modules.styles
 import logging
+from ldm.data.util import AddMiDaS
+from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
 
+from einops import repeat, rearrange
 
 # some of those options should not be changed at all because they would break the model, so I removed them from options.
 opt_C = 4
@@ -150,11 +153,26 @@ class StableDiffusionProcessing():
 
         return image_conditioning
 
-    def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
-        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
-            # Dummy zero conditioning if we're not using inpainting model.
-            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+    def depth2img_image_conditioning(self, source_image):
+        # Use the AddMiDaS helper to format our source image to suit the MiDaS model
+        transformer = AddMiDaS(model_type="dpt_hybrid")
+        transformed = transformer({"jpg": rearrange(source_image[0], "c h w -> h w c")})
+        midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
+        midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
+
+        conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+        conditioning = torch.nn.functional.interpolate(
+            self.sd_model.depth_model(midas_in),
+            size=conditioning_image.shape[2:],
+            mode="bicubic",
+            align_corners=False,
+        )
+
+        (depth_min, depth_max) = torch.aminmax(conditioning)
+        conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
+        return conditioning
 
+    def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
         self.is_using_inpainting_conditioning = True
 
         # Handle the different mask inputs
@@ -191,6 +209,18 @@ class StableDiffusionProcessing():
 
         return image_conditioning
 
+    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
+        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
+        # identify itself with a field common to all models. The conditioning_key is also hybrid.
+        if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
+            return self.depth2img_image_conditioning(source_image)
+
+        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
+            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
+
+        # Dummy zero conditioning if we're not using inpainting or depth model.
+        return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
+
     def init(self, all_prompts, all_seeds, all_subseeds):
         pass
 
-- 
cgit v1.2.3
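
For reference, below is a minimal standalone sketch of the depth-conditioning math the new depth2img_image_conditioning method performs: a depth map is resized to the latent resolution with bicubic interpolation, then rescaled into [-1, 1]. The batch size, latent resolution, and the random tensor standing in for the MiDaS depth output are illustrative assumptions, not values taken from the webui codebase.

import torch

# Illustrative stand-ins (not from the webui codebase): a batch of 2 images,
# 64x64 latents, and a random tensor playing the role of the MiDaS depth output.
batch_size = 2
latent_hw = (64, 64)                             # conditioning_image.shape[2:] in the patch
raw_depth = torch.rand(batch_size, 1, 384, 384)  # stand-in for self.sd_model.depth_model(midas_in)

# Resize the depth map to the latent resolution, as the patch does.
conditioning = torch.nn.functional.interpolate(
    raw_depth,
    size=latent_hw,
    mode="bicubic",
    align_corners=False,
)

# torch.aminmax returns the (min, max) of the whole tensor in one pass;
# the affine map below rescales the depth values into [-1, 1].
depth_min, depth_max = torch.aminmax(conditioning)
conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.

assert conditioning.shape == (batch_size, 1, *latent_hw)
assert conditioning.min() == -1. and conditioning.max() == 1.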