From d1ba46b6e15d2f2917e7d392955c8c6f988527ba Mon Sep 17 00:00:00 2001 From: Robert Barron Date: Wed, 9 Aug 2023 07:46:30 -0700 Subject: allow first pass and hires pass to use a single prompt to do different prompt editing, hires is 1.0..2.0: relative time range is [1..2] absolute time range is [steps+1..steps+hire_steps], e.g. with 30 steps and 20 hires steps, '20' is 2/3rds through first pass, and 40 is halfway through hires pass --- modules/processing.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index 31745006..0750b299 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -319,12 +319,14 @@ class StableDiffusionProcessing: self.all_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_prompts] self.all_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_negative_prompts] - def cached_params(self, required_prompts, steps, extra_network_data): + def cached_params(self, required_prompts, steps, hires_steps, extra_network_data, use_old_scheduling): """Returns parameters that invalidate the cond cache if changed""" return ( required_prompts, steps, + hires_steps, + use_old_scheduling, opts.CLIP_stop_at_last_layers, shared.sd_model.sd_checkpoint_info, extra_network_data, @@ -334,7 +336,7 @@ class StableDiffusionProcessing: self.height, ) - def get_conds_with_caching(self, function, required_prompts, steps, caches, extra_network_data): + def get_conds_with_caching(self, function, required_prompts, steps, hires_steps, caches, extra_network_data): """ Returns the result of calling function(shared.sd_model, required_prompts, steps) using a cache to store the result if the same arguments have been used before. @@ -347,7 +349,7 @@ class StableDiffusionProcessing: caches is a list with items described above. """ - cached_params = self.cached_params(required_prompts, steps, extra_network_data) + cached_params = self.cached_params(required_prompts, steps, hires_steps, extra_network_data, shared.opts.use_old_scheduling) for cache in caches: if cache[0] is not None and cached_params == cache[0]: @@ -356,7 +358,7 @@ class StableDiffusionProcessing: cache = caches[0] with devices.autocast(): - cache[1] = function(shared.sd_model, required_prompts, steps) + cache[1] = function(shared.sd_model, required_prompts, steps, hires_steps, shared.opts.use_old_scheduling) cache[0] = cached_params return cache[1] @@ -367,8 +369,9 @@ class StableDiffusionProcessing: sampler_config = sd_samplers.find_sampler_config(self.sampler_name) self.step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1 - self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data) - self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data) + self.firstpass_steps = self.steps * self.step_multiplier + self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.firstpass_steps, None, [self.cached_uc], self.extra_network_data) + self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.firstpass_steps, None, [self.cached_c], self.extra_network_data) def parse_extra_network_prompts(self): self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts) @@ -1225,8 +1228,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): hr_prompts = prompt_parser.SdConditioning(self.hr_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y) hr_negative_prompts = prompt_parser.SdConditioning(self.hr_negative_prompts, width=self.hr_upscale_to_x, height=self.hr_upscale_to_y, is_negative_prompt=True) - self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.steps * self.step_multiplier, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) - self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.steps * self.step_multiplier, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) + hires_steps = (self.hr_second_pass_steps or self.steps) * self.step_multiplier + self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, hr_negative_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_uc, self.cached_uc], self.hr_extra_network_data) + self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, hr_prompts, self.firstpass_steps, hires_steps, [self.cached_hr_c, self.cached_c], self.hr_extra_network_data) def setup_conds(self): super().setup_conds() -- cgit v1.2.3