From 5ef7590324891ec7263c767d178a51827a6f9b33 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 11:38:59 +0300 Subject: always show extra networks tabs in the UI --- modules/ui_extra_networks.py | 58 ++++++++++++++++++-------------------------- 1 file changed, 23 insertions(+), 35 deletions(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 6c73998f..0eb02873 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -355,7 +355,7 @@ def pages_in_preferred_order(pages): return sorted(pages, key=lambda x: tab_scores[x.name]) -def create_ui(container, button, tabname): +def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] @@ -363,48 +363,35 @@ def create_ui(container, button, tabname): ui.stored_extra_pages = pages_in_preferred_order(extra_pages.copy()) ui.tabname = tabname - with gr.Tabs(elem_id=tabname+"_extra_tabs"): - for page in ui.stored_extra_pages: - with gr.Tab(page.title, id=page.id_page): - elem_id = f"{tabname}_{page.id_page}_cards_html" - page_elem = gr.HTML('Loading...', elem_id=elem_id) - ui.pages.append(page_elem) + related_tabs = [] - page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) + for page in ui.stored_extra_pages: + with gr.Tab(page.title, id=page.id_page) as tab: + elem_id = f"{tabname}_{page.id_page}_cards_html" + page_elem = gr.HTML('Loading...', elem_id=elem_id) + ui.pages.append(page_elem) - editor = page.create_user_metadata_editor(ui, tabname) - editor.create_ui() - ui.user_metadata_editors.append(editor) + page_elem.change(fn=lambda: None, _js='function(){applyExtraNetworkFilter(' + quote_js(tabname) + '); return []}', inputs=[], outputs=[]) - gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", placeholder="Search...", visible=False) - gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", multiselect=False, visible=False, show_label=False, interactive=True) - ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder") - button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh") + editor = page.create_user_metadata_editor(ui, tabname) + editor.create_ui() + ui.user_metadata_editors.append(editor) - ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) - ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - - def toggle_visibility(is_visible): - is_visible = not is_visible - - return is_visible, gr.update(visible=is_visible), gr.update(variant=("secondary-down" if is_visible else "secondary")) + related_tabs.append(tab) - def fill_tabs(is_empty): - """Creates HTML for extra networks' tabs when the extra networks button is clicked for the first time.""" + edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) + dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True) + button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) + button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) - if not ui.pages_contents: - refresh() - - if is_empty: - return True, *ui.pages_contents - - return True, *[gr.update() for _ in ui.pages_contents] + ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) + ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) - state_visible = gr.State(value=False) - button.click(fn=toggle_visibility, inputs=[state_visible], outputs=[state_visible, container, button], show_progress=False) + for tab in unrelated_tabs: + tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) - state_empty = gr.State(value=True) - button.click(fn=fill_tabs, inputs=[state_empty], outputs=[state_empty, *ui.pages], show_progress=False) + for tab in related_tabs: + tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) def refresh(): for pg in ui.stored_extra_pages: @@ -414,6 +401,7 @@ def create_ui(container, button, tabname): return ui.pages_contents + interface.load(fn=refresh, inputs=[], outputs=[*ui.pages]) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui -- cgit v1.2.3 From 57d61de25cb6de2e317ae23580971e98c70f542e Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sun, 16 Jul 2023 11:52:29 +0300 Subject: fix unneded reload from disk --- modules/ui_extra_networks.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 0eb02873..c11f1d5b 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -393,6 +393,12 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): for tab in related_tabs: tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) + def pages_html(): + if not ui.pages_contents: + return refresh() + + return ui.pages_contents + def refresh(): for pg in ui.stored_extra_pages: pg.refresh() @@ -401,7 +407,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): return ui.pages_contents - interface.load(fn=refresh, inputs=[], outputs=[*ui.pages]) + interface.load(fn=pages_html, inputs=[], outputs=[*ui.pages]) button_refresh.click(fn=refresh, inputs=[], outputs=ui.pages) return ui -- cgit v1.2.3 From 543ea5730b8c2eea271739cab74bd962b45a4fea Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Mon, 17 Jul 2023 16:15:52 +0900 Subject: fix extra search button --- javascript/extraNetworks.js | 2 +- modules/ui_extra_networks.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extra_networks.py') diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 1835717b..2361144a 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -177,7 +177,7 @@ function saveCardPreview(event, tabname, filename) { } function extraNetworksSearchButton(tabs_id, event) { - var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > div > textarea'); + var searchTextarea = gradioApp().querySelector("#" + tabs_id + ' > label > textarea'); var button = event.target; var text = button.classList.contains("search-all") ? "" : button.textContent.trim(); diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index c11f1d5b..b913cb3e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -163,7 +163,7 @@ class ExtraNetworksPage: subdirs = {"": 1, **subdirs} subdirs_html = "".join([f""" - """ for subdir in subdirs]) -- cgit v1.2.3 From c278e60131d34b58069c91d441e60a5d87f14a22 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Wed, 19 Jul 2023 04:58:30 +0900 Subject: add dropdown extra_sort_order lable --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index b913cb3e..7387d01e 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -380,7 +380,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): related_tabs.append(tab) edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) - dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True) + dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) -- cgit v1.2.3 From b73c405013f63afa82d6358f9ce544931e5e9bc6 Mon Sep 17 00:00:00 2001 From: Littleor Date: Wed, 26 Jul 2023 11:02:34 +0800 Subject: fix: error rendering name and description in extra network ui --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 49612298..48537bc1 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -253,8 +253,8 @@ class ExtraNetworksPage: "prompt": item.get("prompt", None), "tabname": quote_js(tabname), "local_preview": quote_js(item["local_preview"]), - "name": item["name"], - "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "name": html.escape(item["name"]), + "description": html.escape(item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), -- cgit v1.2.3 From 187323a606bc0e4913240e5f20c51c9789234654 Mon Sep 17 00:00:00 2001 From: Littleor Date: Wed, 26 Jul 2023 17:23:57 +0800 Subject: fix: extra network ui description allow HTML tags --- modules/ui_extra_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 48537bc1..f2752f10 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -254,7 +254,7 @@ class ExtraNetworksPage: "tabname": quote_js(tabname), "local_preview": quote_js(item["local_preview"]), "name": html.escape(item["name"]), - "description": html.escape(item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), + "description": (item.get("description") or "" if shared.opts.extra_networks_card_show_desc else ""), "card_clicked": onclick, "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {quote_js(tabname)}, {quote_js(item["local_preview"])})""") + '"', "search_term": item.get("search_term", ""), -- cgit v1.2.3 From 45601766409e531d2b4ee512bf1433600f140183 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 22:05:40 +0300 Subject: added VAE selection to checkpoint user metadata --- modules/extra_networks.py | 19 +++++++ modules/sd_vae.py | 13 ++++- modules/ui_extra_networks.py | 13 +---- modules/ui_extra_networks_checkpoints.py | 3 ++ .../ui_extra_networks_checkpoints_user_metadata.py | 60 ++++++++++++++++++++++ 5 files changed, 96 insertions(+), 12 deletions(-) create mode 100644 modules/ui_extra_networks_checkpoints_user_metadata.py (limited to 'modules/ui_extra_networks.py') diff --git a/modules/extra_networks.py b/modules/extra_networks.py index 6ae07e91..fa28ac75 100644 --- a/modules/extra_networks.py +++ b/modules/extra_networks.py @@ -1,3 +1,5 @@ +import json +import os import re from collections import defaultdict @@ -177,3 +179,20 @@ def parse_prompts(prompts): return res, extra_data + +def get_user_metadata(filename): + if filename is None: + return {} + + basename, ext = os.path.splitext(filename) + metadata_filename = basename + '.json' + + metadata = {} + try: + if os.path.isfile(metadata_filename): + with open(metadata_filename, "r", encoding="utf8") as file: + metadata = json.load(file) + except Exception as e: + errors.display(e, f"reading extra network user metadata from {metadata_filename}") + + return metadata diff --git a/modules/sd_vae.py b/modules/sd_vae.py index 84271db0..0bd5e19b 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,6 +1,6 @@ import os import collections -from modules import paths, shared, devices, script_callbacks, sd_models +from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks import glob from copy import deepcopy @@ -16,6 +16,7 @@ checkpoint_info = None checkpoints_loaded = collections.OrderedDict() + def get_base_vae(model): if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: return base_vae @@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file): if shared.cmd_opts.vae_path is not None: return shared.cmd_opts.vae_path, 'from commandline argument' + metadata = extra_networks.get_user_metadata(checkpoint_file) + vae_metadata = metadata.get("vae", None) + if vae_metadata is not None and vae_metadata != "Automatic": + if vae_metadata == "None": + return None, None + + vae_from_metadata = vae_dict.get(vae_metadata, None) + if vae_from_metadata is not None: + return vae_from_metadata, "from user metadata" + is_automatic = shared.opts.sd_vae in {"Automatic", "auto"} # "auto" for people with old config vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file) diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index f2752f10..c6390db7 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -2,7 +2,7 @@ import os.path import urllib.parse from pathlib import Path -from modules import shared, ui_extra_networks_user_metadata, errors +from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo from modules.ui import up_down_symbol import gradio as gr @@ -101,16 +101,7 @@ class ExtraNetworksPage: def read_user_metadata(self, item): filename = item.get("filename", None) - basename, ext = os.path.splitext(filename) - metadata_filename = basename + '.json' - - metadata = {} - try: - if os.path.isfile(metadata_filename): - with open(metadata_filename, "r", encoding="utf8") as file: - metadata = json.load(file) - except Exception as e: - errors.display(e, f"reading extra network user metadata from {metadata_filename}") + metadata = extra_networks.get_user_metadata(filename) desc = metadata.get("description", None) if desc is not None: diff --git a/modules/ui_extra_networks_checkpoints.py b/modules/ui_extra_networks_checkpoints.py index 891d8f2c..77885022 100644 --- a/modules/ui_extra_networks_checkpoints.py +++ b/modules/ui_extra_networks_checkpoints.py @@ -3,6 +3,7 @@ import os from modules import shared, ui_extra_networks, sd_models from modules.ui_extra_networks import quote_js +from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): @@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage): def allowed_directories_for_previews(self): return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None] + def create_user_metadata_editor(self, ui, tabname): + return CheckpointUserMetadataEditor(ui, tabname, self) diff --git a/modules/ui_extra_networks_checkpoints_user_metadata.py b/modules/ui_extra_networks_checkpoints_user_metadata.py new file mode 100644 index 00000000..2c69aab8 --- /dev/null +++ b/modules/ui_extra_networks_checkpoints_user_metadata.py @@ -0,0 +1,60 @@ +import gradio as gr + +from modules import ui_extra_networks_user_metadata, sd_vae +from modules.ui_common import create_refresh_button + + +class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor): + def __init__(self, ui, tabname, page): + super().__init__(ui, tabname, page) + + self.select_vae = None + + def save_user_metadata(self, name, desc, notes, vae): + user_metadata = self.get_user_metadata(name) + user_metadata["description"] = desc + user_metadata["notes"] = notes + user_metadata["vae"] = vae + + self.write_user_metadata(name, user_metadata) + + def put_values_into_components(self, name): + user_metadata = self.get_user_metadata(name) + values = super().put_values_into_components(name) + + return [ + *values[0:5], + user_metadata.get('vae', ''), + ] + + def create_editor(self): + self.create_default_editor_elems() + + with gr.Row(): + self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae") + create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae") + + self.edit_notes = gr.TextArea(label='Notes', lines=4) + + self.create_default_buttons() + + viewed_components = [ + self.edit_name, + self.edit_description, + self.html_filedata, + self.html_preview, + self.edit_notes, + self.select_vae, + ] + + self.button_edit\ + .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\ + .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box]) + + edited_components = [ + self.edit_description, + self.edit_notes, + self.select_vae, + ] + + self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components) -- cgit v1.2.3 From c74c708ed8c422bf7ca1f388a3ee772c7d1e4ddd Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 5 Aug 2023 09:15:18 +0300 Subject: add checkbox to show/hide dirs for extra networks --- javascript/extraNetworks.js | 29 +++++++++++++++++++++++++++++ modules/ui_extra_networks.py | 5 +++-- style.css | 5 ++++- 3 files changed, 36 insertions(+), 3 deletions(-) (limited to 'modules/ui_extra_networks.py') diff --git a/javascript/extraNetworks.js b/javascript/extraNetworks.js index 44d02349..897ebeba 100644 --- a/javascript/extraNetworks.js +++ b/javascript/extraNetworks.js @@ -1,3 +1,20 @@ +function toggleCss(key, css, enable) { + var style = document.getElementById(key); + if (enable && !style) { + style = document.createElement('style'); + style.id = key; + style.type = 'text/css'; + document.head.appendChild(style); + } + if (style && !enable) { + document.head.removeChild(style); + } + if (style) { + style.innerHTML == ''; + style.appendChild(document.createTextNode(css)); + } +} + function setupExtraNetworksForTab(tabname) { gradioApp().querySelector('#' + tabname + '_extra_tabs').classList.add('extra-networks'); @@ -7,12 +24,15 @@ function setupExtraNetworksForTab(tabname) { var sort = gradioApp().getElementById(tabname + '_extra_sort'); var sortOrder = gradioApp().getElementById(tabname + '_extra_sortorder'); var refresh = gradioApp().getElementById(tabname + '_extra_refresh'); + var showDirsDiv = gradioApp().getElementById(tabname + '_extra_show_dirs'); + var showDirs = gradioApp().querySelector('#' + tabname + '_extra_show_dirs input'); sort.dataset.sortkey = 'sortDefault'; tabs.appendChild(searchDiv); tabs.appendChild(sort); tabs.appendChild(sortOrder); tabs.appendChild(refresh); + tabs.appendChild(showDirsDiv); var applyFilter = function() { var searchTerm = search.value.toLowerCase(); @@ -78,6 +98,15 @@ function setupExtraNetworksForTab(tabname) { }); extraNetworksApplyFilter[tabname] = applyFilter; + + var showDirsUpdate = function() { + var css = '#' + tabname + '_extra_tabs .extra-network-subdirs { display: none; }'; + toggleCss(tabname + '_extra_show_dirs_style', css, !showDirs.checked); + localSet('extra-networks-show-dirs', showDirs.checked ? 1 : 0); + }; + showDirs.checked = localGet('extra-networks-show-dirs', 1) == 1; + showDirs.addEventListener("change", showDirsUpdate); + showDirsUpdate(); } function applyExtraNetworkFilter(tabname) { diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 3a73c89e..e0b932b9 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -375,15 +375,16 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) + checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) ui.button_save_preview = gr.Button('Save preview', elem_id=tabname+"_save_preview", visible=False) ui.preview_target_filename = gr.Textbox('Preview save filename', elem_id=tabname+"_preview_filename", visible=False) for tab in unrelated_tabs: - tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) + tab.select(fn=lambda: [gr.update(visible=False) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) for tab in related_tabs: - tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, edit_search, dropdown_sort, button_sortorder, button_refresh], show_progress=False) + tab.select(fn=lambda: [gr.update(visible=True) for _ in range(5)], inputs=[], outputs=[edit_search, dropdown_sort, button_sortorder, button_refresh, checkbox_show_dirs], show_progress=False) def pages_html(): if not ui.pages_contents: diff --git a/style.css b/style.css index 52919f71..dc4d37b9 100644 --- a/style.css +++ b/style.css @@ -801,9 +801,12 @@ footer { margin: 0 0.15em; } .extra-networks .tab-nav .search, -.extra-networks .tab-nav .sort{ +.extra-networks .tab-nav .sort, +.extra-networks .tab-nav .show-dirs +{ margin: 0.3em; align-self: center; + width: auto; } .extra-networks .tab-nav .search { -- cgit v1.2.3 From 95821f0132f5437ef30b0dbcac7c51e55818c18f Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Wed, 9 Aug 2023 18:11:13 +0300 Subject: split webui.py's initialization and utility functions into separate files --- modules/gradio_extensons.py | 4 +- modules/initialize.py | 168 ++++++++++++++++++++ modules/initialize_util.py | 195 +++++++++++++++++++++++ modules/shared_init.py | 3 - modules/ui_extra_networks.py | 3 +- modules/ui_tempdir.py | 5 +- webui.py | 368 ++++--------------------------------------- 7 files changed, 405 insertions(+), 341 deletions(-) create mode 100644 modules/initialize.py create mode 100644 modules/initialize_util.py (limited to 'modules/ui_extra_networks.py') diff --git a/modules/gradio_extensons.py b/modules/gradio_extensons.py index 5af7fd8e..77c34c8b 100644 --- a/modules/gradio_extensons.py +++ b/modules/gradio_extensons.py @@ -1,6 +1,6 @@ import gradio as gr -from modules import scripts +from modules import scripts, ui_tempdir def add_classes_to_gradio_component(comp): """ @@ -58,3 +58,5 @@ original_BlockContext_init = gr.blocks.BlockContext.__init__ gr.components.IOComponent.__init__ = IOComponent_init gr.blocks.Block.get_config = Block_get_config gr.blocks.BlockContext.__init__ = BlockContext_init + +ui_tempdir.install_ui_tempdir_override() diff --git a/modules/initialize.py b/modules/initialize.py new file mode 100644 index 00000000..f24f7637 --- /dev/null +++ b/modules/initialize.py @@ -0,0 +1,168 @@ +import importlib +import logging +import sys +import warnings +from threading import Thread + +from modules.timer import startup_timer + + +def imports(): + logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... + logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) + + import torch # noqa: F401 + startup_timer.record("import torch") + import pytorch_lightning # noqa: F401 + startup_timer.record("import torch") + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + + import gradio # noqa: F401 + startup_timer.record("import gradio") + + from modules import paths, timer, import_hook, errors # noqa: F401 + startup_timer.record("setup paths") + + import ldm.modules.encoders.modules # noqa: F401 + startup_timer.record("import ldm") + + import sgm.modules.encoders.modules # noqa: F401 + startup_timer.record("import sgm") + + from modules import shared_init + shared_init.initialize() + startup_timer.record("initialize shared") + + from modules import processing, gradio_extensons, ui # noqa: F401 + startup_timer.record("other imports") + + +def check_versions(): + from modules.shared_cmd_options import cmd_opts + + if not cmd_opts.skip_version_check: + from modules import errors + errors.check_versions() + + +def initialize(): + from modules import initialize_util + initialize_util.fix_torch_version() + initialize_util.fix_asyncio_event_loop_policy() + initialize_util.validate_tls_options() + initialize_util.configure_sigint_handler() + initialize_util.configure_opts_onchange() + + from modules import modelloader + modelloader.cleanup_models() + + from modules import sd_models + sd_models.setup_model() + startup_timer.record("setup SD model") + + from modules.shared_cmd_options import cmd_opts + + from modules import codeformer_model + warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision.transforms.functional_tensor") + codeformer_model.setup_model(cmd_opts.codeformer_models_path) + startup_timer.record("setup codeformer") + + from modules import gfpgan_model + gfpgan_model.setup_model(cmd_opts.gfpgan_models_path) + startup_timer.record("setup gfpgan") + + initialize_rest(reload_script_modules=False) + + +def initialize_rest(*, reload_script_modules=False): + """ + Called both from initialize() and when reloading the webui. + """ + from modules.shared_cmd_options import cmd_opts + + from modules import sd_samplers + sd_samplers.set_samplers() + startup_timer.record("set samplers") + + from modules import extensions + extensions.list_extensions() + startup_timer.record("list extensions") + + from modules import initialize_util + initialize_util.restore_config_state_file() + startup_timer.record("restore config state file") + + from modules import shared, upscaler, scripts + if cmd_opts.ui_debug_mode: + shared.sd_upscalers = upscaler.UpscalerLanczos().scalers + scripts.load_scripts() + return + + from modules import sd_models + sd_models.list_models() + startup_timer.record("list SD models") + + from modules import localization + localization.list_localizations(cmd_opts.localizations_dir) + startup_timer.record("list localizations") + + with startup_timer.subcategory("load scripts"): + scripts.load_scripts() + + if reload_script_modules: + for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: + importlib.reload(module) + startup_timer.record("reload script modules") + + from modules import modelloader + modelloader.load_upscalers() + startup_timer.record("load upscalers") + + from modules import sd_vae + sd_vae.refresh_vae_list() + startup_timer.record("refresh VAE") + + from modules import textual_inversion + textual_inversion.textual_inversion.list_textual_inversion_templates() + startup_timer.record("refresh textual inversion templates") + + from modules import script_callbacks, sd_hijack_optimizations, sd_hijack + script_callbacks.on_list_optimizers(sd_hijack_optimizations.list_optimizers) + sd_hijack.list_optimizers() + startup_timer.record("scripts list_optimizers") + + from modules import sd_unet + sd_unet.list_unets() + startup_timer.record("scripts list_unets") + + def load_model(): + """ + Accesses shared.sd_model property to load model. + After it's available, if it has been loaded before this access by some extension, + its optimization may be None because the list of optimizaers has neet been filled + by that time, so we apply optimization again. + """ + + shared.sd_model # noqa: B018 + + if sd_hijack.current_optimizer is None: + sd_hijack.apply_optimizations() + + from modules import devices + devices.first_time_calculation() + + Thread(target=load_model).start() + + from modules import shared_items + shared_items.reload_hypernetworks() + startup_timer.record("reload hypernetworks") + + from modules import ui_extra_networks + ui_extra_networks.initialize() + ui_extra_networks.register_default_pages() + + from modules import extra_networks + extra_networks.initialize() + extra_networks.register_default_extra_networks() + startup_timer.record("initialize extra networks") diff --git a/modules/initialize_util.py b/modules/initialize_util.py new file mode 100644 index 00000000..e59bd3c4 --- /dev/null +++ b/modules/initialize_util.py @@ -0,0 +1,195 @@ +import json +import logging +import os +import signal +import sys +import re + +from modules.timer import startup_timer + +def setup_logging(): + # We can't use cmd_opts for this because it will not have been initialized at this point. + log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") + if log_level: + log_level = getattr(logging, log_level.upper(), None) or logging.INFO + logging.basicConfig( + level=log_level, + format='%(asctime)s %(levelname)s [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + ) + + +def gradio_server_name(): + from modules.shared_cmd_options import cmd_opts + + if cmd_opts.server_name: + return cmd_opts.server_name + else: + return "0.0.0.0" if cmd_opts.listen else None + + +def fix_torch_version(): + import torch + + # Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors + if ".dev" in torch.__version__ or "+git" in torch.__version__: + torch.__long_version__ = torch.__version__ + torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) + + +def fix_asyncio_event_loop_policy(): + """ + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + """ + + import asyncio + + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + +def restore_config_state_file(): + from modules import shared, config_states + + config_state_file = shared.opts.restore_config_state_file + if config_state_file == "": + return + + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_config(config_state) + startup_timer.record("restore extension config") + elif config_state_file: + print(f"!!! Config state backup not found: {config_state_file}") + + +def validate_tls_options(): + from modules.shared_cmd_options import cmd_opts + + if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): + return + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + startup_timer.record("TLS") + + +def get_gradio_auth_creds(): + """ + Convert the gradio_auth and gradio_auth_path commandline arguments into + an iterable of (username, password) tuples. + """ + from modules.shared_cmd_options import cmd_opts + + def process_credential_line(s): + s = s.strip() + if not s: + return None + return tuple(s.split(':', 1)) + + if cmd_opts.gradio_auth: + for cred in cmd_opts.gradio_auth.split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + if cmd_opts.gradio_auth_path: + with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: + for line in file.readlines(): + for cred in line.strip().split(','): + cred = process_credential_line(cred) + if cred: + yield cred + + +def configure_sigint_handler(): + # make the program just exit at ctrl+c without waiting for anything + def sigint_handler(sig, frame): + print(f'Interrupted with signal {sig} in {frame}') + os._exit(0) + + if not os.environ.get("COVERAGE_RUN"): + # Don't install the immediate-quit handler when running under coverage, + # as then the coverage report won't be generated. + signal.signal(signal.SIGINT, sigint_handler) + + +def configure_opts_onchange(): + from modules import shared, sd_models, sd_vae, ui_tempdir, sd_hijack + from modules.call_queue import wrap_queued_call + + shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: sd_models.reload_model_weights()), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: sd_vae.reload_vae_weights()), call=False) + shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) + shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) + startup_timer.record("opts onchange") + + +def setup_middleware(app): + from starlette.middleware.gzip import GZipMiddleware + + app.middleware_stack = None # reset current middleware to allow modifying user provided list + app.add_middleware(GZipMiddleware, minimum_size=1000) + configure_cors_middleware(app) + app.build_middleware_stack() # rebuild middleware stack on-the-fly + + +def configure_cors_middleware(app): + from starlette.middleware.cors import CORSMiddleware + from modules.shared_cmd_options import cmd_opts + + cors_options = { + "allow_methods": ["*"], + "allow_headers": ["*"], + "allow_credentials": True, + } + if cmd_opts.cors_allow_origins: + cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') + if cmd_opts.cors_allow_origins_regex: + cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex + app.add_middleware(CORSMiddleware, **cors_options) + diff --git a/modules/shared_init.py b/modules/shared_init.py index b88d1d8e..d3fb687e 100644 --- a/modules/shared_init.py +++ b/modules/shared_init.py @@ -5,9 +5,6 @@ import torch from modules import shared from modules.shared import cmd_opts -import sys -sys.setrecursionlimit(1000) - def initialize(): """Initializes fields inside the shared module in a controlled manner. diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index e0b932b9..16d76a45 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -4,7 +4,6 @@ from pathlib import Path from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks from modules.images import read_info_from_image, save_image_with_geninfo -from modules.ui import up_down_symbol import gradio as gr import json import html @@ -348,6 +347,8 @@ def pages_in_preferred_order(pages): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): + from modules.ui import up_down_symbol + ui = ExtraNetworksUi() ui.pages = [] ui.pages_contents = [] diff --git a/modules/ui_tempdir.py b/modules/ui_tempdir.py index fb75137e..506017e5 100644 --- a/modules/ui_tempdir.py +++ b/modules/ui_tempdir.py @@ -57,8 +57,9 @@ def save_pil_to_file(self, pil_image, dir=None, format="png"): return file_obj.name -# override save to file function so that it also writes PNG info -gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file +def install_ui_tempdir_override(): + """override save to file function so that it also writes PNG info""" + gradio.components.IOComponent.pil_to_temp_file = save_pil_to_file def on_tmpdir_changed(): diff --git a/webui.py b/webui.py index 0f1ace97..738b3bef 100644 --- a/webui.py +++ b/webui.py @@ -1,349 +1,43 @@ from __future__ import annotations import os -import sys import time -import importlib -import signal -import re -import warnings -import json -from threading import Thread -from typing import Iterable - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.middleware.gzip import GZipMiddleware - -import logging - -# We can't use cmd_opts for this because it will not have been initialized at this point. -log_level = os.environ.get("SD_WEBUI_LOG_LEVEL") -if log_level: - log_level = getattr(logging, log_level.upper(), None) or logging.INFO - logging.basicConfig( - level=log_level, - format='%(asctime)s %(levelname)s [%(name)s] %(message)s', - datefmt='%Y-%m-%d %H:%M:%S', - ) - -logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... -logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) from modules import timer +from modules import initialize_util +from modules import initialize + startup_timer = timer.startup_timer startup_timer.record("launcher") -import torch -import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them -warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") -warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") -startup_timer.record("import torch") - -import gradio # noqa: F401 -startup_timer.record("import gradio") - -from modules import paths, timer, import_hook, errors # noqa: F401 -startup_timer.record("setup paths") - -import ldm.modules.encoders.modules # noqa: F401 -startup_timer.record("import ldm") - -from modules import shared_init, shared, shared_items -shared_init.initialize() -startup_timer.record("initialize shared") - -from modules import extra_networks -from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 - -# Truncate version number of nightly/local build of PyTorch to not cause exceptions with CodeFormer or Safetensors -if ".dev" in torch.__version__ or "+git" in torch.__version__: - torch.__long_version__ = torch.__version__ - torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) - -if not shared.cmd_opts.skip_version_check: - errors.check_versions() - -import modules.codeformer_model as codeformer -import modules.gfpgan_model as gfpgan -from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states -import modules.face_restoration -import modules.img2img - -import modules.lowvram -import modules.scripts -import modules.sd_hijack -import modules.sd_hijack_optimizations -import modules.sd_models -import modules.sd_vae -import modules.sd_unet -import modules.txt2img -import modules.script_callbacks -import modules.textual_inversion.textual_inversion -import modules.progress - -import modules.ui -from modules import modelloader, devices -from modules.shared import cmd_opts -import modules.hypernetworks.hypernetwork - -startup_timer.record("other imports") - - -if cmd_opts.server_name: - server_name = cmd_opts.server_name -else: - server_name = "0.0.0.0" if cmd_opts.listen else None - - -def fix_asyncio_event_loop_policy(): - """ - The default `asyncio` event loop policy only automatically creates - event loops in the main threads. Other threads must create event - loops explicitly or `asyncio.get_event_loop` (and therefore - `.IOLoop.current`) will fail. Installing this policy allows event - loops to be created automatically on any thread, matching the - behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). - """ - - import asyncio - - if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): - # "Any thread" and "selector" should be orthogonal, but there's not a clean - # interface for composing policies so pick the right base. - _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore - else: - _BasePolicy = asyncio.DefaultEventLoopPolicy - - class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore - """Event loop policy that allows loop creation on any thread. - Usage:: - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - """ - - def get_event_loop(self) -> asyncio.AbstractEventLoop: - try: - return super().get_event_loop() - except (RuntimeError, AssertionError): - # This was an AssertionError in python 3.4.2 (which ships with debian jessie) - # and changed to a RuntimeError in 3.4.3. - # "There is no current event loop in thread %r" - loop = self.new_event_loop() - self.set_event_loop(loop) - return loop - - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) - - -def restore_config_state_file(): - config_state_file = shared.opts.restore_config_state_file - if config_state_file == "": - return - - shared.opts.restore_config_state_file = "" - shared.opts.save(shared.config_filename) - - if os.path.isfile(config_state_file): - print(f"*** About to restore extension state from file: {config_state_file}") - with open(config_state_file, "r", encoding="utf-8") as f: - config_state = json.load(f) - config_states.restore_extension_config(config_state) - startup_timer.record("restore extension config") - elif config_state_file: - print(f"!!! Config state backup not found: {config_state_file}") - - -def validate_tls_options(): - if not (cmd_opts.tls_keyfile and cmd_opts.tls_certfile): - return - - try: - if not os.path.exists(cmd_opts.tls_keyfile): - print("Invalid path to TLS keyfile given") - if not os.path.exists(cmd_opts.tls_certfile): - print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") - except TypeError: - cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None - print("TLS setup invalid, running webui without TLS") - else: - print("Running with TLS") - startup_timer.record("TLS") - - -def get_gradio_auth_creds() -> Iterable[tuple[str, ...]]: - """ - Convert the gradio_auth and gradio_auth_path commandline arguments into - an iterable of (username, password) tuples. - """ - def process_credential_line(s) -> tuple[str, ...] | None: - s = s.strip() - if not s: - return None - return tuple(s.split(':', 1)) - - if cmd_opts.gradio_auth: - for cred in cmd_opts.gradio_auth.split(','): - cred = process_credential_line(cred) - if cred: - yield cred - - if cmd_opts.gradio_auth_path: - with open(cmd_opts.gradio_auth_path, 'r', encoding="utf8") as file: - for line in file.readlines(): - for cred in line.strip().split(','): - cred = process_credential_line(cred) - if cred: - yield cred - - -def configure_sigint_handler(): - # make the program just exit at ctrl+c without waiting for anything - def sigint_handler(sig, frame): - print(f'Interrupted with signal {sig} in {frame}') - os._exit(0) - - if not os.environ.get("COVERAGE_RUN"): - # Don't install the immediate-quit handler when running under coverage, - # as then the coverage report won't be generated. - signal.signal(signal.SIGINT, sigint_handler) - - -def configure_opts_onchange(): - shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False) - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) - shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) - shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) - shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False) - startup_timer.record("opts onchange") - - -def initialize(): - fix_asyncio_event_loop_policy() - validate_tls_options() - configure_sigint_handler() - modelloader.cleanup_models() - configure_opts_onchange() - - modules.sd_models.setup_model() - startup_timer.record("setup SD model") - - codeformer.setup_model(cmd_opts.codeformer_models_path) - startup_timer.record("setup codeformer") - - gfpgan.setup_model(cmd_opts.gfpgan_models_path) - startup_timer.record("setup gfpgan") - - initialize_rest(reload_script_modules=False) - - -def initialize_rest(*, reload_script_modules=False): - """ - Called both from initialize() and when reloading the webui. - """ - sd_samplers.set_samplers() - extensions.list_extensions() - startup_timer.record("list extensions") - - restore_config_state_file() - - if cmd_opts.ui_debug_mode: - shared.sd_upscalers = upscaler.UpscalerLanczos().scalers - modules.scripts.load_scripts() - return - - modules.sd_models.list_models() - startup_timer.record("list SD models") +initialize_util.setup_logging() - localization.list_localizations(cmd_opts.localizations_dir) +initialize.imports() - with startup_timer.subcategory("load scripts"): - modules.scripts.load_scripts() - - if reload_script_modules: - for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: - importlib.reload(module) - startup_timer.record("reload script modules") - - modelloader.load_upscalers() - startup_timer.record("load upscalers") - - modules.sd_vae.refresh_vae_list() - startup_timer.record("refresh VAE") - modules.textual_inversion.textual_inversion.list_textual_inversion_templates() - startup_timer.record("refresh textual inversion templates") - - modules.script_callbacks.on_list_optimizers(modules.sd_hijack_optimizations.list_optimizers) - modules.sd_hijack.list_optimizers() - startup_timer.record("scripts list_optimizers") - - modules.sd_unet.list_unets() - startup_timer.record("scripts list_unets") - - def load_model(): - """ - Accesses shared.sd_model property to load model. - After it's available, if it has been loaded before this access by some extension, - its optimization may be None because the list of optimizaers has neet been filled - by that time, so we apply optimization again. - """ - - shared.sd_model # noqa: B018 - - if modules.sd_hijack.current_optimizer is None: - modules.sd_hijack.apply_optimizations() - - devices.first_time_calculation() - - Thread(target=load_model).start() - - shared_items.reload_hypernetworks() - startup_timer.record("reload hypernetworks") - - ui_extra_networks.initialize() - ui_extra_networks.register_default_pages() - - extra_networks.initialize() - extra_networks.register_default_extra_networks() - startup_timer.record("initialize extra networks") - - -def setup_middleware(app): - app.middleware_stack = None # reset current middleware to allow modifying user provided list - app.add_middleware(GZipMiddleware, minimum_size=1000) - configure_cors_middleware(app) - app.build_middleware_stack() # rebuild middleware stack on-the-fly - - -def configure_cors_middleware(app): - cors_options = { - "allow_methods": ["*"], - "allow_headers": ["*"], - "allow_credentials": True, - } - if cmd_opts.cors_allow_origins: - cors_options["allow_origins"] = cmd_opts.cors_allow_origins.split(',') - if cmd_opts.cors_allow_origins_regex: - cors_options["allow_origin_regex"] = cmd_opts.cors_allow_origins_regex - app.add_middleware(CORSMiddleware, **cors_options) +initialize.check_versions() def create_api(app): from modules.api.api import Api + from modules.call_queue import queue_lock + api = Api(app, queue_lock) return api def api_only(): - initialize() + from fastapi import FastAPI + from modules.shared_cmd_options import cmd_opts + + initialize.initialize() app = FastAPI() - setup_middleware(app) + initialize_util.setup_middleware(app) api = create_api(app) - modules.script_callbacks.before_ui_callback() - modules.script_callbacks.app_started_callback(None, app) + from modules import script_callbacks + script_callbacks.before_ui_callback() + script_callbacks.app_started_callback(None, app) print(f"Startup time: {startup_timer.summary()}.") api.launch( @@ -354,24 +48,28 @@ def api_only(): def webui(): + from modules.shared_cmd_options import cmd_opts + launch_api = cmd_opts.api - initialize() + initialize.initialize() + + from modules import shared, ui_tempdir, script_callbacks, ui, progress, ui_extra_networks while 1: if shared.opts.clean_temp_dir_at_start: ui_tempdir.cleanup_tmpdr() startup_timer.record("cleanup temp dir") - modules.script_callbacks.before_ui_callback() + script_callbacks.before_ui_callback() startup_timer.record("scripts before_ui_callback") - shared.demo = modules.ui.create_ui() + shared.demo = ui.create_ui() startup_timer.record("create ui") if not cmd_opts.no_gradio_queue: shared.demo.queue(64) - gradio_auth_creds = list(get_gradio_auth_creds()) or None + gradio_auth_creds = list(initialize_util.get_gradio_auth_creds()) or None auto_launch_browser = False if os.getenv('SD_WEBUI_RESTARTING') != '1': @@ -382,7 +80,7 @@ def webui(): app, local_url, share_url = shared.demo.launch( share=cmd_opts.share, - server_name=server_name, + server_name=initialize_util.gradio_server_name(), server_port=cmd_opts.port, ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, @@ -407,10 +105,10 @@ def webui(): # running its code. We disable this here. Suggested by RyotaK. app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] - setup_middleware(app) + initialize_util.setup_middleware(app) - modules.progress.setup_progress_api(app) - modules.ui.setup_ui_api(app) + progress.setup_progress_api(app) + ui.setup_ui_api(app) if launch_api: create_api(app) @@ -420,7 +118,7 @@ def webui(): startup_timer.record("add APIs") with startup_timer.subcategory("app_started_callback"): - modules.script_callbacks.app_started_callback(shared.demo, app) + script_callbacks.app_started_callback(shared.demo, app) timer.startup_record = startup_timer.dump() print(f"Startup time: {startup_timer.summary()}.") @@ -450,14 +148,16 @@ def webui(): shared.demo.close() time.sleep(0.5) startup_timer.reset() - modules.script_callbacks.app_reload_callback() + script_callbacks.app_reload_callback() startup_timer.record("app reload callback") - modules.script_callbacks.script_unloaded_callback() + script_callbacks.script_unloaded_callback() startup_timer.record("scripts unloaded callback") - initialize_rest(reload_script_modules=True) + initialize.initialize_rest(reload_script_modules=True) if __name__ == "__main__": + from modules.shared_cmd_options import cmd_opts + if cmd_opts.nowebui: api_only() else: -- cgit v1.2.3 From 2ceb4f81e2f291ba651dd24c7eb158ea3b446b42 Mon Sep 17 00:00:00 2001 From: catboxanon <122327233+catboxanon@users.noreply.github.com> Date: Wed, 9 Aug 2023 14:40:18 -0400 Subject: Use better symbol for extra networks sort --- modules/ui_extra_networks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'modules/ui_extra_networks.py') diff --git a/modules/ui_extra_networks.py b/modules/ui_extra_networks.py index 16d76a45..063bd7b8 100644 --- a/modules/ui_extra_networks.py +++ b/modules/ui_extra_networks.py @@ -347,7 +347,7 @@ def pages_in_preferred_order(pages): def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): - from modules.ui import up_down_symbol + from modules.ui import switch_values_symbol ui = ExtraNetworksUi() ui.pages = [] @@ -374,7 +374,7 @@ def create_ui(interface: gr.Blocks, unrelated_tabs, tabname): edit_search = gr.Textbox('', show_label=False, elem_id=tabname+"_extra_search", elem_classes="search", placeholder="Search...", visible=False, interactive=True) dropdown_sort = gr.Dropdown(choices=['Default Sort', 'Date Created', 'Date Modified', 'Name'], value='Default Sort', elem_id=tabname+"_extra_sort", elem_classes="sort", multiselect=False, visible=False, show_label=False, interactive=True, label=tabname+"_extra_sort_order") - button_sortorder = ToolButton(up_down_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) + button_sortorder = ToolButton(switch_values_symbol, elem_id=tabname+"_extra_sortorder", elem_classes="sortorder", visible=False) button_refresh = gr.Button('Refresh', elem_id=tabname+"_extra_refresh", visible=False) checkbox_show_dirs = gr.Checkbox(True, label='Show dirs', elem_id=tabname+"_extra_show_dirs", elem_classes="show-dirs", visible=False) -- cgit v1.2.3