aboutsummaryrefslogtreecommitdiffstats
path: root/modules/scripts.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-08-12 14:46:13 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-08-12 14:46:13 +0000
commitf0b72b81211881e083c84cff585380bb70d17271 (patch)
tree5da3afb6c761ed98d33a67701144b6b4fb2a3de8 /modules/scripts.py
parent6aa26a26d5beb317d708c4fa85c38056347ea5d3 (diff)
downloadstable-diffusion-webui-gfx803-f0b72b81211881e083c84cff585380bb70d17271.tar.gz
stable-diffusion-webui-gfx803-f0b72b81211881e083c84cff585380bb70d17271.tar.bz2
stable-diffusion-webui-gfx803-f0b72b81211881e083c84cff585380bb70d17271.zip
move seed, variation seed and variation seed strength to a single row, dump resize seed from UI
add a way for scripts to register a callback for before/after just a single component's creation
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--modules/scripts.py77
1 files changed, 76 insertions, 1 deletions
diff --git a/modules/scripts.py b/modules/scripts.py
index 51da732a..66fbec0d 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -3,6 +3,7 @@ import re
import sys
import inspect
from collections import namedtuple
+from dataclasses import dataclass
import gradio as gr
@@ -21,6 +22,11 @@ class PostprocessBatchListArgs:
self.images = images
+@dataclass
+class OnComponent:
+ component: gr.blocks.Block
+
+
class Script:
name = None
"""script's internal name derived from title"""
@@ -35,6 +41,7 @@ class Script:
is_txt2img = False
is_img2img = False
+ tabname = None
group = None
"""A gr.Group component that has all script's UI inside it."""
@@ -55,6 +62,12 @@ class Script:
api_info = None
"""Generated value of type modules.api.models.ScriptInfo with information about the script for API"""
+ on_before_component_elem_id = []
+ """list of callbacks to be called before a component with an elem_id is created"""
+
+ on_after_component_elem_id = []
+ """list of callbacks to be called after a component with an elem_id is created"""
+
def title(self):
"""this function should return the title of the script. This is what will be displayed in the dropdown menu."""
@@ -215,6 +228,24 @@ class Script:
pass
+ def on_before_component(self, callback, *, elem_id):
+ """
+ Calls callback before a component is created. The callback function is called with a single argument of type OnComponent.
+
+ This function is an alternative to before_component in that it also cllows to run before a component is created, but
+ it doesn't require to be called for every created component - just for the one you need.
+ """
+
+ self.on_before_component_elem_id.append((elem_id, callback))
+
+ def on_after_component(self, callback, *, elem_id):
+ """
+ Calls callback after a component is created. The callback function is called with a single argument of type OnComponent.
+ """
+
+ self.on_after_component_elem_id.append((elem_id, callback))
+
+
def describe(self):
"""unused"""
return ""
@@ -236,6 +267,17 @@ class Script:
pass
+class ScriptBuiltin(Script):
+
+ def elem_id(self, item_id):
+ """helper function to generate id for a HTML element, constructs final id out of tab and user-supplied item_id"""
+
+ need_tabname = self.show(True) == self.show(False)
+ tabname = ('img2img' if self.is_img2img else 'txt2txt') + "_" if need_tabname else ""
+
+ return f'{tabname}{item_id}'
+
+
current_basedir = paths.script_path
@@ -354,10 +396,17 @@ class ScriptRunner:
self.selectable_scripts = []
self.alwayson_scripts = []
self.titles = []
+ self.title_map = {}
self.infotext_fields = []
self.paste_field_names = []
self.inputs = [None]
+ self.on_before_component_elem_id = {}
+ """dict of callbacks to be called before an element is created; key=elem_id, value=list of callbacks"""
+
+ self.on_after_component_elem_id = {}
+ """dict of callbacks to be called after an element is created; key=elem_id, value=list of callbacks"""
+
def initialize_scripts(self, is_img2img):
from modules import scripts_auto_postprocessing
@@ -372,6 +421,7 @@ class ScriptRunner:
script.filename = script_data.path
script.is_txt2img = not is_img2img
script.is_img2img = is_img2img
+ script.tabname = "img2img" if is_img2img else "txt2img"
visibility = script.show(script.is_img2img)
@@ -446,6 +496,8 @@ class ScriptRunner:
self.inputs = [None]
def setup_ui(self):
+ all_titles = [wrap_call(script.title, script.filename, "title") or script.filename for script in self.scripts]
+ self.title_map = {title.lower(): script for title, script in zip(all_titles, self.scripts)}
self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]
self.setup_ui_for_section(None)
@@ -492,6 +544,13 @@ class ScriptRunner:
self.infotext_fields.append((dropdown, lambda x: gr.update(value=x.get('Script', 'None'))))
self.infotext_fields.extend([(script.group, onload_script_visibility) for script in self.selectable_scripts])
+ for script in self.scripts:
+ for elem_id, callback in script.on_before_component_elem_id:
+ self.on_before_component_elem_id.get(elem_id, []).append((callback, script))
+
+ for elem_id, callback in script.on_after_component_elem_id:
+ self.on_after_component_elem_id.get(elem_id, []).append((callback, script))
+
return self.inputs
def run(self, p, *args):
@@ -585,6 +644,13 @@ class ScriptRunner:
errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)
def before_component(self, component, **kwargs):
+ for callbacks in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
+ for callback, script in callbacks:
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_before_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.before_component(component, **kwargs)
@@ -592,12 +658,22 @@ class ScriptRunner:
errors.report(f"Error running before_component: {script.filename}", exc_info=True)
def after_component(self, component, **kwargs):
+ for callbacks in self.on_after_component_elem_id.get(component.elem_id, []):
+ for callback, script in callbacks:
+ try:
+ callback(OnComponent(component=component))
+ except Exception:
+ errors.report(f"Error running on_after_component: {script.filename}", exc_info=True)
+
for script in self.scripts:
try:
script.after_component(component, **kwargs)
except Exception:
errors.report(f"Error running after_component: {script.filename}", exc_info=True)
+ def script(self, title):
+ return self.title_map.get(title.lower())
+
def reload_sources(self, cache):
for si, script in list(enumerate(self.scripts)):
args_from = script.args_from
@@ -616,7 +692,6 @@ class ScriptRunner:
self.scripts[si].args_from = args_from
self.scripts[si].args_to = args_to
-
def before_hr(self, p):
for script in self.alwayson_scripts:
try: