From bb57f30c2de46cfca5419ad01738a41705f96cc3 Mon Sep 17 00:00:00 2001 From: MalumaDev Date: Fri, 14 Oct 2022 10:56:41 +0200 Subject: init --- modules/processing.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) (limited to 'modules/processing.py') diff --git a/modules/processing.py b/modules/processing.py index d5172f00..9a033759 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -316,11 +316,16 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration return f"{all_prompts[index]}{negative_prompt_text}\n{generation_params_text}".strip() -def process_images(p: StableDiffusionProcessing) -> Processed: +def process_images(p: StableDiffusionProcessing, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, + aesthetic_imgs=None,aesthetic_slerp=False) -> Processed: """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + aesthetic_lr = float(aesthetic_lr) + aesthetic_weight = float(aesthetic_weight) + aesthetic_steps = int(aesthetic_steps) + if type(p.prompt) == list: - assert(len(p.prompt) > 0) + assert (len(p.prompt) > 0) else: assert p.prompt is not None @@ -394,7 +399,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed: #uc = p.sd_model.get_learned_conditioning(len(prompts) * [p.negative_prompt]) #c = p.sd_model.get_learned_conditioning(prompts) with devices.autocast(): - uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(0, 0, 0) + uc = prompt_parser.get_learned_conditioning(shared.sd_model, len(prompts) * [p.negative_prompt], + p.steps) + if hasattr(shared.sd_model.cond_stage_model, "set_aesthetic_params"): + shared.sd_model.cond_stage_model.set_aesthetic_params(aesthetic_lr, aesthetic_weight, + aesthetic_steps, aesthetic_imgs,aesthetic_slerp) c = prompt_parser.get_multicond_learned_conditioning(shared.sd_model, prompts, p.steps) if len(model_hijack.comments) > 0: -- cgit v1.2.3