From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 09:02:25 +0300 Subject: add ability to create extensions that add localizations --- modules/ui.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 76ca9b07..23643c22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call): shared.state.need_restart = True restart_gradio.click( - fn=request_restart, + _js='restart_reload', inputs=[], outputs=[], - _js='restart_reload' ) if column is not None: -- cgit v1.2.3 From e5b4e3f820cd09e751f1d168ab05d606d078a0d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 10:12:53 +0300 Subject: add tags to extensions, and ability to filter out tags list changed Settings keys in UI do not print VRAM/etc stats everywhere but in calls that use GPU --- modules/ui.py | 25 ++++++++++++---------- modules/ui_extensions.py | 55 ++++++++++++++++++++++++++++++++++++++---------- style.css | 5 +++++ webui.py | 2 +- 4 files changed, 64 insertions(+), 23 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 23643c22..c946ad59 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None): gr.processing_utils.save_pil_to_file = save_pil_to_file -def wrap_gradio_call(func, extra_outputs=None): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() @@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None): res = extra_outputs_array + [f"
{plaintext_to_html(type(e).__name__+': '+str(e))}
"] + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_s:.2f}s" - if (elapsed_m > 0): + if elapsed_m > 0: elapsed_text = f"{elapsed_m}m "+elapsed_text if run_memmon: @@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

Time taken: {elapsed_text}

{vram_html}
" - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - return tuple(res) return f @@ -1436,7 +1439,7 @@ def create_ui(wrap_gradio_gpu_call): opts.reorder() def run_settings(*args): - changed = 0 + changed = [] for key, value, comp in zip(opts.data_labels.keys(), args, components): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" @@ -1454,12 +1457,12 @@ def create_ui(wrap_gradio_gpu_call): if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - changed += 1 + changed.append(key) try: opts.save(shared.config_filename) except RuntimeError: - return opts.dumpjson(), f'{changed} settings changed without save.' - return opts.dumpjson(), f'{changed} settings changed.' + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 8e0d41d5..02ab9643 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url): +def install_extension_from_index(url, hide_tags): ext_table, message = install_extension_from_url(None, url) - return refresh_available_extensions_from_data(), ext_table, message + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, ext_table, message -def refresh_available_extensions(url): + +def refresh_available_extensions(url, hide_tags): global available_extensions import urllib.request @@ -155,13 +157,25 @@ def refresh_available_extensions(url): available_extensions = json.loads(text) - return url, refresh_available_extensions_from_data(), '' + code, tags = refresh_available_extensions_from_data(hide_tags) + + return url, code, gr.CheckboxGroup.update(choices=tags), '' + + +def refresh_available_extensions_for_tags(hide_tags): + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, '' -def refresh_available_extensions_from_data(): + +def refresh_available_extensions_from_data(hide_tags): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + code = f""" @@ -178,17 +192,24 @@ def refresh_available_extensions_from_data(): name = ext.get("name", "noname") url = ext.get("url", None) description = ext.get("description", "") + extension_tags = ext.get("tags", []) if url is None: continue + if len([x for x in extension_tags if x in tags_to_hide]) > 0: + hidden += 1 + continue + existing = installed_extension_urls.get(normalize_git_url(url), None) install_code = f"""""" + tags_text = ", ".join([f"{x}" for x in extension_tags]) + code += f""" - + @@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
{html.escape(name)}{html.escape(name)}
{tags_text}
{html.escape(description)} {install_code}
""" - return code + if hidden > 0: + code += f"

Extension hidden: {hidden}

" + + return code, list(tags) def create_ui(): @@ -238,21 +262,30 @@ def create_ui(): extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) + install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]), - inputs=[available_extensions_index], - outputs=[available_extensions_index, available_extensions_table, install_result], + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags], + outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install], + inputs=[extension_to_install, hide_tags], outputs=[available_extensions_table, extensions_table, install_result], ) + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags], + outputs=[available_extensions_table, install_result] + ) + with gr.TabItem("Install from URL"): install_url = gr.Text(label="URL for extension's git repository") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") diff --git a/style.css b/style.css index a0382a8c..e2b71f25 100644 --- a/style.css +++ b/style.css @@ -563,6 +563,11 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h opacity: 0.5; } +.extension-tag{ + font-weight: bold; + font-size: 95%; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running diff --git a/webui.py b/webui.py index 4342a962..f4f1d74d 100644 --- a/webui.py +++ b/webui.py @@ -57,7 +57,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return res - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) def initialize(): -- cgit v1.2.3 From 32c0eab89538ba3900bf499291720f80ae4b43e5 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 14:39:41 +0300 Subject: load all settings in one call instead of one by one when the page loads --- modules/ui.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index c946ad59..34c31ef1 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1141,7 +1141,7 @@ def create_ui(wrap_gradio_gpu_call): outputs=[html, generation_info, html2], ) - with gr.Blocks() as modelmerger_interface: + with gr.Blocks(analytics_enabled=False) as modelmerger_interface: with gr.Row().style(equal_height=False): with gr.Column(variant='panel'): gr.HTML(value="

A merger of the two checkpoints will be generated in your checkpoint directory.

") @@ -1161,7 +1161,7 @@ def create_ui(wrap_gradio_gpu_call): sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings() - with gr.Blocks() as train_interface: + with gr.Blocks(analytics_enabled=False) as train_interface: with gr.Row().style(equal_height=False): gr.HTML(value="

See wiki for detailed explanation.

") @@ -1420,15 +1420,14 @@ def create_ui(wrap_gradio_gpu_call): if info.refresh is not None: if is_quicksettings: - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: with gr.Row(variant="compact"): - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) create_refresh_button(res, info.refresh, info.component_args, "refresh_" + key) else: - res = comp(label=info.label, value=fun, elem_id=elem_id, **(args or {})) - + res = comp(label=info.label, value=fun(), elem_id=elem_id, **(args or {})) return res @@ -1639,6 +1638,17 @@ def create_ui(wrap_gradio_gpu_call): outputs=[component, text_settings], ) + component_keys = [k for k in opts.data_labels.keys() if k in component_dict] + + def get_settings_values(): + return [getattr(opts, key) for key in component_keys] + + demo.load( + fn=get_settings_values, + inputs=[], + outputs=[component_dict[k] for k in component_keys], + ) + def modelmerger(*args): try: results = modules.extras.run_modelmerger(*args) -- cgit v1.2.3 From c5334fc56b3d44976425da2e6d0a303ae96836a1 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:35:01 +0300 Subject: fix javascript duplication bug after pressing the restart UI button --- modules/ui.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/ui.py b/modules/ui.py index 34c31ef1..67cf1d6a 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1752,7 +1752,7 @@ def create_ui(wrap_gradio_gpu_call): return demo -def load_javascript(raw_response): +def reload_javascript(): with open(os.path.join(script_path, "script.js"), "r", encoding="utf8") as jsfile: javascript = f'' @@ -1768,7 +1768,7 @@ def load_javascript(raw_response): javascript += f"\n" def template_response(*args, **kwargs): - res = raw_response(*args, **kwargs) + res = shared.GradioTemplateResponseOriginal(*args, **kwargs) res.body = res.body.replace( b'', f'{javascript}'.encode("utf8")) res.init_headers() @@ -1777,4 +1777,5 @@ def load_javascript(raw_response): gradio.routes.templates.TemplateResponse = template_response -reload_javascript = partial(load_javascript, gradio.routes.templates.TemplateResponse) +if not hasattr(shared, 'GradioTemplateResponseOriginal'): + shared.GradioTemplateResponseOriginal = gradio.routes.templates.TemplateResponse -- cgit v1.2.3 From 1610b3258458025025e9c4faae57d290e4519745 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Tue, 8 Nov 2022 08:38:10 +0300 Subject: add callback for creating a tab in train UI --- modules/script_callbacks.py | 27 +++++++++++++++++++++++++-- modules/ui.py | 4 ++++ 2 files changed, 29 insertions(+), 2 deletions(-) (limited to 'modules/ui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 74dfb880..f19e164c 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -7,6 +7,7 @@ from typing import Optional from fastapi import FastAPI from gradio import Blocks + def report_exception(c, job): print(f"Error executing callback {job} for {c.script}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) @@ -45,15 +46,21 @@ class CFGDenoiserParams: """Total number of sampling steps planned""" +class UiTrainTabParams: + def __init__(self, txt2img_preview_params): + self.txt2img_preview_params = txt2img_preview_params + + ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"]) callback_map = dict( callbacks_app_started=[], callbacks_model_loaded=[], callbacks_ui_tabs=[], + callbacks_ui_train_tabs=[], callbacks_ui_settings=[], callbacks_before_image_saved=[], callbacks_image_saved=[], - callbacks_cfg_denoiser=[] + callbacks_cfg_denoiser=[], ) @@ -61,6 +68,7 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -79,7 +87,7 @@ def model_loaded_callback(sd_model): def ui_tabs_callback(): res = [] - + for c in callback_map['callbacks_ui_tabs']: try: res += c.callback() or [] @@ -89,6 +97,14 @@ def ui_tabs_callback(): return res +def ui_train_tabs_callback(params: UiTrainTabParams): + for c in callback_map['callbacks_ui_train_tabs']: + try: + c.callback(params) + except Exception: + report_exception(c, 'callbacks_ui_train_tabs') + + def ui_settings_callback(): for c in callback_map['callbacks_ui_settings']: try: @@ -169,6 +185,13 @@ def on_ui_tabs(callback): add_callback(callback_map['callbacks_ui_tabs'], callback) +def on_ui_train_tabs(callback): + """register a function to be called when the UI is creating new tabs for the train tab. + Create your new tabs with gr.Tab. + """ + add_callback(callback_map['callbacks_ui_train_tabs'], callback) + + def on_ui_settings(callback): """register a function to be called before UI settings are populated; add your settings by using shared.opts.add_option(shared.OptionInfo(...)) """ diff --git a/modules/ui.py b/modules/ui.py index 67cf1d6a..7ea1177f 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call): train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary') train_embedding = gr.Button(value="Train Embedding", variant='primary') + params = script_callbacks.UiTrainTabParams(txt2img_preview_params) + + script_callbacks.ui_train_tabs_callback(params) + with gr.Column(): progressbar = gr.HTML(elem_id="ti_progressbar") ti_output = gr.Text(elem_id="ti_output", value="", show_label=False) -- cgit v1.2.3