diff options
Diffstat (limited to 'webui.py')
-rw-r--r-- | webui.py | 229 |
1 files changed, 162 insertions, 67 deletions
@@ -1,108 +1,203 @@ import os
+import sys
import threading
-
-from modules import devices
-from modules.paths import script_path
+import time
+import importlib
import signal
import threading
-import modules.paths
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.middleware.gzip import GZipMiddleware
+
+from modules import import_hook, errors
+from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
+from modules.paths import script_path
+
+from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer
-import modules.esrgan_model as esrgan
-import modules.bsrgan_model as bsrgan
import modules.extras
import modules.face_restoration
import modules.gfpgan_model as gfpgan
import modules.img2img
-import modules.ldsr_model as ldsr
+
import modules.lowvram
-import modules.realesrgan_model as realesrgan
+import modules.paths
import modules.scripts
import modules.sd_hijack
import modules.sd_models
-import modules.shared as shared
-import modules.swinir_model as swinir
+import modules.sd_vae
import modules.txt2img
+import modules.script_callbacks
+
import modules.ui
from modules import modelloader
-from modules.paths import script_path
from modules.shared import cmd_opts
+import modules.hypernetworks.hypernetwork
+
+
+if cmd_opts.server_name:
+ server_name = cmd_opts.server_name
+else:
+ server_name = "0.0.0.0" if cmd_opts.listen else None
+
+
+def initialize():
+ extensions.list_extensions()
+ localization.list_localizations(cmd_opts.localizations_dir)
+
+ if cmd_opts.ui_debug_mode:
+ shared.sd_upscalers = upscaler.UpscalerLanczos().scalers
+ modules.scripts.load_scripts()
+ return
+
+ modelloader.cleanup_models()
+ modules.sd_models.setup_model()
+ codeformer.setup_model(cmd_opts.codeformer_models_path)
+ gfpgan.setup_model(cmd_opts.gfpgan_models_path)
+ shared.face_restorers.append(modules.face_restoration.FaceRestoration())
+
+ modelloader.list_builtin_upscalers()
+ modules.scripts.load_scripts()
+ modelloader.load_upscalers()
+
+ modules.sd_vae.refresh_vae_list()
+
+ try:
+ modules.sd_models.load_model()
+ except Exception as e:
+ errors.display(e, "loading stable diffusion model")
+ print("", file=sys.stderr)
+ print("Stable diffusion model failed to load, exiting", file=sys.stderr)
+ exit(1)
+
+ shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
+ shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
+ shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
+ shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
+
+ if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:
+
+ try:
+ if not os.path.exists(cmd_opts.tls_keyfile):
+ print("Invalid path to TLS keyfile given")
+ if not os.path.exists(cmd_opts.tls_certfile):
+ print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'")
+ except TypeError:
+ cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None
+ print("TLS setup invalid, running webui without TLS")
+ else:
+ print("Running with TLS")
-modelloader.cleanup_models()
-modules.sd_models.setup_model(cmd_opts.ckpt_dir)
-codeformer.setup_model(cmd_opts.codeformer_models_path)
-gfpgan.setup_model(cmd_opts.gfpgan_models_path)
-shared.face_restorers.append(modules.face_restoration.FaceRestoration())
-modelloader.load_upscalers()
-queue_lock = threading.Lock()
+ # make the program just exit at ctrl+c without waiting for anything
+ def sigint_handler(sig, frame):
+ print(f'Interrupted with signal {sig} in {frame}')
+ os._exit(0)
+ signal.signal(signal.SIGINT, sigint_handler)
-def wrap_queued_call(func):
- def f(*args, **kwargs):
- with queue_lock:
- res = func(*args, **kwargs)
- return res
+def setup_cors(app):
+ if cmd_opts.cors_allow_origins and cmd_opts.cors_allow_origins_regex:
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
+ elif cmd_opts.cors_allow_origins:
+ app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*'])
+ elif cmd_opts.cors_allow_origins_regex:
+ app.add_middleware(CORSMiddleware, allow_origin_regex=cmd_opts.cors_allow_origins_regex, allow_methods=['*'])
- return f
+def create_api(app):
+ from modules.api.api import Api
+ api = Api(app, queue_lock)
+ return api
-def wrap_gradio_gpu_call(func):
- def f(*args, **kwargs):
- devices.torch_gc()
- shared.state.sampling_step = 0
- shared.state.job_count = -1
- shared.state.job_no = 0
- shared.state.job_timestamp = shared.state.get_job_timestamp()
- shared.state.current_latent = None
- shared.state.current_image = None
- shared.state.current_image_sampling_step = 0
- shared.state.interrupted = False
+def wait_on_server(demo=None):
+ while 1:
+ time.sleep(0.5)
+ if shared.state.need_restart:
+ shared.state.need_restart = False
+ time.sleep(0.5)
+ demo.close()
+ time.sleep(0.5)
+ break
- with queue_lock:
- res = func(*args, **kwargs)
- shared.state.job = ""
- shared.state.job_count = 0
+def api_only():
+ initialize()
- devices.torch_gc()
+ app = FastAPI()
+ setup_cors(app)
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+ api = create_api(app)
- return res
+ modules.script_callbacks.app_started_callback(None, app)
- return modules.ui.wrap_gradio_call(f)
+ api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861)
-modules.scripts.load_scripts(os.path.join(script_path, "scripts"))
+def webui():
+ launch_api = cmd_opts.api
+ initialize()
-shared.sd_model = modules.sd_models.load_model()
-shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+ while 1:
+ if shared.opts.clean_temp_dir_at_start:
+ ui_tempdir.cleanup_tmpdr()
+ shared.demo = modules.ui.create_ui()
-def webui():
- # make the program just exit at ctrl+c without waiting for anything
- def sigint_handler(sig, frame):
- print(f'Interrupted with signal {sig} in {frame}')
- os._exit(0)
+ app, local_url, share_url = shared.demo.queue(default_enabled=False).launch(
+ share=cmd_opts.share,
+ server_name=server_name,
+ server_port=cmd_opts.port,
+ ssl_keyfile=cmd_opts.tls_keyfile,
+ ssl_certfile=cmd_opts.tls_certfile,
+ debug=cmd_opts.gradio_debug,
+ auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
+ inbrowser=cmd_opts.autolaunch,
+ prevent_thread_lock=True
+ )
+ # after initial launch, disable --autolaunch for subsequent restarts
+ cmd_opts.autolaunch = False
- signal.signal(signal.SIGINT, sigint_handler)
+ # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
+ # an attacker to trick the user into opening a malicious HTML page, which makes a request to the
+ # running web ui and do whatever the attacker wants, including installing an extension and
+ # running its code. We disable this here. Suggested by RyotaK.
+ app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware']
+
+ setup_cors(app)
+
+ app.add_middleware(GZipMiddleware, minimum_size=1000)
+
+ if launch_api:
+ create_api(app)
+
+ modules.script_callbacks.app_started_callback(shared.demo, app)
+ modules.script_callbacks.app_started_callback(shared.demo, app)
+
+ wait_on_server(shared.demo)
+ print('Restarting UI...')
+
+ sd_samplers.set_samplers()
+
+ extensions.list_extensions()
+
+ localization.list_localizations(cmd_opts.localizations_dir)
+
+ modelloader.forbid_loaded_nonbuiltin_upscalers()
+ modules.scripts.reload_scripts()
+ modelloader.load_upscalers()
- demo = modules.ui.create_ui(
- txt2img=wrap_gradio_gpu_call(modules.txt2img.txt2img),
- img2img=wrap_gradio_gpu_call(modules.img2img.img2img),
- run_extras=wrap_gradio_gpu_call(modules.extras.run_extras),
- run_pnginfo=modules.extras.run_pnginfo,
- run_modelmerger=modules.extras.run_modelmerger
- )
+ for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
+ importlib.reload(module)
- demo.launch(
- share=cmd_opts.share,
- server_name="0.0.0.0" if cmd_opts.listen else None,
- server_port=cmd_opts.port,
- debug=cmd_opts.gradio_debug,
- auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None,
- inbrowser=cmd_opts.autolaunch,
- )
+ modules.sd_models.list_models()
if __name__ == "__main__":
- webui()
+ if cmd_opts.nowebui:
+ api_only()
+ else:
+ webui()
|