Skip to content

Commit

Permalink
Write specs in Python instead of parsing comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gahjelle committed Sep 6, 2024
1 parent f5d7228 commit a91bddf
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 39 deletions.
48 changes: 15 additions & 33 deletions generate_spherely_vfunc_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
STUB_FILE_PATH = Path(__file__).parent / "src" / "spherely.pyi"
BEGIN_MARKER = "# /// Begin types"
END_MARKER = "# /// End types"
LINE_PREFIX = "# - "


def update_stub_file(path=STUB_FILE_PATH):
def update_stub_file(path=STUB_FILE_PATH, **type_specs):
stub_text = path.read_text(encoding="utf-8")
try:
start_idx = stub_text.index(BEGIN_MARKER)
Expand All @@ -19,41 +18,19 @@ def update_stub_file(path=STUB_FILE_PATH):
f"were not found in stub file '{path}'"
) from None

args_specs = [
_parse_vfunctype_args(line.removeprefix(LINE_PREFIX))
for line in stub_text[start_idx:end_idx].splitlines()
if line.startswith(LINE_PREFIX)
]

header = "\n".join(
[BEGIN_MARKER, "#"]
+ [
f"{LINE_PREFIX}{', '.join(f'{a}={t}' for a, t in args.items())}"
for args in args_specs
]
+ ["#", ""]
header = f"{BEGIN_MARKER}\n"
code = "\n\n".join(
_vfunctype_factory(name, **args) for name, args in type_specs.items()
)
code = "\n\n".join(_vfunctype_factory(**args) for args in args_specs)
updated_stub_text = stub_text[:start_idx] + header + code + stub_text[end_idx:]
path.write_text(updated_stub_text, encoding="utf-8")


def _parse_vfunctype_args(signature):
types = {}
for arg in signature.split(","):
arg_name, _, arg_type = arg.strip().partition("=")
types[arg_name.strip()] = arg_type.strip()

# The `n_in` parameter isn't a type and should be interpreted as an int
return types | {"n_in": int(types["n_in"])}


def _vfunctype_factory(n_in, **optargs):
def _vfunctype_factory(class_name, n_in, **optargs):
"""Create new VFunc types.
Based on the number of input arrays and optional arguments and their types."""
names = ["geography"] if n_in == 1 else list(string.ascii_lowercase[:n_in])
class_name = f"_VFunc_Nin{n_in}{''.join(optargs)}_Nout1"
arg_names = ["geography"] if n_in == 1 else list(string.ascii_lowercase[:n_in])
class_code = [
f"class {class_name}(",
" Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]",
Expand All @@ -67,13 +44,14 @@ def _vfunctype_factory(n_in, **optargs):
)

geog_types = ["Geography", "npt.ArrayLike"]
for types in itertools.product(geog_types, repeat=n_in):
for arg_types in itertools.product(geog_types, repeat=n_in):
arg_str = ", ".join(
f"{arg_name}: {arg_type}" for arg_name, arg_type in zip(names, types)
f"{arg_name}: {arg_type}"
for arg_name, arg_type in zip(arg_names, arg_types)
)
return_type = (
"_ScalarReturnType"
if all(t == geog_types[0] for t in types)
if all(t == geog_types[0] for t in arg_types)
else "npt.NDArray[_ArrayReturnDType]"
)
class_code.extend(
Expand All @@ -93,4 +71,8 @@ def _vfunctype_factory(n_in, **optargs):


if __name__ == "__main__":
update_stub_file()
update_stub_file(
_VFunc_Nin1_Nout1={"n_in": 1},
_VFunc_Nin2_Nout1={"n_in": 2},
# _VFunc_Nin2radius_Nout1={"n_in": 2, "radius": "float"},
)
8 changes: 2 additions & 6 deletions src/spherely.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,10 @@ _ScalarReturnType = TypeVar("_ScalarReturnType", bound=Any)
_ArrayReturnDType = TypeVar("_ArrayReturnDType", bound=Any)

# The following types are auto-generated. Please don't edit them by hand.
# Instead, create lines with n_in and optional arguments like below and run
# generate_spherely_vfunc_types.py to update them.
# Instead, update the generate_spherely_vfunc_types.py script and run it
# to update the types.
#
# /// Begin types
#
# - n_in=1
# - n_in=2
#
class _VFunc_Nin1_Nout1(Generic[_NameType, _ScalarReturnType, _ArrayReturnDType]):
@property
def __name__(self) -> _NameType: ...
Expand Down

0 comments on commit a91bddf

Please sign in to comment.