From cb31abcf58ea1f64266e6d821937eed058c35f4d Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Sun, 30 Oct 2022 21:54:31 +0700 Subject: Settings to select VAE --- webui.py | 5 +++++ 1 file changed, 5 insertions(+) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 29530872..27949f3d 100644 --- a/webui.py +++ b/webui.py @@ -21,6 +21,7 @@ import modules.paths import modules.scripts import modules.sd_hijack import modules.sd_models +import modules.sd_vae import modules.shared as shared import modules.txt2img @@ -74,8 +75,12 @@ def initialize(): modules.scripts.load_scripts() + modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model))) + # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, + # so for now this reloads the whole model too, and no cache + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model, force=True)), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From 056f06d3738c267b1014e6e8e1ef5bd97af1fb45 Mon Sep 17 00:00:00 2001 From: Muhammad Rizqi Nur Date: Wed, 2 Nov 2022 12:51:46 +0700 Subject: Reload VAE without reloading sd checkpoint --- modules/sd_models.py | 15 ++++---- modules/sd_vae.py | 97 ++++++++++++++++++++++++++++++++++++++++++++++++---- webui.py | 4 +-- 3 files changed, 98 insertions(+), 18 deletions(-) (limited to 'webui.py') diff --git a/modules/sd_models.py b/modules/sd_models.py index 6ab85b65..883639d1 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -159,15 +159,13 @@ def get_state_dict_from_checkpoint(pl_sd): return pl_sd -vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} - def load_model_weights(model, checkpoint_info, vae_file="auto"): checkpoint_file = checkpoint_info.filename sd_model_hash = checkpoint_info.hash vae_file = sd_vae.resolve_vae(checkpoint_file, vae_file=vae_file) - checkpoint_key = (checkpoint_info, vae_file) + checkpoint_key = checkpoint_info if checkpoint_key not in checkpoints_loaded: print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}") @@ -190,13 +188,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): devices.dtype = torch.float32 if shared.cmd_opts.no_half else torch.float16 devices.dtype_vae = torch.float32 if shared.cmd_opts.no_half or shared.cmd_opts.no_half_vae else torch.float16 - sd_vae.load_vae(model, vae_file) - model.first_stage_model.to(devices.dtype_vae) - if shared.opts.sd_checkpoint_cache > 0: + # if PR #4035 were to get merged, restore base VAE first before caching checkpoints_loaded[checkpoint_key] = model.state_dict().copy() while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache: checkpoints_loaded.popitem(last=False) # LRU + else: vae_name = sd_vae.get_filename(vae_file) print(f"Loading weights [{sd_model_hash}] with {vae_name} VAE from cache") @@ -207,6 +204,8 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"): model.sd_model_checkpoint = checkpoint_file model.sd_checkpoint_info = checkpoint_info + sd_vae.load_vae(model, vae_file) + def load_model(checkpoint_info=None): from modules import lowvram, sd_hijack @@ -254,14 +253,14 @@ def 
load_model(checkpoint_info=None): return sd_model -def reload_model_weights(sd_model=None, info=None, force=False): +def reload_model_weights(sd_model=None, info=None): from modules import lowvram, devices, sd_hijack checkpoint_info = info or select_checkpoint() if not sd_model: sd_model = shared.sd_model - if sd_model.sd_model_checkpoint == checkpoint_info.filename and not force: + if sd_model.sd_model_checkpoint == checkpoint_info.filename: return if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info): diff --git a/modules/sd_vae.py b/modules/sd_vae.py index e9239326..78e14e8a 100644 --- a/modules/sd_vae.py +++ b/modules/sd_vae.py @@ -1,26 +1,65 @@ import torch import os from collections import namedtuple -from modules import shared, devices +from modules import shared, devices, script_callbacks from modules.paths import models_path import glob + model_dir = "Stable-diffusion" model_path = os.path.abspath(os.path.join(models_path, model_dir)) vae_dir = "VAE" vae_path = os.path.abspath(os.path.join(models_path, vae_dir)) + vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"} + + default_vae_dict = {"auto": "auto", "None": "None"} default_vae_list = ["auto", "None"] + + default_vae_values = [default_vae_dict[x] for x in default_vae_list] vae_dict = dict(default_vae_dict) vae_list = list(default_vae_list) first_load = True + +base_vae = None +loaded_vae_file = None +checkpoint_info = None + + +def get_base_vae(model): + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model: + return base_vae + return None + + +def store_base_vae(model): + global base_vae, checkpoint_info + if checkpoint_info != model.sd_checkpoint_info: + base_vae = model.first_stage_model.state_dict().copy() + checkpoint_info = model.sd_checkpoint_info + + +def delete_base_vae(): + global base_vae, checkpoint_info + base_vae = None + checkpoint_info = None + + +def restore_base_vae(model): + global base_vae, checkpoint_info + if base_vae is not None and checkpoint_info == model.sd_checkpoint_info: + load_vae_dict(model, base_vae) + delete_base_vae() + + def get_filename(filepath): return os.path.splitext(os.path.basename(filepath))[0] + def refresh_vae_list(vae_path=vae_path, model_path=model_path): global vae_dict, vae_list res = {} @@ -43,6 +82,7 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path): vae_dict.update(res) return vae_list + def resolve_vae(checkpoint_file, vae_file="auto"): global first_load, vae_dict, vae_list # save_settings = False @@ -96,24 +136,26 @@ def resolve_vae(checkpoint_file, vae_file="auto"): return vae_file -def load_vae(model, vae_file): - global first_load, vae_dict, vae_list + +def load_vae(model, vae_file=None): + global first_load, vae_dict, vae_list, loaded_vae_file # save_settings = False if vae_file: print(f"Loading VAE weights from: {vae_file}") vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location) vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys} - model.first_stage_model.load_state_dict(vae_dict_1) + load_vae_dict(model, vae_dict_1) - # If vae used is not in dict, update it - # It will be removed on refresh though - if vae_file is not None: + # If vae used is not in dict, update it + # It will be removed on refresh though vae_opt = get_filename(vae_file) if vae_opt not in vae_dict: vae_dict[vae_opt] = vae_file vae_list.append(vae_opt) + 
loaded_vae_file = vae_file + """ # Save current VAE to VAE settings, maybe? will it work? if save_settings: @@ -124,4 +166,45 @@ def load_vae(model, vae_file): """ first_load = False + + +# don't call this from outside +def load_vae_dict(model, vae_dict_1=None): + if vae_dict_1: + store_base_vae(model) + model.first_stage_model.load_state_dict(vae_dict_1) + else: + restore_base_vae() model.first_stage_model.to(devices.dtype_vae) + + +def reload_vae_weights(sd_model=None, vae_file="auto"): + from modules import lowvram, devices, sd_hijack + + if not sd_model: + sd_model = shared.sd_model + + checkpoint_info = sd_model.sd_checkpoint_info + checkpoint_file = checkpoint_info.filename + vae_file = resolve_vae(checkpoint_file, vae_file=vae_file) + + if loaded_vae_file == vae_file: + return + + if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: + lowvram.send_everything_to_cpu() + else: + sd_model.to(devices.cpu) + + sd_hijack.model_hijack.undo_hijack(sd_model) + + load_vae(sd_model, vae_file) + + sd_hijack.model_hijack.hijack(sd_model) + script_callbacks.model_loaded_callback(sd_model) + + if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram: + sd_model.to(devices.device) + + print(f"VAE Weights loaded.") + return sd_model diff --git a/webui.py b/webui.py index 7cb4691b..034777a2 100644 --- a/webui.py +++ b/webui.py @@ -81,9 +81,7 @@ def initialize(): modules.sd_vae.refresh_vae_list() modules.sd_models.load_model() shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights())) - # I don't know what needs to be done to only reload VAE, with all those hijacks callbacks, and lowvram, - # so for now this reloads the whole model too - shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(force=True)), call=False) + shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False) shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) -- cgit v1.2.3 From dd2108fdac2ebf943d4ac3563a49202222b88acf Mon Sep 17 00:00:00 2001 From: Maiko Tan Date: Wed, 2 Nov 2022 15:04:35 +0800 Subject: fix: should invoke callback as well in api only mode --- modules/script_callbacks.py | 3 ++- webui.py | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) (limited to 'webui.py') diff --git a/modules/script_callbacks.py b/modules/script_callbacks.py index da88635b..c28e220e 100644 --- a/modules/script_callbacks.py +++ b/modules/script_callbacks.py @@ -2,6 +2,7 @@ import sys import traceback from collections import namedtuple import inspect +from typing import Optional from fastapi import FastAPI from gradio import Blocks @@ -62,7 +63,7 @@ def clear_callbacks(): callbacks_image_saved.clear() callbacks_cfg_denoiser.clear() -def app_started_callback(demo: Blocks, app: FastAPI): +def app_started_callback(demo: Optional[Blocks], app: FastAPI): for c in callbacks_app_started: try: c.callback(demo, app) diff --git a/webui.py b/webui.py index 84e5c1fd..dc4223dc 100644 --- a/webui.py +++ b/webui.py @@ -114,6 +114,8 @@ def api_only(): app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) + modules.script_callbacks.app_started_callback(None, app) + api.launch(server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861) -- cgit v1.2.3 
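Because api_only() now invokes app_started_callback(None, app), an extension that hooks app startup can no longer assume a Gradio Blocks instance exists. Below is a minimal sketch of such a callback as an extension script might register it; it assumes the usual on_app_started registration helper in modules.script_callbacks, and the /my-extension/ping route is purely illustrative, not part of the patches above.

from typing import Optional

from fastapi import FastAPI
from gradio import Blocks

from modules import script_callbacks


def handle_app_started(demo: Optional[Blocks], app: FastAPI):
    # demo is None when webui.py runs api_only() (--nowebui), so guard any UI access
    if demo is None:
        print("app started in API-only mode; skipping UI-related setup")

    # hypothetical extra endpoint an extension might expose
    @app.get("/my-extension/ping")
    def ping():
        return {"alive": True}


script_callbacks.on_app_started(handle_app_started)
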
From 5f0117154382eb0e2547c72630256681673e353b Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Fri, 4 Nov 2022 10:07:29 +0300 Subject: shut down gradio's "everything allowed" CORS policy; I checked the main functionality to work with this, but if this breaks some exotic workflow, I'm sorry. --- README.md | 7 ++++--- webui.py | 6 ++++++ 2 files changed, 10 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/README.md b/README.md index 55c050d5..33508f31 100644 --- a/README.md +++ b/README.md @@ -155,14 +155,15 @@ The documentation was moved from this README over to the project's [wiki](https: - Swin2SR - https://github.com/mv-lab/swin2sr - LDSR - https://github.com/Hafiidz/latent-diffusion - Ideas for optimizations - https://github.com/basujindal/stable-diffusion -- Doggettx - Cross Attention layer optimization - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. -- InvokeAI, lstein - Cross Attention layer optimization - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) -- Rinon Gal - Textual Inversion - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). +- Cross Attention layer optimization - Doggettx - https://github.com/Doggettx/stable-diffusion, original idea for prompt editing. +- Cross Attention layer optimization - InvokeAI, lstein - https://github.com/invoke-ai/InvokeAI (originally http://github.com/lstein/stable-diffusion) +- Textual Inversion - Rinon Gal - https://github.com/rinongal/textual_inversion (we're not using his code, but we are using his ideas). - Idea for SD upscale - https://github.com/jquesnelle/txt2imghd - Noise generation for outpainting mk2 - https://github.com/parlance-zz/g-diffuser-bot - CLIP interrogator idea and borrowing some code - https://github.com/pharmapsychotic/clip-interrogator - Idea for Composable Diffusion - https://github.com/energy-based-model/Compositional-Visual-Generation-with-Composable-Diffusion-Models-PyTorch - xformers - https://github.com/facebookresearch/xformers - DeepDanbooru - interrogator for anime diffusers https://github.com/KichangKim/DeepDanbooru +- Security advice - RyotaK - Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user. - (You) diff --git a/webui.py b/webui.py index 3b21c071..81df09dd 100644 --- a/webui.py +++ b/webui.py @@ -141,6 +141,12 @@ def webui(): # after initial launch, disable --autolaunch for subsequent restarts cmd_opts.autolaunch = False + # gradio uses a very open CORS policy via app.user_middleware, which makes it possible for + # an attacker to trick the user into opening a malicious HTML page, which makes a request to the + # running web ui and do whatever the attcker wants, including installing an extension and + # runnnig its code. We disable this here. Suggested by RyotaK. 
+ app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + app.add_middleware(GZipMiddleware, minimum_size=1000) if launch_api: -- cgit v1.2.3 From b8435e632f7ba0da12a2c8e9c788dda519279d24 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sat, 5 Nov 2022 02:36:47 +0800 Subject: add --cors-allow-origins cmd opt --- modules/shared.py | 7 ++++--- webui.py | 9 +++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index a9e28b9c..e83cbcdf 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,7 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--cors-allow-origins", type=str, help="Allowed CORS origins", default=None) cmd_opts = parser.parse_args() restricted_opts = { @@ -147,9 +148,9 @@ class State: self.interrupted = True def nextjob(self): - if opts.show_progress_every_n_steps == -1: + if opts.show_progress_every_n_steps == -1: self.do_set_current_image() - + self.job_no += 1 self.sampling_step = 0 self.current_image_sampling_step = 0 @@ -198,7 +199,7 @@ class State: return if self.current_latent is None: return - + if opts.show_progress_grid: self.current_image = sd_samplers.samples_to_image_grid(self.current_latent) else: diff --git a/webui.py b/webui.py index 81df09dd..3788af0b 100644 --- a/webui.py +++ b/webui.py @@ -5,6 +5,7 @@ import importlib import signal import threading from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path @@ -93,6 +94,11 @@ def initialize(): signal.signal(signal.SIGINT, sigint_handler) +def setup_cors(app): + if cmd_opts.cors_allow_origins: + app.add_middleware(CORSMiddleware, allow_origins=cmd_opts.cors_allow_origins.split(','), allow_methods=['*']) + + def create_api(app): from modules.api.api import Api api = Api(app, queue_lock) @@ -114,6 +120,7 @@ def api_only(): initialize() app = FastAPI() + setup_cors(app) app.add_middleware(GZipMiddleware, minimum_size=1000) api = create_api(app) @@ -147,6 +154,8 @@ def webui(): # runnnig its code. We disable this here. Suggested by RyotaK. 
app.user_middleware = [x for x in app.user_middleware if x.cls.__name__ != 'CORSMiddleware'] + setup_cors(app) + app.add_middleware(GZipMiddleware, minimum_size=1000) if launch_api: -- cgit v1.2.3 From e9a5562b9b27a1a4f9c282637b111cefd9727a41 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:06:51 -0500 Subject: add support for tls (gradio tls options) --- modules/shared.py | 3 +++ webui.py | 22 ++++++++++++++++++++-- 2 files changed, 23 insertions(+), 2 deletions(-) (limited to 'webui.py') diff --git a/modules/shared.py b/modules/shared.py index 962115f6..7a20c3af 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -86,6 +86,9 @@ parser.add_argument("--nowebui", action='store_true', help="use api=True to laun parser.add_argument("--ui-debug-mode", action='store_true', help="Don't load model to quickly launch UI") parser.add_argument("--device-id", type=str, help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)", default=None) parser.add_argument("--administrator", action='store_true', help="Administrator rights", default=False) +parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requires --tls-certfile to fully function", default=None) +parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None) +parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None) cmd_opts = parser.parse_args() restricted_opts = { diff --git a/webui.py b/webui.py index 81df09dd..d366f4ca 100644 --- a/webui.py +++ b/webui.py @@ -34,7 +34,7 @@ from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork queue_lock = threading.Lock() - +server_name = "0.0.0.0" if cmd_opts.listen else cmd_opts.server_name def wrap_queued_call(func): def f(*args, **kwargs): @@ -85,6 +85,22 @@ def initialize(): shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork))) shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength) + if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None: + + try: + if not os.path.exists(cmd_opts.tls_keyfile): + print("Invalid path to TLS keyfile given") + if not os.path.exists(cmd_opts.tls_certfile): + print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") + except TypeError: + cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None + print(f"path: '{cmd_opts.tls_keyfile}' {type(cmd_opts.tls_keyfile)}") + print(f"path: '{cmd_opts.tls_certfile}' {type(cmd_opts.tls_certfile)}") + print("TLS setup invalid, running webui without TLS") + else: + print("Running with TLS") + + # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -131,8 +147,10 @@ def webui(): app, local_url, share_url = demo.launch( share=cmd_opts.share, - server_name="0.0.0.0" if cmd_opts.listen else None, + server_name=server_name, server_port=cmd_opts.port, + ssl_keyfile=cmd_opts.tls_keyfile, + ssl_certfile=cmd_opts.tls_certfile, debug=cmd_opts.gradio_debug, auth=[tuple(cred.split(':')) for cred in cmd_opts.gradio_auth.strip('"').split(',')] if cmd_opts.gradio_auth else None, inbrowser=cmd_opts.autolaunch, -- cgit v1.2.3 From a02bad570ef7718436369bb4e4aa5b8e0f1f5689 Mon Sep 17 00:00:00 2001 From: papuSpartan Date: Sat, 5 Nov 2022 04:14:21 -0500 Subject: rm dbg --- webui.py 
| 2 -- 1 file changed, 2 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index d366f4ca..222dbeee 100644 --- a/webui.py +++ b/webui.py @@ -94,8 +94,6 @@ def initialize(): print(f"Invalid path to TLS certfile: '{cmd_opts.tls_certfile}'") except TypeError: cmd_opts.tls_keyfile = cmd_opts.tls_certfile = None - print(f"path: '{cmd_opts.tls_keyfile}' {type(cmd_opts.tls_keyfile)}") - print(f"path: '{cmd_opts.tls_certfile}' {type(cmd_opts.tls_certfile)}") print("TLS setup invalid, running webui without TLS") else: print("Running with TLS") -- cgit v1.2.3 From a2a1a2f7270a865175f64475229838a8d64509ea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 09:02:25 +0300 Subject: add ability to create extensions that add localizations --- javascript/ui.js | 2 ++ modules/localization.py | 6 ++++++ modules/scripts.py | 1 - modules/shared.py | 2 -- modules/ui.py | 3 +-- webui.py | 9 +++++---- 6 files changed, 14 insertions(+), 9 deletions(-) (limited to 'webui.py') diff --git a/javascript/ui.js b/javascript/ui.js index 7e116465..95cfd106 100644 --- a/javascript/ui.js +++ b/javascript/ui.js @@ -208,4 +208,6 @@ function update_token_counter(button_id) { function restart_reload(){ document.body.innerHTML='

<h1 style="font-family:monospace; margin-top:20%; color:lightgray; text-align:center;">Reloading...</h1>
'; setTimeout(function(){location.reload()},2000) + + return [] } diff --git a/modules/localization.py b/modules/localization.py index b1810cda..f6a6f2fb 100644 --- a/modules/localization.py +++ b/modules/localization.py @@ -3,6 +3,7 @@ import os import sys import traceback + localizations = {} @@ -16,6 +17,11 @@ def list_localizations(dirname): localizations[fn] = os.path.join(dirname, file) + from modules import scripts + for file in scripts.list_scripts("localizations", ".json"): + fn, ext = os.path.splitext(file.filename) + localizations[fn] = file.path + def localization_js(current_localization_name): fn = localizations.get(current_localization_name, None) diff --git a/modules/scripts.py b/modules/scripts.py index 366c90d7..637b2329 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -3,7 +3,6 @@ import sys import traceback from collections import namedtuple -import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing diff --git a/modules/shared.py b/modules/shared.py index 70b998ff..e8bacd3c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -221,8 +221,6 @@ interrogator = modules.interrogate.InterrogateModels("interrogate") face_restorers = [] -localization.list_localizations(cmd_opts.localizations_dir) - def realesrgan_models_names(): import modules.realesrgan_model diff --git a/modules/ui.py b/modules/ui.py index 76ca9b07..23643c22 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -1563,11 +1563,10 @@ def create_ui(wrap_gradio_gpu_call): shared.state.need_restart = True restart_gradio.click( - fn=request_restart, + _js='restart_reload', inputs=[], outputs=[], - _js='restart_reload' ) if column is not None: diff --git a/webui.py b/webui.py index a5a520f0..4342a962 100644 --- a/webui.py +++ b/webui.py @@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler, extensions +from modules import devices, sd_samplers, upscaler, extensions, localization import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -28,9 +28,7 @@ import modules.txt2img import modules.script_callbacks import modules.ui -from modules import devices from modules import modelloader -from modules.paths import script_path from modules.shared import cmd_opts import modules.hypernetworks.hypernetwork @@ -64,6 +62,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): extensions.list_extensions() + localization.list_localizations(cmd_opts.localizations_dir) if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers @@ -99,7 +98,6 @@ def initialize(): else: print("Running with TLS") - # make the program just exit at ctrl+c without waiting for anything def sigint_handler(sig, frame): print(f'Interrupted with signal {sig} in {frame}') @@ -185,6 +183,9 @@ def webui(): print('Reloading extensions') extensions.list_extensions() + + localization.list_localizations(cmd_opts.localizations_dir) + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') -- cgit v1.2.3 From e5b4e3f820cd09e751f1d168ab05d606d078a0d9 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 6 Nov 2022 10:12:53 +0300 Subject: add tags to extensions, and ability to filter out tags list changed Settings keys in UI do not print VRAM/etc stats everywhere but in calls that use GPU --- modules/ui.py | 25 ++++++++++++---------- modules/ui_extensions.py 
| 55 ++++++++++++++++++++++++++++++++++++++---------- style.css | 5 +++++ webui.py | 2 +- 4 files changed, 64 insertions(+), 23 deletions(-) (limited to 'webui.py') diff --git a/modules/ui.py b/modules/ui.py index 23643c22..c946ad59 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -174,9 +174,9 @@ def save_pil_to_file(pil_image, dir=None): gr.processing_utils.save_pil_to_file = save_pil_to_file -def wrap_gradio_call(func, extra_outputs=None): +def wrap_gradio_call(func, extra_outputs=None, add_stats=False): def f(*args, extra_outputs_array=extra_outputs, **kwargs): - run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled + run_memmon = opts.memmon_poll_rate > 0 and not shared.mem_mon.disabled and add_stats if run_memmon: shared.mem_mon.monitor() t = time.perf_counter() @@ -203,11 +203,18 @@ def wrap_gradio_call(func, extra_outputs=None): res = extra_outputs_array + [f"
<div class='error'>{plaintext_to_html(type(e).__name__+': '+str(e))}</div>
"] + shared.state.skipped = False + shared.state.interrupted = False + shared.state.job_count = 0 + + if not add_stats: + return tuple(res) + elapsed = time.perf_counter() - t elapsed_m = int(elapsed // 60) elapsed_s = elapsed % 60 elapsed_text = f"{elapsed_s:.2f}s" - if (elapsed_m > 0): + if elapsed_m > 0: elapsed_text = f"{elapsed_m}m "+elapsed_text if run_memmon: @@ -225,10 +232,6 @@ def wrap_gradio_call(func, extra_outputs=None): # last item is always HTML res[-1] += f"

<div class='performance'><p class='time'>Time taken: {elapsed_text}</p>{vram_html}</div>
" - shared.state.skipped = False - shared.state.interrupted = False - shared.state.job_count = 0 - return tuple(res) return f @@ -1436,7 +1439,7 @@ def create_ui(wrap_gradio_gpu_call): opts.reorder() def run_settings(*args): - changed = 0 + changed = [] for key, value, comp in zip(opts.data_labels.keys(), args, components): assert comp == dummy_component or opts.same_type(value, opts.data_labels[key].default), f"Bad value for setting {key}: {value}; expecting {type(opts.data_labels[key].default).__name__}" @@ -1454,12 +1457,12 @@ def create_ui(wrap_gradio_gpu_call): if opts.data_labels[key].onchange is not None: opts.data_labels[key].onchange() - changed += 1 + changed.append(key) try: opts.save(shared.config_filename) except RuntimeError: - return opts.dumpjson(), f'{changed} settings changed without save.' - return opts.dumpjson(), f'{changed} settings changed.' + return opts.dumpjson(), f'{len(changed)} settings changed without save: {", ".join(changed)}.' + return opts.dumpjson(), f'{len(changed)} settings changed: {", ".join(changed)}.' def run_settings_single(value, key): if not opts.same_type(value, opts.data_labels[key].default): diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index 8e0d41d5..02ab9643 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -140,13 +140,15 @@ def install_extension_from_url(dirname, url): shutil.rmtree(tmpdir, True) -def install_extension_from_index(url): +def install_extension_from_index(url, hide_tags): ext_table, message = install_extension_from_url(None, url) - return refresh_available_extensions_from_data(), ext_table, message + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, ext_table, message -def refresh_available_extensions(url): + +def refresh_available_extensions(url, hide_tags): global available_extensions import urllib.request @@ -155,13 +157,25 @@ def refresh_available_extensions(url): available_extensions = json.loads(text) - return url, refresh_available_extensions_from_data(), '' + code, tags = refresh_available_extensions_from_data(hide_tags) + + return url, code, gr.CheckboxGroup.update(choices=tags), '' + + +def refresh_available_extensions_for_tags(hide_tags): + code, _ = refresh_available_extensions_from_data(hide_tags) + return code, '' -def refresh_available_extensions_from_data(): + +def refresh_available_extensions_from_data(hide_tags): extlist = available_extensions["extensions"] installed_extension_urls = {normalize_git_url(extension.remote): extension.name for extension in extensions.extensions} + tags = available_extensions.get("tags", {}) + tags_to_hide = set(hide_tags) + hidden = 0 + code = f""" @@ -178,17 +192,24 @@ def refresh_available_extensions_from_data(): name = ext.get("name", "noname") url = ext.get("url", None) description = ext.get("description", "") + extension_tags = ext.get("tags", []) if url is None: continue + if len([x for x in extension_tags if x in tags_to_hide]) > 0: + hidden += 1 + continue + existing = installed_extension_urls.get(normalize_git_url(url), None) install_code = f"""""" + tags_text = ", ".join([f"{x}" for x in extension_tags]) + code += f""" - + @@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
<tr>
-                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a></td>
+                <td><a href="{html.escape(url)}" target="_blank">{html.escape(name)}</a><br />{tags_text}</td>
                 <td>{html.escape(description)}</td>
                 <td>{install_code}</td>
             </tr>
@@ -199,7 +220,10 @@ def refresh_available_extensions_from_data():
     """
-    return code
+    if hidden > 0:
+        code += f"<p>Extension hidden: {hidden}</p>
" + + return code, list(tags) def create_ui(): @@ -238,21 +262,30 @@ def create_ui(): extension_to_install = gr.Text(elem_id="extension_to_install", visible=False) install_extension_button = gr.Button(elem_id="install_extension_button", visible=False) + with gr.Row(): + hide_tags = gr.CheckboxGroup(value=["ads", "localization"], label="Hide extensions with tags", choices=["script", "ads", "localization"]) + install_result = gr.HTML() available_extensions_table = gr.HTML() refresh_available_extensions_button.click( - fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update()]), - inputs=[available_extensions_index], - outputs=[available_extensions_index, available_extensions_table, install_result], + fn=modules.ui.wrap_gradio_call(refresh_available_extensions, extra_outputs=[gr.update(), gr.update(), gr.update()]), + inputs=[available_extensions_index, hide_tags], + outputs=[available_extensions_index, available_extensions_table, hide_tags, install_result], ) install_extension_button.click( fn=modules.ui.wrap_gradio_call(install_extension_from_index, extra_outputs=[gr.update(), gr.update()]), - inputs=[extension_to_install], + inputs=[extension_to_install, hide_tags], outputs=[available_extensions_table, extensions_table, install_result], ) + hide_tags.change( + fn=modules.ui.wrap_gradio_call(refresh_available_extensions_for_tags, extra_outputs=[gr.update()]), + inputs=[hide_tags], + outputs=[available_extensions_table, install_result] + ) + with gr.TabItem("Install from URL"): install_url = gr.Text(label="URL for extension's git repository") install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") diff --git a/style.css b/style.css index a0382a8c..e2b71f25 100644 --- a/style.css +++ b/style.css @@ -563,6 +563,11 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h opacity: 0.5; } +.extension-tag{ + font-weight: bold; + font-size: 95%; +} + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running diff --git a/webui.py b/webui.py index 4342a962..f4f1d74d 100644 --- a/webui.py +++ b/webui.py @@ -57,7 +57,7 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): return res - return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs) + return modules.ui.wrap_gradio_call(f, extra_outputs=extra_outputs, add_stats=True) def initialize(): -- cgit v1.2.3