diff options
author | Spaceginner <ivan.demian2009@gmail.com> | 2023-01-27 12:35:54 +0000 |
---|---|---|
committer | Spaceginner <ivan.demian2009@gmail.com> | 2023-01-27 12:35:54 +0000 |
commit | 56c83e453a2ac333a0888ab3835ad4c82feacc25 (patch) | |
tree | bf7090e3b8faf0ab02e3fe5bd43ac1cde2dc62dc /modules/processing.py | |
parent | 9ecf1e827c5966e11495a0c066a127defbba9bcc (diff) | |
parent | 63391419c11c1749a3d83dade19235a836c509f9 (diff) | |
download | stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.tar.gz stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.tar.bz2 stable-diffusion-webui-gfx803-56c83e453a2ac333a0888ab3835ad4c82feacc25.zip |
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/modules/processing.py b/modules/processing.py index 9e5a2f38..262806a1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -13,7 +13,7 @@ from skimage import exposure from typing import Any, Dict, List, Optional
import modules.sd_hijack
-from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx
+from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks, extra_networks, sd_vae_approx, scripts
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -172,7 +172,7 @@ class StableDiffusionProcessing: midas_in = torch.from_numpy(transformed["midas_in"][None, ...]).to(device=shared.device)
midas_in = repeat(midas_in, "1 ... -> n ...", n=self.batch_size)
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_unet) if devices.unet_needs_upcast else source_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image.to(devices.dtype_vae) if devices.unet_needs_upcast else source_image))
conditioning_image = conditioning_image.float() if devices.unet_needs_upcast else conditioning_image
conditioning = torch.nn.functional.interpolate(
self.sd_model.depth_model(midas_in),
@@ -185,7 +185,12 @@ class StableDiffusionProcessing: conditioning = 2. * (conditioning - depth_min) / (depth_max - depth_min) - 1.
return conditioning
- def inpainting_image_conditioning(self, source_image, latent_image, image_mask = None):
+ def edit_image_conditioning(self, source_image):
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(source_image))
+
+ return conditioning_image
+
+ def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
self.is_using_inpainting_conditioning = True
# Handle the different mask inputs
@@ -212,7 +217,7 @@ class StableDiffusionProcessing: )
# Encode the new masked image using first stage of network.
- conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_unet) if devices.unet_needs_upcast else conditioning_image))
+ conditioning_image = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(conditioning_image.to(devices.dtype_vae) if devices.unet_needs_upcast else conditioning_image))
# Create the concatenated conditioning tensor to be fed to `c_concat`
conditioning_mask = torch.nn.functional.interpolate(conditioning_mask, size=latent_image.shape[-2:])
@@ -228,6 +233,9 @@ class StableDiffusionProcessing: if isinstance(self.sd_model, LatentDepth2ImageDiffusion):
return self.depth2img_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image)
+ if self.sd_model.cond_stage_key == "edit":
+ return self.edit_image_conditioning(source_image)
+
if self.sampler.conditioning_key in {'hybrid', 'concat'}:
return self.inpainting_image_conditioning(source_image.float() if devices.unet_needs_upcast else source_image, latent_image, image_mask=image_mask)
@@ -409,7 +417,7 @@ def create_random_tensors(shape, seeds, subseeds=None, subseed_strength=0.0, see def decode_first_stage(model, x):
with devices.autocast(disable=x.dtype == devices.dtype_vae):
- x = model.decode_first_stage(x)
+ x = model.decode_first_stage(x.to(devices.dtype_vae) if devices.unet_needs_upcast else x)
return x
@@ -650,6 +658,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image = Image.fromarray(x_sample)
+ if p.scripts is not None:
+ pp = scripts.PostprocessImageArgs(image)
+ p.scripts.postprocess_image(p, pp)
+ image = pp.image
+
if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
@@ -993,7 +1006,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing): image = torch.from_numpy(batch_images)
image = 2. * image - 1.
- image = image.to(device=shared.device, dtype=devices.dtype_unet if devices.unet_needs_upcast else None)
+ image = image.to(device=shared.device, dtype=devices.dtype_vae if devices.unet_needs_upcast else None)
self.init_latent = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(image))
|