From b7f45e67dcf48914d2f34d4ace977a431a5aa12e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 11 Feb 2024 12:56:53 +0300 Subject: add before_token_counter callback and use it for prompt comments --- modules/script_callbacks.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a54cb3eb..08bc5256 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -1,3 +1,4 @@ +import dataclasses import inspect import os from collections import namedtuple @@ -106,6 +107,15 @@ class ImageGridLoopParams: self.rows = rows +@dataclasses.dataclass +class BeforeTokenCounterParams: + prompt: str + steps: int + styles: list + + is_positive: bool = True + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], @@ -128,6 +138,7 @@ callback_map = dict( callbacks_on_reload=[], callbacks_list_optimizers=[], callbacks_list_unets=[], + callbacks_before_token_counter=[], ) @@ -309,6 +320,14 @@ def list_unets_callback(): return res +def before_token_counter_callback(params: BeforeTokenCounterParams): + for c in callback_map['callbacks_before_token_counter']: + try: + c.callback(params) + except Exception: + report_exception(c, 'before_token_counter') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if stack else 'unknown file' @@ -483,3 +502,10 @@ def on_list_unets(callback): The function will be called with one argument, a list, and shall add objects of type modules.sd_unet.SdUnetOption to it.""" add_callback(callback_map['callbacks_list_unets'], callback) + + +def on_before_token_counter(callback): + """register a function to be called when UI is counting tokens for a prompt. + The function will be called with one argument of type BeforeTokenCounterParams, and should modify its fields if necessary.""" + + add_callback(callback_map['callbacks_before_token_counter'], callback) -- cgit v1.2.3