From c251e8db8d71e649e4350f13aad1a76ed98d35c3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 23:18:16 +0300 Subject: Merge pull request #11957 from ljleb/pp-batch-list Add postprocess_batch_list script callback --- modules/scripts.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index f34240a0..5b4edcac 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -16,6 +16,11 @@ class PostprocessImageArgs: self.image = image +class PostprocessBatchListArgs: + def __init__(self, images): + self.images = images + + class Script: name = None """script's internal name derived from title""" @@ -156,6 +161,25 @@ class Script: pass + def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs): + """ + Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor. + This is useful when you want to update the entire batch instead of individual images. + + You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc. + If the number of images is different from the batch size when returning, + then the script has the responsibility to also update the following attributes in the processing object (p): + - p.prompts + - p.negative_prompts + - p.seeds + - p.subseeds + + **kwargs will have same items as process_batch, and also: + - batch_number - index of current batch, from 0 to number of batches-1 + """ + + pass + def postprocess_image(self, p, pp: PostprocessImageArgs, *args): """ Called for every image after it has been generated. @@ -536,6 +560,14 @@ class ScriptRunner: except Exception: errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True) + def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs): + for script in self.alwayson_scripts: + try: + script_args = p.script_args[script.args_from:script.args_to] + script.postprocess_batch_list(p, pp, *script_args, **kwargs) + except Exception: + errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True) + def postprocess_image(self, p, pp: PostprocessImageArgs): for script in self.alwayson_scripts: try: -- cgit v1.2.3