From a99d5708e6d603e8f7cfd1b8c6595f8026219ba0 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 17 Jul 2023 20:10:24 +0300 Subject: skip installing packages with pip if theyare already installed record time it took to launch --- webui.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'webui.py') diff --git a/webui.py b/webui.py index 34c2fd18..2aafc09f 100644 --- a/webui.py +++ b/webui.py @@ -31,21 +31,22 @@ if log_level: logging.getLogger("torch.distributed.nn").setLevel(logging.ERROR) # sshh... logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) -from modules import paths, timer, import_hook, errors, devices # noqa: F401 - +from modules import timer startup_timer = timer.startup_timer +startup_timer.record("launcher") import torch import pytorch_lightning # noqa: F401 # pytorch_lightning should be imported after torch, but it re-enables warnings on import so import once to disable them warnings.filterwarnings(action="ignore", category=DeprecationWarning, module="pytorch_lightning") warnings.filterwarnings(action="ignore", category=UserWarning, module="torchvision") - - startup_timer.record("import torch") import gradio # noqa: F401 startup_timer.record("import gradio") +from modules import paths, timer, import_hook, errors, devices # noqa: F401 +startup_timer.record("setup paths") + import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") -- cgit v1.2.3 From 118529a6dc9481499ebcd589a9b42f98b58bca8a Mon Sep 17 00:00:00 2001 From: Jabasukuriputo Wang Date: Fri, 21 Jul 2023 21:49:33 +0800 Subject: typo fix --- modules/launch_utils.py | 6 +++--- webui.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) (limited to 'webui.py') diff --git a/modules/launch_utils.py b/modules/launch_utils.py index 03552bc2..8e11bf42 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -233,7 +233,7 @@ def run_extensions_installers(settings_file): re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") -def requrements_met(requirements_file): +def requirements_met(requirements_file): """ Does a simple parse of a requirements.txt file to determine if all rerqirements in it are already installed. Returns True if so, False if not installed or parsing fails. @@ -293,7 +293,7 @@ def prepare_environment(): try: # the existance of this file is a signal to webui.sh/bat that webui needs to be restarted when it stops execution os.remove(os.path.join(script_path, "tmp", "restart")) - os.environ.setdefault('SD_WEBUI_RESTARTING ', '1') + os.environ.setdefault('SD_WEBUI_RESTARTING', '1') except OSError: pass @@ -354,7 +354,7 @@ def prepare_environment(): if not os.path.isfile(requirements_file): requirements_file = os.path.join(script_path, requirements_file) - if not requrements_met(requirements_file): + if not requirements_met(requirements_file): run_pip(f"install -r \"{requirements_file}\"", "requirements") run_extensions_installers(settings_file=args.ui_settings_file) diff --git a/webui.py b/webui.py index 2aafc09f..2314735f 100644 --- a/webui.py +++ b/webui.py @@ -407,7 +407,7 @@ def webui(): ssl_verify=cmd_opts.disable_tls_verify, debug=cmd_opts.gradio_debug, auth=gradio_auth_creds, - inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING ') != '1', + inbrowser=cmd_opts.autolaunch and os.getenv('SD_WEBUI_RESTARTING') != '1', prevent_thread_lock=True, allowed_paths=cmd_opts.gradio_allowed_path, app_kwargs={ -- cgit v1.2.3 From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: Use less RAM when creating models --- modules/cmd_args.py | 1 + modules/sd_disable_initialization.py | 106 +++++++++++++++++++++++++++++++++-- modules/sd_models.py | 16 ++++-- webui.py | 4 +- 4 files changed, 114 insertions(+), 13 deletions(-) (limited to 'webui.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index dd5fadc4..cb4ec5f7 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6..695c5736 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,31 @@ import open_clip import torch import transformers.utils.hub +from modules import shared -class DisableInitialization: + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +44,7 @@ class DisableInitialization: """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,8 +109,81 @@ class DisableInitialization: self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - for obj, field, original in self.replaced: - setattr(obj, field, original) + self.restore() - self.replaced.clear() +class InitializeOnMeta(ReplaceHelper): + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + def set_device(x): + x["device"] = "meta" + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() + + +class LoadStateDictOnMeta(ReplaceHelper): + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. + Meant to be used together with InitializeOnMeta above. + + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device): + super().__init__() + self.state_dict = state_dict + self.device = device + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + + for name, param in params: + if param.is_meta: + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + + original(self, state_dict, prefix, *args, **kwargs) + + for name, _ in params: + key = prefix + name + if key in sd: + del sd[key] + + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a793..acb1e817 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -460,7 +460,6 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) - def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - sd_model = instantiate_from_config(sd_config.model) - except Exception: - pass + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + + except Exception as e: + errors.display(e, "creating model quickly", full_traceback=True) if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - sd_model = instantiate_from_config(sd_config.model) + + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config timer.record("create model") - load_model_weights(sd_model, checkpoint_info, state_dict, timer) + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) diff --git a/webui.py b/webui.py index 2314735f..51248c39 100644 --- a/webui.py +++ b/webui.py @@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False): if modules.sd_hijack.current_optimizer is None: modules.sd_hijack.apply_optimizations() - Thread(target=load_model).start() + devices.first_time_calculation() - Thread(target=devices.first_time_calculation).start() + Thread(target=load_model).start() shared.reload_hypernetworks() startup_timer.record("reload hypernetworks") -- cgit v1.2.3 From b1a16a298c7f71e8e7d6254f16d4143f77b5c9a5 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Tue, 25 Jul 2023 20:36:39 +0900 Subject: api only subpath (rootpath) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-Authored-By: 陈杰 --- modules/api/api.py | 4 ++-- webui.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'webui.py') diff --git a/modules/api/api.py b/modules/api/api.py index 9d73083f..606db179 100644 --- a/modules/api/api.py +++ b/modules/api/api.py @@ -724,9 +724,9 @@ class Api: cuda = {'error': f'{err}'} return models.MemoryResponse(ram=ram, cuda=cuda) - def launch(self, server_name, port): + def launch(self, server_name, port, root_path): self.app.include_router(self.router) - uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive) + uvicorn.run(self.app, host=server_name, port=port, timeout_keep_alive=shared.cmd_opts.timeout_keep_alive, root_path=root_path) def kill_webui(self): restart.stop_program() diff --git a/webui.py b/webui.py index 2314735f..6bf06854 100644 --- a/webui.py +++ b/webui.py @@ -374,7 +374,7 @@ def api_only(): api.launch( server_name="0.0.0.0" if cmd_opts.listen else "127.0.0.1", port=cmd_opts.port if cmd_opts.port else 7861, - root_path = f"/{cmd_opts.subpath}" + root_path=f"/{cmd_opts.subpath}" if cmd_opts.subpath else "" ) -- cgit v1.2.3 From af528552d6006268e352f422d533a69a6a718a5d Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Thu, 3 Aug 2023 23:18:50 +0300 Subject: fix linter issues --- modules/ui.py | 57 +++++++++++++++++++++++++-------------------------------- webui.py | 4 ++-- 2 files changed, 27 insertions(+), 34 deletions(-) (limited to 'webui.py') diff --git a/modules/ui.py b/modules/ui.py index c059dcec..70364a8b 100644 --- a/modules/ui.py +++ b/modules/ui.py @@ -12,30 +12,23 @@ import numpy as np from PIL import Image, PngImagePlugin # noqa: F401 from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call -from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles +from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles, scripts from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML from modules.paths import script_path from modules.ui_common import create_refresh_button from modules.ui_gradio_extensions import reload_javascript - from modules.shared import opts, cmd_opts -import modules.codeformer_model import modules.generation_parameters_copypaste as parameters_copypaste -import modules.gfpgan_model -import modules.hypernetworks.ui -import modules.scripts +import modules.hypernetworks.ui as hypernetworks_ui +import modules.textual_inversion.ui as textual_inversion_ui +import modules.textual_inversion.textual_inversion as textual_inversion import modules.shared as shared -import modules.styles -import modules.textual_inversion.ui from modules import prompt_parser from modules.sd_hijack import model_hijack from modules.sd_samplers import samplers, samplers_for_img2img -from modules.textual_inversion import textual_inversion -import modules.hypernetworks.ui from modules.generation_parameters_copypaste import image_from_url_text -import modules.extras create_setting_component = ui_settings.create_setting_component @@ -391,11 +384,11 @@ def create_ui(): parameters_copypaste.reset() - modules.scripts.scripts_current = modules.scripts.scripts_txt2img - modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False) + scripts.scripts_current = scripts.scripts_txt2img + scripts.scripts_txt2img.initialize_scripts(is_img2img=False) with gr.Blocks(analytics_enabled=False) as txt2img_interface: - toprow = txt2img_toprow = Toprow(is_img2img=False) + toprow = Toprow(is_img2img=False) dummy_component = gr.Label(visible=False) txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False) @@ -406,7 +399,7 @@ def create_ui(): with gr.Row().style(equal_height=False): with gr.Column(variant='compact', elem_id="txt2img_settings"): - modules.scripts.scripts_txt2img.prepare_ui() + scripts.scripts_txt2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": @@ -474,10 +467,10 @@ def create_ui(): elif category == "scripts": with FormGroup(elem_id="txt2img_script_container"): - custom_inputs = modules.scripts.scripts_txt2img.setup_ui() + custom_inputs = scripts.scripts_txt2img.setup_ui() else: - modules.scripts.scripts_txt2img.setup_ui_for_section(category) + scripts.scripts_txt2img.setup_ui_for_section(category) hr_resolution_preview_inputs = [enable_hr, width, height, hr_scale, hr_resize_x, hr_resize_y] @@ -611,7 +604,7 @@ def create_ui(): (hr_prompt, "Hires prompt"), (hr_negative_prompt, "Hires negative prompt"), (hr_prompts_container, lambda d: gr.update(visible=True) if d.get("Hires prompt", "") != "" or d.get("Hires negative prompt", "") != "" else gr.update()), - *modules.scripts.scripts_txt2img.infotext_fields + *scripts.scripts_txt2img.infotext_fields ] parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings) parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding( @@ -634,11 +627,11 @@ def create_ui(): ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery) - modules.scripts.scripts_current = modules.scripts.scripts_img2img - modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True) + scripts.scripts_current = scripts.scripts_img2img + scripts.scripts_img2img.initialize_scripts(is_img2img=True) with gr.Blocks(analytics_enabled=False) as img2img_interface: - toprow = img2img_toprow = Toprow(is_img2img=True) + toprow = Toprow(is_img2img=True) img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False) @@ -740,7 +733,7 @@ def create_ui(): with FormRow(): resize_mode = gr.Radio(label="Resize mode", elem_id="resize_mode", choices=["Just resize", "Crop and resize", "Resize and fill", "Just resize (latent upscale)"], type="index", value="Just resize") - modules.scripts.scripts_img2img.prepare_ui() + scripts.scripts_img2img.prepare_ui() for category in ordered_ui_categories(): if category == "sampler": @@ -821,7 +814,7 @@ def create_ui(): elif category == "scripts": with FormGroup(elem_id="img2img_script_container"): - custom_inputs = modules.scripts.scripts_img2img.setup_ui() + custom_inputs = scripts.scripts_img2img.setup_ui() elif category == "inpaint": with FormGroup(elem_id="inpaint_controls", visible=False) as inpaint_controls: @@ -852,7 +845,7 @@ def create_ui(): outputs=[inpaint_controls, mask_alpha], ) else: - modules.scripts.scripts_img2img.setup_ui_for_section(category) + scripts.scripts_img2img.setup_ui_for_section(category) img2img_gallery, generation_info, html_info, html_log = create_output_panel("img2img", opts.outdir_img2img_samples) @@ -1001,7 +994,7 @@ def create_ui(): (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()), (denoising_strength, "Denoising strength"), (mask_blur, "Mask blur"), - *modules.scripts.scripts_img2img.infotext_fields + *scripts.scripts_img2img.infotext_fields ] parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings) parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings) @@ -1009,7 +1002,7 @@ def create_ui(): paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None, )) - modules.scripts.scripts_current = None + scripts.scripts_current = None with gr.Blocks(analytics_enabled=False) as extras_interface: ui_postprocessing.create_ui() @@ -1063,7 +1056,7 @@ def create_ui(): new_hypernetwork_name = gr.Textbox(label="Name", elem_id="train_new_hypernetwork_name") new_hypernetwork_sizes = gr.CheckboxGroup(label="Modules", value=["768", "320", "640", "1280"], choices=["768", "1024", "320", "640", "1280"], elem_id="train_new_hypernetwork_sizes") new_hypernetwork_layer_structure = gr.Textbox("1, 2, 1", label="Enter hypernetwork layer structure", placeholder="1st and last digit must be 1. ex:'1, 2, 1'", elem_id="train_new_hypernetwork_layer_structure") - new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=modules.hypernetworks.ui.keys, elem_id="train_new_hypernetwork_activation_func") + new_hypernetwork_activation_func = gr.Dropdown(value="linear", label="Select activation function of hypernetwork. Recommended : Swish / Linear(none)", choices=hypernetworks_ui.keys, elem_id="train_new_hypernetwork_activation_func") new_hypernetwork_initialization_option = gr.Dropdown(value = "Normal", label="Select Layer weights initialization. Recommended: Kaiming for relu-like, Xavier for sigmoid-like, Normal otherwise", choices=["Normal", "KaimingUniform", "KaimingNormal", "XavierUniform", "XavierNormal"], elem_id="train_new_hypernetwork_initialization_option") new_hypernetwork_add_layer_norm = gr.Checkbox(label="Add layer normalization", elem_id="train_new_hypernetwork_add_layer_norm") new_hypernetwork_use_dropout = gr.Checkbox(label="Use dropout", elem_id="train_new_hypernetwork_use_dropout") @@ -1208,7 +1201,7 @@ def create_ui(): ti_outcome = gr.HTML(elem_id="ti_error", value="") create_embedding.click( - fn=modules.textual_inversion.ui.create_embedding, + fn=textual_inversion_ui.create_embedding, inputs=[ new_embedding_name, initialization_text, @@ -1223,7 +1216,7 @@ def create_ui(): ) create_hypernetwork.click( - fn=modules.hypernetworks.ui.create_hypernetwork, + fn=hypernetworks_ui.create_hypernetwork, inputs=[ new_hypernetwork_name, new_hypernetwork_sizes, @@ -1243,7 +1236,7 @@ def create_ui(): ) run_preprocess.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.preprocess, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(textual_inversion_ui.preprocess, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1279,7 +1272,7 @@ def create_ui(): ) train_embedding.click( - fn=wrap_gradio_gpu_call(modules.textual_inversion.ui.train_embedding, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(textual_inversion_ui.train_embedding, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, @@ -1313,7 +1306,7 @@ def create_ui(): ) train_hypernetwork.click( - fn=wrap_gradio_gpu_call(modules.hypernetworks.ui.train_hypernetwork, extra_outputs=[gr.update()]), + fn=wrap_gradio_gpu_call(hypernetworks_ui.train_hypernetwork, extra_outputs=[gr.update()]), _js="start_training_textual_inversion", inputs=[ dummy_component, diff --git a/webui.py b/webui.py index 2dc4f1aa..8d84e5a4 100644 --- a/webui.py +++ b/webui.py @@ -58,10 +58,10 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) -from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states import modules.codeformer_model as codeformer -import modules.face_restoration import modules.gfpgan_model as gfpgan +from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states +import modules.face_restoration import modules.img2img import modules.lowvram -- cgit v1.2.3 From 073c0ebba380acbd73be8262feba41212165ff84 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Fri, 4 Aug 2023 08:04:06 +0300 Subject: add gradio version warning --- modules/errors.py | 50 ++++++++++++++++++++++++++++++++++++++++++++++++++ webui.py | 41 +++++++---------------------------------- 2 files changed, 57 insertions(+), 34 deletions(-) (limited to 'webui.py') diff --git a/modules/errors.py b/modules/errors.py index dffabe45..192cd8ff 100644 --- a/modules/errors.py +++ b/modules/errors.py @@ -84,3 +84,53 @@ def run(code, task): code() except Exception as e: display(task, e) + + +def check_versions(): + from packaging import version + from modules import shared + + import torch + import gradio + + expected_torch_version = "2.0.0" + expected_xformers_version = "0.0.20" + expected_gradio_version = "3.39.0" + + if version.parse(torch.__version__) < version.parse(expected_torch_version): + print_error_explanation(f""" +You are running torch {torch.__version__}. +The program is tested to work with torch {expected_torch_version}. +To reinstall the desired version, run with commandline flag --reinstall-torch. +Beware that this will cause a lot of large files to be downloaded, as well as +there are reports of issues with training tab on the latest version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if shared.xformers_available: + import xformers + + if version.parse(xformers.__version__) < version.parse(expected_xformers_version): + print_error_explanation(f""" +You are running xformers {xformers.__version__}. +The program is tested to work with xformers {expected_xformers_version}. +To reinstall the desired version, run with commandline flag --reinstall-xformers. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + + if gradio.__version__ != expected_gradio_version: + print_error_explanation(f""" +You are running gradio {gradio.__version__}. +The program is designed to work with gradio {expected_gradio_version}. +Using a different version of gradio is extremely likely to break the program. + +Reasons why you have the mismatched gradio version can be: + - you use --skip-install flag. + - you use webui.py to start the program instead of launch.py. + - an extension installs the incompatible gradio version. + +Use --skip-version-check commandline argument to disable this check. + """.strip()) + diff --git a/webui.py b/webui.py index 8d84e5a4..1803ea8a 100644 --- a/webui.py +++ b/webui.py @@ -14,7 +14,6 @@ from typing import Iterable from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.gzip import GZipMiddleware -from packaging import version import logging @@ -50,6 +49,7 @@ startup_timer.record("setup paths") import ldm.modules.encoders.modules # noqa: F401 startup_timer.record("import ldm") + from modules import extra_networks from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, queue_lock # noqa: F401 @@ -58,9 +58,14 @@ if ".dev" in torch.__version__ or "+git" in torch.__version__: torch.__long_version__ = torch.__version__ torch.__version__ = re.search(r'[\d.]+[\d]', torch.__version__).group(0) +from modules import shared + +if not shared.cmd_opts.skip_version_check: + errors.check_versions() + import modules.codeformer_model as codeformer import modules.gfpgan_model as gfpgan -from modules import shared, sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states +from modules import sd_samplers, upscaler, extensions, localization, ui_tempdir, ui_extra_networks, config_states import modules.face_restoration import modules.img2img @@ -130,37 +135,6 @@ def fix_asyncio_event_loop_policy(): asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy()) -def check_versions(): - if shared.cmd_opts.skip_version_check: - return - - expected_torch_version = "2.0.0" - - if version.parse(torch.__version__) < version.parse(expected_torch_version): - errors.print_error_explanation(f""" -You are running torch {torch.__version__}. -The program is tested to work with torch {expected_torch_version}. -To reinstall the desired version, run with commandline flag --reinstall-torch. -Beware that this will cause a lot of large files to be downloaded, as well as -there are reports of issues with training tab on the latest version. - -Use --skip-version-check commandline argument to disable this check. - """.strip()) - - expected_xformers_version = "0.0.20" - if shared.xformers_available: - import xformers - - if version.parse(xformers.__version__) < version.parse(expected_xformers_version): - errors.print_error_explanation(f""" -You are running xformers {xformers.__version__}. -The program is tested to work with xformers {expected_xformers_version}. -To reinstall the desired version, run with commandline flag --reinstall-xformers. - -Use --skip-version-check commandline argument to disable this check. - """.strip()) - - def restore_config_state_file(): config_state_file = shared.opts.restore_config_state_file if config_state_file == "": @@ -248,7 +222,6 @@ def initialize(): fix_asyncio_event_loop_policy() validate_tls_options() configure_sigint_handler() - check_versions() modelloader.cleanup_models() configure_opts_onchange() -- cgit v1.2.3