aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAarni Koskela <akx@iki.fi>2023-05-11 20:46:45 +0000
committerAarni Koskela <akx@iki.fi>2023-05-17 07:15:03 +0000
commit85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b (patch)
tree716f033e712fe3eb7fa14bdfb6f8fe51836ae58c
parent4b07f2f584596604c4499efb0b0295e96985080f (diff)
downloadstable-diffusion-webui-gfx803-85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b.tar.gz
stable-diffusion-webui-gfx803-85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b.tar.bz2
stable-diffusion-webui-gfx803-85b4f89926f7c3aaa7846dcbb47df3fd3b483b6b.zip
Replace state.need_restart with state.server_command + replace poll loop with signal
-rw-r--r--modules/shared.py42
-rw-r--r--modules/ui.py6
-rw-r--r--modules/ui_extensions.py7
-rw-r--r--webui.py39
4 files changed, 68 insertions, 26 deletions
diff --git a/modules/shared.py b/modules/shared.py
index 3abf71c0..648a2a19 100644
--- a/modules/shared.py
+++ b/modules/shared.py
@@ -2,6 +2,7 @@ import datetime
import json
import os
import sys
+import threading
import time
import gradio as gr
@@ -110,8 +111,47 @@ class State:
id_live_preview = 0
textinfo = None
time_start = None
- need_restart = False
server_start = None
+ _server_command_signal = threading.Event()
+ _server_command: str | None = None
+
+ @property
+ def need_restart(self) -> bool:
+ # Compatibility getter for need_restart.
+ return self.server_command == "restart"
+
+ @need_restart.setter
+ def need_restart(self, value: bool) -> None:
+ # Compatibility setter for need_restart.
+ if value:
+ self.server_command = "restart"
+
+ @property
+ def server_command(self):
+ return self._server_command
+
+ @server_command.setter
+ def server_command(self, value: str | None) -> None:
+ """
+ Set the server command to `value` and signal that it's been set.
+ """
+ self._server_command = value
+ self._server_command_signal.set()
+
+ def wait_for_server_command(self, timeout: float | None = None) -> str | None:
+ """
+ Wait for server command to get set; return and clear the value and signal.
+ """
+ if self._server_command_signal.wait(timeout):
+ self._server_command_signal.clear()
+ req = self._server_command
+ self._server_command = None
+ return req
+ return None
+
+ def request_restart(self) -> None:
+ self.interrupt()
+ self.server_command = True
def skip(self):
self.skipped = True
diff --git a/modules/ui.py b/modules/ui.py
index 8e51e782..bed8464e 100644
--- a/modules/ui.py
+++ b/modules/ui.py
@@ -1609,12 +1609,8 @@ def create_ui():
outputs=[]
)
- def request_restart():
- shared.state.interrupt()
- shared.state.need_restart = True
-
restart_gradio.click(
- fn=request_restart,
+ fn=shared.state.request_restart,
_js='restart_reload',
inputs=[],
outputs=[],
diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py
index d7a0f685..4ba3bdd7 100644
--- a/modules/ui_extensions.py
+++ b/modules/ui_extensions.py
@@ -52,9 +52,7 @@ def apply_and_restart(disable_list, update_list, disable_all):
shared.opts.disabled_extensions = disabled
shared.opts.disable_all_extensions = disable_all
shared.opts.save(shared.config_filename)
-
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
def save_config_state(name):
@@ -92,8 +90,7 @@ def restore_config_state(confirmed, config_state_name, restore_type):
if restore_type == "webui" or restore_type == "both":
config_states.restore_webui_config(config_state)
- shared.state.interrupt()
- shared.state.need_restart = True
+ shared.state.request_restart()
return ""
diff --git a/webui.py b/webui.py
index 293a16cc..39dec3ca 100644
--- a/webui.py
+++ b/webui.py
@@ -234,7 +234,10 @@ def initialize():
print(f'Interrupted with signal {sig} in {frame}')
os._exit(0)
- signal.signal(signal.SIGINT, sigint_handler)
+ if not os.environ.get("COVERAGE_RUN"):
+ # Don't install the immediate-quit handler when running under coverage,
+ # as then the coverage report won't be generated.
+ signal.signal(signal.SIGINT, sigint_handler)
def setup_middleware(app):
@@ -255,19 +258,6 @@ def create_api(app):
return api
-def wait_on_server(demo=None):
- while 1:
- time.sleep(0.5)
- if shared.state.need_restart:
- shared.state.need_restart = False
- time.sleep(0.5)
- demo.close()
- time.sleep(0.5)
-
- modules.script_callbacks.app_reload_callback()
- break
-
-
def api_only():
initialize()
@@ -328,6 +318,7 @@ def webui():
inbrowser=cmd_opts.autolaunch,
prevent_thread_lock=True
)
+
# after initial launch, disable --autolaunch for subsequent restarts
cmd_opts.autolaunch = False
@@ -359,8 +350,26 @@ def webui():
redirector.get("/")
gradio.mount_gradio_app(redirector, shared.demo, path=f"/{cmd_opts.subpath}")
- wait_on_server(shared.demo)
+ try:
+ while True:
+ server_command = shared.state.wait_for_server_command(timeout=5)
+ if server_command:
+ if server_command in ("stop", "restart"):
+ break
+ else:
+ print(f"Unknown server command: {server_command}")
+ except KeyboardInterrupt:
+ server_command = "stop"
+
+ if server_command == "stop":
+ # If we catch a keyboard interrupt, we want to stop the server and exit.
+ print('Caught KeyboardInterrupt, stopping...')
+ shared.demo.close()
+ break
print('Restarting UI...')
+ shared.demo.close()
+ time.sleep(0.5)
+ modules.script_callbacks.app_reload_callback()
startup_timer.reset()