author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-30 12:12:09 +0000
---|---|---
committer | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-30 12:12:09 +0000
commit | a64fbe89288802f8b5ec8ca7bcab5aaf2c7bfea5 (patch) |
tree | ea36cfc5843df7e3bb933e06d9397321b5d25d17 |
parent | eec540b22798ddcf8a03d947519c36635d77d722 (diff) |
make it possible to use checkpoints of different types (SD1, SDXL) in first and second pass of hires fix
-rw-r--r-- | modules/processing.py | 10 |
1 file changed, 7 insertions(+), 3 deletions(-)
```diff
diff --git a/modules/processing.py b/modules/processing.py
index 6fb14516..c4da208f 100644
--- a/modules/processing.py
+++ b/modules/processing.py
@@ -1060,16 +1060,21 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
         if not self.enable_hr:
             return samples
 
+        if self.latent_scale_mode is None:
+            decoded_samples = decode_first_stage(self.sd_model, samples)
+        else:
+            decoded_samples = None
+
         current = shared.sd_model.sd_checkpoint_info
         try:
             if self.hr_checkpoint_info is not None:
                 sd_models.reload_model_weights(info=self.hr_checkpoint_info)
 
-            return self.sample_hr_pass(samples, seeds, subseeds, subseed_strength, prompts)
+            return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)
         finally:
             sd_models.reload_model_weights(info=current)
 
-    def sample_hr_pass(self, samples, seeds, subseeds, subseed_strength, prompts):
+    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
         self.is_hr_pass = True
 
         target_width = self.hr_upscale_to_x
@@ -1100,7 +1105,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
             else:
                 image_conditioning = self.txt2img_image_conditioning(samples)
         else:
-            decoded_samples = decode_first_stage(self.sd_model, samples)
             lowres_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
 
             batch_images = []
```
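The reason the decode is hoisted out of `sample_hr_pass` and ahead of the checkpoint swap: latents can only be decoded by the VAE of the model that produced them, so when the hires pass runs on a checkpoint of a different type (SD1 vs. SDXL), `decode_first_stage` has to run before `reload_model_weights` replaces the model. Below is a minimal, self-contained sketch of that ordering; `FakeModel`, `hires_fix`, and `latent_upscale` are hypothetical stand-ins, not the webui's actual API.

```python
import torch


class FakeModel:
    """Hypothetical stand-in for a loaded checkpoint with its own VAE."""

    def __init__(self, name):
        self.name = name

    def decode(self, latents):
        # Stand-in for decode_first_stage(): SD VAEs upsample latents 8x.
        return latents.repeat_interleave(8, dim=-1).repeat_interleave(8, dim=-2)


def hires_fix(model, samples, hr_checkpoint=None, latent_upscale=False):
    if not latent_upscale:
        # Decode now, while the checkpoint that produced these latents
        # (and its matching VAE) is still loaded.
        decoded = model.decode(samples)
    else:
        decoded = None  # latent-space upscaling never leaves latent space

    current = model.name
    try:
        if hr_checkpoint is not None:
            model = FakeModel(hr_checkpoint)  # may be a different architecture
        # The second pass receives pre-decoded images, so the (possibly SDXL)
        # hr model is never asked to decode latents from the first-pass model.
        return decoded if decoded is not None else samples
    finally:
        model = FakeModel(current)  # mirror the reload back to the first-pass checkpoint


out = hires_fix(FakeModel("sd15"), torch.randn(1, 4, 64, 64), hr_checkpoint="sdxl")
print(out.shape)  # torch.Size([1, 4, 512, 512])
```

Note that the decode is skipped entirely when upscaling in latent space (`self.latent_scale_mode` is not None in the real code), which is why the diff assigns `decoded_samples = None` in that branch and `sample_hr_pass` only uses it on the image-space path.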