From 0fc7dc1c04a046d95588651ffc4e71a7d40378d3 Mon Sep 17 00:00:00 2001 From: wfjsw Date: Sat, 11 Nov 2023 04:01:13 -0600 Subject: implementing script metadata and DAG sorting mechanism --- modules/extensions.py | 80 +++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 71 insertions(+), 9 deletions(-) (limited to 'modules/extensions.py') diff --git a/modules/extensions.py b/modules/extensions.py index bf9a1878..e317a404 100644 --- a/modules/extensions.py +++ b/modules/extensions.py @@ -1,3 +1,5 @@ +import configparser +import functools import os import threading @@ -23,8 +25,9 @@ class Extension: lock = threading.Lock() cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version'] - def __init__(self, name, path, enabled=True, is_builtin=False): + def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None): self.name = name + self.canonical_name = canonical_name or name.lower() self.path = path self.enabled = enabled self.status = '' @@ -37,6 +40,17 @@ class Extension: self.remote = None self.have_info_from_repo = False + @functools.cached_property + def metadata(self): + if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")): + try: + config = configparser.ConfigParser() + config.read(os.path.join(self.path, "sd_webui_metadata.ini")) + return config + except Exception: + errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True) + return None + def to_dict(self): return {x: getattr(self, x) for x in self.cached_fields} @@ -136,9 +150,6 @@ class Extension: def list_extensions(): extensions.clear() - if not os.path.isdir(extensions_dir): - return - if shared.cmd_opts.disable_all_extensions: print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***") elif shared.opts.disable_all_extensions == "all": @@ -148,18 +159,69 @@ def list_extensions(): elif shared.opts.disable_all_extensions == "extra": print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***") - extension_paths = [] + extension_dependency_map = {} + + # scan through extensions directory and load metadata for dirname in [extensions_dir, extensions_builtin_dir]: if not os.path.isdir(dirname): - return + continue for extension_dirname in sorted(os.listdir(dirname)): path = os.path.join(dirname, extension_dirname) if not os.path.isdir(path): continue - extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir)) + canonical_name = extension_dirname + requires = None + + if os.path.isfile(os.path.join(path, "sd_webui_metadata.ini")): + try: + config = configparser.ConfigParser() + config.read(os.path.join(path, "sd_webui_metadata.ini")) + canonical_name = config.get("Extension", "Name", fallback=canonical_name) + requires = config.get("Extension", "Requires", fallback=None) + continue + except Exception: + errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. " + f"Will load regardless.", exc_info=True) + + canonical_name = canonical_name.lower().strip() + + # check for duplicated canonical names + if canonical_name in extension_dependency_map: + errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions " + f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". " + f"The current loading extension will be discarded.", exc_info=False) + continue - for dirname, path, is_builtin in extension_paths: - extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin) + # we want to wash the data to lowercase and remove whitespaces just in case + requires = [x.strip() for x in requires.lower().split(',')] if requires else [] + + extension_dependency_map[canonical_name] = { + "dirname": extension_dirname, + "path": path, + "requires": requires, + } + + # check for requirements + for (_, extension_data) in extension_dependency_map.items(): + dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires'] + requirement_met = True + for req in requires: + if req not in extension_dependency_map: + errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. " + f"The current loading extension will be discarded.", exc_info=False) + requirement_met = False + break + dep_dirname = extension_dependency_map[req]['dirname'] + if dep_dirname in shared.opts.disabled_extensions: + errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. " + f"The current loading extension will be discarded.", exc_info=False) + requirement_met = False + break + + is_builtin = dirname == extensions_builtin_dir + extension = Extension(name=dirname, path=path, + enabled=dirname not in shared.opts.disabled_extensions and requirement_met, + is_builtin=is_builtin) extensions.append(extension) -- cgit v1.2.3