From b270ded268c92950a35a7a326da54496ef4151c8 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Tue, 18 Jul 2023 18:10:04 +0300 Subject: fix the issue with /sdapi/v1/options failing (this time for sure!) fix automated tests downloading CLIP model --- modules/cmd_args.py | 1 + 1 file changed, 1 insertion(+) (limited to 'modules/cmd_args.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index ae78f469..e401f641 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -15,6 +15,7 @@ parser.add_argument("--update-check", action='store_true', help="launch.py argum parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") +parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored") parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",) parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",) -- cgit v1.2.3 From a8d4213317c6970aa3ca8cbeeaacb07b936b591c Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Sat, 22 Jul 2023 17:08:45 +0300 Subject: add --log-startup option to print detailed startup progress --- modules/cmd_args.py | 1 + modules/launch_utils.py | 7 +++++-- modules/timer.py | 23 +++++++++++++++++++---- 3 files changed, 25 insertions(+), 6 deletions(-) (limited to 'modules/cmd_args.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index e401f641..dd5fadc4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -13,6 +13,7 @@ parser.add_argument("--reinstall-xformers", action='store_true', help="launch.py parser.add_argument("--reinstall-torch", action='store_true', help="launch.py argument: install the appropriate version of torch even if you have some version already installed") parser.add_argument("--update-check", action='store_true', help="launch.py argument: check for updates at startup") parser.add_argument("--test-server", action='store_true', help="launch.py argument: configure server for testing") +parser.add_argument("--log-startup", action='store_true', help="launch.py argument: print a detailed log of what's happening at startup") parser.add_argument("--skip-prepare-environment", action='store_true', help="launch.py argument: skip all environment preparation") parser.add_argument("--skip-install", action='store_true', help="launch.py argument: skip installation of packages") parser.add_argument("--do-not-download-clip", action='store_true', help="do not download CLIP model even if it's not included in the checkpoint") diff --git a/modules/launch_utils.py b/modules/launch_utils.py index c9e4344b..f77b577a 100644 --- a/modules/launch_utils.py +++ b/modules/launch_utils.py @@ -226,8 +226,11 @@ def run_extensions_installers(settings_file): with startup_timer.subcategory("run extensions installers"): for dirname_extension in list_extensions(settings_file): - run_extension_installer(os.path.join(extensions_dir, dirname_extension)) - startup_timer.record(dirname_extension) + path = os.path.join(extensions_dir, dirname_extension) + + if os.path.isdir(path): + run_extension_installer(path) + startup_timer.record(dirname_extension) re_requirement = re.compile(r"\s*([-_a-zA-Z0-9]+)\s*(?:==\s*([-+_.a-zA-Z0-9]+))?\s*") diff --git a/modules/timer.py b/modules/timer.py index da99e49f..1d38595c 100644 --- a/modules/timer.py +++ b/modules/timer.py @@ -1,4 +1,5 @@ import time +import argparse class TimerSubcategory: @@ -11,20 +12,27 @@ class TimerSubcategory: def __enter__(self): self.start = time.time() self.timer.base_category = self.original_base_category + self.category + "/" + self.timer.subcategory_level += 1 + + if self.timer.print_log: + print(f"{' ' * self.timer.subcategory_level}{self.category}:") def __exit__(self, exc_type, exc_val, exc_tb): elapsed_for_subcategroy = time.time() - self.start self.timer.base_category = self.original_base_category self.timer.add_time_to_record(self.original_base_category + self.category, elapsed_for_subcategroy) - self.timer.record(self.category) + self.timer.subcategory_level -= 1 + self.timer.record(self.category, disable_log=True) class Timer: - def __init__(self): + def __init__(self, print_log=False): self.start = time.time() self.records = {} self.total = 0 self.base_category = '' + self.print_log = print_log + self.subcategory_level = 0 def elapsed(self): end = time.time() @@ -38,13 +46,16 @@ class Timer: self.records[category] += amount - def record(self, category, extra_time=0): + def record(self, category, extra_time=0, disable_log=False): e = self.elapsed() self.add_time_to_record(self.base_category + category, e + extra_time) self.total += e + extra_time + if self.print_log and not disable_log: + print(f"{' ' * self.subcategory_level}{category}: done in {e + extra_time:.3f}s") + def subcategory(self, name): self.elapsed() @@ -71,6 +82,10 @@ class Timer: self.__init__() -startup_timer = Timer() +parser = argparse.ArgumentParser(add_help=False) +parser.add_argument("--log-startup", action='store_true', help="print a detailed log of what's happening at startup") +args = parser.parse_known_args()[0] + +startup_timer = Timer(print_log=args.log_startup) startup_record = None -- cgit v1.2.3 From 0a89cd1a584b1584a0609c0ba27fb35c434b0b68 Mon Sep 17 00:00:00 2001 From: AUTOMATIC1111 <16777216c@gmail.com> Date: Mon, 24 Jul 2023 22:08:08 +0300 Subject: Use less RAM when creating models --- modules/cmd_args.py | 1 + modules/sd_disable_initialization.py | 106 +++++++++++++++++++++++++++++++++-- modules/sd_models.py | 16 ++++-- webui.py | 4 +- 4 files changed, 114 insertions(+), 13 deletions(-) (limited to 'modules/cmd_args.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index dd5fadc4..cb4ec5f7 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -67,6 +67,7 @@ parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="pre parser.add_argument("--disable-opt-split-attention", action='store_true', help="prefer no cross-attention layer optimization for automatic choice of optimization") parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI") parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower) +parser.add_argument("--disable-model-loading-ram-optimization", action='store_true', help="disable an optimization that reduces RAM use when loading a model") parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests") parser.add_argument("--port", type=int, help="launch gradio with given server port, you need root/admin rights for ports < 1024, defaults to 7860 if available", default=None) parser.add_argument("--show-negative-prompt", action='store_true', help="does not do anything", default=False) diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 9fc89dc6..695c5736 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -3,8 +3,31 @@ import open_clip import torch import transformers.utils.hub +from modules import shared -class DisableInitialization: + +class ReplaceHelper: + def __init__(self): + self.replaced = [] + + def replace(self, obj, field, func): + original = getattr(obj, field, None) + if original is None: + return None + + self.replaced.append((obj, field, original)) + setattr(obj, field, func) + + return original + + def restore(self): + for obj, field, original in self.replaced: + setattr(obj, field, original) + + self.replaced.clear() + + +class DisableInitialization(ReplaceHelper): """ When an object of this class enters a `with` block, it starts: - preventing torch's layer initialization functions from working @@ -21,7 +44,7 @@ class DisableInitialization: """ def __init__(self, disable_clip=True): - self.replaced = [] + super().__init__() self.disable_clip = disable_clip def replace(self, obj, field, func): @@ -86,8 +109,81 @@ class DisableInitialization: self.transformers_utils_hub_get_from_cache = self.replace(transformers.utils.hub, 'get_from_cache', transformers_utils_hub_get_from_cache) def __exit__(self, exc_type, exc_val, exc_tb): - for obj, field, original in self.replaced: - setattr(obj, field, original) + self.restore() - self.replaced.clear() +class InitializeOnMeta(ReplaceHelper): + """ + Context manager that causes all parameters for linear/conv2d/mha layers to be allocated on meta device, + which results in those parameters having no values and taking no memory. model.to() will be broken and + will need to be repaired by using LoadStateDictOnMeta below when loading params from state dict. + + Usage: + ``` + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + ``` + """ + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + def set_device(x): + x["device"] = "meta" + return x + + linear_init = self.replace(torch.nn.Linear, '__init__', lambda *args, **kwargs: linear_init(*args, **set_device(kwargs))) + conv2d_init = self.replace(torch.nn.Conv2d, '__init__', lambda *args, **kwargs: conv2d_init(*args, **set_device(kwargs))) + mha_init = self.replace(torch.nn.MultiheadAttention, '__init__', lambda *args, **kwargs: mha_init(*args, **set_device(kwargs))) + self.replace(torch.nn.Module, 'to', lambda *args, **kwargs: None) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() + + +class LoadStateDictOnMeta(ReplaceHelper): + """ + Context manager that allows to read parameters from state_dict into a model that has some of its parameters in the meta device. + As those parameters are read from state_dict, they will be deleted from it, so by the end state_dict will be mostly empty, to save memory. + Meant to be used together with InitializeOnMeta above. + + Usage: + ``` + with sd_disable_initialization.LoadStateDictOnMeta(state_dict): + model.load_state_dict(state_dict, strict=False) + ``` + """ + + def __init__(self, state_dict, device): + super().__init__() + self.state_dict = state_dict + self.device = device + + def __enter__(self): + if shared.cmd_opts.disable_model_loading_ram_optimization: + return + + sd = self.state_dict + device = self.device + + def load_from_state_dict(original, self, state_dict, prefix, *args, **kwargs): + params = [(name, param) for name, param in self._parameters.items() if param is not None and param.is_meta] + + for name, param in params: + if param.is_meta: + self._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device), requires_grad=param.requires_grad) + + original(self, state_dict, prefix, *args, **kwargs) + + for name, _ in params: + key = prefix + name + if key in sd: + del sd[key] + + linear_load_from_state_dict = self.replace(torch.nn.Linear, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(linear_load_from_state_dict, *args, **kwargs)) + conv2d_load_from_state_dict = self.replace(torch.nn.Conv2d, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(conv2d_load_from_state_dict, *args, **kwargs)) + mha_load_from_state_dict = self.replace(torch.nn.MultiheadAttention, '_load_from_state_dict', lambda *args, **kwargs: load_from_state_dict(mha_load_from_state_dict, *args, **kwargs)) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.restore() diff --git a/modules/sd_models.py b/modules/sd_models.py index fb31a793..acb1e817 100644 --- a/modules/sd_models.py +++ b/modules/sd_models.py @@ -460,7 +460,6 @@ def get_empty_cond(sd_model): return sd_model.cond_stage_model([""]) - def load_model(checkpoint_info=None, already_loaded_state_dict=None): from modules import lowvram, sd_hijack checkpoint_info = checkpoint_info or select_checkpoint() @@ -495,19 +494,24 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None): sd_model = None try: with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd or shared.cmd_opts.do_not_download_clip): - sd_model = instantiate_from_config(sd_config.model) - except Exception: - pass + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) + + except Exception as e: + errors.display(e, "creating model quickly", full_traceback=True) if sd_model is None: print('Failed to create model quickly; will retry using slow method.', file=sys.stderr) - sd_model = instantiate_from_config(sd_config.model) + + with sd_disable_initialization.InitializeOnMeta(): + sd_model = instantiate_from_config(sd_config.model) sd_model.used_config = checkpoint_config timer.record("create model") - load_model_weights(sd_model, checkpoint_info, state_dict, timer) + with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu): + load_model_weights(sd_model, checkpoint_info, state_dict, timer) if shared.cmd_opts.lowvram or shared.cmd_opts.medvram: lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram) diff --git a/webui.py b/webui.py index 2314735f..51248c39 100644 --- a/webui.py +++ b/webui.py @@ -320,9 +320,9 @@ def initialize_rest(*, reload_script_modules=False): if modules.sd_hijack.current_optimizer is None: modules.sd_hijack.apply_optimizations() - Thread(target=load_model).start() + devices.first_time_calculation() - Thread(target=devices.first_time_calculation).start() + Thread(target=load_model).start() shared.reload_hypernetworks() startup_timer.record("reload hypernetworks") -- cgit v1.2.3 From bbfff771d7337707bf501b27f98da2f7a7c06f73 Mon Sep 17 00:00:00 2001 From: w-e-w <40751091+w-e-w@users.noreply.github.com> Date: Sat, 29 Jul 2023 01:07:35 +0900 Subject: --disable-all-extensions --disable-extra-extensions --- modules/cmd_args.py | 2 ++ modules/extensions.py | 10 +++++++--- modules/ui_extensions.py | 18 +++++++++++------- 3 files changed, 20 insertions(+), 10 deletions(-) (limited to 'modules/cmd_args.py') diff --git a/modules/cmd_args.py b/modules/cmd_args.py index dd5fadc4..1262f1a4 100644 --- a/modules/cmd_args.py +++ b/modules/cmd_args.py @@ -111,3 +111,5 @@ parser.add_argument('--subpath', type=str, help='customize the subpath for gradi parser.add_argument('--add-stop-route', action='store_true', help='add /_stop route to stop server') parser.add_argument('--api-server-stop', action='store_true', help='enable server stop/restart/kill via api') parser.add_argument('--timeout-keep-alive', type=int, default=30, help='set timeout_keep_alive for uvicorn') +parser.add_argument("--disable-all-extensions", action='store_true', help="prevent all extensions from running regardless of any other settings", default=False) +parser.add_argument("--disable-extra-extensions", action='store_true', help=" prevent all extensions except built-in from running regardless of any other settings", default=False) diff --git a/modules/extensions.py b/modules/extensions.py index 3ad5ed53..e4633af4 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -11,9 +11,9 @@ os.makedirs(extensions_dir, exist_ok=True) def active(): - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": return [] - elif shared.opts.disable_all_extensions == "extra": + elif shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions == "extra": return [x for x in extensions if x.enabled and x.is_builtin] else: return [x for x in extensions if x.enabled] @@ -141,8 +141,12 @@ def list_extensions(): if not os.path.isdir(extensions_dir): return - if shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_all_extensions: + print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") + elif shared.opts.disable_all_extensions == "all": print("*** \"Disable all extensions\" option was set, will not load any extensions ***") + elif shared.cmd_opts.disable_extra_extensions: + print("*** \"--disable-extra-extensions\" arg was used, will only load built-in extensions ***") elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") diff --git a/modules/ui_extensions.py b/modules/ui_extensions.py index f3e4fba7..bd28bfcf 100644 --- a/modules/ui_extensions.py +++ b/modules/ui_extensions.py @@ -164,7 +164,7 @@ def extension_table(): ext_status = ext.status style = "" - if shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.opts.disable_all_extensions == "all": + if shared.cmd_opts.disable_extra_extensions and not ext.is_builtin or shared.opts.disable_all_extensions == "extra" and not ext.is_builtin or shared.cmd_opts.disable_all_extensions or shared.opts.disable_all_extensions == "all": style = STYLE_PRIMARY version_link = ext.version @@ -537,12 +537,16 @@ def create_ui(): extensions_update_list = gr.Text(elem_id="extensions_update_list", visible=False).style(container=False) html = "" - if shared.opts.disable_all_extensions != "none": - html = """ - - "Disable all extensions" was set, change it to "none" to load all extensions again - - """ + + if shared.cmd_opts.disable_all_extensions or shared.cmd_opts.disable_extra_extensions or shared.opts.disable_all_extensions != "none": + if shared.cmd_opts.disable_all_extensions: + msg = '"--disable-all-extensions" was used, remove it to load all extensions again' + elif shared.opts.disable_all_extensions != "none": + msg = '"Disable all extensions" was set, change it to "none" to load all extensions again' + elif shared.cmd_opts.disable_extra_extensions: + msg = '"--disable-extra-extensions" was used, remove it to load all extensions again' + html = f'{msg}' + info = gr.HTML(html) extensions_table = gr.HTML('Loading...') ui.load(fn=extension_table, inputs=[], outputs=[extensions_table]) -- cgit v1.2.3