author | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-04 11:47:28 +0000
---|---|---
committer | Muhammad Rizqi Nur <rizqinur2010@gmail.com> | 2022-11-04 11:47:28 +0000
commit | 3277f90e933485d2590a55998480d02f9499be5c (patch) |
tree | ae1c4774a503ff0401a1589703f924fcbc59ee91 /modules/processing.py |
parent | 31a98d0dc0a97640afa0611eb261ef9c3ba10532 (diff) |
parent | 81973091bc07c706d056809d89221bafcd01b38a (diff) |
Merge branch 'master' into gradient-clipping
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 38 |
1 file changed, 20 insertions(+), 18 deletions(-)
diff --git a/modules/processing.py b/modules/processing.py
index 3a364b5f..03c9143d 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
# Dummy zero conditioning if we're not using inpainting model.
# Still takes up a bit of memory, but no encoder call.
# Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
- return torch.zeros(
- x.shape[0], 5, 1, 1,
- dtype=x.dtype,
- device=x.device
- )
+ return x.new_zeros(x.shape[0], 5, 1, 1)
height = height or self.height
width = width or self.width
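
This hunk and the next rely on the same idiom: `Tensor.new_zeros` (and `new_ones`, below) builds a tensor that inherits `dtype` and `device` from the source tensor, which is exactly what the removed four-argument `torch.zeros` calls spelled out by hand. A minimal sketch of the equivalence (not part of the patch):

```python
import torch

# Tensor.new_zeros inherits dtype and device from the source tensor,
# so the two calls below produce identical results.
x = torch.randn(4, 3, 64, 64, dtype=torch.float16)

verbose = torch.zeros(x.shape[0], 5, 1, 1, dtype=x.dtype, device=x.device)
concise = x.new_zeros(x.shape[0], 5, 1, 1)

assert verbose.dtype == concise.dtype and verbose.device == concise.device
assert torch.equal(verbose, concise)
```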
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
# Dummy zero conditioning if we're not using inpainting model.
- return torch.zeros(
- latent_image.shape[0], 5, 1, 1,
- dtype=latent_image.dtype,
- device=latent_image.device
- )
+ return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)
# Handle the different mask inputs
if image_mask is not None:
@@ -174,11 +166,11 @@ class StableDiffusionProcessing():
# Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
conditioning_mask = torch.round(conditioning_mask)
else:
- conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
+ conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])
# Create another latent image, this time with a masked version of the original input.
# Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
- conditioning_mask = conditioning_mask.to(source_image.device)
+ conditioning_mask = conditioning_mask.to(source_image.device).to(source_image.dtype)
conditioning_image = torch.lerp(
source_image,
source_image * (1.0 - conditioning_mask),
getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight)
)
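
The added `.to(source_image.dtype)` matters when the model runs in half precision: `torch.ones` defaults to `float32`, and `torch.lerp` expects its tensor operands to share a dtype, so an uncast mask can fail once the latents are `float16`. A small sketch of the pattern (the `weight` value is a stand-in for `opts.inpainting_mask_weight`):

```python
import torch

source = torch.randn(1, 4, 64, 64, dtype=torch.float16)  # fp16 latents
mask = torch.ones(1, 1, 64, 64)                           # float32 by default
mask = mask.to(source.device).to(source.dtype)            # the cast this hunk adds

weight = 0.7  # stand-in for opts.inpainting_mask_weight
# weight == 1.0 keeps only the masked image; 0.0 keeps the original
blended = torch.lerp(source, source * (1.0 - mask), weight)
```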
@@ -426,13 +418,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
try:
for k, v in p.override_settings.items():
- opts.data[k] = v # we don't call onchange for simplicity which makes changing model, hypernet impossible
+ setattr(opts, k, v) # we don't call onchange for simplicity which makes changing model, hypernet impossible
res = process_images_inner(p)
finally:
for k, v in stored_opts.items():
- opts.data[k] = v
+ setattr(opts, k, v)
return res
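
`opts` here is the webui `Options` object, which defines `__setattr__`; writing through `setattr` lets that logic run, whereas `opts.data[k] = v` poked the backing dict and bypassed it entirely. A simplified stand-in to illustrate the difference (the validation shown is an assumption, not the real `Options` implementation):

```python
class Options:
    """Simplified stand-in for the webui Options object (assumption)."""

    def __init__(self):
        self.__dict__["data"] = {"sd_model_checkpoint": "v1-5.ckpt"}

    def __setattr__(self, key, value):
        if key not in self.data:
            raise KeyError(f"unknown option: {key}")
        self.data[key] = value  # the real class can also validate/notify here

opts = Options()
setattr(opts, "sd_model_checkpoint", "v2-0.ckpt")  # goes through __setattr__
opts.data["sd_model_checkpoint"] = "v1-5.ckpt"     # silent write, no checks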
@@ -509,6 +501,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
if len(prompts) == 0:
break
+ if p.scripts is not None:
+ p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+
with devices.autocast():
uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps)
c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps)
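
The new `process_batch` call gives scripts a per-batch hook with the final prompts and seeds before conditioning is computed. Roughly how an extension might consume it, as a sketch (the `LogBatches` class is illustrative and not part of the patch; the keyword names mirror the call site above):

```python
import modules.scripts as scripts

class LogBatches(scripts.Script):
    def title(self):
        return "Log batches"  # illustrative script, not part of the patch

    def process_batch(self, p, *args, **kwargs):
        # kwargs carries the keyword arguments from the call site above.
        batch_number = kwargs.get("batch_number")
        prompts = kwargs.get("prompts", [])
        print(f"batch {batch_number}: {len(prompts)} prompt(s)")
```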
@@ -673,10 +668,17 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
images.save_image(image, self.outpath_samples, "", seeds[index], prompts[index], opts.samples_format, suffix="-before-highres-fix")
if opts.use_scale_latent_for_hires_fix:
- samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
-
for i in range(samples.shape[0]):
save_intermediate(samples, i)
+
+ samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
+
+ # Avoid making the inpainting conditioning unless necessary as
+ # this does need some extra compute to decode / encode the image again.
+ if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
+ image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
+ else:
+ image_conditioning = self.txt2img_image_conditioning(samples)
else:
decoded_samples = decode_first_stage(self.sd_model, samples)
lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
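
Two things change in the latent-upscale path above: intermediates are now saved before the `interpolate` call, so the pre-upscale latents are what gets written, and the expensive decode/re-encode for inpainting conditioning is skipped when `inpainting_mask_weight` is 1.0, where the mask would cancel the source image out of the lerp anyway. For reference, a sketch of the bilinear latent upscale itself (shapes assume Stable Diffusion's f=8 VAE):

```python
import torch
import torch.nn.functional as F

opt_f = 8                                # VAE downscale factor for SD
samples = torch.randn(2, 4, 64, 64)      # latents for 512x512 images
target = (768 // opt_f, 768 // opt_f)    # hires-fix target of 768x768
upscaled = F.interpolate(samples, size=target, mode="bilinear")
print(upscaled.shape)                    # torch.Size([2, 4, 96, 96])
```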
@@ -700,14 +702,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))
+ image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)
+
shared.state.nextjob()
self.sampler = sd_samplers.create_sampler_with_index(sd_samplers.samplers, self.sampler_index, self.sd_model)
noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
- image_conditioning = self.txt2img_image_conditioning(x)
-
# GC now before running the next img2img to prevent running out of memory
x = None
devices.torch_gc()
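
The final hunk moves the conditioning computation up so it is derived from the re-encoded `samples` rather than the first-pass tensor `x`, which can then be released before the second sampling pass. `devices.torch_gc()` is the webui helper for that; roughly (an approximation, not the helper's exact source):

```python
import gc
import torch

def torch_gc():
    # Approximation of modules.devices.torch_gc: drop unreachable Python
    # objects, then return cached CUDA blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
```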