aboutsummaryrefslogtreecommitdiffstats
path: root/modules/script_callbacks.py
diff options
context:
space:
mode:
authorrandom-thoughtss <116161560+random-thoughtss@users.noreply.github.com>2022-11-03 22:55:54 +0000
committerGitHub <noreply@github.com>2022-11-03 22:55:54 +0000
commit243253ff4a8ae944ba142abe9c1e78a92dd14ebe (patch)
treec40402e18a29ca9a9b167a2f9e47dab39dce0943 /modules/script_callbacks.py
parentd9e4e4d7a09d4aee8ce249a3c8e91ce165b10fa5 (diff)
parent20a860b525cb7a319a42994f75a94bbca9a54d89 (diff)
downloadstable-diffusion-webui-gfx803-243253ff4a8ae944ba142abe9c1e78a92dd14ebe.tar.gz
stable-diffusion-webui-gfx803-243253ff4a8ae944ba142abe9c1e78a92dd14ebe.tar.bz2
stable-diffusion-webui-gfx803-243253ff4a8ae944ba142abe9c1e78a92dd14ebe.zip
Merge branch 'AUTOMATIC1111:master' into master
Diffstat (limited to 'modules/script_callbacks.py')
-rw-r--r--modules/script_callbacks.py56
1 files changed, 55 insertions, 1 deletions
diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py
index 6ea58d61..c28e220e 100644
--- a/modules/script_callbacks.py
+++ b/modules/script_callbacks.py
@@ -2,7 +2,10 @@ import sys
import traceback
from collections import namedtuple
import inspect
+from typing import Optional
+from fastapi import FastAPI
+from gradio import Blocks
def report_exception(c, job):
print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -24,12 +27,32 @@ class ImageSaveParams:
"""dictionary with parameters for image's PNG info data; infotext will have the key 'parameters'"""
+class CFGDenoiserParams:
+ def __init__(self, x, image_cond, sigma, sampling_step, total_sampling_steps):
+ self.x = x
+ """Latent image representation in the process of being denoised"""
+
+ self.image_cond = image_cond
+ """Conditioning image"""
+
+ self.sigma = sigma
+ """Current sigma noise step value"""
+
+ self.sampling_step = sampling_step
+ """Current Sampling step number"""
+
+ self.total_sampling_steps = total_sampling_steps
+ """Total number of sampling steps planned"""
+
+
ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
+callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
callbacks_before_image_saved = []
callbacks_image_saved = []
+callbacks_cfg_denoiser = []
def clear_callbacks():
@@ -38,6 +61,14 @@ def clear_callbacks():
callbacks_ui_settings.clear()
callbacks_before_image_saved.clear()
callbacks_image_saved.clear()
+ callbacks_cfg_denoiser.clear()
+
+def app_started_callback(demo: Optional[Blocks], app: FastAPI):
+ for c in callbacks_app_started:
+ try:
+ c.callback(demo, app)
+ except Exception:
+ report_exception(c, 'app_started_callback')
def model_loaded_callback(sd_model):
@@ -69,7 +100,7 @@ def ui_settings_callback():
def before_image_saved_callback(params: ImageSaveParams):
- for c in callbacks_image_saved:
+ for c in callbacks_before_image_saved:
try:
c.callback(params)
except Exception:
@@ -84,6 +115,14 @@ def image_saved_callback(params: ImageSaveParams):
report_exception(c, 'image_saved_callback')
+def cfg_denoiser_callback(params: CFGDenoiserParams):
+ for c in callbacks_cfg_denoiser:
+ try:
+ c.callback(params)
+ except Exception:
+ report_exception(c, 'cfg_denoiser_callback')
+
+
def add_callback(callbacks, fun):
stack = [x for x in inspect.stack() if x.filename != __file__]
filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -91,6 +130,12 @@ def add_callback(callbacks, fun):
callbacks.append(ScriptCallback(filename, fun))
+def on_app_started(callback):
+ """register a function to be called when the webui started, the gradio `Block` component and
+ fastapi `FastAPI` object are passed as the arguments"""
+ add_callback(callbacks_app_started, callback)
+
+
def on_model_loaded(callback):
"""register a function to be called when the stable diffusion model is created; the model is
passed as an argument"""
@@ -130,3 +175,12 @@ def on_image_saved(callback):
- params: ImageSaveParams - parameters the image was saved with. Changing fields in this object does nothing.
"""
add_callback(callbacks_image_saved, callback)
+
+
+def on_cfg_denoiser(callback):
+ """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs.
+ The callback is called with one argument:
+ - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
+ """
+ add_callback(callbacks_cfg_denoiser, callback)
+