diff options
author | AUTOMATIC <16777216c@gmail.com> | 2023-05-18 17:16:09 +0000 |
---|---|---|
committer | AUTOMATIC <16777216c@gmail.com> | 2023-05-18 17:16:09 +0000 |
commit | ff0e17174f8d93a71fdd5a4a80a4629bbf97f822 (patch) | |
tree | 3a4d30e009bfbe0ab86dc0dead8a53933e876394 /modules/processing.py | |
parent | 5ec2c294ee800fc360f6883340af8b30df850322 (diff) | |
download | stable-diffusion-webui-gfx803-ff0e17174f8d93a71fdd5a4a80a4629bbf97f822.tar.gz stable-diffusion-webui-gfx803-ff0e17174f8d93a71fdd5a4a80a4629bbf97f822.tar.bz2 stable-diffusion-webui-gfx803-ff0e17174f8d93a71fdd5a4a80a4629bbf97f822.zip |
rework hires prompts/sampler code to among other things support different extra networks in first/second pass
rework quoting for infotext items that have commas in them to use json (should be backwards compatible except for cases where it didn't work previously)
add some locals from processing function into the Processing class as fields
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 261 |
1 files changed, 149 insertions, 112 deletions
diff --git a/modules/processing.py b/modules/processing.py index dd14c486..29a3743f 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -169,6 +169,16 @@ class StableDiffusionProcessing: self.is_hr_pass = False
self.sampler = None
+ self.prompts = None
+ self.negative_prompts = None
+ self.seeds = None
+ self.subseeds = None
+
+ self.step_multiplier = 1
+ self.cached_uc = [None, None]
+ self.cached_c = [None, None]
+ self.uc = None
+ self.c = None
@property
def sd_model(self):
@@ -271,11 +281,15 @@ class StableDiffusionProcessing: def init(self, all_prompts, all_seeds, all_subseeds):
pass
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
raise NotImplementedError()
def close(self):
self.sampler = None
+ self.c = None
+ self.uc = None
+ self.cached_c = [None, None]
+ self.cached_uc = [None, None]
def get_token_merging_ratio(self, for_hr=False):
if for_hr:
@@ -283,6 +297,52 @@ class StableDiffusionProcessing: return self.token_merging_ratio or opts.token_merging_ratio
+ def setup_prompts(self):
+ if type(self.prompt) == list:
+ self.all_prompts = self.prompt
+ else:
+ self.all_prompts = self.batch_size * self.n_iter * [self.prompt]
+
+ if type(self.negative_prompt) == list:
+ self.all_negative_prompts = self.negative_prompt
+ else:
+ self.all_negative_prompts = self.batch_size * self.n_iter * [self.negative_prompt]
+
+ self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts]
+ self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts]
+
+ def get_conds_with_caching(self, function, required_prompts, steps, cache):
+ """
+ Returns the result of calling function(shared.sd_model, required_prompts, steps)
+ using a cache to store the result if the same arguments have been used before.
+
+ cache is an array containing two elements. The first element is a tuple
+ representing the previously used arguments, or None if no arguments
+ have been used before. The second element is where the previously
+ computed result is stored.
+ """
+
+ if cache[0] is not None and (required_prompts, steps) == cache[0]:
+ return cache[1]
+
+ with devices.autocast():
+ cache[1] = function(shared.sd_model, required_prompts, steps)
+
+ cache[0] = (required_prompts, steps)
+ return cache[1]
+
+ def setup_conds(self):
+ sampler_config = sd_samplers.find_sampler_config(self.sampler_name)
+ self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
+
+ self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
+ self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.prompts, self.steps * self.step_multiplier, self.cached_c)
+
+ def parse_extra_network_prompts(self):
+ self.prompts, extra_network_data = extra_networks.parse_prompts(self.prompts)
+
+ return extra_network_data
+
class Processed:
def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -582,29 +642,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: comments = {}
- if type(p.prompt) == list:
- p.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.prompt]
- else:
- p.all_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.negative_prompt]
- else:
- p.all_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.negative_prompt, p.styles)]
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr and p.hr_prompt == '':
- p.all_hr_prompts, p.all_hr_negative_prompts = p.all_prompts, p.all_negative_prompts
- elif p.enable_hr and p.hr_prompt != '':
- if type(p.prompt) == list:
- p.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, p.styles) for x in p.hr_prompt]
- else:
- p.all_hr_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_styles_to_prompt(p.hr_prompt, p.styles)]
-
- if type(p.negative_prompt) == list:
- p.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, p.styles) for x in p.hr_negative_prompt]
- else:
- p.all_hr_negative_prompts = p.batch_size * p.n_iter * [shared.prompt_styles.apply_negative_styles_to_prompt(p.hr_negative_prompt, p.styles)]
+ p.setup_prompts()
if type(seed) == list:
p.all_seeds = seed
@@ -628,29 +666,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: infotexts = []
output_images = []
- cached_uc = [None, None]
- cached_c = [None, None]
-
- def get_conds_with_caching(function, required_prompts, steps, cache):
- """
- Returns the result of calling function(shared.sd_model, required_prompts, steps)
- using a cache to store the result if the same arguments have been used before.
-
- cache is an array containing two elements. The first element is a tuple
- representing the previously used arguments, or None if no arguments
- have been used before. The second element is where the previously
- computed result is stored.
- """
-
- if cache[0] is not None and (required_prompts, steps) == cache[0]:
- return cache[1]
-
- with devices.autocast():
- cache[1] = function(shared.sd_model, required_prompts, steps)
-
- cache[0] = (required_prompts, steps)
- return cache[1]
-
with torch.no_grad(), p.sd_model.ema_scope():
with devices.autocast():
p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -672,40 +687,25 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if state.interrupted:
break
- prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- if p.hr_prompt == '':
- hr_prompts, hr_negative_prompts = prompts, negative_prompts
- else:
- hr_prompts = p.all_hr_prompts[n * p.batch_size:(n + 1) * p.batch_size]
- hr_negative_prompts = p.all_hr_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
-
- seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
- subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
+ p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
+ p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size]
if p.scripts is not None:
- p.scripts.before_process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
- if len(prompts) == 0:
+ if len(p.prompts) == 0:
break
- prompts, extra_network_data = extra_networks.parse_prompts(prompts)
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr and hr_prompts != prompts:
- _, hr_extra_network_data = extra_networks.parse_prompts(hr_prompts)
- extra_network_data.update(hr_extra_network_data)
-
+ extra_network_data = p.parse_extra_network_prompts()
if not p.disable_extra_networks:
with devices.autocast():
extra_networks.activate(p, extra_network_data)
if p.scripts is not None:
- p.scripts.process_batch(p, batch_number=n, prompts=prompts, seeds=seeds, subseeds=subseeds)
+ p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds)
# params.txt should be saved after scripts.process_batch, since the
# infotext could be modified by that callback
@@ -716,18 +716,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: processed = Processed(p, [], p.seed, "")
file.write(processed.infotext(p, 0))
- sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
- step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
- uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, p.steps * step_multiplier, cached_uc)
- c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, p.steps * step_multiplier, cached_c)
-
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- if prompts != hr_prompts:
- hr_uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, p.steps, cached_uc)
- hr_c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, p.steps, cached_c)
- else:
- hr_uc, hr_c = uc, c
+ p.setup_conds()
if len(model_hijack.comments) > 0:
for comment in model_hijack.comments:
@@ -736,15 +725,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.n_iter > 1:
shared.state.job = f"Batch {n+1} out of {p.n_iter}"
-
with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
- if type(p) == StableDiffusionProcessingTxt2Img:
- if p.enable_hr:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, hr_conditioning=hr_c, hr_unconditional_conditioning=hr_uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- else:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
- else:
- samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
+ samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
for x in x_samples_ddim:
@@ -771,7 +753,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.restore_faces:
if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
- images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
+ images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")
devices.torch_gc()
@@ -788,13 +770,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.color_corrections is not None and i < len(p.color_corrections):
if opts.save and not p.do_not_save_samples and opts.save_images_before_color_correction:
image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
- images.save_image(image_without_cc, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
+ images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-color-correction")
image = apply_color_correction(p.color_corrections[i], image)
image = apply_overlay(image, p.paste_to, i, p.overlay_images)
if opts.samples_save and not p.do_not_save_samples:
- images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p)
+ images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p)
text = infotext(n, i)
infotexts.append(text)
@@ -807,10 +789,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
if opts.save_mask:
- images.save_image(image_mask, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
+ images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask")
if opts.save_mask_composite:
- images.save_image(image_mask_composite, p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
+ images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-mask-composite")
if opts.return_mask:
output_images.append(image_mask)
@@ -879,7 +861,7 @@ def old_hires_fix_first_pass_dimensions(width, height): class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
sampler = None
- def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler: str = '---', hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
+ def __init__(self, enable_hr: bool = False, denoising_strength: float = 0.75, firstphase_width: int = 0, firstphase_height: int = 0, hr_scale: float = 2.0, hr_upscaler: str = None, hr_second_pass_steps: int = 0, hr_resize_x: int = 0, hr_resize_y: int = 0, hr_sampler_name: str = None, hr_prompt: str = '', hr_negative_prompt: str = '', **kwargs):
super().__init__(**kwargs)
self.enable_hr = enable_hr
self.denoising_strength = denoising_strength
@@ -890,9 +872,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_resize_y = hr_resize_y
self.hr_upscale_to_x = hr_resize_x
self.hr_upscale_to_y = hr_resize_y
- self.hr_sampler = hr_sampler
- self.hr_prompt = hr_prompt if hr_prompt != '' else ''
- self.hr_negative_prompt = hr_negative_prompt if hr_negative_prompt != '' else ''
+ self.hr_sampler_name = hr_sampler_name
+ self.hr_prompt = hr_prompt
+ self.hr_negative_prompt = hr_negative_prompt
self.all_hr_prompts = None
self.all_hr_negative_prompts = None
@@ -906,14 +888,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.truncate_y = 0
self.applied_old_hires_behavior_to = None
+ self.hr_prompts = None
+ self.hr_negative_prompts = None
+ self.hr_extra_network_data = None
+
+ self.hr_c = None
+ self.hr_uc = None
+
def init(self, all_prompts, all_seeds, all_subseeds):
if self.enable_hr:
- if self.hr_sampler != '---':
- self.extra_generation_params["Hires sampler"] = self.hr_sampler
+ if self.hr_sampler_name is not None and self.hr_sampler_name != self.sampler_name:
+ self.extra_generation_params["Hires sampler"] = self.hr_sampler_name
+
+ if tuple(self.hr_prompt) != tuple(self.prompt):
+ self.extra_generation_params["Hires prompt"] = self.hr_prompt
- if self.hr_prompt != '':
- self.extra_generation_params["Hires prompt"] = f'({self.hr_prompt.replace(",", ";")})'
- self.extra_generation_params["Hires negative prompt"] = f'({self.hr_negative_prompt.replace(",", ";")})'
+ if tuple(self.hr_negative_prompt) != tuple(self.negative_prompt):
+ self.extra_generation_params["Hires negative prompt"] = self.hr_negative_prompt
if opts.use_old_hires_fix_width_height and self.applied_old_hires_behavior_to != (self.width, self.height):
self.hr_resize_x = self.width
@@ -975,7 +966,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): if self.hr_upscaler is not None:
self.extra_generation_params["Hires upscaler"] = self.hr_upscaler
- def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts, hr_conditioning=None, hr_unconditional_conditioning=None):
+ def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
latent_scale_mode = shared.latent_upscale_modes.get(self.hr_upscaler, None) if self.hr_upscaler is not None else shared.latent_upscale_modes.get(shared.latent_upscale_default_mode, "nearest")
@@ -1044,16 +1035,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): shared.state.nextjob()
- img2img_sampler_name = self.sampler_name
+ img2img_sampler_name = self.hr_sampler_name or self.sampler_name
if self.sampler_name in ['PLMS', 'UniPC']: # PLMS/UniPC do not support img2img so we just silently switch to DDIM
img2img_sampler_name = 'DDIM'
- if self.hr_sampler == '---':
- pass
- else:
- img2img_sampler_name = self.hr_sampler
-
self.sampler = sd_samplers.create_sampler(img2img_sampler_name, self.sd_model)
samples = samples[:, :, self.truncate_y//2:samples.shape[2]-(self.truncate_y+1)//2, self.truncate_x//2:samples.shape[3]-(self.truncate_x+1)//2]
@@ -1064,9 +1050,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): x = None
devices.torch_gc()
+ if not self.disable_extra_networks:
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
- samples = self.sampler.sample_img2img(self, samples, noise, hr_conditioning, hr_unconditional_conditioning, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
+ samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())
@@ -1074,6 +1064,53 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples
+ def close(self):
+ self.hr_c = None
+ self.hr_uc = None
+
+ def setup_prompts(self):
+ super().setup_prompts()
+
+ if not self.enable_hr:
+ return
+
+ if self.hr_prompt == '':
+ self.hr_prompt = self.prompt
+
+ if self.hr_negative_prompt == '':
+ self.hr_negative_prompt = self.negative_prompt
+
+ if type(self.hr_prompt) == list:
+ self.all_hr_prompts = self.hr_prompt
+ else:
+ self.all_hr_prompts = self.batch_size * self.n_iter * [self.hr_prompt]
+
+ if type(self.hr_negative_prompt) == list:
+ self.all_hr_negative_prompts = self.hr_negative_prompt
+ else:
+ self.all_hr_negative_prompts = self.batch_size * self.n_iter * [self.hr_negative_prompt]
+
+ self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
+ self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+
+ def setup_conds(self):
+ super().setup_conds()
+
+ if self.enable_hr:
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c)
+
+ def parse_extra_network_prompts(self):
+ res = super().parse_extra_network_prompts()
+
+ if self.enable_hr:
+ self.hr_prompts = self.all_hr_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+ self.hr_negative_prompts = self.all_hr_negative_prompts[self.iteration * self.batch_size:(self.iteration + 1) * self.batch_size]
+
+ self.hr_prompts, self.hr_extra_network_data = extra_networks.parse_prompts(self.hr_prompts)
+
+ return res
+
class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
sampler = None
|