From ca9d55216c03541f7907e538fcedaa6c7e568664 Mon Sep 17 00:00:00 2001 From: Devon Stewart Date: Mon, 30 Sep 2024 18:54:30 +0000 Subject: [PATCH] feat/finishing typeddict inputs (#95) Why === We got pretty close to having TypedDicts for river-python inputs before, but had to roll back due to a protocol mismatch. Trying again, and also adding some tests to confirm that at the very least the Pydantic models can decode what was encoded by the TypedDict encoders. It's not a perfect science, but it should be good enough to start building more confidence as we make additional progress. ### The reason for "janky" tests There's a bit of a chicken-and-egg situation when trying to test code generation at runtime. We have three options: - write pytest handlers where each invocation runs the codegen with a temp target (like the shell script does here), writes a static file for each test into that directory, then executes a new python into that directory. The challenge with this is that it would suck to write or maintain. - write pytest handlers which run the codegen with unique module name targets (like `gen1`, `gen2`, `gen3`, one for each codegen run necessary) and carefully juggle the imports to make sure we don't try to import something that's not there yet. This _might_ be the best option, but I'm not convinced about the ergonomics at the moment. It might be OK though, with highly targeted `.gitignore`'s. - maintain a bespoke test runner, optimize for writing and maintaining these tests, and just acknowledge that we are doing something obscure and difficult. I definitely wrote the tests here in a way that would give some coverage and also provide confidence, while intentionally deferring the above decision so we can keep making progress in the meantime. 
What changed ============ - Added some janky tests for comparing the encoding of both models - Fixed many bugs in the TypedDict codegen and encoders Test plan ========= ``` $ bash scripts/parity.sh Using /tmp/river-codegen-parity.bAZ Starting... Verified ``` --- mypy.ini | 11 ++- replit_river/codegen/client.py | 101 ++++++++++++++------- scripts/parity.sh | 40 ++++++++ scripts/parity/check_parity.py | 161 +++++++++++++++++++++++++++++++++ scripts/parity/gen.py | 45 +++++++++ 5 files changed, 326 insertions(+), 32 deletions(-) create mode 100644 scripts/parity.sh create mode 100644 scripts/parity/check_parity.py create mode 100644 scripts/parity/gen.py diff --git a/mypy.ini b/mypy.ini index 214eb5b..67a94b6 100644 --- a/mypy.ini +++ b/mypy.ini @@ -4,4 +4,13 @@ disallow_untyped_defs = True warn_return_any = True [mypy-grpc.*] -ignore_missing_imports = True \ No newline at end of file +ignore_missing_imports = True + +[mypy-parity.gen.*] +ignore_missing_imports = True + +[mypy-pyd.*] +ignore_missing_imports = True + +[mypy-tyd.*] +ignore_missing_imports = True diff --git a/replit_river/codegen/client.py b/replit_river/codegen/client.py index 4a9b431..cac352c 100644 --- a/replit_river/codegen/client.py +++ b/replit_river/codegen/client.py @@ -12,6 +12,7 @@ Set, Tuple, Union, + cast, ) import black @@ -80,8 +81,17 @@ def reindent(prefix: str, code: str) -> str: return indent(dedent(code), prefix) +def is_literal(tpe: RiverType) -> bool: + if isinstance(tpe, RiverUnionType): + return all(is_literal(t) for t in tpe.anyOf) + elif isinstance(tpe, RiverConcreteType): + return tpe.type in set(["string", "number", "boolean"]) + else: + return False + + def encode_type( - type: RiverType, prefix: str, base_model: str = "BaseModel" + type: RiverType, prefix: str, base_model: str ) -> Tuple[str, Sequence[str]]: chunks: List[str] = [] if isinstance(type, RiverNotType): @@ -219,14 +229,6 @@ def flatten_union(tpe: RiverType) -> list[RiverType]: type = original_type any_of: List[str] 
= [] - def is_literal(tpe: RiverType) -> bool: - if isinstance(tpe, RiverUnionType): - return all(is_literal(t) for t in tpe.anyOf) - elif isinstance(tpe, RiverConcreteType): - return tpe.type in set(["string", "number", "boolean"]) - else: - return False - typeddict_encoder = [] for i, t in enumerate(type.anyOf): type_name, type_chunks = encode_type(t, f"{prefix}AnyOf_{i}", base_model) @@ -273,44 +275,44 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: # Handle the case where type is not specified typeddict_encoder.append("x") return ("Any", ()) - if type.type == "string": + elif type.type == "string": if type.const: typeddict_encoder.append(f"'{type.const}'") return (f"Literal['{type.const}']", ()) else: typeddict_encoder.append("x") return ("str", ()) - if type.type == "Uint8Array": + elif type.type == "Uint8Array": typeddict_encoder.append("x.decode()") return ("bytes", ()) - if type.type == "number": + elif type.type == "number": if type.const is not None: # enums are represented as const number in the schema typeddict_encoder.append(f"{type.const}") return (f"Literal[{type.const}]", ()) typeddict_encoder.append("x") return ("float", ()) - if type.type == "integer": + elif type.type == "integer": if type.const is not None: # enums are represented as const number in the schema typeddict_encoder.append(f"{type.const}") return (f"Literal[{type.const}]", ()) typeddict_encoder.append("x") return ("int", ()) - if type.type == "boolean": + elif type.type == "boolean": typeddict_encoder.append("x") return ("bool", ()) - if type.type == "null": + elif type.type == "null": typeddict_encoder.append("None") return ("None", ()) - if type.type == "Date": + elif type.type == "Date": typeddict_encoder.append("TODO: dstewart") return ("datetime.datetime", ()) - if type.type == "array" and type.items: + elif type.type == "array" and type.items: type_name, type_chunks = encode_type(type.items, prefix, base_model) typeddict_encoder.append("TODO: dstewart") 
return (f"List[{type_name}]", type_chunks) - if ( + elif ( type.type == "object" and type.patternProperties and "^(.*)$" in type.patternProperties @@ -323,7 +325,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: assert type.type == "object", type.type current_chunks: List[str] = [f"class {prefix}({base_model}):"] + # For the encoder path, do we need "x" to be bound? + # lambda x: ... vs lambda _: {} + needs_binding = False if type.properties: + needs_binding = True typeddict_encoder.append("{") for name, prop in type.properties.items(): typeddict_encoder.append(f"'{name}':") @@ -353,18 +359,35 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: ) if name not in prop.required: typeddict_encoder.append( - f"if x['{safe_name}'] else None" + dedent( + f""" + if '{safe_name}' in x + and x['{safe_name}'] is not None + else None + """ + ) ) elif prop.type == "array": - assert type_name.startswith( - "List[" - ) # in case we change to list[...] - _inner_type_name = type_name[len("List[") : -len("]")] - typeddict_encoder.append( - f"[encode_{_inner_type_name}(y) for y in x['{name}']]" - ) + items = cast(RiverConcreteType, prop).items + assert items, "Somehow items was none" + if is_literal(cast(RiverType, items)): + typeddict_encoder.append(f"x['{name}']") + else: + assert type_name.startswith( + "List[" + ) # in case we change to list[...] 
+ _inner_type_name = type_name[len("List[") : -len("]")] + typeddict_encoder.append( + f"""[ + encode_{_inner_type_name}(y) + for y in x['{name}'] + ]""" + ) else: - typeddict_encoder.append(f"x['{safe_name}']") + if name in prop.required: + typeddict_encoder.append(f"x['{safe_name}']") + else: + typeddict_encoder.append(f"x.get('{safe_name}')") if name == "$kind": # If the field is a literal, the Python type-checker will complain @@ -403,8 +426,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]: current_chunks.append("") if base_model == "TypedDict": + binding = "x" if needs_binding else "_" current_chunks = ( - [f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda x: "] + [f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda {binding}: "] + typeddict_encoder + [")"] + current_chunks @@ -449,7 +473,7 @@ def generate_river_client_module( if schema_root.handshakeSchema is not None: (handshake_type, handshake_chunks) = encode_type( - schema_root.handshakeSchema, "HandshakeSchema" + schema_root.handshakeSchema, "HandshakeSchema", "BaseModel" ) chunks.extend(handshake_chunks) else: @@ -482,7 +506,9 @@ def __init__(self, client: river.Client[{handshake_type}]): ) chunks.extend(input_chunks) output_type, output_chunks = encode_type( - procedure.output, f"{schema_name.title()}{name.title()}Output" + procedure.output, + f"{schema_name.title()}{name.title()}Output", + "BaseModel", ) chunks.extend(output_chunks) if procedure.errors: @@ -517,7 +543,20 @@ def __init__(self, client: river.Client[{handshake_type}]): """.rstrip() if typed_dict_inputs: - render_input_method = f"encode_{input_type}" + if is_literal(procedure.input): + render_input_method = "lambda x: x" + elif isinstance( + procedure.input, RiverConcreteType + ) and procedure.input.type in ["array"]: + assert input_type.startswith( + "List[" + ) # in case we change to list[...] 
+ _input_type_name = input_type[len("List[") : -len("]")] + render_input_method = ( + f"lambda xs: [encode_{_input_type_name}(x) for x in xs]" + ) + else: + render_input_method = f"encode_{input_type}" else: render_input_method = f"""\ lambda x: TypeAdapter({input_type}) diff --git a/scripts/parity.sh b/scripts/parity.sh new file mode 100644 index 0000000..c9beda2 --- /dev/null +++ b/scripts/parity.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bash +# +# parity.sh: Generate Pydantic and TypedDict models and check for deep equality. +# This script expects that ai-infra is cloned alongside river-python. + +set -e + +scripts="$(dirname "$0")" +cd "${scripts}/.." + +root="$(mktemp -d --tmpdir 'river-codegen-parity.XXX')" +mkdir "$root/src" + +echo "Using $root" >&2 + +function cleanup { + if [ -z "${DEBUG}" ]; then + echo "Cleaning up..." >&2 + rm -rfv "${root}" >&2 + fi +} +trap "cleanup" 0 2 3 15 + +gen() { + fname="$1"; shift + name="$1"; shift + poetry run python -m replit_river.codegen \ + client \ + --output "${root}/src/${fname}" \ + --client-name "${name}" \ + ../ai-infra/pkgs/pid2_client/src/schema/schema.json \ + "$@" +} + +gen tyd.py Pid2TypedDict --typed-dict-inputs +gen pyd.py Pid2Pydantic + +PYTHONPATH="${root}/src:${scripts}" +poetry run bash -c "MYPYPATH='$PYTHONPATH' mypy -m parity.check_parity" +poetry run bash -c "PYTHONPATH='$PYTHONPATH' python -m parity.check_parity" diff --git a/scripts/parity/check_parity.py b/scripts/parity/check_parity.py new file mode 100644 index 0000000..4ecc362 --- /dev/null +++ b/scripts/parity/check_parity.py @@ -0,0 +1,161 @@ +from typing import Any, Callable, Literal, TypedDict, TypeVar, Union, cast + +import pyd +import tyd +from parity.gen import ( + gen_bool, + gen_choice, + gen_dict, + gen_float, + gen_int, + gen_list, + gen_opt, + gen_str, +) +from pydantic import TypeAdapter + +A = TypeVar("A") + + +def baseTestPattern( + x: A, encode: Callable[[A], Any], adapter: TypeAdapter[Any] +) -> None: + a = encode(x) + m = 
adapter.validate_python(a) + z = adapter.dump_python(m) + + assert a == z + + +def testAiexecExecInit() -> None: + x: tyd.AiexecExecInit = { + "args": gen_list(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "omitStdout": gen_opt(gen_bool)(), + "omitStderr": gen_opt(gen_bool)(), + "useReplitRunEnv": gen_opt(gen_bool)(), + } + + baseTestPattern(x, tyd.encode_AiexecExecInit, TypeAdapter(pyd.AiexecExecInit)) + + +def testAgenttoollanguageserverOpendocumentInput() -> None: + x: tyd.AgenttoollanguageserverOpendocumentInput = { + "uri": gen_str(), + "languageId": gen_str(), + "version": gen_float(), + "text": gen_str(), + } + + baseTestPattern( + x, + tyd.encode_AgenttoollanguageserverOpendocumentInput, + TypeAdapter(pyd.AgenttoollanguageserverOpendocumentInput), + ) + + +kind_type = Union[ + Literal[1], + Literal[2], + Literal[3], + Literal[4], + Literal[5], + Literal[6], + Literal[7], + Literal[8], + Literal[9], + Literal[10], + Literal[11], + Literal[12], + Literal[13], + Literal[14], + Literal[15], + Literal[16], + Literal[17], + Literal[18], + Literal[19], + Literal[20], + Literal[21], + Literal[22], + Literal[23], + Literal[24], + Literal[25], + Literal[26], + None, +] + + +def testAgenttoollanguageserverGetcodesymbolInput() -> None: + x: tyd.AgenttoollanguageserverGetcodesymbolInput = { + "uri": gen_str(), + "position": { + "line": gen_float(), + "character": gen_float(), + }, + "kind": cast(kind_type, gen_opt(gen_choice(list(range(1, 27))))()), + } + + baseTestPattern( + x, + tyd.encode_AgenttoollanguageserverGetcodesymbolInput, + TypeAdapter(pyd.AgenttoollanguageserverGetcodesymbolInput), + ) + + +class size_type(TypedDict): + rows: int + cols: int + + +def testShellexecSpawnInput() -> None: + x: tyd.ShellexecSpawnInput = { + "cmd": gen_str(), + "args": gen_opt(gen_list(gen_str))(), + "initialCmd": gen_opt(gen_str)(), + "env": gen_opt(gen_dict(gen_str))(), + "cwd": gen_opt(gen_str)(), + "size": gen_opt( + lambda: cast( + 
size_type, + { + "rows": gen_int(), + "cols": gen_int(), + }, + ) + )(), + "useReplitRunEnv": gen_opt(gen_bool)(), + "useCgroupMagic": gen_opt(gen_bool)(), + "interactive": gen_opt(gen_bool)(), + } + + baseTestPattern( + x, + tyd.encode_ShellexecSpawnInput, + TypeAdapter(pyd.ShellexecSpawnInput), + ) + + +def testConmanfilesystemPersistInput() -> None: + x: tyd.ConmanfilesystemPersistInput = {} + + baseTestPattern( + x, + tyd.encode_ConmanfilesystemPersistInput, + TypeAdapter(pyd.ConmanfilesystemPersistInput), + ) + + +def main() -> None: + testAiexecExecInit() + testAgenttoollanguageserverOpendocumentInput() + testAgenttoollanguageserverGetcodesymbolInput() + testShellexecSpawnInput() + testConmanfilesystemPersistInput() + + +if __name__ == "__main__": + print("Starting...") + for _ in range(0, 100): + main() + print("Verified") diff --git a/scripts/parity/gen.py b/scripts/parity/gen.py new file mode 100644 index 0000000..ae77718 --- /dev/null +++ b/scripts/parity/gen.py @@ -0,0 +1,45 @@ +import random +from typing import Callable, Optional, TypeVar + +A = TypeVar("A") + + +def gen_char() -> str: + pos = random.randint(0, 26 * 2) + if pos < 26: + return chr(ord("A") + pos) + else: + return chr(ord("a") + pos - 26) + + +def gen_str() -> str: + return "".join(gen_char() for _ in range(0, random.randint(0, 20))) + + +def gen_list(gen_x: Callable[[], A]) -> Callable[[], list[A]]: + return lambda: [gen_x() for _ in range(0, random.randint(0, 10))] + + +def gen_bool() -> bool: + # Ten times more likely to be true than false + return bool(random.randint(0, 10)) + + +def gen_float() -> float: + return random.random() * 100 + + +def gen_int() -> int: + return random.randint(0, 2048) + + +def gen_choice(choices: list[A]) -> Callable[[], A]: + return lambda: random.choice(choices) + + +def gen_opt(gen_x: Callable[[], A]) -> Callable[[], Optional[A]]: + return lambda: gen_x() if gen_bool() else None + + +def gen_dict(gen_x: Callable[[], A]) -> Callable[[], dict[str, A]]: + 
return lambda: dict((gen_str(), gen_x()) for _ in range(0, random.randint(0, 5)))