From 2b5b0f6502cc245b2365e4a8b5fdaaefb9d67b5d Mon Sep 17 00:00:00 2001 From: sigoden Date: Sat, 8 Jun 2024 10:30:34 +0800 Subject: refactor: improve Argcfile.sh and scripts/run-tool* (#34) * refactor: improve Argcfile and scripts/run-tool* * fix run-tool.js on windows --- scripts/run-tool.py | 99 ++++++++++++++++++++++++++++++++++------------------- 1 file changed, 64 insertions(+), 35 deletions(-) (limited to 'scripts/run-tool.py') diff --git a/scripts/run-tool.py b/scripts/run-tool.py index 220b099..f5aef4f 100755 --- a/scripts/run-tool.py +++ b/scripts/run-tool.py @@ -6,16 +6,39 @@ import sys import importlib.util -def parse_argv(): - tool_name = sys.argv[0] +def main(): + (tool_name, raw_data) = parse_argv("run-tool.py") + tool_data = parse_raw_data(raw_data) + + root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) + setup_env(root_dir, tool_name) + + tool_path = os.path.join(root_dir, f"tools/{tool_name}.py") + run(tool_path, "run", tool_data) + + +def parse_raw_data(data): + if not data: + raise ValueError("No JSON data") + + try: + return json.loads(data) + except Exception: + raise ValueError("Invalid JSON data") + + +def parse_argv(this_file_name): + argv = sys.argv[:] + [None] * max(0, 3 - len(sys.argv)) + + tool_name = argv[0] tool_data = None - if tool_name.endswith("run-tool.py"): - tool_name = sys.argv[1] if len(sys.argv) > 1 else None - tool_data = sys.argv[2] if len(sys.argv) > 2 else None + if tool_name.endswith(this_file_name): + tool_name = argv[1] + tool_data = argv[2] else: tool_name = os.path.basename(tool_name) - tool_data = sys.argv[1] if len(sys.argv) > 1 else None + tool_data = sys.argv[1] if tool_name.endswith(".py"): tool_name = tool_name[:-3] @@ -23,17 +46,11 @@ def parse_argv(): return tool_name, tool_data -def load_module(tool_name): - tool_file_name = f"{tool_name}.py" - tool_path = os.path.join(os.environ["LLM_ROOT_DIR"], f"tools/{tool_file_name}") - if os.path.exists(tool_path): - spec = importlib.util.spec_from_file_location(f"{tool_file_name}", tool_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module - else: - print(f"Invalid function: {tool_file_name}") - sys.exit(1) +def setup_env(root_dir, tool_name): + os.environ["LLM_ROOT_DIR"] = root_dir + load_env(os.path.join(root_dir, ".env")) + os.environ["LLM_TOOL_NAME"] = tool_name + os.environ["LLM_TOOL_CACHE_DIR"] = os.path.join(root_dir, "cache", tool_name) def load_env(file_path): @@ -50,27 +67,39 @@ def load_env(file_path): pass -LLM_ROOT_DIR = os.environ["LLM_ROOT_DIR"] = os.path.abspath( - os.path.join(os.path.dirname(__file__), "..") -) +def run(tool_path, tool_func, tool_data): + try: + spec = importlib.util.spec_from_file_location( + os.path.basename(tool_path), tool_path + ) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + except: + raise Exception(f"Unable to load tool at '{tool_path}'") -load_env(os.path.join(LLM_ROOT_DIR, ".env")) + if not hasattr(mod, tool_func): + raise Exception(f"Not module function '{tool_func}' at '{tool_path}'") -tool_name, tool_data = parse_argv() + value = getattr(mod, tool_func)(**tool_data) + dump_value(value) -os.environ["LLM_TOOL_NAME"] = tool_name -os.environ["LLM_TOOL_CACHE_DIR"] = os.path.join(LLM_ROOT_DIR, "cache", tool_name) -if not tool_data: - print("No json data") - sys.exit(1) +def dump_value(value): + if value is None: + return -data = None -try: - data = json.loads(tool_data) -except (json.JSONDecodeError, TypeError): - print("Invalid json data") - sys.exit(1) + value_type = type(value).__name__ + if value_type in ("str", "int", "float", "bool"): + print(value) + elif value_type == "dict" or value_type == "list": + value_str = json.dumps(value, indent=2) + assert value == json.loads(value_str) + print(value_str) -module = load_module(tool_name) -module.run(**data) + +if __name__ == "__main__": + try: + main() + except Exception as e: + print(e, file=sys.stderr) + sys.exit(1) -- cgit v1.2.3