diff options
author | AUTOMATIC1111 <16777216c@gmail.com> | 2023-07-24 20:18:16 +0000 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-07-24 20:18:16 +0000 |
commit | f7c0a963f15b165c8518faad87ab10323baa173f (patch) | |
tree | 1f889b348ce267bde81ba36d9e3ee8e2f46bbf97 | |
parent | f451994053140622ef5e394bc02ac166fb74e56f (diff) | |
parent | 5b06607476d1ef2c9d16fe8b21c786b2ca13b95c (diff) | |
download | stable-diffusion-webui-gfx803-f7c0a963f15b165c8518faad87ab10323baa173f.tar.gz stable-diffusion-webui-gfx803-f7c0a963f15b165c8518faad87ab10323baa173f.tar.bz2 stable-diffusion-webui-gfx803-f7c0a963f15b165c8518faad87ab10323baa173f.zip |
Merge pull request #11957 from ljleb/pp-batch-list
Add postprocess_batch_list script callback
-rw-r--r-- | modules/processing.py | 26 | ||||
-rw-r--r-- | modules/scripts.py | 32 |
2 files changed, 57 insertions, 1 deletions
diff --git a/modules/processing.py b/modules/processing.py index a74a5302..6dc178e1 100644 --- a/modules/processing.py +++ b/modules/processing.py @@ -717,7 +717,27 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))]
def infotext(iteration=0, position_in_batch=0, use_main_prompt=False):
- return create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
+ all_prompts = p.all_prompts[:]
+ all_negative_prompts = p.all_negative_prompts[:]
+ all_seeds = p.all_seeds[:]
+ all_subseeds = p.all_subseeds[:]
+
+ # apply changes to generation data
+ all_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.prompts
+ all_negative_prompts[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.negative_prompts
+ all_seeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.seeds
+ all_subseeds[iteration * p.batch_size:(iteration + 1) * p.batch_size] = p.subseeds
+
+ # update p.all_negative_prompts in case extensions changed the size of the batch
+ # create_infotext below uses it
+ old_negative_prompts = p.all_negative_prompts
+ p.all_negative_prompts = all_negative_prompts
+
+ try:
+ return create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration, position_in_batch, use_main_prompt)
+ finally:
+ # restore p.all_negative_prompts in case extensions changed the size of the batch
+ p.all_negative_prompts = old_negative_prompts
if os.path.exists(cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings:
model_hijack.embedding_db.load_textual_inversion_embeddings()
@@ -806,6 +826,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed: if p.scripts is not None:
p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)
+ postprocess_batch_list_args = scripts.PostprocessBatchListArgs(list(x_samples_ddim))
+ p.scripts.postprocess_batch_list(p, postprocess_batch_list_args, batch_number=n)
+ x_samples_ddim = postprocess_batch_list_args.images
+
for i, x_sample in enumerate(x_samples_ddim):
p.batch_index = i
diff --git a/modules/scripts.py b/modules/scripts.py index f34240a0..5b4edcac 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -16,6 +16,11 @@ class PostprocessImageArgs: self.image = image
+class PostprocessBatchListArgs:
+ def __init__(self, images):
+ self.images = images
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -156,6 +161,25 @@ class Script: pass
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, *args, **kwargs):
+ """
+ Same as postprocess_batch(), but receives batch images as a list of 3D tensors instead of a 4D tensor.
+ This is useful when you want to update the entire batch instead of individual images.
+
+ You can modify the postprocessing object (pp) to update the images in the batch, remove images, add images, etc.
+ If the number of images is different from the batch size when returning,
+ then the script has the responsibility to also update the following attributes in the processing object (p):
+ - p.prompts
+ - p.negative_prompts
+ - p.seeds
+ - p.subseeds
+
+ **kwargs will have same items as process_batch, and also:
+ - batch_number - index of current batch, from 0 to number of batches-1
+ """
+
+ pass
+
def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
"""
Called for every image after it has been generated.
@@ -536,6 +560,14 @@ class ScriptRunner: except Exception:
errors.report(f"Error running postprocess_batch: {script.filename}", exc_info=True)
+ def postprocess_batch_list(self, p, pp: PostprocessBatchListArgs, **kwargs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_batch_list(p, pp, *script_args, **kwargs)
+ except Exception:
+ errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)
+
def postprocess_image(self, p, pp: PostprocessImageArgs):
for script in self.alwayson_scripts:
try:
|