diff options
Diffstat (limited to 'webui.py')
-rw-r--r-- | webui.py | 74 |
1 files changed, 28 insertions, 46 deletions
@@ -14,7 +14,6 @@ from typing import Iterable from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
-from packaging import version
import logging
@@ -44,12 +43,16 @@ startup_timer.record("import torch") import gradio # noqa: F401
startup_timer.record("import gradio")
-from modules import paths, timer, import_hook, errors, devices # noqa: F401
+from modules import paths, timer, import_hook, errors # noqa: F401
startup_timer.record("setup paths")
import ldm.modules.encoders.modules # noqa: F401
startup_timer.record("import ldm")
+from modules import shared_init, shared, shared_items
+shared_init.initialize()
+startup_timer.record("initialize shared")
+
from modules import extra_networks
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401
@@ -58,10 +61,13 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__
torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0)
-from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+if not shared.cmd_opts.skip_version_check:
+ errors.check_versions()
+
import modules.codeformer_model as codeformer
-import modules.face_restoration
import modules.gfpgan_model as gfpgan
+from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states
+import modules.face_restoration
import modules.img2img
import modules.lowvram
@@ -77,7 +83,7 @@ import modules.textual_inversion.textual_inversion import modules.progress
import modules.ui
-from modules import modelloader
+from modules import modelloader, devices
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork
@@ -130,37 +136,6 @@ def fix_asyncio_event_loop_policy(): asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
-def check_versions():
- if shared.cmd_opts.skip_version_check:
- return
-
- expected_torch_version = "2.0.0"
-
- if version.parse(torch.__version__) < version.parse(expected_torch_version):
- errors.print_error_explanation(f"""
-You are running torch {torch.__version__}.
-The program is tested to work with torch {expected_torch_version}.
-To reinstall the desired version, run with commandline flag --reinstall-torch.
-Beware that this will cause a lot of large files to be downloaded, as well as
-there are reports of issues with training tab on the latest version.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
- expected_xformers_version = "0.0.20"
- if shared.xformers_available:
- import xformers
-
- if version.parse(xformers.__version__) < version.parse(expected_xformers_version):
- errors.print_error_explanation(f"""
-You are running xformers {xformers.__version__}.
-The program is tested to work with xformers {expected_xformers_version}.
-To reinstall the desired version, run with commandline flag --reinstall-xformers.
-
-Use --skip-version-check commandline argument to disable this check.
- """.strip())
-
-
def restore_config_state_file():
config_state_file = shared.opts.restore_config_state_file
if config_state_file == "":
@@ -237,7 +212,7 @@ def configure_sigint_handler(): def configure_opts_onchange():
shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
- shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
+ shared.opts.onchange("sd_vae_overrides_per_model_preferences", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)
shared.opts.onchange("gradio_theme", shared.reload_gradio_theme)
shared.opts.onchange("cross_attention_optimization", wrap_queued_call(lambda: modules.sd_hijack.model_hijack.redo_hijack(shared.sd_model)), call=False)
@@ -248,7 +223,6 @@ def initialize(): fix_asyncio_event_loop_policy()
validate_tls_options()
configure_sigint_handler()
- check_versions()
modelloader.cleanup_models()
configure_opts_onchange()
@@ -320,11 +294,11 @@ def initialize_rest(*, reload_script_modules=False): if modules.sd_hijack.current_optimizer is None:
modules.sd_hijack.apply_optimizations()
- Thread(target=load_model).start()
+ devices.first_time_calculation()
- Thread(target=devices.first_time_calculation).start()
+ Thread(target=load_model).start()
- shared.reload_hypernetworks()
+ shared_items.reload_hypernetworks()
startup_timer.record("reload hypernetworks")
ui_extra_networks.initialize()
@@ -368,13 +342,14 @@ def api_only(): setup_middleware(app)
api = create_api(app)
+ modules.script_callbacks.before_ui_callback()
modules.script_callbacks.app_started_callback(None, app)
print(f"Startup time: {startup_timer.summary()}.")
api.launch(
server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1",
port=cmd_opts.port if cmd_opts.port else 7861,
- root_path = f"/{cmd_opts.subpath}"
+ root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else ""
)
@@ -398,6 +373,13 @@ def webui(): gradio_auth_creds = list(get_gradio_auth_creds()) or None
+ auto_launch_browser = False
+ if os.getenv('SD_WEBUI_RESTARTING') != '1':
+ if shared.opts.auto_launch_browser == "Remote" or cmd_opts.autolaunch:
+ auto_launch_browser = True
+ elif shared.opts.auto_launch_browser == "Local":
+ auto_launch_browser = not any([cmd_opts.listen, cmd_opts.share, cmd_opts.ngrok])
+
app, local_url, share_url = shared.demo.launch(
share=cmd_opts.share,
server_name=server_name,
@@ -407,7 +389,7 @@ def webui(): ssl_verify=cmd_opts.disable_tls_verify,
debug=cmd_opts.gradio_debug,
auth=gradio_auth_creds,
- inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1',
+ inbrowser=auto_launch_browser,
prevent_thread_lock=True,
allowed_paths=cmd_opts.gradio_allowed_path,
app_kwargs={
@@ -417,9 +399,6 @@ def webui(): root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "",
)
- # after initial launch, disable --autolaunch for subsequent restarts
- cmd_opts.autolaunch = False
-
startup_timer.record("gradio launch")
# gradio uses a very open CORS policy via app.user_middleware, which makes it possible for
@@ -464,6 +443,9 @@ def webui(): shared.demo.close()
break
+ # disable auto launch webui in browser for subsequent UI Reload
+ os.environ.setdefault('SD_WEBUI_RESTARTING', '1')
+
print('Restarting UI...')
shared.demo.close()
time.sleep(0.5)
|