diff options
Diffstat (limited to 'modules/scripts.py')
-rw-r--r-- | modules/scripts.py | 73 |
1 files changed, 59 insertions, 14 deletions
diff --git a/modules/scripts.py b/modules/scripts.py index 722f8685..6e9dc0c0 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -1,16 +1,21 @@ import os
+import re
import sys
import traceback
from collections import namedtuple
import gradio as gr
-from modules.processing import StableDiffusionProcessing
-from modules import shared, paths, script_callbacks, extensions, script_loading
+from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing
AlwaysVisible = object()
+class PostprocessImageArgs:
+ def __init__(self, image):
+ self.image = image
+
+
class Script:
filename = None
args_from = None
@@ -64,7 +69,7 @@ class Script: args contains all values returned by components from ui()
"""
- raise NotImplementedError()
+ pass
def process(self, p, *args):
"""
@@ -99,6 +104,13 @@ class Script: pass
+ def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
+ """
+ Called for every image after it has been generated.
+ """
+
+ pass
+
def postprocess(self, p, processed, *args):
"""
This function is called after processing ends for AlwaysVisible scripts.
@@ -128,6 +140,15 @@ class Script: """unused"""
return ""
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of script name, tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+ title = re.sub(r'[^a-z_0-9]', '', re.sub(r'\s', '_', self.title().lower()))
+
+ return f'script_{tabname}{title}_{item_id}'
+
current_basedir = paths.script_path
@@ -140,9 +161,11 @@ def basedir(): return current_basedir
-scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])
-ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir"])
+
+scripts_data = []
+postprocessing_scripts_data = []
+ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])
def list_scripts(scriptdirname, extension):
@@ -180,23 +203,31 @@ def list_files_with_name(filename): def load_scripts():
global current_basedir
scripts_data.clear()
+ postprocessing_scripts_data.clear()
script_callbacks.clear_callbacks()
scripts_list = list_scripts("scripts", ".py")
syspath = sys.path
+ def register_scripts_from_module(module):
+ for key, script_class in module.__dict__.items():
+ if type(script_class) != type:
+ continue
+
+ if issubclass(script_class, Script):
+ scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+ elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
+ postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
+
for scriptfile in sorted(scripts_list):
try:
if scriptfile.basedir != paths.script_path:
sys.path = [scriptfile.basedir] + sys.path
current_basedir = scriptfile.basedir
- module = script_loading.load_module(scriptfile.path)
-
- for key, script_class in module.__dict__.items():
- if type(script_class) == type and issubclass(script_class, Script):
- scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir))
+ script_module = script_loading.load_module(scriptfile.path)
+ register_scripts_from_module(script_module)
except Exception:
print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -227,11 +258,15 @@ class ScriptRunner: self.infotext_fields = []
def initialize_scripts(self, is_img2img):
+ from modules import scripts_auto_postprocessing
+
self.scripts.clear()
self.alwayson_scripts.clear()
self.selectable_scripts.clear()
- for script_class, path, basedir in scripts_data:
+ auto_processing_scripts = scripts_auto_postprocessing.create_auto_preprocessing_script_data()
+
+ for script_class, path, basedir, script_module in auto_processing_scripts + scripts_data:
script = script_class()
script.filename = path
script.is_txt2img = not is_img2img
@@ -280,7 +315,6 @@ class ScriptRunner: script.group = group
dropdown = gr.Dropdown(label="Script", elem_id="script_list", choices=["None"] + self.titles, value="None", type="index")
- dropdown.save_to_config = True
inputs[0] = dropdown
for script in self.selectable_scripts:
@@ -313,7 +347,7 @@ class ScriptRunner: return inputs
- def run(self, p: StableDiffusionProcessing, *args):
+ def run(self, p, *args):
script_index = args[0]
if script_index == 0:
@@ -367,6 +401,15 @@ class ScriptRunner: print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
print(traceback.format_exc(), file=sys.stderr)
+ def postprocess_image(self, p, pp: PostprocessImageArgs):
+ for script in self.alwayson_scripts:
+ try:
+ script_args = p.script_args[script.args_from:script.args_to]
+ script.postprocess_image(p, pp, *script_args)
+ except Exception:
+ print(f"Error running postprocess_batch: {script.filename}", file=sys.stderr)
+ print(traceback.format_exc(), file=sys.stderr)
+
def before_component(self, component, **kwargs):
for script in self.scripts:
try:
@@ -404,6 +447,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None
@@ -414,12 +458,13 @@ def reload_script_body_only(): def reload_scripts():
- global scripts_txt2img, scripts_img2img
+ global scripts_txt2img, scripts_img2img, scripts_postproc
load_scripts()
scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
+ scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
def IOComponent_init(self, *args, **kwargs):
|