path: root/modules/scripts.py
author     wfjsw <wfjsw@users.noreply.github.com>    2023-11-11 10:01:13 +0000
committer  wfjsw <wfjsw@users.noreply.github.com>    2023-11-11 10:01:13 +0000
commit     0fc7dc1c04a046d95588651ffc4e71a7d40378d3 (patch)
tree       816f304a9b606700c350b999c07392400a815297 /modules/scripts.py
parent     5e80d9ee99c5899e5e2b130408ffb65a0585a62a (diff)
implementing script metadata and DAG sorting mechanism
Diffstat (limited to 'modules/scripts.py')
-rw-r--r--  modules/scripts.py  141
1 file changed, 126 insertions(+), 15 deletions(-)
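Note: the sorting mechanism introduced by this commit is built on Python's standard-library graphlib module (available since Python 3.9). The snippet below is a minimal, self-contained sketch of that mechanism with hypothetical script names, showing how add(node, *predecessors) and static_order() produce a load order and how CycleError can be handled; it is an illustration, not part of the commit.

    from graphlib import TopologicalSorter, CycleError

    # hypothetical "load after" edges: script -> scripts it must be loaded after
    load_after = {
        "a.py": [],
        "b.py": ["a.py"],          # b.py loads after a.py
        "c.py": ["a.py", "b.py"],  # c.py loads after both
    }

    sorter = TopologicalSorter()
    for script, predecessors in load_after.items():
        sorter.add(script, *predecessors)

    try:
        # static_order() is a generator, so it must be consumed here
        # for a CycleError to be caught by this try block
        order = list(sorter.static_order())
    except CycleError:
        # fall back to insertion order if the declared dependencies form a cycle
        order = list(load_after)

    print(order)  # e.g. ['a.py', 'b.py', 'c.py']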
diff --git a/modules/scripts.py b/modules/scripts.py
index 5c6e0226..e92a34a0 100644
--- a/modules/scripts.py
+++ b/modules/scripts.py
@@ -2,6 +2,7 @@ import os
import re
import sys
import inspect
+from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from dataclasses import dataclass
@@ -314,15 +315,131 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    scripts_list = []
-
-    basedir = os.path.join(paths.script_path, scriptdirname)
-    if os.path.exists(basedir):
-        for filename in sorted(os.listdir(basedir)):
-            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
+    script_dependency_map = {}
+
+    # build script dependency map
+
+    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
+    if os.path.exists(root_script_basedir):
+        for filename in sorted(os.listdir(root_script_basedir)):
+            script_dependency_map[filename] = {
+                "extension": None,
+                "extension_dirname": None,
+                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
+                "requires": [],
+                "load_before": [],
+                "load_after": [],
+            }
    if include_extensions:
        for ext in extensions.active():
-            scripts_list += ext.list_files(scriptdirname, extension)
+            extension_scripts_list = ext.list_files(scriptdirname, extension)
+            for extension_script in extension_scripts_list:
+                # this is built on the assumption that script names are unique.
+                # bad things are likely to happen if names collide in the current implementation anyway,
+                # but we will need to refactor here if this assumption is broken later on.
+                if extension_script.filename in script_dependency_map:
+                    errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions "
+                                  f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". "
+                                  f"The current loading file will be discarded.", exc_info=False)
+                    continue
+
+                relative_path = scriptdirname + "/" + extension_script.filename
+
+                requires = None
+                load_before = None
+                load_after = None
+
+                if ext.metadata is not None:
+                    requires = ext.metadata.get(relative_path, "Requires", fallback=None)
+                    load_before = ext.metadata.get(relative_path, "Before", fallback=None)
+                    load_after = ext.metadata.get(relative_path, "After", fallback=None)
+
+                requires = [x.strip() for x in requires.split(',')] if requires else []
+                load_after = [x.strip() for x in load_after.split(',')] if load_after else []
+                load_before = [x.strip() for x in load_before.split(',')] if load_before else []
+
+                script_dependency_map[extension_script.filename] = {
+                    "extension": ext.canonical_name,
+                    "extension_dirname": ext.name,
+                    "script_file": extension_script,
+                    "requires": requires,
+                    "load_before": load_before,
+                    "load_after": load_after,
+                }
+
+    # resolve dependencies
+
+    loaded_extensions = set()
+    for _, script_data in script_dependency_map.items():
+        if script_data['extension'] is not None:
+            loaded_extensions.add(script_data['extension'])
+
+    for script_filename, script_data in script_dependency_map.items():
+        # a load_before entry is an inverse dependency;
+        # in that case, append this script's name to the load_after list of the specified script
+        for load_before_script in script_data['load_before']:
+            if load_before_script.startswith('ext:'):
+                # if this requires an extension to be loaded before
+                required_extension = load_before_script[4:]
+                for _, script_data2 in script_dependency_map.items():
+                    if script_data2['extension'] == required_extension:
+                        script_data2['load_after'].append(script_filename)
+                        break
+            else:
+                # if this requires an individual script to be loaded before
+                if load_before_script in script_dependency_map:
+                    script_dependency_map[load_before_script]['load_after'].append(script_filename)
+
+        # resolve extension name in load_after lists
+        for load_after_script in script_data['load_after']:
+            if load_after_script.startswith('ext:'):
+                # if this requires an extension to be loaded after
+                required_extension = load_after_script[4:]
+                for script_file_name2, script_data2 in script_dependency_map.items():
+                    if script_data2['extension'] == required_extension:
+                        script_data['load_after'].append(script_file_name2)
+
+        # remove all extension names in load_after lists
+        script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')]
+
+    # build the DAG
+    sorter = TopologicalSorter()
+    for script_filename, script_data in script_dependency_map.items():
+        requirement_met = True
+        for required_script in script_data['requires']:
+            if required_script.startswith('ext:'):
+                # if this requires an extension to be installed
+                required_extension = required_script[4:]
+                if required_extension not in loaded_extensions:
+                    errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to "
+                                  f"be loaded, but it is not. Skipping.",
+                                  exc_info=False)
+                    requirement_met = False
+                    break
+            else:
+                # if this requires an individual script to be loaded
+                if required_script not in script_dependency_map:
+                    errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to "
+                                  f"be loaded, but it is not. Skipping.",
+                                  exc_info=False)
+                    requirement_met = False
+                    break
+        if not requirement_met:
+            continue
+
+        sorter.add(script_filename, *script_data['load_after'])
+
+    # sort the scripts
+    try:
+        ordered_script = sorter.static_order()
+    except CycleError:
+        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
+        ordered_script = script_dependency_map.keys()
+
+    for script_filename in ordered_script:
+        script_data = script_dependency_map[script_filename]
+        scripts_list.append(script_data['script_file'])

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
@@ -365,15 +482,9 @@ def load_scripts():
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
-    def orderby(basedir):
-        # 1st webui, 2nd extensions-builtin, 3rd extensions
-        priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
-        for key in priority:
-            if basedir.startswith(key):
-                return priority[key]
-        return 9999
-
-    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
+    # here scripts_list is already ordered;
+    # processing_script is not considered here, though
+    for scriptfile in scripts_list:
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
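For reference, the per-script Requires/Before/After values read above use ext.metadata.get(section, option, fallback=None), where the section name is the script path relative to the extension (scriptdirname + "/" + filename). That signature matches configparser's, so the sketch below uses configparser to show how such metadata could be declared and parsed; the extension and script names are made up, and the metadata file name is an assumption rather than something stated in this diff.

    import configparser

    # hypothetical extension metadata, one section per script file (names are illustrative)
    metadata_text = """
    [scripts/my_feature.py]
    Requires = shared_helpers.py
    After = ext:some_other_extension
    Before = late_running_script.py
    """

    metadata = configparser.ConfigParser()
    metadata.read_string(metadata_text)

    section = "scripts/my_feature.py"
    requires = metadata.get(section, "Requires", fallback=None)
    load_before = metadata.get(section, "Before", fallback=None)
    load_after = metadata.get(section, "After", fallback=None)

    # same comma-splitting as in the diff; the "ext:" prefix refers to a whole extension
    requires = [x.strip() for x in requires.split(',')] if requires else []
    load_before = [x.strip() for x in load_before.split(',')] if load_before else []
    load_after = [x.strip() for x in load_after.split(',')] if load_after else []

    print(requires, load_before, load_after)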