diff options
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 73 |
1 files changed, 68 insertions, 5 deletions
diff --git a/modules/processing.py b/modules/processing.py index 59717b4c..0a46174c 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -247,7 +247,7 @@ class StableDiffusionProcessing: def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
raise NotImplementedError()
def close(self):
@@ -527,6 +527,20 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: else:
p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
+ if type(p) == StableDiffusionProcessingTxt2Img:
+ if p.enable_hr and p.hr_prompt == '':
+ p.all_hr_prompts, p.all_hr_negative_prompts = p.all_prompts, p.all_negative_prompts
+ elif p.enable_hr and p.hr_prompt != '':
+ if type(p.prompt) == list:
+ p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
+ else:
+ p.all_hr_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
+
+ if type(p.negative_prompt) == list:
+ p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.hr_negative_prompt]
+ else:
+ p.all_hr_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
+
if type(seed) == list:
p.all_seeds = seed
else:
@@ -595,6 +609,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
+ if type(p) == StableDiffusionProcessingTxt2Img:
+ if p.enable_hr:
+ if p.hr_prompt == '':
+ hr_prompts, hr_negative_prompts = prompts, negative_prompts
+ else:
+ hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ hr_negative_prompts = p.all_hr_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+
seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -606,6 +629,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: prompts, extra_network_data = extra_networks.parse_prompts(prompts)
+ if type(p) == StableDiffusionProcessingTxt2Img:
+ if p.enable_hr and hr_prompts != prompts:
+ _, hr_extra_network_data = extra_networks.parse_prompts(hr_prompts)
+ extra_network_data.update(hr_extra_network_data)
+
+
if not p.disable_extra_networks:
with devices.autocast():
extra_networks.activate(p, extra_network_data)
@@ -625,6 +654,14 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps, cached_uc)
c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps, cached_c)
+ if type(p) == StableDiffusionProcessingTxt2Img:
+ if p.enable_hr:
+ if prompts != hr_prompts:
+ hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, cached_uc)
+ hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, cached_c)
+ else:
+ hr_uc, hr_c = uc, c
+
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
comments[comment] = 1
@@ -632,8 +669,15 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
+
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ if type(p) == StableDiffusionProcessingTxt2Img:
+ if p.enable_hr:
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ else:
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ else:
+ samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -741,7 +785,7 @@ def old_hires_fix_first_pass_dimensions(width, height): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler: str = '---', hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -752,6 +796,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
+ self.hr_sampler = hr_sampler
+ self.hr_prompt = hr_prompt if hr_prompt != '' else ''
+ self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else ''
+ self.all_hr_prompts = None
+ self.all_hr_negative_prompts = None
if firstphase_width != 0 or firstphase_height != 0:
self.hr_upscale_to_x = self.width
@@ -765,6 +814,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
+ if self.hr_sampler != '---':
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler
+
+ if self.hr_prompt != '':
+ self.extra_generation_params["Hires prompt"] = f'({self.hr_prompt.replace(",", ";")})'
+ self.extra_generation_params["Hires negative prompt"] = f'({self.hr_negative_prompt.replace(",", ";")})'
+
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
self.hr_resize_y = self.height
@@ -825,7 +881,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
@@ -893,8 +949,15 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob()
img2img_sampler_name = self.sampler_name
+
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
+
+ if self.hr_sampler == '---':
+ pass
+ else:
+ img2img_sampler_name = self.hr_sampler
+
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -905,7 +968,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None
devices.torch_gc()
- samples = self.sampler.sample_img2img(self, samples, noise, conditioning, unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
return samples
|