aboutsummaryrefslogtreecommitdiffstats
path: root/modules/extensions.py
diff options
context:
space:
mode:
authorAUTOMATIC1111 <16777216c@gmail.com>2023-11-20 11:47:09 +0000
committerAUTOMATIC1111 <16777216c@gmail.com>2023-11-20 11:47:09 +0000
commit9b471436b2226458a767077707ea102e331b5d78 (patch)
treeab80bd7865737e71e2d8552c1f53f8ed56201e1f /modules/extensions.py
parentbde439ef67776be126d6a8c569a23d54dbc3e707 (diff)
downloadstable-diffusion-webui-gfx803-9b471436b2226458a767077707ea102e331b5d78.tar.gz
stable-diffusion-webui-gfx803-9b471436b2226458a767077707ea102e331b5d78.tar.bz2
stable-diffusion-webui-gfx803-9b471436b2226458a767077707ea102e331b5d78.zip
rework extensions metadata: use custom sorter that doesn't mess the order as much and ignores cyclic errors, use classes with named fields instead of dictionaries, eliminate some duplicated code
Diffstat (limited to 'modules/extensions.py')
-rw-r--r--modules/extensions.py132
1 files changed, 70 insertions, 62 deletions
diff --git a/modules/extensions.py b/modules/extensions.py
index f3988d02..1899cd52 100644
--- a/modules/extensions.py
+++ b/modules/extensions.py
@@ -1,5 +1,6 @@
+from __future__ import annotations
+
import configparser
-import functools
import os
import threading
import re
@@ -8,7 +9,6 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path # noqa: F401
-extensions = []
os.makedirs(extensions_dir, exist_ok=True)
@@ -22,13 +22,56 @@ def active():
return [x for x in extensions if x.enabled]
+class ExtensionMetadata:
+ filename = "metadata.ini"
+ config: configparser.ConfigParser
+ canonical_name: str
+ requires: list
+
+ def __init__(self, path, canonical_name):
+ self.config = configparser.ConfigParser()
+
+ filepath = os.path.join(path, self.filename)
+ if os.path.isfile(filepath):
+ try:
+ self.config.read(filepath)
+ except Exception:
+ errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)
+
+ self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
+ self.canonical_name = canonical_name.lower().strip()
+
+ self.requires = self.get_script_requirements("Requires", "Extension")
+
+ def get_script_requirements(self, field, section, extra_section=None):
+ """reads a list of requirements from the config; field is the name of the field in the ini file,
+ like Requires or Before, and section is the name of the [section] in the ini file; additionally,
+ reads more requirements from [extra_section] if specified."""
+
+ x = self.config.get(section, field, fallback='')
+
+ if extra_section:
+ x = x + ', ' + self.config.get(extra_section, field, fallback='')
+
+ return self.parse_list(x.lower())
+
+ def parse_list(self, text):
+ """converts a line from config ("ext1 ext2, ext3 ") into a python list (["ext1", "ext2", "ext3"])"""
+
+ if not text:
+ return []
+
+ # both "," and " " are accepted as separator
+ return [x for x in re.split(r"[,\s]+", text.strip()) if x]
+
+
class Extension:
lock = threading.Lock()
cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
+ metadata: ExtensionMetadata
- def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
+ def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
self.name = name
- self.canonical_name = canonical_name or name.lower()
self.path = path
self.enabled = enabled
self.status = ''
@@ -40,18 +83,8 @@ class Extension:
self.branch = None
self.remote = None
self.have_info_from_repo = False
-
- @functools.cached_property
- def metadata(self):
- if os.path.isfile(os.path.join(self.path, "metadata.ini")):
- try:
- config = configparser.ConfigParser()
- config.read(os.path.join(self.path, "metadata.ini"))
- return config
- except Exception:
- errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
- exc_info=True)
- return None
+ self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
+ self.canonical_name = metadata.canonical_name
def to_dict(self):
return {x: getattr(self, x) for x in self.cached_fields}
@@ -162,7 +195,7 @@ def list_extensions():
elif shared.opts.disable_all_extensions == "extra":
print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
- extension_dependency_map = {}
+ loaded_extensions = {}
# scan through extensions directory and load metadata
for dirname in [extensions_builtin_dir, extensions_dir]:
@@ -175,55 +208,30 @@ def list_extensions():
continue
canonical_name = extension_dirname
- requires = None
+ metadata = ExtensionMetadata(path, canonical_name)
- if os.path.isfile(os.path.join(path, "metadata.ini")):
- try:
- config = configparser.ConfigParser()
- config.read(os.path.join(path, "metadata.ini"))
- canonical_name = config.get("Extension", "Name", fallback=canonical_name)
- requires = config.get("Extension", "Requires", fallback=None)
- except Exception:
- errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
- f"Will load regardless.", exc_info=True)
+ # check for duplicated canonical names
+ already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
+ if already_loaded_extension is not None:
+ errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
+ continue
- canonical_name = canonical_name.lower().strip()
+ is_builtin = dirname == extensions_builtin_dir
+ extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
+ extensions.append(extension)
+ loaded_extensions[canonical_name] = extension
- # check for duplicated canonical names
- if canonical_name in extension_dependency_map:
- errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
- f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
- f"The current loading extension will be discarded.", exc_info=False)
+ # check for requirements
+ for extension in extensions:
+ for req in extension.metadata.requires:
+ required_extension = loaded_extensions.get(req)
+ if required_extension is None:
+ errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
continue
- # both "," and " " are accepted as separator
- requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
+ if not extension.enabled:
+ errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
+ continue
- extension_dependency_map[canonical_name] = {
- "dirname": extension_dirname,
- "path": path,
- "requires": requires,
- }
- # check for requirements
- for (_, extension_data) in extension_dependency_map.items():
- dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
- requirement_met = True
- for req in requires:
- if req not in extension_dependency_map:
- errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
- f"The current loading extension will be discarded.", exc_info=False)
- requirement_met = False
- break
- dep_dirname = extension_dependency_map[req]['dirname']
- if dep_dirname in shared.opts.disabled_extensions:
- errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
- f"The current loading extension will be discarded.", exc_info=False)
- requirement_met = False
- break
-
- is_builtin = dirname == extensions_builtin_dir
- extension = Extension(name=dirname, path=path,
- enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
- is_builtin=is_builtin)
- extensions.append(extension)
+extensions: list[Extension] = []