aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/batch.py
diff options
context:
space:
mode:
authorAUTOMATIC <16777216c@gmail.com>2022-09-10 05:45:55 +0000
committerAUTOMATIC <16777216c@gmail.com>2022-09-10 05:45:55 +0000
commit22faf30c0b485d1907ed15fb037b623b3d6a5016 (patch)
tree063bead1f5188c10f9c26439f48e829c75d76e65 /scripts/batch.py
parent6f678ec79ce6319607a79ac2b75bbe987611603e (diff)
downloadstable-diffusion-webui-gfx803-22faf30c0b485d1907ed15fb037b623b3d6a5016.tar.gz
stable-diffusion-webui-gfx803-22faf30c0b485d1907ed15fb037b623b3d6a5016.tar.bz2
stable-diffusion-webui-gfx803-22faf30c0b485d1907ed15fb037b623b3d6a5016.zip
add script for batch file processing
Diffstat (limited to 'scripts/batch.py')
-rw-r--r--scripts/batch.py59
1 files changed, 59 insertions, 0 deletions
diff --git a/scripts/batch.py b/scripts/batch.py
new file mode 100644
index 00000000..1af4a7bc
--- /dev/null
+++ b/scripts/batch.py
@@ -0,0 +1,59 @@
+import math
+import os
+import sys
+import traceback
+
+import modules.scripts as scripts
+import gradio as gr
+
+from modules.processing import Processed, process_images
+from PIL import Image
+from modules.shared import opts, cmd_opts, state
+
+
+class Script(scripts.Script):
+ def title(self):
+ return "Batch processing"
+
+ def show(self, is_img2img):
+ return is_img2img
+
+ def ui(self, is_img2img):
+ input_dir = gr.Textbox(label="Input directory", lines=1)
+ output_dir = gr.Textbox(label="Output directory", lines=1)
+
+ return [input_dir, output_dir]
+
+ def run(self, p, input_dir, output_dir):
+ images = [file for file in [os.path.join(input_dir, x) for x in os.listdir(input_dir)] if os.path.isfile(file)]
+
+ batch_count = math.ceil(len(images) / p.batch_size)
+ print(f"Will process {len(images)} images in {batch_count} batches.")
+
+ p.batch_count = 1
+ p.do_not_save_grid = True
+ p.do_not_save_samples = True
+
+ state.job_count = batch_count
+
+ for batch_no in range(batch_count):
+ batch_images = []
+ for path in images[batch_no*p.batch_size:(batch_no+1)*p.batch_size]:
+ try:
+ img = Image.open(path)
+ batch_images.append((img, path))
+ except:
+ print(f"Error processing {path}:", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
+ if len(batch_images) == 0:
+ continue
+
+ state.job = f"{batch_no} out of {batch_count}: {batch_images[0][1]}"
+ p.init_images = [x[0] for x in batch_images]
+ proc = process_images(p)
+ for image, (_, path) in zip(proc.images, batch_images):
+ filename = os.path.basename(path)
+ image.save(os.path.join(output_dir, filename))
+
+ return Processed(p, [], p.seed, "")