From 3078001439d25b66ef5627c9e3d431aa23bbed73 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Sun, 14 May 2023 01:49:41 +0000 Subject: Add/modify CFG callbacks Required by self-attn guidance extension https://github.com/ashen-sensored/sd_webui_SAG --- modules/script_callbacks.py | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 7d9dd736..e83c6ecf 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -53,6 +53,21 @@ class CFGDenoiserParams: class CFGDenoisedParams: + def __init__(self, x, sampling_step, total_sampling_steps, inner_model): + self.x = x + """Latent image representation in the process of being denoised""" + + self.sampling_step = sampling_step + """Current Sampling step number""" + + self.total_sampling_steps = total_sampling_steps + """Total number of sampling steps planned""" + + self.inner_model = inner_model + """Inner model reference that is being used for denoising""" + + +class AfterCFGCallbackParams: def __init__(self, x, sampling_step, total_sampling_steps): self.x = x """Latent image representation in the process of being denoised""" @@ -63,6 +78,9 @@ class CFGDenoisedParams: self.total_sampling_steps = total_sampling_steps """Total number of sampling steps planned""" + self.output_altered = False + """A flag for CFGDenoiser that indicates whether the output has been altered by the callback""" + class UiTrainTabParams: def __init__(self, txt2img_preview_params): @@ -87,6 +105,7 @@ callback_map = dict( callbacks_image_saved=[], callbacks_cfg_denoiser=[], callbacks_cfg_denoised=[], + callbacks_cfg_after_cfg=[], callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], @@ -186,6 +205,14 @@ def cfg_denoised_callback(params: CFGDenoisedParams): report_exception(c, 'cfg_denoised_callback') +def cfg_after_cfg_callback(params: AfterCFGCallbackParams): + for c in callback_map['callbacks_cfg_after_cfg']: + try: + c.callback(params) + except Exception: + report_exception(c, 'cfg_after_cfg_callback') + + def before_component_callback(component, **kwargs): for c in callback_map['callbacks_before_component']: try: @@ -332,6 +359,14 @@ def on_cfg_denoised(callback): add_callback(callback_map['callbacks_cfg_denoised'], callback) +def on_cfg_after_cfg(callback): + """register a function to be called in the kdiffussion cfg_denoiser method after cfg calculations has completed. + The callback is called with one argument: + - params: CFGDenoisedParams - parameters to be passed to the inner model and sampling state details. + """ + add_callback(callback_map['callbacks_cfg_after_cfg'], callback) + + def on_before_component(callback): """register a function to be called before a component is created. The callback is called with arguments: -- cgit v1.2.3