diff options
Diffstat (limited to 'modules/scripts.py')
-rw-r--r-- | modules/scripts.py | 77 |
1 files changed, 76 insertions, 1 deletions
diff --git a/modules/scripts.py b/modules/scripts.py index 51da732a..66fbec0d 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,6 +3,7 @@ import re import sys
import inspect
from collections import namedtuple
+from dataclasses import dataclass
import gradio as gr
@@ -21,6 +22,11 @@ class PostprocessBatchListArgs: self.images = images
+@dataclass
+class OnComponent:
+ component: gr.blocks.Block
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -35,6 +41,7 @@ class Script: is_txt2img = False
is_img2img = False
+ tabname = None
group = None
"""A gr.Group component that has all script's UI inside it."""
@@ -55,6 +62,12 @@ class Script: api_info = None
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+ on_before_component_elem_id = []
+ """list of callbacks to be called before a component with an elem_id is created"""
+
+ on_after_component_elem_id = []
+ """list of callbacks to be called after a component with an elem_id is created"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -215,6 +228,24 @@ class Script: pass
+ def on_before_component(self, callback, *, elem_id):
+ """
+ Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
+
+ This function is an alternative to before_component in that it also cllows to run before a component is created, but
+ it doesn't require to be called for every created component - just for the one you need.
+ """
+
+ self.on_before_component_elem_id.append((elem_id, callback))
+
+ def on_after_component(self, callback, *, elem_id):
+ """
+ Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
+ """
+
+ self.on_after_component_elem_id.append((elem_id, callback))
+
+
def describe(self):
"""unused"""
return ""
@@ -236,6 +267,17 @@ class Script: pass
+class ScriptBuiltin(Script):
+
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+
+ return f'{tabname}{item_id}'
+
+
current_basedir = paths.script_path
@@ -354,10 +396,17 @@ class ScriptRunner: self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = []
+ self.title_map = {}
self.infotext_fields = []
self.paste_field_names = []
self.inputs = [None]
+ self.on_before_component_elem_id = {}
+ """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
+
+ self.on_after_component_elem_id = {}
+ """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
+
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -372,6 +421,7 @@ class ScriptRunner: script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
+ script.tabname = "img2img" if is_img2img else "txt2img"
visibility = script.show(script.is_img2img)
@@ -446,6 +496,8 @@ class ScriptRunner: self.inputs = [None]
def setup_ui(self):
+ all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
+ self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
self.setup_ui_for_section(None)
@@ -492,6 +544,13 @@ class ScriptRunner: self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
+ for script in self.scripts:
+ for elem_id, callback in script.on_before_component_elem_id:
+ self.on_before_component_elem_id.get(elem_id, []).append((callback, script))
+
+ for elem_id, callback in script.on_after_component_elem_id:
+ self.on_after_component_elem_id.get(elem_id, []).append((callback, script))
+
return self.inputs
def run(self, p, *args):
@@ -585,6 +644,13 @@ class ScriptRunner: errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
+ for callbacks in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
+ for callback, script in callbacks:
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.before_component(component, **kwargs)
@@ -592,12 +658,22 @@ class ScriptRunner: errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
+ for callbacks in self.on_after_component_elem_id.get(component.elem_id, []):
+ for callback, script in callbacks:
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
+ def script(self, title):
+ return self.title_map.get(title.lower())
+
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
args_from = script.args_from
@@ -616,7 +692,6 @@ class ScriptRunner: self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
-
def before_hr(self, p):
for script in self.alwayson_scripts:
try:
|