aboutsummaryrefslogtreecommitdiffstats
path: root/scripts/run-tool.py
diff options
context:
space:
mode:
authorsigoden <sigoden@gmail.com>2024-06-08 10:30:34 +0800
committerGitHub <noreply@github.com>2024-06-08 10:30:34 +0800
commit2b5b0f6502cc245b2365e4a8b5fdaaefb9d67b5d (patch)
treed54633e31f57b55bf3f1a456a54cbbb6f93f1ebe /scripts/run-tool.py
parent0c6b609c261cb6f586668626d860dc6754725794 (diff)
downloadllm-functions-docker-2b5b0f6502cc245b2365e4a8b5fdaaefb9d67b5d.tar.gz
refactor: improve Argcfile.sh and scripts/run-tool* (#34)
* refactor: improve Argcfile and scripts/run-tool* * fix run-tool.js on windows
Diffstat (limited to 'scripts/run-tool.py')
-rwxr-xr-xscripts/run-tool.py99
1 files changed, 64 insertions, 35 deletions
diff --git a/scripts/run-tool.py b/scripts/run-tool.py
index 220b099..f5aef4f 100755
--- a/scripts/run-tool.py
+++ b/scripts/run-tool.py
@@ -6,16 +6,39 @@ import sys
import importlib.util
-def parse_argv():
- tool_name = sys.argv[0]
+def main():
+ (tool_name, raw_data) = parse_argv("run-tool.py")
+ tool_data = parse_raw_data(raw_data)
+
+ root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
+ setup_env(root_dir, tool_name)
+
+ tool_path = os.path.join(root_dir, f"tools/{tool_name}.py")
+ run(tool_path, "run", tool_data)
+
+
+def parse_raw_data(data):
+ if not data:
+ raise ValueError("No JSON data")
+
+ try:
+ return json.loads(data)
+ except Exception:
+ raise ValueError("Invalid JSON data")
+
+
+def parse_argv(this_file_name):
+ argv = sys.argv[:] + [None] * max(0, 3 - len(sys.argv))
+
+ tool_name = argv[0]
tool_data = None
- if tool_name.endswith("run-tool.py"):
- tool_name = sys.argv[1] if len(sys.argv) > 1 else None
- tool_data = sys.argv[2] if len(sys.argv) > 2 else None
+ if tool_name.endswith(this_file_name):
+ tool_name = argv[1]
+ tool_data = argv[2]
else:
tool_name = os.path.basename(tool_name)
- tool_data = sys.argv[1] if len(sys.argv) > 1 else None
+ tool_data = sys.argv[1]
if tool_name.endswith(".py"):
tool_name = tool_name[:-3]
@@ -23,17 +46,11 @@ def parse_argv():
return tool_name, tool_data
-def load_module(tool_name):
- tool_file_name = f"{tool_name}.py"
- tool_path = os.path.join(os.environ["LLM_ROOT_DIR"], f"tools/{tool_file_name}")
- if os.path.exists(tool_path):
- spec = importlib.util.spec_from_file_location(f"{tool_file_name}", tool_path)
- module = importlib.util.module_from_spec(spec)
- spec.loader.exec_module(module)
- return module
- else:
- print(f"Invalid function: {tool_file_name}")
- sys.exit(1)
+def setup_env(root_dir, tool_name):
+ os.environ["LLM_ROOT_DIR"] = root_dir
+ load_env(os.path.join(root_dir, ".env"))
+ os.environ["LLM_TOOL_NAME"] = tool_name
+ os.environ["LLM_TOOL_CACHE_DIR"] = os.path.join(root_dir, "cache", tool_name)
def load_env(file_path):
@@ -50,27 +67,39 @@ def load_env(file_path):
pass
-LLM_ROOT_DIR = os.environ["LLM_ROOT_DIR"] = os.path.abspath(
- os.path.join(os.path.dirname(__file__), "..")
-)
+def run(tool_path, tool_func, tool_data):
+ try:
+ spec = importlib.util.spec_from_file_location(
+ os.path.basename(tool_path), tool_path
+ )
+ mod = importlib.util.module_from_spec(spec)
+ spec.loader.exec_module(mod)
+ except:
+ raise Exception(f"Unable to load tool at '{tool_path}'")
-load_env(os.path.join(LLM_ROOT_DIR, ".env"))
+ if not hasattr(mod, tool_func):
+ raise Exception(f"Not module function '{tool_func}' at '{tool_path}'")
-tool_name, tool_data = parse_argv()
+ value = getattr(mod, tool_func)(**tool_data)
+ dump_value(value)
-os.environ["LLM_TOOL_NAME"] = tool_name
-os.environ["LLM_TOOL_CACHE_DIR"] = os.path.join(LLM_ROOT_DIR, "cache", tool_name)
-if not tool_data:
- print("No json data")
- sys.exit(1)
+def dump_value(value):
+ if value is None:
+ return
-data = None
-try:
- data = json.loads(tool_data)
-except (json.JSONDecodeError, TypeError):
- print("Invalid json data")
- sys.exit(1)
+ value_type = type(value).__name__
+ if value_type in ("str", "int", "float", "bool"):
+ print(value)
+ elif value_type == "dict" or value_type == "list":
+ value_str = json.dumps(value, indent=2)
+ assert value == json.loads(value_str)
+ print(value_str)
-module = load_module(tool_name)
-module.run(**data)
+
+if __name__ == "__main__":
+ try:
+ main()
+ except Exception as e:
+ print(e, file=sys.stderr)
+ sys.exit(1)