From 739a832d87c00e3b5977a24bba5654fa5ea7a702 Mon Sep 17 00:00:00 2001 From: sigoden Date: Fri, 7 Jun 2024 15:16:31 +0800 Subject: feat: js/py generate declarations from comments (#30) --- scripts/build-declarations.py | 177 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 177 insertions(+) create mode 100644 scripts/build-declarations.py (limited to 'scripts/build-declarations.py') diff --git a/scripts/build-declarations.py b/scripts/build-declarations.py new file mode 100644 index 0000000..17a27b6 --- /dev/null +++ b/scripts/build-declarations.py @@ -0,0 +1,177 @@ +#!/usr/bin/env python + +import ast +import os +import json +import re +import sys +from collections import OrderedDict + + +def main(): + scriptfile = sys.argv[1] + with open(scriptfile, "r", encoding="utf-8") as f: + contents = f.read() + + functions = extract_functions(contents) + declarations = [] + for function in functions: + func_name, docstring, func_args = function + description, params = parse_docstring(docstring) + declarations.append( + build_declaration(func_name, description, params, func_args) + ) + + name = os.path.splitext(os.path.basename(scriptfile))[0] + if declarations: + declarations = declarations[0:1] + declarations[0]["name"] = name + + print(json.dumps(declarations, indent=2)) + + +def extract_functions(contents: str): + tree = ast.parse(contents) + output = [] + for node in ast.walk(tree): + if not isinstance(node, ast.FunctionDef): + continue + func_name = node.name + if func_name.startswith("_"): + continue + docstring = ast.get_docstring(node) or "" + func_args = OrderedDict() + for arg in node.args.args: + arg_name = arg.arg + arg_type = get_arg_type(arg.annotation) + func_args[arg_name] = arg_type + output.append((func_name, docstring, func_args)) + return output + + +def get_arg_type(annotation) -> str: + if annotation is None: + return "" + elif isinstance(annotation, ast.Name): + return annotation.id + elif isinstance(annotation, ast.Subscript): + if isinstance(annotation.value, ast.Name): + type_name = annotation.value.id + if type_name == "List": + child = get_arg_type(annotation.slice) + return f"list[{child}]" + if type_name == "Literal": + literals = [ast.unparse(el) for el in annotation.slice.elts] + return f"{'|'.join(literals)}" + if type_name == "Optional": + child = get_arg_type(annotation.slice) + return f"{child}?" + return "any" + + +def parse_docstring(docstring: str): + lines = docstring.splitlines() + description = "" + rawParams = [] + is_in_args = False + for line in lines: + if not is_in_args: + if line.startswith("Args:"): + is_in_args = True + else: + description += f"\n{line}" + continue + else: + if re.search(r"^\s+", line): + rawParams.append(line.strip()) + else: + break + params = {} + for rawParam in rawParams: + name, type_, description = parse_param(rawParam) + params[name] = (type_, description) + return (description.strip(), params) + + +def parse_param(raw_param: str): + name = "" + description = "" + type_from_comment = "" + if ":" in raw_param: + name, description = raw_param.split(":", 1) + name = name.strip() + description = description.strip() + else: + name = raw_param + if " " in name: + name, type_from_comment = name.split(" ", 1) + type_from_comment = type_from_comment.strip() + + if type_from_comment.startswith("(") and type_from_comment.endswith(")"): + type_from_comment = type_from_comment[1:-1] + type_parts = [value.strip() for value in type_from_comment.split(",")] + type_ = type_parts[0] + if "optional" in type_parts[1:]: + type_ = f"{type_}?" + + return (name, type_, description) + + +def build_declaration( + name: str, description: str, params: dict, args: OrderedDict[str, str] +) -> dict[str, dict]: + schema = { + "name": name, + "description": description, + "properties": {}, + } + required_params = [] + for arg_name, arg_type in args.items(): + type_ = arg_type + description = "" + required = True + if params.get(arg_name): + param_type, description = params[arg_name] + if not type_: + type_ = param_type + if type_.endswith("?"): + type_ = type_[:-1] + required = False + try: + property = build_property(type_, description) + except: + raise ValueError(f"Unable to parse arg '{arg_name}' of function '{name}'") + schema["properties"][arg_name] = property + if required: + required_params.append(arg_name) + if required_params: + schema["required"] = required_params + return schema + + +def build_property(type_: str, description: str): + property = {} + if "|" in type_: + property["type"] = "string" + property["enum"] = type_.replace("'", "").split("|") + elif type_ == "bool": + property["type"] = "boolean" + elif type_ == "str": + property["type"] = "string" + elif type_ == "int": + property["type"] = "integer" + elif type_ == "float": + property["type"] = "number" + elif type_ == "list[str]": + property["type"] = "array" + property["items"] = {"type": "string"} + elif type_ == "": + property["type"] = "string" + else: + raise ValueError(f"Unsupported type `{type_}`") + property["description"] = description + return property + + +if __name__ == "__main__": + main() -- cgit v1.2.3