diff options
Diffstat (limited to 'modules/processing.py')
-rw-r--r-- | modules/processing.py | 32 |
1 files changed, 29 insertions, 3 deletions
diff --git a/modules/processing.py b/modules/processing.py index fae83788..b65c7beb 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -739,7 +739,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: del samples_ddim
- if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
+ if lowvram.is_enabled(shared.sd_model):
lowvram.send_everything_to_cpu()
devices.torch_gc()
@@ -894,6 +894,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.hr_negative_prompts = None
self.hr_extra_network_data = None
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1056,6 +1058,9 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): with devices.autocast():
extra_networks.activate(self, self.hr_extra_network_data)
+ with devices.autocast():
+ self.calculate_hr_conds()
+
sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio(for_hr=True))
samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)
@@ -1067,6 +1072,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): return samples
def close(self):
+ self.cached_hr_uc = [None, None]
+ self.cached_hr_c = [None, None]
self.hr_c = None
self.hr_uc = None
@@ -1095,12 +1102,31 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing): self.all_hr_prompts = [shared.prompt_styles.apply_styles_to_prompt(x, self.styles) for x in self.all_hr_prompts]
self.all_hr_negative_prompts = [shared.prompt_styles.apply_negative_styles_to_prompt(x, self.styles) for x in self.all_hr_negative_prompts]
+ def calculate_hr_conds(self):
+ if self.hr_c is not None:
+ return
+
+ self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_hr_uc, self.hr_extra_network_data)
+ self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_hr_c, self.hr_extra_network_data)
+
def setup_conds(self):
super().setup_conds()
+ self.hr_uc = None
+ self.hr_c = None
+
if self.enable_hr:
- self.hr_uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, self.hr_negative_prompts, self.steps * self.step_multiplier, self.cached_uc, self.hr_extra_network_data)
- self.hr_c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, self.hr_prompts, self.steps * self.step_multiplier, self.cached_c, self.hr_extra_network_data)
+ if shared.opts.hires_fix_use_firstpass_conds:
+ self.calculate_hr_conds()
+
+ elif lowvram.is_enabled(shared.sd_model): # if in lowvram mode, we need to calculate conds right away, before the cond NN is unloaded
+ with devices.autocast():
+ extra_networks.activate(self, self.hr_extra_network_data)
+
+ self.calculate_hr_conds()
+
+ with devices.autocast():
+ extra_networks.activate(self, self.extra_network_data)
def parse_extra_network_prompts(self):
res = super().parse_extra_network_prompts()
|