From b78c5e87baaf8c88d039bf60082c3b5ae35ec4ff Mon Sep 17 00:00:00 2001 From: opparco Date: Sat, 11 Feb 2023 11:18:38 +0900 Subject: Add cfg_denoised_callback --- modules/script_callbacks.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 4bb45ec7..edd0e2a7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -46,6 +46,18 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" +class CFGDenoisedParams: + def __init__(self, x, sampling_step, total_sampling_steps): + self.x = x + """Latent image representation in the process of being denoised""" + + self.sampling_step = sampling_step + """Current Sampling step number""" + + self.total_sampling_steps = total_sampling_steps + """Total number of sampling steps planned""" + + class UiTrainTabParams: def __init__(self, txt2img_preview_params): self.txt2img_preview_params = txt2img_preview_params @@ -68,6 +80,7 @@ callback_map = dict( callbacks_before_image_saved=[], callbacks_image_saved=[], callbacks_cfg_denoiser=[], + callbacks_cfg_denoised=[], callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], @@ -150,6 +163,14 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): report_exception(c, 'cfg_denoiser_callback') +def cfg_denoised_callback(params: CFGDenoisedParams): + for c in callback_map['callbacks_cfg_denoised']: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_denoised_callback') + + def before_component_callback(component, **kwargs): for c in callback_map['callbacks_before_component']: try: @@ -283,6 +304,14 @@ def on_cfg_denoiser(callback): add_callback(callback_map['callbacks_cfg_denoiser'], callback) +def on_cfg_denoised(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after building the inner model inputs. + The callback is called with one argument: + - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callback_map['callbacks_cfg_denoised'], callback) + + def on_before_component(callback): """register a function to be called before a component is created. The callback is called with arguments: -- cgit v1.2.3