diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-01-04 14:58:07 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-01-04 14:58:07 +0000 |
commit | 525cea924562afd676f55470095268a0f6fca59e (patch) | |
tree | 535c3c054cc111197649f0b56e41f488efce4478 /modules/processing.py | |
parent | 184e670126f5fc50ba56fa0fedcf0cf60e45ed7e (diff) | |
download | stable-diffusion-webui-gfx803-525cea924562afd676f55470095268a0f6fca59e.tar.gz stable-diffusion-webui-gfx803-525cea924562afd676f55470095268a0f6fca59e.tar.bz2 stable-diffusion-webui-gfx803-525cea924562afd676f55470095268a0f6fca59e.zip |
use shared function from processing for creating dummy mask when training inpainting model
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 39 |
1 files changed, 20 insertions, 19 deletions
diff --git a/modules/processing.py b/modules/processing.py index c03e77e7..c7264aff 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -76,6 +76,24 @@ def apply_overlay(image, paste_loc, index, overlays): return image
+def txt2img_image_conditioning(sd_model, x, width, height):
+ if sd_model.model.conditioning_key not in {'hybrid', 'concat'}:
+ # Dummy zero conditioning if we're not using inpainting model.
+ # Still takes up a bit of memory, but no encoder call.
+ # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
+ return x.new_zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
+
+ # The "masked-image" in this case will just be all zeros since the entire image is masked.
+ image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
+ image_conditioning = sd_model.get_first_stage_encoding(sd_model.encode_first_stage(image_conditioning))
+
+ # Add the fake full 1s mask to the first dimension.
+ image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
+ image_conditioning = image_conditioning.to(x.dtype)
+
+ return image_conditioning
+
+
class StableDiffusionProcessing():
"""
The first set of paramaters: sd_models -> do_not_reload_embeddings represent the minimum required to create a StableDiffusionProcessing
@@ -139,26 +157,9 @@ class StableDiffusionProcessing(): self.iteration = 0
def txt2img_image_conditioning(self, x, width=None, height=None):
- if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
- # Dummy zero conditioning if we're not using inpainting model.
- # Still takes up a bit of memory, but no encoder call.
- # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return x.new_zeros(x.shape[0], 5, 1, 1)
+ self.is_using_inpainting_conditioning = self.sd_model.model.conditioning_key in {'hybrid', 'concat'}
- self.is_using_inpainting_conditioning = True
-
- height = height or self.height
- width = width or self.width
-
- # The "masked-image" in this case will just be all zeros since the entire image is masked.
- image_conditioning = torch.zeros(x.shape[0], 3, height, width, device=x.device)
- image_conditioning = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image_conditioning))
-
- # Add the fake full 1s mask to the first dimension.
- image_conditioning = torch.nn.functional.pad(image_conditioning, (0, 0, 0, 0, 1, 0), value=1.0)
- image_conditioning = image_conditioning.to(x.dtype)
-
- return image_conditioning
+ return txt2img_image_conditioning(self.sd_model, x, width or self.width, height or self.height)
def depth2img_image_conditioning(self, source_image):
# Use the AddMiDaS helper to Format our source image to suit the MiDaS model
|