From 3662a274e2b6482c4ad831cc2d7976d919b40212 Mon Sep 17 00:00:00 2001 From: rucadi Date: Fri, 16 Dec 2022 17:10:13 +0100 Subject: Add polling callback --- webui.py | 1 + 1 file changed, 1 insertion(+) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 5b5c2139..6c2b511c 100644 --- a/webui.py +++ b/webui.py @@ -171,6 +171,7 @@ def create_api(app): def wait_on_server(demo=None): while 1: time.sleep(0.5) + modules.script_callbacks.app_polling_callback(None, demo) if shared.state.need_restart: shared.state.need_restart = False time.sleep(0.5) -- cgit v1.2.3 From eb5eb8aa117c3d9ff9c55c59bc589ad8e983919e Mon Sep 17 00:00:00 2001 From: rucadi Date: Fri, 16 Dec 2022 18:31:20 +0100 Subject: Add a callback called before reloading the server --- modules/script_callbacks.py | 13 ++++++++++++- webui.py | 1 + 2 files changed, 13 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index 7763936f..91fd21f4 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -75,6 +75,7 @@ callback_map = dict( callbacks_script_unloaded=[], callbacks_before_ui=[], callbacks_on_polling=[], + callbacks_on_reload=[], ) @@ -82,7 +83,6 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() - def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -97,6 +97,14 @@ def app_polling_callback(demo: Optional[Blocks], app: FastAPI): except Exception: report_exception(c, 'callbacks_on_polling') +def app_reload_callback(demo: Optional[Blocks], app: FastAPI): + for c in callback_map['callbacks_on_reload']: + try: + c.callback() + except Exception: + report_exception(c, 'callbacks_on_reload') + + def model_loaded_callback(sd_model): for c in callback_map['callbacks_model_loaded']: try: @@ -238,6 +246,9 @@ def on_polling(callback): """register a function to be called on each polling of the server.""" add_callback(callback_map['callbacks_on_polling'], callback) +def on_before_reload(callback): + """register a function to be called just before the server reloads.""" + add_callback(callback_map['callbacks_on_reload'], callback) def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is diff --git a/webui.py b/webui.py index 6c2b511c..ddccf870 100644 --- a/webui.py +++ b/webui.py @@ -173,6 +173,7 @@ def wait_on_server(demo=None): time.sleep(0.5) modules.script_callbacks.app_polling_callback(None, demo) if shared.state.need_restart: + modules.script_callbacks.app_reload_callback(None, demo) shared.state.need_restart = False time.sleep(0.5) demo.close() -- cgit v1.2.3 From c9647c8d23efa8c939c6af39878784e246082122 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Sat, 25 Mar 2023 16:11:41 -0400 Subject: Support Gradio's theme API --- modules/shared.py | 35 +++++++++++++++++++++++++++++++++++ modules/ui.py | 2 +- webui.py | 1 + 3 files changed, 37 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index 11be3985..2f7892cd 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -4,6 +4,7 @@ import json import os import sys import time +import requests from PIL import Image import gradio as gr @@ -54,6 +55,21 @@ ui_reorder_categories = [ "scripts", ] +# https://huggingface.co/datasets/freddyaboulton/gradio-theme-subdomains/resolve/main/subdomains.json +gradio_hf_hub_themes = [ + "gradio/glass", + "gradio/monochrome", + "gradio/seafoam", + "gradio/soft", + "freddyaboulton/dracula_revamped", + "gradio/dracula_test", + "abidlabs/dracula_test", + "abidlabs/pakistan", + "dawood/microsoft_windows", + "ysharma/steampunk" +] + + cmd_opts.disable_extension_access = (cmd_opts.share or cmd_opts.listen or cmd_opts.server_name) and not cmd_opts.enable_insecure_extension_access devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_esrgan, devices.device_codeformer = \ @@ -387,6 +403,7 @@ options_templates.update(options_section(('ui', "User interface"), { "ui_reorder": OptionInfo(", ".join(ui_reorder_categories), "txt2img/img2img UI item order"), "ui_extra_networks_tab_reorder": OptionInfo("", "Extra networks tab order"), "localization": OptionInfo("None", "Localization (requires restart)", gr.Dropdown, lambda: {"choices": ["None"] + list(localization.localizations.keys())}, refresh=lambda: localization.list_localizations(cmd_opts.localizations_dir)), + "gradio_theme": OptionInfo("Default", "Gradio theme (requires restart)", gr.Dropdown, lambda: {"choices": ["Default"] + gradio_hf_hub_themes}) })) options_templates.update(options_section(('ui', "Live previews"), { @@ -599,6 +616,24 @@ clip_model = None progress_print_out = sys.stdout +gradio_theme = gr.themes.Base() + + +def reload_gradio_theme(theme_name=None): + global gradio_theme + if not theme_name: + theme_name = opts.gradio_theme + + if theme_name == "Default": + gradio_theme = gr.themes.Default() + else: + try: + gradio_theme = gr.themes.ThemeClass.from_hub(theme_name) + except requests.exceptions.ConnectionError: + print("Can't access HuggingFace Hub, falling back to default Gradio theme") + gradio_theme = gr.themes.Default() + + class TotalTQDM: def __init__(self): diff --git a/modules/ui.py b/modules/ui.py index af8546c2..6e049881 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1592,7 +1592,7 @@ def create_ui(): for _interface, label, _ifid in interfaces: shared.tab_names.append(label) - with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: + with gr.Blocks(css=css, theme=shared.gradio_theme, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings", variant="compact"): for i, k, item in sorted(quicksettings_list, key=lambda x: quicksettings_names.get(x[1], x[0])): component = create_setting_component(k, is_quicksettings=True) diff --git a/webui.py b/webui.py index 30f3e4a1..6986e576 100644 --- a/webui.py +++ b/webui.py @@ -150,6 +150,7 @@ def initialize(): shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed) + shared.opts.onchange("gradio_theme", shared.reload_gradio_theme) startup_timer.record("opts onchange") shared.reload_hypernetworks() -- cgit v1.2.3 From ad5afcaae0b47e9e68b49aacf04cc3ad59d41a8e Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 29 Mar 2023 16:46:03 -0500 Subject: Save/restore working webui/extension configs --- javascript/extensions.js | 16 ++++ modules/extensions.py | 25 +++++-- modules/paths_internal.py | 1 + modules/shared.py | 1 + modules/ui_extensions.py | 186 +++++++++++++++++++++++++++++++++++++++++++++- webui.py | 25 ++++++- 6 files changed, 244 insertions(+), 10 deletions(-) (limited to 'webui.py') diff --git a/javascript/extensions.js b/javascript/extensions.js index 72924a28..c2786499 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -47,3 +47,19 @@ function install_extension_from_index(button, url){ gradioApp().querySelector('#install_extension_button').click() } + +function config_state_confirm_restore(_, config_state_name, config_restore_type) { + if (config_state_name == "Current") { + return [false, config_state_name]; + } + let restored = ""; + if (config_restore_type == "extensions") { + restored = "all saved extension versions"; + } else if (config_restore_type == "webui") { + restored = "the webui version"; + } else { + restored = "the webui version and all saved extension versions"; + } + let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".\n(A backup of the current state will be made.)"); + return [confirmed, config_state_name, config_restore_type]; +} diff --git a/modules/extensions.py b/modules/extensions.py index 3a7a0372..a87beaa3 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -3,10 +3,11 @@ import sys import traceback import time +from datetime import datetime import git from modules import shared -from modules.paths_internal import extensions_dir, extensions_builtin_dir +from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path extensions = [] @@ -31,12 +32,15 @@ class Extension: self.status = '' self.can_update = False self.is_builtin = is_builtin + self.commit_hash = '' + self.commit_date = None self.version = '' + self.branch = None self.remote = None self.have_info_from_repo = False def read_info_from_repo(self): - if self.have_info_from_repo: + if self.is_builtin or self.have_info_from_repo: return self.have_info_from_repo = True @@ -56,10 +60,15 @@ class Extension: self.status = 'unknown' self.remote = next(repo.remote().urls, None) head = repo.head.commit - ts = time.asctime(time.gmtime(repo.head.commit.committed_date)) - self.version = f'{head.hexsha[:8]} ({ts})' - - except Exception: + self.commit_date = repo.head.commit.committed_date + ts = time.asctime(time.gmtime(self.commit_date)) + if repo.active_branch: + self.branch = repo.active_branch.name + self.commit_hash = head.hexsha + self.version = f'{self.commit_hash[:8]} ({ts})' + + except Exception as ex: + print(f"Failed reading extension data from Git repository ({self.name}): {ex}", file=sys.stderr) self.remote = None def list_files(self, subdir, extension): @@ -88,12 +97,12 @@ class Extension: self.can_update = False self.status = "latest" - def fetch_and_reset_hard(self): + def fetch_and_reset_hard(self, commit='origin'): repo = git.Repo(self.path) # Fix: `error: Your local changes to the following files would be overwritten by merge`, # because WSL2 Docker set 755 file permissions instead of 644, this results to the error. repo.git.fetch(all=True) - repo.git.reset('origin', hard=True) + repo.git.reset(commit, hard=True) def list_extensions(): diff --git a/modules/paths_internal.py b/modules/paths_internal.py index 926ec3bb..6765bafe 100644 --- a/modules/paths_internal.py +++ b/modules/paths_internal.py @@ -20,3 +20,4 @@ data_path = cmd_opts_pre.data_dir models_path = os.path.join(data_path, "models") extensions_dir = os.path.join(data_path, "extensions") extensions_builtin_dir = os.path.join(script_path, "extensions-builtin") +config_states_dir = os.path.join(script_path, "config_states") diff --git a/modules/shared.py b/modules/shared.py index 5fd0eecb..ffc5e4fe 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -424,6 +424,7 @@ options_templates.update(options_section(('postprocessing', "Postprocessing"), { options_templates.update(options_section((None, "Hidden options"), { "disabled_extensions": OptionInfo([], "Disable these extensions"), "disable_all_extensions": OptionInfo("none", "Disable all extensions (preserves the list of disabled extensions)", gr.Radio, {"choices": ["none", "extra", "all"]}), + "restore_config_state_file": OptionInfo("", "Config state file to restore from, under 'config-states/' folder"), "sd_checkpoint_hash": OptionInfo("", "SHA256 hash of the current checkpoint"), })) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index efd6cda2..ed677b3e 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -2,6 +2,7 @@ import json import os.path import sys import time +from datetime import datetime import traceback import git @@ -11,7 +12,8 @@ import html import shutil import errno -from modules import extensions, shared, paths +from modules import extensions, shared, paths, config_states +from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call available_extensions = {"extensions": []} @@ -30,6 +32,9 @@ def apply_and_restart(disable_list, update_list, disable_all): update = json.loads(update_list) assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + if update: + save_config_state("Backup (pre-update)") + update = set(update) for ext in extensions.extensions: @@ -50,6 +55,48 @@ def apply_and_restart(disable_list, update_list, disable_all): shared.state.need_restart = True +def save_config_state(name): + current_config_state = config_states.get_config() + if not name: + name = "Config" + current_config_state["name"] = name + filename = os.path.join(config_states_dir, datetime.now().strftime("%Y_%m_%d-%H_%M_%S") + ".json") + print(f"Saving backup of webui/extension state to {filename}.") + with open(filename, "w", encoding="utf-8") as f: + json.dump(current_config_state, f) + config_states.list_config_states() + new_value = next(iter(config_states.all_config_states.keys()), "Current") + new_choices = ["Current"] + list(config_states.all_config_states.keys()) + return gr.Dropdown.update(value=new_value, choices=new_choices), f"Saved current webui/extension state to '{filename}'" + + +def restore_config_state(confirmed, config_state_name, restore_type): + if config_state_name == "Current": + return "Select a config to restore from." + if not confirmed: + return "Cancelled." + + check_access() + + save_config_state("Backup (pre-restore)") + + config_state = config_states.all_config_states[config_state_name] + + print(f"Restoring webui state from backup: {restore_type}") + + if restore_type == "extensions" or restore_type == "both": + shared.opts.restore_config_state_file = config_state["filename"] + shared.opts.save(shared.config_filename) + + if restore_type == "webui" or restore_type == "both": + config_states.restore_webui_config(config_state) + + shared.state.interrupt() + shared.state.need_restart = True + + return "" + + def check_updates(id_task, disable_list): check_access() @@ -121,6 +168,117 @@ def extension_table(): return code +def update_config_states_table(state_name): + if state_name == "Current": + config_state = config_states.get_config() + else: + config_state = config_states.all_config_states[state_name] + + config_name = config_state.get("name", "Config") + created_date = time.asctime(time.gmtime(config_state["created_at"])) + + code = f"""""" + + webui_remote = config_state["webui"]["remote"] or "" + webui_branch = config_state["webui"]["branch"] + webui_commit_hash = config_state["webui"]["commit_hash"] + if webui_commit_hash: + webui_commit_hash = webui_commit_hash[:8] + else: + webui_commit_hash = "" + webui_commit_date = config_state["webui"]["commit_date"] + if webui_commit_date: + webui_commit_date = time.asctime(time.gmtime(webui_commit_date)) + else: + webui_commit_date = "" + + code += f"""

Config Backup: {config_name}

+ Created at: {created_date}""" + + code += f"""

WebUI State

+ + + + + + + + + + + + + + + + + +
URLBranchCommitDate
{webui_remote}{webui_branch}{webui_commit_hash}{webui_commit_date}
+ """ + + code += """

Extension State

+ + + + + + + + + + + + """ + + ext_map = {ext.name: ext for ext in extensions.extensions} + + for ext_name, ext_conf in config_state["extensions"].items(): + ext_remote = ext_conf["remote"] or "" + ext_branch = ext_conf["branch"] or "" + ext_enabled = ext_conf["enabled"] + ext_commit_hash = ext_conf["commit_hash"] or "" + ext_commit_date = ext_conf["commit_date"] + if ext_commit_date: + ext_commit_date = time.asctime(time.gmtime(ext_commit_date)) + else: + ext_commit_date = "" + + remote = f"""{html.escape(ext_remote or '')}""" + + style_enabled = "" + style_remote = "" + style_branch = "" + style_commit = "" + if ext_name in ext_map: + current_ext = ext_map[ext_name] + current_ext.read_info_from_repo() + if current_ext.enabled != ext_enabled: + style_enabled = ' style="color: var(--primary-400)"' + if current_ext.remote != ext_remote: + style_remote = ' style="color: var(--primary-400)"' + if current_ext.branch != ext_branch: + style_branch = ' style="color: var(--primary-400)"' + if current_ext.commit_hash != ext_commit_hash: + style_commit = ' style="color: var(--primary-400)"' + + code += f""" + + + + + + + + """ + + code += """ + +
ExtensionURLBranchCommitDate
{html.escape(ext_name)}{remote}{ext_branch}{ext_commit_hash[:8]}{ext_commit_date}
+ """ + + return code + + def normalize_git_url(url): if url is None: return "" @@ -292,6 +450,8 @@ def refresh_available_extensions_from_data(hide_tags, sort_column, filter_text=" def create_ui(): import modules.ui + config_states.list_config_states() + with gr.Blocks(analytics_enabled=False) as ui: with gr.Tabs(elem_id="tabs_extensions") as tabs: with gr.TabItem("Installed"): @@ -386,4 +546,28 @@ def create_ui(): outputs=[extensions_table, install_result], ) + with gr.TabItem("Backup/Restore"): + with gr.Row(elem_id="extensions_backup_top_row"): + config_states_list = gr.Dropdown(label="Saved Configs", elem_id="extension_backup_saved_configs", value="Current", choices=["Current"] + list(config_states.all_config_states.keys())) + modules.ui.create_refresh_button(config_states_list, config_states.list_config_states, lambda: {"choices": ["Current"] + list(config_states.all_config_states.keys())}, "refresh_config_states") + config_restore_type = gr.Radio(label="State to restore", choices=["extensions", "webui", "both"], value="extensions", elem_id="extension_backup_restore_type") + config_restore_button = gr.Button(value="Restore Selected Config", variant="primary", elem_id="extension_backup_restore") + with gr.Row(elem_id="extensions_backup_top_row2"): + config_save_name = gr.Textbox("", placeholder="Config Name", show_label=False) + config_save_button = gr.Button(value="Save Current Config") + + config_states_info = gr.HTML("") + config_states_table = gr.HTML(lambda: update_config_states_table("Current")) + + config_save_button.click(fn=save_config_state, inputs=[config_save_name], outputs=[config_states_list, config_states_info]) + + dummy_component = gr.Label(visible=False) + config_restore_button.click(fn=restore_config_state, _js="config_state_confirm_restore", inputs=[dummy_component, config_states_list, config_restore_type], outputs=[config_states_info]) + + config_states_list.change( + fn=update_config_states_table, + inputs=[config_states_list], + outputs=[config_states_table], + ) + return ui diff --git a/webui.py b/webui.py index b570895f..b8f9a2c1 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,7 @@ import importlib import signal import re import warnings +import json from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -37,7 +38,7 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks +from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states import modules.codeformer_model as codeformer import modules.face_restoration import modules.gfpgan_model as gfpgan @@ -105,6 +106,17 @@ def initialize(): localization.list_localizations(cmd_opts.localizations_dir) startup_timer.record("list extensions") + config_state_file = shared.opts.restore_config_state_file + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_state(config_state) + startup_timer.record("restore extension config") + if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers modules.scripts.load_scripts() @@ -301,6 +313,17 @@ def webui(): extensions.list_extensions() startup_timer.record("list extensions") + config_state_file = shared.opts.restore_config_state_file + shared.opts.restore_config_state_file = "" + shared.opts.save(shared.config_filename) + + if os.path.isfile(config_state_file): + print(f"*** About to restore extension state from file: {config_state_file}") + with open(config_state_file, "r", encoding="utf-8") as f: + config_state = json.load(f) + config_states.restore_extension_state(config_state) + startup_timer.record("restore extension config") + localization.list_localizations(cmd_opts.localizations_dir) modelloader.forbid_loaded_nonbuiltin_upscalers() -- cgit v1.2.3 From f3320b802c12f29e5a3201fcc0abfe72be294293 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 29 Mar 2023 18:32:54 -0500 Subject: Various UI fixes in config state tab --- .gitignore | 1 + javascript/extensions.js | 10 ++- modules/config_states.py | 200 +++++++++++++++++++++++++++++++++++++++++++++++ modules/ui_extensions.py | 45 ++++++----- webui.py | 8 +- 5 files changed, 241 insertions(+), 23 deletions(-) create mode 100644 modules/config_states.py (limited to 'webui.py') diff --git a/.gitignore b/.gitignore index 0b1d17ca..7c89b673 100644 --- a/.gitignore +++ b/.gitignore @@ -33,3 +33,4 @@ notification.mp3 /test/stdout.txt /test/stderr.txt /cache.json +/config_states/ diff --git a/javascript/extensions.js b/javascript/extensions.js index c2786499..3c2f995a 100644 --- a/javascript/extensions.js +++ b/javascript/extensions.js @@ -50,7 +50,7 @@ function install_extension_from_index(button, url){ function config_state_confirm_restore(_, config_state_name, config_restore_type) { if (config_state_name == "Current") { - return [false, config_state_name]; + return [false, config_state_name, config_restore_type]; } let restored = ""; if (config_restore_type == "extensions") { @@ -60,6 +60,12 @@ function config_state_confirm_restore(_, config_state_name, config_restore_type) } else { restored = "the webui version and all saved extension versions"; } - let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + ".\n(A backup of the current state will be made.)"); + let confirmed = confirm("Are you sure you want to restore from this state?\nThis will reset " + restored + "."); + if (confirmed) { + restart_reload(); + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ + x.innerHTML = "Loading..." + }) + } return [confirmed, config_state_name, config_restore_type]; } diff --git a/modules/config_states.py b/modules/config_states.py new file mode 100644 index 00000000..2ea00929 --- /dev/null +++ b/modules/config_states.py @@ -0,0 +1,200 @@ +""" +Supports saving and restoring webui and extensions from a known working set of commits +""" + +import os +import sys +import traceback +import json +import time +import tqdm + +from datetime import datetime +from collections import OrderedDict +import git + +from modules import shared, extensions +from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path, config_states_dir + + +all_config_states = OrderedDict() + + +def list_config_states(): + global all_config_states + + all_config_states.clear() + os.makedirs(config_states_dir, exist_ok=True) + + config_states = [] + for filename in os.listdir(config_states_dir): + if filename.endswith(".json"): + path = os.path.join(config_states_dir, filename) + with open(path, "r", encoding="utf-8") as f: + j = json.load(f) + j["filepath"] = path + config_states.append(j) + + config_states = list(sorted(config_states, key=lambda cs: cs["created_at"], reverse=True)) + + for cs in config_states: + timestamp = time.asctime(time.gmtime(cs["created_at"])) + name = cs.get("name", "Config") + full_name = f"{name}: {timestamp}" + all_config_states[full_name] = cs + + return all_config_states + + +def get_webui_config(): + webui_repo = None + + try: + if os.path.exists(os.path.join(script_path, ".git")): + webui_repo = git.Repo(script_path) + except Exception: + print(f"Error reading webui git info from {script_path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + webui_remote = None + webui_commit_hash = None + webui_commit_date = None + webui_branch = None + if webui_repo and not webui_repo.bare: + try: + webui_remote = next(webui_repo.remote().urls, None) + head = webui_repo.head.commit + webui_commit_date = webui_repo.head.commit.committed_date + webui_commit_hash = head.hexsha + webui_branch = webui_repo.active_branch.name + + except Exception: + webui_remote = None + + return { + "remote": webui_remote, + "commit_hash": webui_commit_hash, + "commit_date": webui_commit_date, + "branch": webui_branch, + } + + +def get_extension_config(): + ext_config = {} + + for ext in extensions.extensions: + entry = { + "name": ext.name, + "path": ext.path, + "enabled": ext.enabled, + "is_builtin": ext.is_builtin, + "remote": ext.remote, + "commit_hash": ext.commit_hash, + "commit_date": ext.commit_date, + "branch": ext.branch, + "have_info_from_repo": ext.have_info_from_repo + } + + ext_config[ext.name] = entry + + return ext_config + + +def get_config(): + creation_time = datetime.now().timestamp() + webui_config = get_webui_config() + ext_config = get_extension_config() + + return { + "created_at": creation_time, + "webui": webui_config, + "extensions": ext_config + } + + +def restore_webui_config(config): + print("* Restoring webui state...") + + if "webui" not in config: + print("Error: No webui data saved to config") + return + + webui_config = config["webui"] + + if "commit_hash" not in webui_config: + print("Error: No commit saved to webui config") + return + + webui_commit_hash = webui_config.get("commit_hash", None) + webui_repo = None + + try: + if os.path.exists(os.path.join(script_path, ".git")): + webui_repo = git.Repo(script_path) + except Exception: + print(f"Error reading webui git info from {script_path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + return + + try: + webui_repo.git.fetch(all=True) + webui_repo.git.reset(webui_commit_hash, hard=True) + print(f"* Restored webui to commit {webui_commit_hash}.") + except Exception: + print(f"Error restoring webui to commit {webui_commit_hash}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + +def restore_extension_config(config): + print("* Restoring extension state...") + + if "extensions" not in config: + print("Error: No extension data saved to config") + return + + ext_config = config["extensions"] + + results = [] + disabled = [] + + for ext in tqdm.tqdm(extensions.extensions): + if ext.is_builtin: + continue + + ext.read_info_from_repo() + current_commit = ext.commit_hash + + if ext.name not in ext_config: + ext.disabled = True + disabled.append(ext.name) + results.append((ext, current_commit[:8], False, "Saved extension state not found in config, marking as disabled")) + continue + + entry = ext_config[ext.name] + + if "commit_hash" in entry and entry["commit_hash"]: + try: + ext.fetch_and_reset_hard(entry["commit_hash"]) + ext.read_info_from_repo() + if current_commit != entry["commit_hash"]: + results.append((ext, current_commit[:8], True, entry["commit_hash"][:8])) + except Exception as ex: + results.append((ext, current_commit[:8], False, ex)) + else: + results.append((ext, current_commit[:8], False, "No commit hash found in config")) + + if not entry.get("enabled", False): + ext.disabled = True + disabled.append(ext.name) + else: + ext.disabled = False + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + print("* Finished restoring extensions. Results:") + for ext, prev_commit, success, result in results: + if success: + print(f" + {ext.name}: {prev_commit} -> {result}") + else: + print(f" ! {ext.name}: FAILURE ({result})") diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index ed677b3e..b94b3a3a 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -17,6 +17,7 @@ from modules.paths_internal import config_states_dir from modules.call_queue import wrap_gradio_gpu_call available_extensions = {"extensions": []} +STYLE_PRIMARY = ' style="color: var(--primary-400)"' def check_access(): @@ -67,7 +68,7 @@ def save_config_state(name): config_states.list_config_states() new_value = next(iter(config_states.all_config_states.keys()), "Current") new_choices = ["Current"] + list(config_states.all_config_states.keys()) - return gr.Dropdown.update(value=new_value, choices=new_choices), f"Saved current webui/extension state to '{filename}'" + return gr.Dropdown.update(value=new_value, choices=new_choices), f"Saved current webui/extension state to \"{filename}\"" def restore_config_state(confirmed, config_state_name, restore_type): @@ -78,14 +79,12 @@ def restore_config_state(confirmed, config_state_name, restore_type): check_access() - save_config_state("Backup (pre-restore)") - config_state = config_states.all_config_states[config_state_name] - print(f"Restoring webui state from backup: {restore_type}") + print(f"*** Restoring webui state from backup: {restore_type} ***") if restore_type == "extensions" or restore_type == "both": - shared.opts.restore_config_state_file = config_state["filename"] + shared.opts.restore_config_state_file = config_state["filepath"] shared.opts.save(shared.config_filename) if restore_type == "webui" or restore_type == "both": @@ -149,7 +148,7 @@ def extension_table(): style = "" if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all": - style = ' style="color: var(--primary-400)"' + style = STYLE_PRIMARY code += f""" @@ -181,17 +180,25 @@ def update_config_states_table(state_name): webui_remote = config_state["webui"]["remote"] or "" webui_branch = config_state["webui"]["branch"] - webui_commit_hash = config_state["webui"]["commit_hash"] - if webui_commit_hash: - webui_commit_hash = webui_commit_hash[:8] - else: - webui_commit_hash = "" + webui_commit_hash = config_state["webui"]["commit_hash"] or "" webui_commit_date = config_state["webui"]["commit_date"] if webui_commit_date: webui_commit_date = time.asctime(time.gmtime(webui_commit_date)) else: webui_commit_date = "" + current_webui = config_states.get_webui_config() + + style_remote = "" + style_branch = "" + style_commit = "" + if current_webui["remote"] != webui_remote: + style_remote = STYLE_PRIMARY + if current_webui["branch"] != webui_branch: + style_branch = STYLE_PRIMARY + if current_webui["commit_hash"] != webui_commit_hash: + style_commit = STYLE_PRIMARY + code += f"""

Config Backup: {config_name}

Created at: {created_date}""" @@ -207,10 +214,10 @@ def update_config_states_table(state_name): - {webui_remote} - {webui_branch} - {webui_commit_hash} - {webui_commit_date} + {webui_remote} + {webui_branch} + {webui_commit_hash[:8]} + {webui_commit_date} @@ -253,13 +260,13 @@ def update_config_states_table(state_name): current_ext = ext_map[ext_name] current_ext.read_info_from_repo() if current_ext.enabled != ext_enabled: - style_enabled = ' style="color: var(--primary-400)"' + style_enabled = STYLE_PRIMARY if current_ext.remote != ext_remote: - style_remote = ' style="color: var(--primary-400)"' + style_remote = STYLE_PRIMARY if current_ext.branch != ext_branch: - style_branch = ' style="color: var(--primary-400)"' + style_branch = STYLE_PRIMARY if current_ext.commit_hash != ext_commit_hash: - style_commit = ' style="color: var(--primary-400)"' + style_commit = STYLE_PRIMARY code += f""" diff --git a/webui.py b/webui.py index b8f9a2c1..5ce45056 100644 --- a/webui.py +++ b/webui.py @@ -114,8 +114,10 @@ def initialize(): print(f"*** About to restore extension state from file: {config_state_file}") with open(config_state_file, "r", encoding="utf-8") as f: config_state = json.load(f) - config_states.restore_extension_state(config_state) + config_states.restore_extension_config(config_state) startup_timer.record("restore extension config") + else: + print(f"!!! Config state backup not found: {config_state_file}") if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers @@ -321,8 +323,10 @@ def webui(): print(f"*** About to restore extension state from file: {config_state_file}") with open(config_state_file, "r", encoding="utf-8") as f: config_state = json.load(f) - config_states.restore_extension_state(config_state) + config_states.restore_extension_config(config_state) startup_timer.record("restore extension config") + else: + print(f"!!! Config state backup not found: {config_state_file}") localization.list_localizations(cmd_opts.localizations_dir) -- cgit v1.2.3 From 563d048780790fb9e7e9a5fc54a86bdf1ee65d58 Mon Sep 17 00:00:00 2001 From: space-nuko <24979496+space-nuko@users.noreply.github.com> Date: Wed, 29 Mar 2023 19:22:45 -0500 Subject: Squelch warning if no config restore --- webui.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 5ce45056..ee5e4f11 100644 --- a/webui.py +++ b/webui.py @@ -116,7 +116,7 @@ def initialize(): config_state = json.load(f) config_states.restore_extension_config(config_state) startup_timer.record("restore extension config") - else: + elif config_state_file: print(f"!!! Config state backup not found: {config_state_file}") if cmd_opts.ui_debug_mode: @@ -325,7 +325,7 @@ def webui(): config_state = json.load(f) config_states.restore_extension_config(config_state) startup_timer.record("restore extension config") - else: + elif config_state_file: print(f"!!! Config state backup not found: {config_state_file}") localization.list_localizations(cmd_opts.localizations_dir) -- cgit v1.2.3 From d5063e07e8b4737621978feffd37b18077b9ea64 Mon Sep 17 00:00:00 2001 From: Vladimir Mandic Date: Thu, 30 Mar 2023 10:57:54 -0400 Subject: update torch --- .gitignore | 2 +- environment-wsl2.yaml | 11 ++++++----- launch.py | 6 +++--- requirements.txt | 1 + requirements_versions.txt | 4 ++-- webui-macos-env.sh | 2 +- webui.py | 2 ++ 7 files changed, 16 insertions(+), 12 deletions(-) (limited to 'webui.py') diff --git a/.gitignore b/.gitignore index 0b1d17ca..3b48ba9a 100644 --- a/.gitignore +++ b/.gitignore @@ -32,4 +32,4 @@ notification.mp3 /extensions /test/stdout.txt /test/stderr.txt -/cache.json +/cache.json* diff --git a/environment-wsl2.yaml b/environment-wsl2.yaml index f8872750..06134565 100644 --- a/environment-wsl2.yaml +++ b/environment-wsl2.yaml @@ -4,8 +4,9 @@ channels: - defaults dependencies: - python=3.10 - - pip=22.2.2 - - cudatoolkit=11.3 - - pytorch=1.12.1 - - torchvision=0.13.1 - - numpy=1.23.1 \ No newline at end of file + - pip=23.0 + - cudatoolkit=11.8 + - pytorch=2.0 + - torchvision=0.15 + - numpy=1.23 + \ No newline at end of file diff --git a/launch.py b/launch.py index 68e08114..37c8b516 100644 --- a/launch.py +++ b/launch.py @@ -225,10 +225,10 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") - xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425') + xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379") clip_package = os.environ.get('CLIP_PACKAGE', "git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1") openclip_package = os.environ.get('OPENCLIP_PACKAGE', "git+https://github.com/mlfoundations/open_clip.git@bb6e834e9c70d9c27d0dc3ecedeebeaeb1ffad6b") @@ -296,7 +296,7 @@ def prepare_environment(): if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) - run_pip(f"install -r \"{requirements_file}\"", "requirements for Web UI") + run_pip(f"install -r \"{requirements_file}\"", "requirements") run_extensions_installers(settings_file=args.ui_settings_file) diff --git a/requirements.txt b/requirements.txt index c72b2927..36cdae6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +astunparse blendmodes accelerate basicsr diff --git a/requirements_versions.txt b/requirements_versions.txt index df65431a..6487f1c3 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -1,10 +1,10 @@ blendmodes==2022 transformers==4.25.1 -accelerate==0.12.0 +accelerate==0.18.0 basicsr==1.4.2 gfpgan==1.3.8 gradio==3.23 -numpy==1.23.3 +numpy==1.23.5 Pillow==9.4.0 realesrgan==0.3.0 torch diff --git a/webui-macos-env.sh b/webui-macos-env.sh index 37cac4fb..65d80413 100644 --- a/webui-macos-env.sh +++ b/webui-macos-env.sh @@ -11,7 +11,7 @@ fi export install_dir="$HOME" export COMMANDLINE_ARGS="--skip-torch-cuda-test --upcast-sampling --no-half-vae --use-cpu interrogate" -export TORCH_COMMAND="pip install torch==1.12.1 torchvision==0.13.1" +export TORCH_COMMAND="pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu118" export K_DIFFUSION_REPO="https://github.com/brkirch/k-diffusion.git" export K_DIFFUSION_COMMIT_HASH="51c9778f269cedb55a4d88c79c0246d35bdadb71" export PYTORCH_ENABLE_MPS_FALLBACK=1 diff --git a/webui.py b/webui.py index b570895f..54552cdd 100644 --- a/webui.py +++ b/webui.py @@ -20,6 +20,8 @@ startup_timer = timer.Timer() import torch import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") +warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") + startup_timer.record("import torch") import gradio -- cgit v1.2.3 From aef42bfec09a9ca93d1222b7b47256f37e192a32 Mon Sep 17 00:00:00 2001 From: keith <1868690+wk5ovc@users.noreply.github.com> Date: Mon, 3 Apr 2023 17:05:49 +0800 Subject: Fix #9046 /sdapi/v1/txt2img endpoint not working **Describe what this pull request is trying to achieve.** Fix https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/9046 **Environment this was tested in** * OS: Linux * Browser: chrome * Graphics card: RTX 3090 --- webui.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) (limited to 'webui.py') diff --git a/webui.py b/webui.py index b570895f..8927aa33 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,7 @@ import importlib import signal import re import warnings +import asyncio from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -66,6 +67,46 @@ if cmd_opts.server_name: else: server_name = "0.0.0.0" if cmd_opts.listen else None +if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore +else: + _BasePolicy = asyncio.DefaultEventLoopPolicy + + +class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + + Usage:: + + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + + .. versionadded:: 5.0 + + """ + + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop + + +asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + def check_versions(): if shared.cmd_opts.skip_version_check: -- cgit v1.2.3 From 7ea5be3e29dd79e2a7669e0b02f8ce66354c5416 Mon Sep 17 00:00:00 2001 From: aniaan Date: Wed, 26 Apr 2023 20:59:55 +0800 Subject: perf(webui): Remove duplicate code --- webui.py | 3 --- 1 file changed, 3 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index b570895f..bde4025b 100644 --- a/webui.py +++ b/webui.py @@ -126,9 +126,6 @@ def initialize(): modules.scripts.load_scripts() startup_timer.record("load scripts") - modelloader.load_upscalers() - startup_timer.record("load upscalers") - modules.sd_vae.refresh_vae_list() startup_timer.record("refresh VAE") -- cgit v1.2.3 From 43186ad08407c08835f5678d10c7600e23b7fe8f Mon Sep 17 00:00:00 2001 From: Garrett Sutula Date: Thu, 27 Apr 2023 20:29:21 -0400 Subject: Add tls_verify arg for use with self-signed certs --- modules/cmd_args.py | 1 + webui.py | 1 + 2 files changed, 2 insertions(+) (limited to 'webui.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 81c0b82a..3aeecdcc 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -95,6 +95,7 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin( parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--tls-verify", type=bool, help="Enables the use of self-signed certificates when set to False", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") diff --git a/webui.py b/webui.py index b570895f..c86d3d49 100644 --- a/webui.py +++ b/webui.py @@ -260,6 +260,7 @@ def webui(): server_port=cmd_opts.port, ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, + ssl_verify=cmd_opts.tls_verify, debug=cmd_opts.gradio_debug, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, inbrowser=cmd_opts.autolaunch, -- cgit v1.2.3 From d1e62b296146bd158630b8444e4720a1ae445151 Mon Sep 17 00:00:00 2001 From: Garrett Sutula Date: Thu, 27 Apr 2023 21:30:19 -0400 Subject: Improve param semantics, --- modules/cmd_args.py | 2 +- webui.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index 3aeecdcc..f47c21bb 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -95,7 +95,7 @@ parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origin( parser.add_argument("--cors-allow-origins-regex", type=str, help="Allowed CORS origin(s) in the form of a single regular expression", default=None) parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) -parser.add_argument("--tls-verify", type=bool, help="Enables the use of self-signed certificates when set to False", default=None) +parser.add_argument("--disable-tls-verify", action="store_false", help="When passed, enables the use of self-signed certificates.", default=None) parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) parser.add_argument("--gradio-queue", action='store_true', help="does not do anything", default=True) parser.add_argument("--no-gradio-queue", action='store_true', help="Disables gradio queue; causes the webpage to use http requests instead of websockets; was the defaul in earlier versions") diff --git a/webui.py b/webui.py index c86d3d49..a4709032 100644 --- a/webui.py +++ b/webui.py @@ -260,7 +260,7 @@ def webui(): server_port=cmd_opts.port, ssl_keyfile=cmd_opts.tls_keyfile, ssl_certfile=cmd_opts.tls_certfile, - ssl_verify=cmd_opts.tls_verify, + ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=[tuple(cred.split(':')) for cred in gradio_auth_creds] if gradio_auth_creds else None, inbrowser=cmd_opts.autolaunch, -- cgit v1.2.3 From 86bafb625ae2671722abe1f95a4c0fc5e38381e3 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 10:21:01 +0300 Subject: put asyncio fix into a function to make it more obvious where it starts and ends --- webui.py | 67 ++++++++++++++++++++++++++++++++-------------------------------- 1 file changed, 34 insertions(+), 33 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index fa8e21ae..0e2a3df0 100644 --- a/webui.py +++ b/webui.py @@ -5,7 +5,6 @@ import importlib import signal import re import warnings -import asyncio from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware @@ -67,45 +66,45 @@ if cmd_opts.server_name: else: server_name = "0.0.0.0" if cmd_opts.listen else None -if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): - # "Any thread" and "selector" should be orthogonal, but there's not a clean - # interface for composing policies so pick the right base. - _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore -else: - _BasePolicy = asyncio.DefaultEventLoopPolicy - - -class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore - """Event loop policy that allows loop creation on any thread. - The default `asyncio` event loop policy only automatically creates - event loops in the main threads. Other threads must create event - loops explicitly or `asyncio.get_event_loop` (and therefore - `.IOLoop.current`) will fail. Installing this policy allows event - loops to be created automatically on any thread, matching the - behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). - - Usage:: +def fix_asyncio_event_loop_policy(): + """ + The default `asyncio` event loop policy only automatically creates + event loops in the main threads. Other threads must create event + loops explicitly or `asyncio.get_event_loop` (and therefore + `.IOLoop.current`) will fail. Installing this policy allows event + loops to be created automatically on any thread, matching the + behavior of Tornado versions prior to 5.0 (or 5.0 on Python 2). + """ - asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + import asyncio - .. versionadded:: 5.0 + if sys.platform == "win32" and hasattr(asyncio, "WindowsSelectorEventLoopPolicy"): + # "Any thread" and "selector" should be orthogonal, but there's not a clean + # interface for composing policies so pick the right base. + _BasePolicy = asyncio.WindowsSelectorEventLoopPolicy # type: ignore + else: + _BasePolicy = asyncio.DefaultEventLoopPolicy - """ + class AnyThreadEventLoopPolicy(_BasePolicy): # type: ignore + """Event loop policy that allows loop creation on any thread. + Usage:: - def get_event_loop(self) -> asyncio.AbstractEventLoop: - try: - return super().get_event_loop() - except (RuntimeError, AssertionError): - # This was an AssertionError in python 3.4.2 (which ships with debian jessie) - # and changed to a RuntimeError in 3.4.3. - # "There is no current event loop in thread %r" - loop = self.new_event_loop() - self.set_event_loop(loop) - return loop + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + """ + def get_event_loop(self) -> asyncio.AbstractEventLoop: + try: + return super().get_event_loop() + except (RuntimeError, AssertionError): + # This was an AssertionError in python 3.4.2 (which ships with debian jessie) + # and changed to a RuntimeError in 3.4.3. + # "There is no current event loop in thread %r" + loop = self.new_event_loop() + self.set_event_loop(loop) + return loop -asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) + asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) def check_versions(): @@ -140,6 +139,8 @@ Use --skip-version-check commandline argument to disable this check. def initialize(): + fix_asyncio_event_loop_policy() + check_versions() extensions.list_extensions() -- cgit v1.2.3 From ee71eee1818f6f6eba9895c93ba25e0cad27e069 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 12:36:50 +0300 Subject: stuff related to torch version change --- launch.py | 8 ++++---- modules/safe.py | 5 +---- requirements_versions.txt | 2 +- webui.py | 6 ++++-- 4 files changed, 10 insertions(+), 11 deletions(-) (limited to 'webui.py') diff --git a/launch.py b/launch.py index 4256ef25..af1c8309 100644 --- a/launch.py +++ b/launch.py @@ -121,12 +121,12 @@ def run_python(code, desc=None, errdesc=None): return run(f'"{python}" -c "{code}"', desc, errdesc) -def run_pip(args, desc=None): +def run_pip(args, desc=None, live=False): if skip_install: return index_url_line = f' --index-url {index_url}' if index_url != '' else '' - return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}") + return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}", live=live) def check_run_python(code): @@ -225,7 +225,7 @@ def run_extensions_installers(settings_file): def prepare_environment(): global skip_install - torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118") + torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==2.0.0 torchvision==0.15.1 --index-url https://download.pytorch.org/whl/cu118") requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt") xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.17') @@ -271,7 +271,7 @@ def prepare_environment(): if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers: if platform.system() == "Windows": if platform.python_version().startswith("3.10"): - run_pip(f"install -U -I --no-deps {xformers_package}", "xformers") + run_pip(f"install -U -I --no-deps {xformers_package}", "xformers", live=True) else: print("Installation of xformers is not supported in this version of Python.") print("You can also check this and build manually: https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers#building-xformers-on-windows-by-duckness") diff --git a/modules/safe.py b/modules/safe.py index 82d44be3..dadf319c 100644 --- a/modules/safe.py +++ b/modules/safe.py @@ -1,6 +1,5 @@ # this code is adapted from the script contributed by anon from /h/ -import io import pickle import collections import sys @@ -12,11 +11,9 @@ import _codecs import zipfile import re - # PyTorch 1.13 and later have _TypedStorage renamed to TypedStorage TypedStorage = torch.storage.TypedStorage if hasattr(torch.storage, 'TypedStorage') else torch.storage._TypedStorage - def encode(*args): out = _codecs.encode(*args) return out @@ -27,7 +24,7 @@ class RestrictedUnpickler(pickle.Unpickler): def persistent_load(self, saved_id): assert saved_id[0] == 'storage' - return TypedStorage() + return TypedStorage(_internal=True) def find_class(self, module, name): if self.extra_handler is not None: diff --git a/requirements_versions.txt b/requirements_versions.txt index 52bd9b4d..94d32d3d 100644 --- a/requirements_versions.txt +++ b/requirements_versions.txt @@ -25,6 +25,6 @@ lark==1.1.2 inflection==0.5.1 GitPython==3.1.30 torchsde==0.2.5 -safetensors==0.3.0 +safetensors==0.3.1 httpcore<=0.15 fastapi==0.94.0 diff --git a/webui.py b/webui.py index 1f97b475..3fd6e1e9 100644 --- a/webui.py +++ b/webui.py @@ -21,6 +21,8 @@ import torch import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") +warnings.filterwarnings(action='ignore', category=UserWarning, message='TypedStorage is deprecated') + startup_timer.record("import torch") @@ -113,7 +115,7 @@ def check_versions(): if shared.cmd_opts.skip_version_check: return - expected_torch_version = "1.13.1" + expected_torch_version = "2.0.0" if version.parse(torch.__version__) < version.parse(expected_torch_version): errors.print_error_explanation(f""" @@ -126,7 +128,7 @@ there are reports of issues with training tab on the latest version. Use --skip-version-check commandline argument to disable this check. """.strip()) - expected_xformers_version = "0.0.16rc425" + expected_xformers_version = "0.0.17" if shared.xformers_available: import xformers -- cgit v1.2.3 From aee6d9bb74d9868d0782acc54b6cbefa32422a8f Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 12:39:05 +0300 Subject: remove unneeded warning filter --- webui.py | 1 - 1 file changed, 1 deletion(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 3fd6e1e9..759125e0 100644 --- a/webui.py +++ b/webui.py @@ -21,7 +21,6 @@ import torch import pytorch_lightning # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") -warnings.filterwarnings(action='ignore', category=UserWarning, message='TypedStorage is deprecated') startup_timer.record("import torch") -- cgit v1.2.3 From a95dc02535ce647efbb3bc66964ab61b6f1446f4 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 29 Apr 2023 19:05:43 +0300 Subject: remove unwanted changes from #8789 --- modules/script_callbacks.py | 15 +++------------ webui.py | 4 ++-- 2 files changed, 5 insertions(+), 14 deletions(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index cede1cb0..17109732 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -94,7 +94,6 @@ callback_map = dict( callbacks_script_unloaded=[], callbacks_before_ui=[], callbacks_on_reload=[], - callbacks_on_polling=[], ) @@ -102,6 +101,7 @@ def clear_callbacks(): for callback_list in callback_map.values(): callback_list.clear() + def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callback_map['callbacks_app_started']: try: @@ -109,14 +109,8 @@ def app_started_callback(demo: Optional[Blocks], app: FastAPI): except Exception: report_exception(c, 'app_started_callback') -def app_polling_callback(demo: Optional[Blocks], app: FastAPI): - for c in callback_map['callbacks_on_polling']: - try: - c.callback() - except Exception: - report_exception(c, 'callbacks_on_polling') -def app_reload_callback(demo: Optional[Blocks], app: FastAPI): +def app_reload_callback(): for c in callback_map['callbacks_on_reload']: try: c.callback() @@ -269,14 +263,11 @@ def on_app_started(callback): add_callback(callback_map['callbacks_app_started'], callback) -def on_polling(callback): - """register a function to be called on each polling of the server.""" - add_callback(callback_map['callbacks_on_polling'], callback) - def on_before_reload(callback): """register a function to be called just before the server reloads.""" add_callback(callback_map['callbacks_on_reload'], callback) + def on_model_loaded(callback): """register a function to be called when the stable diffusion model is created; the model is passed as an argument; this function is also called when the script is reloaded. """ diff --git a/webui.py b/webui.py index 98b6d394..a0d5e2cd 100644 --- a/webui.py +++ b/webui.py @@ -264,13 +264,13 @@ def create_api(app): def wait_on_server(demo=None): while 1: time.sleep(0.5) - modules.script_callbacks.app_polling_callback(None, demo) if shared.state.need_restart: - modules.script_callbacks.app_reload_callback(None, demo) shared.state.need_restart = False time.sleep(0.5) demo.close() time.sleep(0.5) + + modules.script_callbacks.app_reload_callback() break -- cgit v1.2.3