From 6398dc9b1049f242576ca309f95a3fb1e654951c Mon Sep 17 00:00:00 2001 From: AUTOMATIC <16777216c@gmail.com> Date: Sat, 22 Oct 2022 13:34:49 +0300 Subject: further support for extensions --- modules/scripts.py | 44 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 9 deletions(-) (limited to 'modules/scripts.py') diff --git a/modules/scripts.py b/modules/scripts.py index 65f25f49..9323af3e 100644 --- a/modules/scripts.py +++ b/modules/scripts.py @@ -102,17 +102,39 @@ def list_scripts(scriptdirname, extension): if os.path.exists(extdir): for dirname in sorted(os.listdir(extdir)): dirpath = os.path.join(extdir, dirname) - if not os.path.isdir(dirpath): + scriptdirpath = os.path.join(dirpath, scriptdirname) + + if not os.path.isdir(scriptdirpath): continue - for filename in sorted(os.listdir(os.path.join(dirpath, scriptdirname))): - scripts_list.append(ScriptFile(dirpath, filename, os.path.join(dirpath, scriptdirname, filename))) + for filename in sorted(os.listdir(scriptdirpath)): + scripts_list.append(ScriptFile(dirpath, filename, os.path.join(scriptdirpath, filename))) scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)] return scripts_list +def list_files_with_name(filename): + res = [] + + dirs = [paths.script_path] + + extdir = os.path.join(paths.script_path, "extensions") + if os.path.exists(extdir): + dirs += [os.path.join(extdir, d) for d in sorted(os.listdir(extdir))] + + for dirpath in dirs: + if not os.path.isdir(dirpath): + continue + + path = os.path.join(dirpath, filename) + if os.path.isfile(filename): + res.append(path) + + return res + + def load_scripts(): global current_basedir scripts_data.clear() @@ -276,7 +298,7 @@ class ScriptRunner: print(f"Error running alwayson script: {script.filename}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) - def reload_sources(self): + def reload_sources(self, cache): for si, script in list(enumerate(self.scripts)): with open(script.filename, "r", encoding="utf8") as file: args_from = script.args_from @@ -286,9 +308,12 @@ class ScriptRunner: from types import ModuleType - compiled = compile(text, filename, 'exec') - module = ModuleType(script.filename) - exec(compiled, module.__dict__) + module = cache.get(filename, None) + if module is None: + compiled = compile(text, filename, 'exec') + module = ModuleType(script.filename) + exec(compiled, module.__dict__) + cache[filename] = module for key, script_class in module.__dict__.items(): if type(script_class) == type and issubclass(script_class, Script): @@ -303,8 +328,9 @@ scripts_img2img = ScriptRunner() def reload_script_body_only(): - scripts_txt2img.reload_sources() - scripts_img2img.reload_sources() + cache = {} + scripts_txt2img.reload_sources(cache) + scripts_img2img.reload_sources(cache) def reload_scripts(): -- cgit v1.2.3