From 893191cab24cc3511135495d6d2c8d81f5ec63a3 Mon Sep 17 00:00:00 2001 From: Tong Zeng Date: Thu, 10 Nov 2022 10:34:03 +0800 Subject: fix a bug in list_files_with_name --- modules/scripts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 637b2329..22d8908b 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -140,7 +140,7 @@ def list_files_with_name(filename): continue path = os.path.join(dirpath, filename) - if os.path.isfile(filename): + if os.path.isfile(path): res.append(path) return res -- cgit v1.2.3 From a1a376331c9ecbbee77b86daeaba44587cc56557 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 12 Nov 2022 10:56:06 +0300 Subject: make existing script loading and new preload code use same code for loading modules limit extension preload scripts to just one file named preload.py --- modules/extensions.py | 21 --------------------- modules/script_loading.py | 34 ++++++++++++++++++++++++++++++++++ modules/scripts.py | 46 +++++++++++++++++----------------------------- modules/shared.py | 5 ++--- 4 files changed, 53 insertions(+), 53 deletions(-) create mode 100644 modules/script_loading.py (limited to 'modules/scripts.py') diff --git a/modules/extensions.py b/modules/extensions.py index 544f3580..94ce479a 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,7 +1,6 @@ import os import sys import traceback -from importlib.machinery import SourceFileLoader import git @@ -85,23 +84,3 @@ def list_extensions(): extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) extensions.append(extension) - -def preload_extensions(parser): - if not os.path.isdir(extensions_dir): - return - - for dirname in sorted(os.listdir(extensions_dir)): - path = os.path.join(extensions_dir, dirname) - if not os.path.isdir(path): - continue - for file in os.listdir(path): - if "preload.py" in file: - full_file = os.path.join(path, file) - print(f"Got preload file: {full_file}") - - try: - ext = SourceFileLoader("preload", full_file).load_module() - parser = ext.preload(parser) - except Exception as e: - print(f"Exception preloading script: {e}") - return parser \ No newline at end of file diff --git a/modules/script_loading.py b/modules/script_loading.py new file mode 100644 index 00000000..f93f0951 --- /dev/null +++ b/modules/script_loading.py @@ -0,0 +1,34 @@ +import os +import sys +import traceback +from types import ModuleType + + +def load_module(path): + with open(path, "r", encoding="utf8") as file: + text = file.read() + + compiled = compile(text, path, 'exec') + module = ModuleType(os.path.basename(path)) + exec(compiled, module.__dict__) + + return module + + +def preload_extensions(extensions_dir, parser): + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + preload_script = os.path.join(extensions_dir, dirname, "preload.py") + if not os.path.isfile(preload_script): + continue + + try: + module = load_module(preload_script) + if hasattr(module, 'preload'): + module.preload(parser) + + except Exception: + print(f"Error running preload() for {preload_script}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) diff --git a/modules/scripts.py b/modules/scripts.py index 22d8908b..986b1914 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -6,7 +6,7 @@ from collections import namedtuple import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks, extensions +from modules import shared, paths, script_callbacks, extensions, script_loading AlwaysVisible = object() @@ -161,13 +161,7 @@ def load_scripts(): sys.path = [scriptfile.basedir] + sys.path current_basedir = scriptfile.basedir - with open(scriptfile.path, "r", encoding="utf8") as file: - text = file.read() - - from types import ModuleType - compiled = compile(text, scriptfile.path, 'exec') - module = ModuleType(scriptfile.filename) - exec(compiled, module.__dict__) + module = script_loading.load_module(scriptfile.path) for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -328,27 +322,21 @@ class ScriptRunner: def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): - with open(script.filename, "r", encoding="utf8") as file: - args_from = script.args_from - args_to = script.args_to - filename = script.filename - text = file.read() - - from types import ModuleType - - module = cache.get(filename, None) - if module is None: - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) - cache[filename] = module - - for key, script_class in module.__dict__.items(): - if type(script_class) == type and issubclass(script_class, Script): - self.scripts[si] = script_class() - self.scripts[si].filename = filename - self.scripts[si].args_from = args_from - self.scripts[si].args_to = args_to + args_from = script.args_from + args_to = script.args_to + filename = script.filename + + module = cache.get(filename, None) + if module is None: + module = script_loading.load_module(script.filename) + cache[filename] = module + + for key, script_class in module.__dict__.items(): + if type(script_class) == type and issubclass(script_class, Script): + self.scripts[si] = script_class() + self.scripts[si].filename = filename + self.scripts[si].args_from = args_from + self.scripts[si].args_to = args_to scripts_txt2img = ScriptRunner() diff --git a/modules/shared.py b/modules/shared.py index 17132e42..6936cbe0 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -3,7 +3,6 @@ import datetime import json import os import sys -from collections import OrderedDict import time import gradio as gr @@ -15,7 +14,7 @@ import modules.memmon import modules.sd_models import modules.styles import modules.devices as devices -from modules import sd_samplers, sd_models, localization, sd_vae, extensions +from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading from modules.hypernetworks import hypernetwork from modules.paths import models_path, script_path, sd_path @@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) -extensions.preload_extensions(parser) +script_loading.preload_extensions(extensions.extensions_dir, parser) cmd_opts = parser.parse_args() -- cgit v1.2.3 From 3596af07493ab7981ef92074f979eeee8fa624c4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 19 Nov 2022 19:10:17 +0300 Subject: Add API for scripts to add elements anywhere in UI. --- modules/script_callbacks.py | 35 +++++++++++++++++++++++ modules/scripts.py | 69 +++++++++++++++++++++++++++++++++++++++++++-- modules/ui.py | 12 ++++++-- 3 files changed, 111 insertions(+), 5 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index f19e164c..8e22f875 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -61,6 +61,8 @@ callback_map = dict( callbacks_before_image_saved=[], callbacks_image_saved=[], callbacks_cfg_denoiser=[], + callbacks_before_component=[], + callbacks_after_component=[], ) @@ -137,6 +139,22 @@ def cfg_denoiser_callback(params: CFGDenoiserParams): report_exception(c, 'cfg_denoiser_callback') +def before_component_callback(component, **kwargs): + for c in callback_map['callbacks_before_component']: + try: + c.callback(component, **kwargs) + except Exception: + report_exception(c, 'before_component_callback') + + +def after_component_callback(component, **kwargs): + for c in callback_map['callbacks_after_component']: + try: + c.callback(component, **kwargs) + except Exception: + report_exception(c, 'after_component_callback') + + def add_callback(callbacks, fun): stack = [x for x in inspect.stack() if x.filename != __file__] filename = stack[0].filename if len(stack) > 0 else 'unknown file' @@ -220,3 +238,20 @@ def on_cfg_denoiser(callback): - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details. """ add_callback(callback_map['callbacks_cfg_denoiser'], callback) + + +def on_before_component(callback): + """register a function to be called before a component is created. + The callback is called with arguments: + - component - gradio component that is about to be created. + - **kwargs - args to gradio.components.IOComponent.__init__ function + + Use elem_id/label fields of kwargs to figure out which component it is. + This can be useful to inject your own components somewhere in the middle of vanilla UI. + """ + add_callback(callback_map['callbacks_before_component'], callback) + + +def on_after_component(callback): + """register a function to be called after a component is created. See on_before_component for more.""" + add_callback(callback_map['callbacks_after_component'], callback) diff --git a/modules/scripts.py b/modules/scripts.py index 986b1914..b934d881 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -17,6 +17,9 @@ class Script: args_to = None alwayson = False + is_txt2img = False + is_img2img = False + """A gr.Group component that has all script's UI inside it""" group = None @@ -93,6 +96,23 @@ class Script: pass + def before_component(self, component, **kwargs): + """ + Called before a component is created. + Use elem_id/label fields of kwargs to figure out which component it is. + This can be useful to inject your own components somewhere in the middle of vanilla UI. + You can return created components in the ui() function to add them to the list of arguments for your processing functions + """ + + pass + + def after_component(self, component, **kwargs): + """ + Called after a component is created. Same as above. + """ + + pass + def describe(self): """unused""" return "" @@ -195,12 +215,18 @@ class ScriptRunner: self.titles = [] self.infotext_fields = [] - def setup_ui(self, is_img2img): + def initialize_scripts(self, is_img2img): + self.scripts.clear() + self.alwayson_scripts.clear() + self.selectable_scripts.clear() + for script_class, path, basedir in scripts_data: script = script_class() script.filename = path + script.is_txt2img = not is_img2img + script.is_img2img = is_img2img - visibility = script.show(is_img2img) + visibility = script.show(script.is_img2img) if visibility == AlwaysVisible: self.scripts.append(script) @@ -211,6 +237,7 @@ class ScriptRunner: self.scripts.append(script) self.selectable_scripts.append(script) + def setup_ui(self): self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts] inputs = [None] @@ -220,7 +247,7 @@ class ScriptRunner: script.args_from = len(inputs) script.args_to = len(inputs) - controls = wrap_call(script.ui, script.filename, "ui", is_img2img) + controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img) if controls is None: return @@ -320,6 +347,22 @@ class ScriptRunner: print(f"Error running postprocess: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) + def before_component(self, component, **kwargs): + for script in self.scripts: + try: + script.before_component(component, **kwargs) + except Exception: + print(f"Error running before_component: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + def after_component(self, component, **kwargs): + for script in self.scripts: + try: + script.after_component(component, **kwargs) + except Exception: + print(f"Error running after_component: {script.filename}", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): args_from = script.args_from @@ -341,6 +384,7 @@ class ScriptRunner: scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() +scripts_current: ScriptRunner = None def reload_script_body_only(): @@ -357,3 +401,22 @@ def reload_scripts(): scripts_txt2img = ScriptRunner() scripts_img2img = ScriptRunner() + +def IOComponent_init(self, *args, **kwargs): + if scripts_current is not None: + scripts_current.before_component(self, **kwargs) + + script_callbacks.before_component_callback(self, **kwargs) + + res = original_IOComponent_init(self, *args, **kwargs) + + script_callbacks.after_component_callback(self, **kwargs) + + if scripts_current is not None: + scripts_current.after_component(self, **kwargs) + + return res + + +original_IOComponent_init = gr.components.IOComponent.__init__ +gr.components.IOComponent.__init__ = IOComponent_init diff --git a/modules/ui.py b/modules/ui.py index bb090c62..a5953fce 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -695,6 +695,9 @@ def create_ui(wrap_gradio_gpu_call): parameters_copypaste.reset() + modules.scripts.scripts_current = modules.scripts.scripts_txt2img + modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) dummy_component = gr.Label(visible=False) @@ -737,7 +740,7 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() with gr.Group(): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False) + custom_inputs = modules.scripts.scripts_txt2img.setup_ui() txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples) parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt) @@ -846,6 +849,9 @@ def create_ui(wrap_gradio_gpu_call): token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter]) + modules.scripts.scripts_current = modules.scripts.scripts_img2img + modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + with gr.Blocks(analytics_enabled=False) as img2img_interface: img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True) @@ -916,7 +922,7 @@ def create_ui(wrap_gradio_gpu_call): seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs() with gr.Group(): - custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True) + custom_inputs = modules.scripts.scripts_img2img.setup_ui() img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples) parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt) @@ -1065,6 +1071,8 @@ def create_ui(wrap_gradio_gpu_call): parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields) + modules.scripts.scripts_current = None + with gr.Blocks(analytics_enabled=False) as extras_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): -- cgit v1.2.3