From fddb4883f4a408b3464076465e1b0949ebe0fc30 Mon Sep 17 00:00:00 2001 From: evshiron Date: Wed, 26 Oct 2022 22:33:45 +0800 Subject: prototype progress api --- modules/shared.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index 1a9b8289..00f61898 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -146,6 +146,19 @@ class State: def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? + def js(self): + obj = { + "skipped": self.skipped, + "interrupted": self.skipped, + "job": self.job, + "job_count": self.job_count, + "job_no": self.job_no, + "sampling_step": self.sampling_step, + "sampling_steps": self.sampling_steps, + } + + return json.dumps(obj) + state = State() -- cgit v1.2.3 From 88f46a5bec610cf03641f18becbe3deda541e982 Mon Sep 17 00:00:00 2001 From: evshiron Date: Sun, 30 Oct 2022 05:04:29 +0800 Subject: update progress response model --- modules/api/api.py | 6 +++--- modules/api/models.py | 4 ++-- modules/shared.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index 7e8522a2..5912d289 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -61,7 +61,7 @@ class Api: self.app.add_api_route("/sdapi/v1/extra-single-image", self.extras_single_image_api, methods=["POST"], response_model=ExtrasSingleImageResponse) self.app.add_api_route("/sdapi/v1/extra-batch-images", self.extras_batch_images_api, methods=["POST"], response_model=ExtrasBatchImagesResponse) self.app.add_api_route("/sdapi/v1/png-info", self.pnginfoapi, methods=["POST"], response_model=PNGInfoResponse) - self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"]) + self.app.add_api_route("/sdapi/v1/progress", self.progressapi, methods=["GET"], response_model=ProgressResponse) def text2imgapi(self, txt2imgreq: StableDiffusionTxt2ImgProcessingAPI): sampler_index = sampler_to_index(txt2imgreq.sampler_index) @@ -171,7 +171,7 @@ class Api: # copy from check_progress_call of ui.py if shared.state.job_count == 0: - return ProgressResponse(progress=0, eta_relative=0, state=shared.state.js()) + return ProgressResponse(progress=0, eta_relative=0, state=shared.state.dict()) # avoid dividing zero progress = 0.01 @@ -187,7 +187,7 @@ class Api: progress = min(progress, 1) - return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.js()) + return ProgressResponse(progress=progress, eta_relative=eta_relative, state=shared.state.dict()) def launch(self, server_name, port): self.app.include_router(self.router) diff --git a/modules/api/models.py b/modules/api/models.py index 709ab5a6..0ab85ec5 100644 --- a/modules/api/models.py +++ b/modules/api/models.py @@ -1,6 +1,6 @@ import inspect from click import prompt -from pydantic import BaseModel, Field, Json, create_model +from pydantic import BaseModel, Field, create_model from typing import Any, Optional from typing_extensions import Literal from inflection import underscore @@ -160,4 +160,4 @@ class PNGInfoResponse(BaseModel): class ProgressResponse(BaseModel): progress: float = Field(title="Progress", description="The progress with a range of 0 to 1") eta_relative: float = Field(title="ETA in secs") - state: Json = Field(title="State", description="The current state snapshot") + state: dict = Field(title="State", description="The current state snapshot") diff --git a/modules/shared.py b/modules/shared.py index 0f4c035d..f7b0990c 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -147,7 +147,7 @@ class State: def get_job_timestamp(self): return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? - def js(self): + def dict(self): obj = { "skipped": self.skipped, "interrupted": self.skipped, @@ -158,7 +158,7 @@ class State: "sampling_steps": self.sampling_steps, } - return json.dumps(obj) + return obj state = State() -- cgit v1.2.3 From 149784202cca8612b43629c601ee27cfda64e623 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sun, 30 Oct 2022 09:10:22 +0300 Subject: rework #3722 to not introduce duplicate code --- modules/api/api.py | 43 +++++++++++++------------------------------ modules/shared.py | 22 +++++++++++++++++++--- webui.py | 19 +++---------------- 3 files changed, 35 insertions(+), 49 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/api/api.py b/modules/api/api.py index 5c5b210f..6c06d449 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -9,31 +9,6 @@ from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusion from modules.sd_samplers import all_samplers from modules.extras import run_extras, run_pnginfo -# copy from wrap_gradio_gpu_call of webui.py -# because queue lock will be acquired in api handlers -# and time start needs to be set -# the function has been modified into two parts - -def before_gpu_call(): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None - shared.state.time_start = time.time() - -def after_gpu_call(): - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() def upscaler_to_index(name: str): try: @@ -41,8 +16,10 @@ def upscaler_to_index(name: str): except: raise HTTPException(status_code=400, detail=f"Invalid upscaler, needs to be on of these: {' , '.join([x.name for x in sd_upscalers])}") + sampler_to_index = lambda name: next(filter(lambda row: name.lower() == row[1].name.lower(), enumerate(all_samplers)), None) + def setUpscalers(req: dict): reqDict = vars(req) reqDict['extras_upscaler_1'] = upscaler_to_index(req.upscaler_1) @@ -51,6 +28,7 @@ def setUpscalers(req: dict): reqDict.pop('upscaler_2') return reqDict + class Api: def __init__(self, app, queue_lock): self.router = APIRouter() @@ -78,10 +56,13 @@ class Api: ) p = StableDiffusionProcessingTxt2Img(**vars(populate)) # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) @@ -119,11 +100,13 @@ class Api: imgs = [img] * p.batch_size p.init_images = imgs - # Override object param - before_gpu_call() + + shared.state.begin() + with self.queue_lock: processed = process_images(p) - after_gpu_call() + + shared.state.end() b64images = list(map(encode_pil_to_base64, processed.images)) diff --git a/modules/shared.py b/modules/shared.py index f7b0990c..e4f163c1 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -144,9 +144,6 @@ class State: self.sampling_step = 0 self.current_image_sampling_step = 0 - def get_job_timestamp(self): - return datetime.datetime.now().strftime("%Y%m%d%H%M%S") # shouldn't this return job_timestamp? - def dict(self): obj = { "skipped": self.skipped, @@ -160,6 +157,25 @@ class State: return obj + def begin(self): + self.sampling_step = 0 + self.job_count = -1 + self.job_no = 0 + self.job_timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S") + self.current_latent = None + self.current_image = None + self.current_image_sampling_step = 0 + self.skipped = False + self.interrupted = False + self.textinfo = None + + devices.torch_gc() + + def end(self): + self.job = "" + self.job_count = 0 + + devices.torch_gc() state = State() diff --git a/webui.py b/webui.py index ade7334b..29530872 100644 --- a/webui.py +++ b/webui.py @@ -46,26 +46,13 @@ def wrap_queued_call(func): def wrap_gradio_gpu_call(func, extra_outputs=None): def f(*args, **kwargs): - devices.torch_gc() - - shared.state.sampling_step = 0 - shared.state.job_count = -1 - shared.state.job_no = 0 - shared.state.job_timestamp = shared.state.get_job_timestamp() - shared.state.current_latent = None - shared.state.current_image = None - shared.state.current_image_sampling_step = 0 - shared.state.skipped = False - shared.state.interrupted = False - shared.state.textinfo = None + + shared.state.begin() with queue_lock: res = func(*args, **kwargs) - shared.state.job = "" - shared.state.job_count = 0 - - devices.torch_gc() + shared.state.end() return res -- cgit v1.2.3 From 910a097ae2ed78a62101951f1b87137f9e1baaea Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 17:36:45 +0300 Subject: add initial version of the extensions tab fix broken Restart Gradio button --- javascript/extensions.js | 24 +++++ modules/extensions.py | 83 +++++++++++++++ modules/generation_parameters_copypaste.py | 5 + modules/scripts.py | 21 +--- modules/shared.py | 10 +- modules/ui.py | 16 ++- modules/ui_extensions.py | 162 +++++++++++++++++++++++++++++ style.css | 22 +++- webui.py | 20 ++-- 9 files changed, 333 insertions(+), 30 deletions(-) create mode 100644 javascript/extensions.js create mode 100644 modules/extensions.py create mode 100644 modules/ui_extensions.py (limited to 'modules/shared.py') diff --git a/javascript/extensions.js b/javascript/extensions.js new file mode 100644 index 00000000..86f5336d --- /dev/null +++ b/javascript/extensions.js @@ -0,0 +1,24 @@ + +function extensions_apply(_, _){ + disable = [] + update = [] + gradioApp().querySelectorAll('#extensions input[type="checkbox"]').forEach(function(x){ + if(x.name.startsWith("enable_") && ! x.checked) + disable.push(x.name.substr(7)) + + if(x.name.startsWith("update_") && x.checked) + update.push(x.name.substr(7)) + }) + + restart_reload() + + return [JSON.stringify(disable), JSON.stringify(update)] +} + +function extensions_check(){ + gradioApp().querySelectorAll('#extensions .extension_status').forEach(function(x){ + x.innerHTML = "Loading..." + }) + + return [] +} \ No newline at end of file diff --git a/modules/extensions.py b/modules/extensions.py new file mode 100644 index 00000000..8d6ae848 --- /dev/null +++ b/modules/extensions.py @@ -0,0 +1,83 @@ +import os +import sys +import traceback + +import git + +from modules import paths, shared + + +extensions = [] +extensions_dir = os.path.join(paths.script_path, "extensions") + + +def active(): + return [x for x in extensions if x.enabled] + + +class Extension: + def __init__(self, name, path, enabled=True): + self.name = name + self.path = path + self.enabled = enabled + self.status = '' + self.can_update = False + + repo = None + try: + if os.path.exists(os.path.join(path, ".git")): + repo = git.Repo(path) + except Exception: + print(f"Error reading github repository info from {path}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + if repo is None or repo.bare: + self.remote = None + else: + self.remote = next(repo.remote().urls, None) + self.status = 'unknown' + + def list_files(self, subdir, extension): + from modules import scripts + + dirpath = os.path.join(self.path, subdir) + if not os.path.isdir(dirpath): + return [] + + res = [] + for filename in sorted(os.listdir(dirpath)): + res.append(scripts.ScriptFile(dirpath, filename, os.path.join(dirpath, filename))) + + res = [x for x in res if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] + + return res + + def check_updates(self): + repo = git.Repo(self.path) + for fetch in repo.remote().fetch("--dry-run"): + if fetch.flags != fetch.HEAD_UPTODATE: + self.can_update = True + self.status = "behind" + return + + self.can_update = False + self.status = "latest" + + def pull(self): + repo = git.Repo(self.path) + repo.remotes.origin.pull() + + +def list_extensions(): + extensions.clear() + + if not os.path.isdir(extensions_dir): + return + + for dirname in sorted(os.listdir(extensions_dir)): + path = os.path.join(extensions_dir, dirname) + if not os.path.isdir(path): + continue + + extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions) + extensions.append(extension) diff --git a/modules/generation_parameters_copypaste.py b/modules/generation_parameters_copypaste.py index df70c728..985ec95e 100644 --- a/modules/generation_parameters_copypaste.py +++ b/modules/generation_parameters_copypaste.py @@ -17,6 +17,11 @@ paste_fields = {} bind_list = [] +def reset(): + paste_fields.clear() + bind_list.clear() + + def quote(text): if ',' not in str(text): return text diff --git a/modules/scripts.py b/modules/scripts.py index 96e44bfd..533db45c 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -7,7 +7,7 @@ import modules.ui as ui import gradio as gr from modules.processing import StableDiffusionProcessing -from modules import shared, paths, script_callbacks +from modules import shared, paths, script_callbacks, extensions AlwaysVisible = object() @@ -107,17 +107,8 @@ def list_scripts(scriptdirname, extension): for filename in sorted(os.listdir(basedir)): scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename))) - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - for dirname in sorted(os.listdir(extdir)): - dirpath = os.path.join(extdir, dirname) - scriptdirpath = os.path.join(dirpath, scriptdirname) - - if not os.path.isdir(scriptdirpath): - continue - - for filename in sorted(os.listdir(scriptdirpath)): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) + for ext in extensions.active(): + scripts_list += ext.list_files(scriptdirname, extension) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] @@ -127,11 +118,7 @@ def list_scripts(scriptdirname, extension): def list_files_with_name(filename): res = [] - dirs = [paths.script_path] - - extdir = os.path.join(paths.script_path, "extensions") - if os.path.exists(extdir): - dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + dirs = [paths.script_path] + [ext.path for ext in extensions.active()] for dirpath in dirs: if not os.path.isdir(dirpath): diff --git a/modules/shared.py b/modules/shared.py index e4f163c1..cce87081 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -132,6 +132,7 @@ class State: current_image = None current_image_sampling_step = 0 textinfo = None + need_restart = False def skip(self): self.skipped = True @@ -354,6 +355,12 @@ options_templates.update(options_section(('sampler-params', "Sampler parameters" 'eta_noise_seed_delta': OptionInfo(0, "Eta noise seed delta", gr.Number, {"precision": 0}), })) +options_templates.update(options_section((None, "Hidden options"), { + "disabled_extensions": OptionInfo([], "Disable those extensions"), +})) + +options_templates.update() + class Options: data = None @@ -365,8 +372,9 @@ class Options: def __setattr__(self, key, value): if self.data is not None: - if key in self.data: + if key in self.data or key in self.data_labels: self.data[key] = value + return return super(Options, self).__setattr__(key, value) diff --git a/modules/ui.py b/modules/ui.py index 5055ca64..2c15abb7 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -19,7 +19,7 @@ import numpy as np from PIL import Image, PngImagePlugin -from modules import sd_hijack, sd_models, localization, script_callbacks +from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions from modules.paths import script_path from modules.shared import opts, cmd_opts, restricted_opts @@ -671,6 +671,7 @@ def create_ui(wrap_gradio_gpu_call): import modules.img2img import modules.txt2img + parameters_copypaste.reset() with gr.Blocks(analytics_enabled=False) as txt2img_interface: txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False) @@ -1511,8 +1512,9 @@ def create_ui(wrap_gradio_gpu_call): column = None with gr.Row(elem_id="settings").style(equal_height=False): for i, (k, item) in enumerate(opts.data_labels.items()): + section_must_be_skipped = item.section[0] is None - if previous_section != item.section: + if previous_section != item.section and not section_must_be_skipped: if cols_displayed < settings_cols and (items_displayed >= items_per_col or previous_section is None): if column is not None: column.__exit__() @@ -1531,6 +1533,8 @@ def create_ui(wrap_gradio_gpu_call): if k in quicksettings_names and not shared.cmd_opts.freeze_settings: quicksettings_list.append((i, k, item)) components.append(dummy_component) + elif section_must_be_skipped: + components.append(dummy_component) else: component = create_setting_component(k) component_dict[k] = component @@ -1572,9 +1576,10 @@ def create_ui(wrap_gradio_gpu_call): def request_restart(): shared.state.interrupt() - settings_interface.gradio_ref.do_restart = True + shared.state.need_restart = True restart_gradio.click( + fn=request_restart, inputs=[], outputs=[], @@ -1612,14 +1617,15 @@ def create_ui(wrap_gradio_gpu_call): interfaces += script_callbacks.ui_tabs_callback() interfaces += [(settings_interface, "Settings", "settings")] + extensions_interface = ui_extensions.create_ui() + interfaces += [(extensions_interface, "Extensions", "extensions")] + with gr.Blocks(css=css, analytics_enabled=False, title="Stable Diffusion") as demo: with gr.Row(elem_id="quicksettings"): for i, k, item in quicksettings_list: component = create_setting_component(k, is_quicksettings=True) component_dict[k] = component - settings_interface.gradio_ref = demo - parameters_copypaste.integrate_settings_paste_fields(component_dict) parameters_copypaste.run_bind() diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py new file mode 100644 index 00000000..b7d747dc --- /dev/null +++ b/modules/ui_extensions.py @@ -0,0 +1,162 @@ +import json +import os.path +import shutil +import sys +import time +import traceback + +import git + +import gradio as gr +import html + +from modules import extensions, shared, paths + + +def apply_and_restart(disable_list, update_list): + disabled = json.loads(disable_list) + assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" + + update = json.loads(update_list) + assert type(update) == list, f"wrong update_list data for apply_and_restart: {update_list}" + + update = set(update) + + for ext in extensions.extensions: + if ext.name not in update: + continue + + try: + ext.pull() + except Exception: + print(f"Error pulling updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + shared.opts.disabled_extensions = disabled + shared.opts.save(shared.config_filename) + + shared.state.interrupt() + shared.state.need_restart = True + + +def check_updates(): + for ext in extensions.extensions: + if ext.remote is None: + continue + + try: + ext.check_updates() + except Exception: + print(f"Error checking updates for {ext.name}:", file=sys.stderr) + print(traceback.format_exc(), file=sys.stderr) + + return extension_table() + + +def extension_table(): + code = f""" + + + + + + + + + + """ + + for ext in extensions.extensions: + if ext.can_update: + ext_status = f"""""" + else: + ext_status = ext.status + + code += f""" + + + + {ext_status} + + """ + + code += """ + +
ExtensionURLUpdate
{html.escape(ext.remote or '')}
+ """ + + return code + + +def install_extension_from_url(dirname, url): + assert url, 'No URL specified' + + if dirname is None or dirname == "": + *parts, last_part = url.split('/') + last_part = last_part.replace(".git", "") + + dirname = last_part + + target_dir = os.path.join(extensions.extensions_dir, dirname) + assert not os.path.exists(target_dir), f'Extension directory already exists: {target_dir}' + + assert len([x for x in extensions.extensions if x.remote == url]) == 0, 'Extension with this URL is already installed' + + tmpdir = os.path.join(paths.script_path, "tmp", dirname) + + try: + shutil.rmtree(tmpdir, True) + + repo = git.Repo.clone_from(url, tmpdir) + repo.remote().fetch() + + os.rename(tmpdir, target_dir) + + extensions.list_extensions() + return [extension_table(), html.escape(f"Installed into {target_dir}. Use Installed tab to restart.")] + finally: + shutil.rmtree(tmpdir, True) + + +def create_ui(): + import modules.ui + + with gr.Blocks(analytics_enabled=False) as ui: + with gr.Tabs(elem_id="tabs_extensions") as tabs: + with gr.TabItem("Installed"): + extensions_disabled_list = gr.Text(elem_id="extensions_disabled_list", visible=False) + extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False) + + with gr.Row(): + apply = gr.Button(value="Apply and restart UI", variant="primary") + check = gr.Button(value="Check for updates") + + extensions_table = gr.HTML(lambda: extension_table()) + + apply.click( + fn=apply_and_restart, + _js="extensions_apply", + inputs=[extensions_disabled_list, extensions_update_list], + outputs=[], + ) + + check.click( + fn=check_updates, + _js="extensions_check", + inputs=[], + outputs=[extensions_table], + ) + + with gr.TabItem("Install from URL"): + install_url = gr.Text(label="URL for extension's git repository") + install_dirname = gr.Text(label="Local directory name", placeholder="Leave empty for auto") + intall_button = gr.Button(value="Install", variant="primary") + intall_result = gr.HTML(elem_id="extension_install_result") + + intall_button.click( + fn=modules.ui.wrap_gradio_call(install_extension_from_url, extra_outputs=[gr.update()]), + inputs=[install_dirname, install_url], + outputs=[extensions_table, intall_result], + ) + + return ui diff --git a/style.css b/style.css index 8b2211b1..859c3933 100644 --- a/style.css +++ b/style.css @@ -530,6 +530,26 @@ img2maskimg, #img2maskimg > .h-60, #img2maskimg > .h-60 > div, #img2maskimg > .h min-height: 480px !important; } +/* Extensions */ + +#extensions{ + border-collapse: collapse; +} + +#extensions td, #extensions th{ + border: 1px solid #ccc; + padding: 0.25em 0.5em; +} + +#extensions input[type="checkbox"]{ + margin-right: 0.5em; +} + +#tab_extensions button{ + max-width: 16em; +} + + /* The following handles localization for right-to-left (RTL) languages like Arabic. The rtl media type will only be activated by the logic in javascript/localization.js. If you change anything above, you need to make sure it is RTL compliant by just running @@ -607,4 +627,4 @@ Then, you will need to add the RTL counterpart only if needed in the rtl section right: unset; left: 0.5em; } -} +} \ No newline at end of file diff --git a/webui.py b/webui.py index 29530872..ad2eb236 100644 --- a/webui.py +++ b/webui.py @@ -9,7 +9,7 @@ from fastapi.middleware.gzip import GZipMiddleware from modules.paths import script_path -from modules import devices, sd_samplers, upscaler +from modules import devices, sd_samplers, upscaler, extensions import modules.codeformer_model as codeformer import modules.extras import modules.face_restoration @@ -60,6 +60,11 @@ def wrap_gradio_gpu_call(func, extra_outputs=None): def initialize(): + extensions.list_extensions() + #for ext in extensions.extensions: + # print(ext.name, ext.path, ext.enabled, ext.remote) + #exit() + if cmd_opts.ui_debug_mode: shared.sd_upscalers = upscaler.UpscalerLanczos().scalers modules.scripts.load_scripts() @@ -92,15 +97,18 @@ def create_api(app): api = Api(app, queue_lock) return api + def wait_on_server(demo=None): while 1: time.sleep(0.5) - if demo and getattr(demo, 'do_restart', False): + if shared.state.need_restart: + shared.state.need_restart = False time.sleep(0.5) demo.close() time.sleep(0.5) break + def api_only(): initialize() @@ -132,14 +140,16 @@ def webui(): app.add_middleware(GZipMiddleware, minimum_size=1000) - if (launch_api): + if launch_api: create_api(app) wait_on_server(demo) sd_samplers.set_samplers() - print('Reloading Custom Scripts') + print('Reloading extensions') + extensions.list_extensions() + print('Reloading custom scripts') modules.scripts.reload_scripts() print('Reloading modules: modules.ui') importlib.reload(modules.ui) @@ -148,8 +158,6 @@ def webui(): print('Restarting Gradio') - -task = [] if __name__ == "__main__": if cmd_opts.nowebui: api_only() -- cgit v1.2.3 From dc7425a56e7a014cbfa3b3d44ad2321e519fe378 Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:33:44 +0300 Subject: disable access to extension stuff for non-local servers --- modules/shared.py | 5 ++++- modules/ui_extensions.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index cce87081..a27c654e 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -40,7 +40,7 @@ parser.add_argument("--lowram", action='store_true', help="load stable diffusion parser.add_argument("--always-batch-cond-uncond", action='store_true', help="disables cond/uncond batching that is enabled to save memory with --medvram or --lowvram") parser.add_argument("--unload-gfpgan", action='store_true', help="does not do anything.") parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["full", "autocast"], default="autocast") -parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site (doesn't work for me but you might have better luck)") +parser.add_argument("--share", action='store_true', help="use share=True for gradio and make the UI accessible through their site") parser.add_argument("--ngrok", type=str, help="ngrok authtoken, alternative to gradio --share", default=None) parser.add_argument("--ngrok-region", type=str, help="The region in which ngrok should start.", default="us") parser.add_argument("--codeformer-models-path", type=str, help="Path to directory with codeformer model file(s).", default=os.path.join(models_path, 'Codeformer')) @@ -97,6 +97,9 @@ restricted_opts = { "outdir_save", } +if cmd_opts.share or cmd_opts.listen: + cmd_opts.disable_extension_access = True + devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index b7d747dc..e74b7d68 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -13,7 +13,13 @@ import html from modules import extensions, shared, paths +def check_access(): + assert not shared.cmd_opts.disable_extension_access, "extension access disabed because of commandline flags" + + def apply_and_restart(disable_list, update_list): + check_access() + disabled = json.loads(disable_list) assert type(disabled) == list, f"wrong disable_list data for apply_and_restart: {disable_list}" @@ -40,6 +46,8 @@ def apply_and_restart(disable_list, update_list): def check_updates(): + check_access() + for ext in extensions.extensions: if ext.remote is None: continue @@ -89,6 +97,8 @@ def extension_table(): def install_extension_from_url(dirname, url): + check_access() + assert url, 'No URL specified' if dirname is None or dirname == "": -- cgit v1.2.3 From 9e22a357545c0395c81dd800c72fa18f350545ec Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Mon, 31 Oct 2022 18:45:50 +0300 Subject: fix the error with extension tab not working because of the previous commit --- modules/shared.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'modules/shared.py') diff --git a/modules/shared.py b/modules/shared.py index a27c654e..c83fb9f5 100644 --- a/modules/shared.py +++ b/modules/shared.py @@ -97,8 +97,7 @@ restricted_opts = { "outdir_save", } -if cmd_opts.share or cmd_opts.listen: - cmd_opts.disable_extension_access = True +cmd_opts.disable_extension_access = cmd_opts.share or cmd_opts.listen devices.device, devices.device_interrogate, devices.device_gfpgan, devices.device_swinir, devices.device_esrgan, devices.device_scunet, devices.device_codeformer = \ (devices.cpu if any(y in cmd_opts.use_cpu for y in [x, 'all']) else devices.get_optimal_device() for x in ['sd', 'interrogate', 'gfpgan', 'swinir', 'esrgan', 'scunet', 'codeformer']) -- cgit v1.2.3