From 5f12b23b8bb7fca585a3a1e844881d06f171364e Mon Sep 17 00:00:00 2001 From: AlUlkesh <99896447+AlUlkesh@users.noreply.github.com> Date: Wed, 28 Dec 2022 22:18:19 +0100 Subject: Adding image numbers on grids New grid option in settings enables adding of image numbers on grids. This makes identifying the images, especially in larger batches, much easier. Revert "Adding image numbers on grids" This reverts commit 3530c283b4b1d3a3cab40efbffe4cf2697938b6f. Implements Callback for image grid loop Necessary to make "Add image's number to its picture in the grid" extension possible. --- modules/script_callbacks.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 8e22f875..0c854407 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -51,6 +51,11 @@ class UiTrainTabParams: self.txt2img_preview_params = txt2img_preview_params +class ImageGridLoopParams: + def __init__(self, img): + self.img = img + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], @@ -63,6 +68,7 @@ callback_map = dict( callbacks_cfg_denoiser=[], callbacks_before_component=[], callbacks_after_component=[], + callbacks_image_grid_loop=[], ) @@ -154,6 +160,12 @@ def after_component_callback(component, **kwargs): except Exception: report_exception(c, 'after_component_callback') +def image_grid_loop_callback(component, **kwargs): + for c in callback_map['callbacks_image_grid_loop']: + try: + c.callback(component, **kwargs) + except Exception: + report_exception(c, 'image_grid_loop') def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] @@ -255,3 +267,11 @@ def on_before_component(callback): def on_after_component(callback): """register a function to be called after a component is created. See on_before_component for more.""" add_callback(callback_map['callbacks_after_component'], callback) + + +def on_image_grid_loop(callback): + """register a function to be called inside the image grid loop. + The callback is called with one argument: + - params: ImageGridLoopParams - parameters to be used inside the image grid loop. + """ + add_callback(callback_map['callbacks_image_grid_loop'], callback) -- cgit v1.2.3 From e672cfb07418a1a3130d3bf21c14a0d3819f81fb Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 1 Jan 2023 18:37:37 +0300 Subject: rework of callback for #6094 --- modules/images.py | 10 ++++++---- modules/script_callbacks.py | 26 +++++++++++++++----------- 2 files changed, 21 insertions(+), 15 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/images.py b/modules/images.py index 719aaf3b..f84fd485 100644 --- a/modules/images.py +++ b/modules/images.py @@ -39,12 +39,14 @@ def image_grid(imgs, batch_size=1, rows=None): cols = math.ceil(len(imgs) / rows) + params = script_callbacks.ImageGridLoopParams(imgs, cols, rows) + script_callbacks.image_grid_callback(params) + w, h = imgs[0].size - grid = Image.new('RGB', size=(cols * w, rows * h), color='black') + grid = Image.new('RGB', size=(params.cols * w, params.rows * h), color='black') - for i, img in enumerate(imgs): - script_callbacks.image_grid_loop_callback(img) - grid.paste(img, box=(i % cols * w, i // cols * h)) + for i, img in enumerate(params.imgs): + grid.paste(img, box=(i % params.cols * w, i // params.cols * h)) return grid diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 0c854407..de69fd9f 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -52,8 +52,10 @@ class UiTrainTabParams: class ImageGridLoopParams: - def __init__(self, img): - self.img = img + def __init__(self, imgs, cols, rows): + self.imgs = imgs + self.cols = cols + self.rows = rows ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) @@ -68,7 +70,7 @@ callback_map = dict( callbacks_cfg_denoiser=[], callbacks_before_component=[], callbacks_after_component=[], - callbacks_image_grid_loop=[], + callbacks_image_grid=[], ) @@ -160,12 +162,14 @@ def after_component_callback(component, **kwargs): except Exception: report_exception(c, 'after_component_callback') -def image_grid_loop_callback(component, **kwargs): - for c in callback_map['callbacks_image_grid_loop']: + +def image_grid_callback(params: ImageGridLoopParams): + for c in callback_map['callbacks_image_grid']: try: - c.callback(component, **kwargs) + c.callback(params) except Exception: - report_exception(c, 'image_grid_loop') + report_exception(c, 'image_grid') + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] @@ -269,9 +273,9 @@ def on_after_component(callback): add_callback(callback_map['callbacks_after_component'], callback) -def on_image_grid_loop(callback): - """register a function to be called inside the image grid loop. +def on_image_grid(callback): + """register a function to be called before making an image grid. The callback is called with one argument: - - params: ImageGridLoopParams - parameters to be used inside the image grid loop. + - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ - add_callback(callback_map['callbacks_image_grid_loop'], callback) + add_callback(callback_map['callbacks_image_grid'], callback) -- cgit v1.2.3 From 65ed4421e609dda3112f236c13e4db14caa71364 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 6 Jan 2023 13:55:50 +0300 Subject: add callback for when the script is unloaded --- modules/script_callbacks.py | 18 +++++++++++++++++- webui.py | 2 ++ 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index de69fd9f..608c5300 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_script_unloaded=[], ) @@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def script_unloaded_callback(): + for c in reversed(callback_map['callbacks_script_unloaded']): + try: + c.callback() + except Exception: + report_exception(c, 'script_unloaded') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -202,7 +211,7 @@ def on_app_started(callback): def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is - passed as an argument""" + passed as an argument; this function is also called when the script is reloaded. """ add_callback(callback_map['callbacks_model_loaded'], callback) @@ -279,3 +288,10 @@ def on_image_grid(callback): - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified. """ add_callback(callback_map['callbacks_image_grid'], callback) + + +def on_script_unloaded(callback): + """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that + the script did should be reverted here""" + + add_callback(callback_map['callbacks_script_unloaded'], callback) diff --git a/webui.py b/webui.py index ff6eb6eb..733a06b5 100644 --- a/webui.py +++ b/webui.py @@ -187,12 +187,14 @@ def webui(): sd_samplers.set_samplers() + modules.script_callbacks.script_unloaded_callback() extensions.list_extensions() localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() modules.scripts.reload_scripts() + modules.script_callbacks.model_loaded_callback(shared.sd_model) modelloader.load_upscalers() for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: -- cgit v1.2.3 From 6c88eaed4f5efca54a882eb1f8f30f01f350332a Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Thu, 12 Jan 2023 13:50:09 -0800 Subject: Add script callback for fixing infotext parameters --- modules/generation_parameters_copypaste.py | 3 ++- modules/script_callbacks.py | 20 +++++++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) (limited to 'modules/script_callbacks.py') diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index 620aa606..593d99ef 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -7,7 +7,7 @@ from pathlib import Path import gradio as gr from modules.shared import script_path -from modules import shared, ui_tempdir +from modules import shared, ui_tempdir, script_callbacks import tempfile from PIL import Image @@ -298,6 +298,7 @@ def connect_paste(button, paste_fields, input_comp, jsfunc=None): prompt = file.read() params = parse_generation_parameters(prompt) + script_callbacks.infotext_pasted_callback(prompt, params) res = [] for output, key in paste_fields: diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 608c5300..a9e19236 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,7 +2,7 @@ import sys import traceback from collections import namedtuple import inspect -from typing import Optional +from typing import Optional, Dict, Any from fastapi import FastAPI from gradio import Blocks @@ -71,6 +71,7 @@ callback_map = dict( callbacks_before_component=[], callbacks_after_component=[], callbacks_image_grid=[], + callbacks_infotext_pasted=[], callbacks_script_unloaded=[], ) @@ -172,6 +173,14 @@ def image_grid_callback(params: ImageGridLoopParams): report_exception(c, 'image_grid') +def infotext_pasted_callback(infotext: str, params: Dict[str, Any]): + for c in callback_map['callbacks_infotext_pasted']: + try: + c.callback(infotext, params) + except Exception: + report_exception(c, 'infotext_pasted') + + def script_unloaded_callback(): for c in reversed(callback_map['callbacks_script_unloaded']): try: @@ -290,6 +299,15 @@ def on_image_grid(callback): add_callback(callback_map['callbacks_image_grid'], callback) +def on_infotext_pasted(callback): + """register a function to be called before applying an infotext. + The callback is called with two arguments: + - infotext: str - raw infotext. + - result: Dict[str, any] - parsed infotext parameters. + """ + add_callback(callback_map['callbacks_infotext_pasted'], callback) + + def on_script_unloaded(callback): """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that the script did should be reverted here""" -- cgit v1.2.3 From 855b9e3d1c5a1bd8c2d815d38a38bc7c410be5a8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 21 Jan 2023 16:15:53 +0300 Subject: Lora support! update readme to reflect some recent changes --- README.md | 14 +- extensions-builtin/Lora/extra_networks_lora.py | 20 +++ extensions-builtin/Lora/lora.py | 198 ++++++++++++++++++++++ extensions-builtin/Lora/scripts/lora_script.py | 30 ++++ extensions-builtin/Lora/ui_extra_networks_lora.py | 35 ++++ modules/extra_networks_hypernet.py | 2 +- modules/script_callbacks.py | 15 ++ modules/ui_extra_networks.py | 2 +- webui.py | 2 + 9 files changed, 314 insertions(+), 4 deletions(-) create mode 100644 extensions-builtin/Lora/extra_networks_lora.py create mode 100644 extensions-builtin/Lora/lora.py create mode 100644 extensions-builtin/Lora/scripts/lora_script.py create mode 100644 extensions-builtin/Lora/ui_extra_networks_lora.py (limited to 'modules/script_callbacks.py') diff --git a/README.md b/README.md index 1ac794e8..9c0cd1ef 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,7 @@ A browser interface based on Gradio library for Stable Diffusion. - Possible to change defaults/mix/max/step values for UI elements via text config - Tiling support, a checkbox to create images that can be tiled like textures - Progress bar and live image generation preview + - Can use a separate neural network to produce previews with almost none VRAM or compute requirement - Negative prompt, an extra text field that allows you to list what you don't want to see in generated image - Styles, a way to save part of prompt and easily apply them via dropdown later - Variations, a way to generate same image but with tiny differences @@ -75,13 +76,22 @@ A browser interface based on Gradio library for Stable Diffusion. - hypernetworks and embeddings options - Preprocessing images: cropping, mirroring, autotagging using BLIP or deepdanbooru (for anime) - Clip skip -- Use Hypernetworks -- Use VAEs +- Hypernetworks +- Loras (same as Hypernetworks but more pretty) +- A sparate UI where you can choose, with preview, which embeddings, hypernetworks or Loras to add to your prompt. +- Can select to load a different VAE from settings screen - Estimated completion time in progress bar - API - Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. - via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embeds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients)) - [Stable Diffusion 2.0](https://github.com/Stability-AI/stablediffusion) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20) for instructions +- [Alt-Diffusion](https://arxiv.org/abs/2211.06679) support - see [wiki](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#alt-diffusion) for instructions +- Now without any bad letters! +- Load checkpoints in safetensors format +- Eased resolution restriction: generated image's domension must be a multiple of 8 rather than 64 +- Now with a license! +- Reorder elements in the UI from settings screen +- ## Installation and Running Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs. diff --git a/extensions-builtin/Lora/extra_networks_lora.py b/extensions-builtin/Lora/extra_networks_lora.py new file mode 100644 index 00000000..8f2e753e --- /dev/null +++ b/extensions-builtin/Lora/extra_networks_lora.py @@ -0,0 +1,20 @@ +from modules import extra_networks +import lora + +class ExtraNetworkLora(extra_networks.ExtraNetwork): + def __init__(self): + super().__init__('lora') + + def activate(self, p, params_list): + names = [] + multipliers = [] + for params in params_list: + assert len(params.items) > 0 + + names.append(params.items[0]) + multipliers.append(float(params.items[1]) if len(params.items) > 1 else 1.0) + + lora.load_loras(names, multipliers) + + def deactivate(self, p): + pass diff --git a/extensions-builtin/Lora/lora.py b/extensions-builtin/Lora/lora.py new file mode 100644 index 00000000..7a3ad9a9 --- /dev/null +++ b/extensions-builtin/Lora/lora.py @@ -0,0 +1,198 @@ +import glob +import os +import re +import torch + +from modules import shared, devices, sd_models + +re_digits = re.compile(r"\d+") +re_unet_down_blocks = re.compile(r"lora_unet_down_blocks_(\d+)_attentions_(\d+)_(.+)") +re_unet_mid_blocks = re.compile(r"lora_unet_mid_block_attentions_(\d+)_(.+)") +re_unet_up_blocks = re.compile(r"lora_unet_up_blocks_(\d+)_attentions_(\d+)_(.+)") +re_text_block = re.compile(r"lora_te_text_model_encoder_layers_(\d+)_(.+)") + + +def convert_diffusers_name_to_compvis(key): + def match(match_list, regex): + r = re.match(regex, key) + if not r: + return False + + match_list.clear() + match_list.extend([int(x) if re.match(re_digits, x) else x for x in r.groups()]) + return True + + m = [] + + if match(m, re_unet_down_blocks): + return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_unet_mid_blocks): + return f"diffusion_model_middle_block_1_{m[1]}" + + if match(m, re_unet_up_blocks): + return f"diffusion_model_output_blocks_{m[0] * 3 + m[1]}_1_{m[2]}" + + if match(m, re_text_block): + return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}" + + return key + + +class LoraOnDisk: + def __init__(self, name, filename): + self.name = name + self.filename = filename + + +class LoraModule: + def __init__(self, name): + self.name = name + self.multiplier = 1.0 + self.modules = {} + self.mtime = None + + +class LoraUpDownModule: + def __init__(self): + self.up = None + self.down = None + + +def assign_lora_names_to_compvis_modules(sd_model): + lora_layer_mapping = {} + + for name, module in shared.sd_model.cond_stage_model.wrapped.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + for name, module in shared.sd_model.model.named_modules(): + lora_name = name.replace(".", "_") + lora_layer_mapping[lora_name] = module + module.lora_layer_name = lora_name + + sd_model.lora_layer_mapping = lora_layer_mapping + + +def load_lora(name, filename): + lora = LoraModule(name) + lora.mtime = os.path.getmtime(filename) + + sd = sd_models.read_state_dict(filename) + + keys_failed_to_match = [] + + for key_diffusers, weight in sd.items(): + fullkey = convert_diffusers_name_to_compvis(key_diffusers) + key, lora_key = fullkey.split(".", 1) + + sd_module = shared.sd_model.lora_layer_mapping.get(key, None) + if sd_module is None: + keys_failed_to_match.append(key_diffusers) + continue + + if type(sd_module) == torch.nn.Linear: + module = torch.nn.Linear(weight.shape[1], weight.shape[0], bias=False) + elif type(sd_module) == torch.nn.Conv2d: + module = torch.nn.Conv2d(weight.shape[1], weight.shape[0], (1, 1), bias=False) + else: + assert False, f'Lora layer {key_diffusers} matched a layer with unsupported type: {type(sd_module).__name__}' + + with torch.no_grad(): + module.weight.copy_(weight) + + module.to(device=devices.device, dtype=devices.dtype) + + lora_module = lora.modules.get(key, None) + if lora_module is None: + lora_module = LoraUpDownModule() + lora.modules[key] = lora_module + + if lora_key == "lora_up.weight": + lora_module.up = module + elif lora_key == "lora_down.weight": + lora_module.down = module + else: + assert False, f'Bad Lora layer name: {key_diffusers} - must end in lora_up.weight or lora_down.weight' + + if len(keys_failed_to_match) > 0: + print(f"Failed to match keys when loading Lora {filename}: {keys_failed_to_match}") + + return lora + + +def load_loras(names, multipliers=None): + already_loaded = {} + + for lora in loaded_loras: + if lora.name in names: + already_loaded[lora.name] = lora + + loaded_loras.clear() + + loras_on_disk = [available_loras.get(name, None) for name in names] + if any([x is None for x in loras_on_disk]): + list_available_loras() + + loras_on_disk = [available_loras.get(name, None) for name in names] + + for i, name in enumerate(names): + lora = already_loaded.get(name, None) + + lora_on_disk = loras_on_disk[i] + if lora_on_disk is not None: + if lora is None or os.path.getmtime(lora_on_disk.filename) > lora.mtime: + lora = load_lora(name, lora_on_disk.filename) + + if lora is None: + print(f"Couldn't find Lora with name {name}") + continue + + lora.multiplier = multipliers[i] if multipliers else 1.0 + loaded_loras.append(lora) + + +def lora_forward(module, input, res): + if len(loaded_loras) == 0: + return res + + lora_layer_name = getattr(module, 'lora_layer_name', None) + for lora in loaded_loras: + module = lora.modules.get(lora_layer_name, None) + if module is not None: + res = res + module.up(module.down(input)) * lora.multiplier + + return res + + +def lora_Linear_forward(self, input): + return lora_forward(self, input, torch.nn.Linear_forward_before_lora(self, input)) + + +def lora_Conv2d_forward(self, input): + return lora_forward(self, input, torch.nn.Conv2d_forward_before_lora(self, input)) + + +def list_available_loras(): + available_loras.clear() + + os.makedirs(lora_dir, exist_ok=True) + + candidates = glob.glob(os.path.join(lora_dir, '**/*.pt'), recursive=True) + glob.glob(os.path.join(lora_dir, '**/*.safetensors'), recursive=True) + + for filename in sorted(candidates): + if os.path.isdir(filename): + continue + + name = os.path.splitext(os.path.basename(filename))[0] + + available_loras[name] = LoraOnDisk(name, filename) + + +lora_dir = os.path.join(shared.models_path, "Lora") +available_loras = {} +loaded_loras = [] + +list_available_loras() + diff --git a/extensions-builtin/Lora/scripts/lora_script.py b/extensions-builtin/Lora/scripts/lora_script.py new file mode 100644 index 00000000..60b9eb64 --- /dev/null +++ b/extensions-builtin/Lora/scripts/lora_script.py @@ -0,0 +1,30 @@ +import torch + +import lora +import extra_networks_lora +import ui_extra_networks_lora +from modules import script_callbacks, ui_extra_networks, extra_networks + + +def unload(): + torch.nn.Linear.forward = torch.nn.Linear_forward_before_lora + torch.nn.Conv2d.forward = torch.nn.Conv2d_forward_before_lora + + +def before_ui(): + ui_extra_networks.register_page(ui_extra_networks_lora.ExtraNetworksPageLora()) + extra_networks.register_extra_network(extra_networks_lora.ExtraNetworkLora()) + + +if not hasattr(torch.nn, 'Linear_forward_before_lora'): + torch.nn.Linear_forward_before_lora = torch.nn.Linear.forward + +if not hasattr(torch.nn, 'Conv2d_forward_before_lora'): + torch.nn.Conv2d_forward_before_lora = torch.nn.Conv2d.forward + +torch.nn.Linear.forward = lora.lora_Linear_forward +torch.nn.Conv2d.forward = lora.lora_Conv2d_forward + +script_callbacks.on_model_loaded(lora.assign_lora_names_to_compvis_modules) +script_callbacks.on_script_unloaded(unload) +script_callbacks.on_before_ui(before_ui) diff --git a/extensions-builtin/Lora/ui_extra_networks_lora.py b/extensions-builtin/Lora/ui_extra_networks_lora.py new file mode 100644 index 00000000..65397890 --- /dev/null +++ b/extensions-builtin/Lora/ui_extra_networks_lora.py @@ -0,0 +1,35 @@ +import os +import lora + +from modules import shared, ui_extra_networks + + +class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage): + def __init__(self): + super().__init__('Lora') + + def refresh(self): + lora.list_available_loras() + + def list_items(self): + for name, lora_on_disk in lora.available_loras.items(): + path, ext = os.path.splitext(lora_on_disk.filename) + previews = [path + ".png", path + ".preview.png"] + + preview = None + for file in previews: + if os.path.isfile(file): + preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file)) + break + + yield { + "name": name, + "filename": path, + "preview": preview, + "prompt": f"", + "local_preview": path + ".png", + } + + def allowed_directories_for_previews(self): + return [lora.lora_dir] + diff --git a/modules/extra_networks_hypernet.py b/modules/extra_networks_hypernet.py index 6a0c4ba8..ff279a1f 100644 --- a/modules/extra_networks_hypernet.py +++ b/modules/extra_networks_hypernet.py @@ -17,5 +17,5 @@ class ExtraNetworkHypernet(extra_networks.ExtraNetwork): hypernetwork.load_hypernetworks(names, multipliers) - def deactivate(p, self): + def deactivate(self, p): pass diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index a9e19236..4bb45ec7 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -73,6 +73,7 @@ callback_map = dict( callbacks_image_grid=[], callbacks_infotext_pasted=[], callbacks_script_unloaded=[], + callbacks_before_ui=[], ) @@ -189,6 +190,14 @@ def script_unloaded_callback(): report_exception(c, 'script_unloaded') +def before_ui_callback(): + for c in reversed(callback_map['callbacks_before_ui']): + try: + c.callback() + except Exception: + report_exception(c, 'before_ui') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -313,3 +322,9 @@ def on_script_unloaded(callback): the script did should be reverted here""" add_callback(callback_map['callbacks_script_unloaded'], callback) + + +def on_before_ui(callback): + """register a function to be called before the UI is created.""" + + add_callback(callback_map['callbacks_before_ui'], callback) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 253e90f7..796e879c 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -10,7 +10,7 @@ extra_pages = [] def register_page(page): - """registers extra networks page for the UI; recommend doing it in on_app_started() callback for extensions""" + """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions""" extra_pages.append(page) diff --git a/webui.py b/webui.py index e8dd822a..88d04840 100644 --- a/webui.py +++ b/webui.py @@ -165,6 +165,8 @@ def webui(): if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() + modules.script_callbacks.before_ui_callback() + shared.demo = modules.ui.create_ui() app, local_url, share_url = shared.demo.launch( -- cgit v1.2.3