feat/finishing typeddict inputs (#95)
=====================================

Why
===

We got pretty close to having TypedDicts for river-python inputs before,
but had to roll back due to a protocol mismatch.

Trying again, and also adding some tests to confirm that, at the very
least, the Pydantic models can decode what was encoded by the TypedDict
encoders. It's not a perfect science, but it should be good enough to
start building more confidence as we make additional progress.
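
To make that concrete, here is a minimal sketch of the property the parity
tests check. `ExampleInput`, `encode_ExampleInput`, and `ExampleInputModel`
are hand-written stand-ins for what the codegen would emit, and it assumes
Pydantic v2's `model_validate`:

```python
# Hand-written stand-ins for generated code; names are illustrative only.
from typing import Any, Callable, TypedDict

from pydantic import BaseModel


class ExampleInput(TypedDict):
    name: str
    count: int


# The TypedDict path emits plain-dict encoders shaped like this.
encode_ExampleInput: Callable[[ExampleInput], Any] = lambda x: {
    "name": x["name"],
    "count": x["count"],
}


class ExampleInputModel(BaseModel):
    name: str
    count: int


# The parity property: the Pydantic model must accept whatever the
# TypedDict encoder emits.
payload = encode_ExampleInput({"name": "river", "count": 3})
assert ExampleInputModel.model_validate(payload) == ExampleInputModel(
    name="river", count=3
)
```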

### The reason for "janky" tests

There's a bit of a chicken-and-egg situation when trying to test code
generation at runtime.

We have three options:
- write pytest handlers where each invocation runs the codegen with a
temp target (like the shell script does here), writes a static file for
each test into that directory, then executes a new Python process against
that directory. The challenge with this is that it would suck to write
and maintain.
- write pytest handlers which run the codegen with unique module name
targets (like `gen1`, `gen2`, `gen3`, one for each codegen run
necessary) and carefully juggle the imports to make sure we don't try to
import something that's not there yet (see the sketch after this list).
This _might_ be the best option, but I'm not convinced about the
ergonomics at the moment. It might be OK though, with highly targeted
`.gitignore`s.
- maintain a bespoke test runner, optimize for writing and maintaining
these tests, and just acknowledge that we are doing something obscure
and difficult.
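
For illustration, the second option could look roughly like this. This is a
sketch only: `tmp_path` plus a `sys.path` shim stands in for committed
`gen1`/`gen2` modules, the schema path is borrowed from `scripts/parity.sh`,
and the assumption that the generated module exposes the `--client-name`
class is unverified:

```python
# Hypothetical pytest shape for the "unique module targets" option.
import importlib
import subprocess
import sys


def test_gen1(tmp_path):
    # Run the codegen with a unique module target...
    subprocess.run(
        [
            sys.executable, "-m", "replit_river.codegen", "client",
            "--output", str(tmp_path / "gen1.py"),
            "--client-name", "Gen1Client",
            "../ai-infra/pkgs/pid2_client/src/schema/schema.json",
        ],
        check=True,
    )
    # ...and only import it after the file actually exists.
    sys.path.insert(0, str(tmp_path))
    gen1 = importlib.import_module("gen1")
    assert hasattr(gen1, "Gen1Client")
```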

I definitely wrote the tests here in a way that would give some coverage
and also provide confidence, while intentionally deferring the above
decision so we can keep making progress in the meantime.

What changed
============

- Added some janky tests comparing the encodings produced by both models
- Fixed many bugs in the TypedDict codegen and encoders (one representative
fix is distilled below)
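
One representative encoder fix from the diff below: optional fields used to
be emitted as `x['field'] if x['field'] else None`, which raises `KeyError`
when the key is absent and collapses present-but-falsy values to `None`. The
generated code now checks key presence and `is not None`, or uses `x.get(...)`
for plain optional fields:

```python
# Distilled from the codegen change; `count` is an optional field.
x = {"name": "river"}  # 'count' absent

# Old emitted shape: x["count"] if x["count"] else None
#   -> KeyError here, and a present {"count": 0} would become None.

# New emitted shape:
count = x["count"] if "count" in x and x["count"] is not None else None
assert count is None

x2 = {"name": "river", "count": 0}
count2 = x2["count"] if "count" in x2 and x2["count"] is not None else None
assert count2 == 0  # falsy-but-present values survive
```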

Test plan
=========

```
$ bash scripts/parity.sh
Using /tmp/river-codegen-parity.bAZ
Starting...
Verified
```

blast-hardcheese authored Sep 30, 2024
1 parent 62b236e commit ca9d552

Showing 5 changed files with 326 additions and 32 deletions.
11 changes: 10 additions & 1 deletion mypy.ini
```diff
@@ -4,4 +4,13 @@ disallow_untyped_defs = True
 warn_return_any = True
 
 [mypy-grpc.*]
-ignore_missing_imports = True
+ignore_missing_imports = True
+
+[mypy-parity.gen.*]
+ignore_missing_imports = True
+
+[mypy-pyd.*]
+ignore_missing_imports = True
+
+[mypy-tyd.*]
+ignore_missing_imports = True
```

101 changes: 70 additions & 31 deletions replit_river/codegen/client.py
```diff
@@ -12,6 +12,7 @@
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import black
@@ -80,8 +81,17 @@ def reindent(prefix: str, code: str) -> str:
     return indent(dedent(code), prefix)
 
 
+def is_literal(tpe: RiverType) -> bool:
+    if isinstance(tpe, RiverUnionType):
+        return all(is_literal(t) for t in tpe.anyOf)
+    elif isinstance(tpe, RiverConcreteType):
+        return tpe.type in set(["string", "number", "boolean"])
+    else:
+        return False
+
+
 def encode_type(
-    type: RiverType, prefix: str, base_model: str = "BaseModel"
+    type: RiverType, prefix: str, base_model: str
 ) -> Tuple[str, Sequence[str]]:
     chunks: List[str] = []
     if isinstance(type, RiverNotType):
@@ -219,14 +229,6 @@ def flatten_union(tpe: RiverType) -> list[RiverType]:
         type = original_type
         any_of: List[str] = []
 
-        def is_literal(tpe: RiverType) -> bool:
-            if isinstance(tpe, RiverUnionType):
-                return all(is_literal(t) for t in tpe.anyOf)
-            elif isinstance(tpe, RiverConcreteType):
-                return tpe.type in set(["string", "number", "boolean"])
-            else:
-                return False
-
         typeddict_encoder = []
         for i, t in enumerate(type.anyOf):
             type_name, type_chunks = encode_type(t, f"{prefix}AnyOf_{i}", base_model)
@@ -273,44 +275,44 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
             # Handle the case where type is not specified
             typeddict_encoder.append("x")
             return ("Any", ())
-        if type.type == "string":
+        elif type.type == "string":
             if type.const:
                 typeddict_encoder.append(f"'{type.const}'")
                 return (f"Literal['{type.const}']", ())
             else:
                 typeddict_encoder.append("x")
                 return ("str", ())
-        if type.type == "Uint8Array":
+        elif type.type == "Uint8Array":
             typeddict_encoder.append("x.decode()")
             return ("bytes", ())
-        if type.type == "number":
+        elif type.type == "number":
             if type.const is not None:
                 # enums are represented as const number in the schema
                 typeddict_encoder.append(f"{type.const}")
                 return (f"Literal[{type.const}]", ())
             typeddict_encoder.append("x")
             return ("float", ())
-        if type.type == "integer":
+        elif type.type == "integer":
             if type.const is not None:
                 # enums are represented as const number in the schema
                 typeddict_encoder.append(f"{type.const}")
                 return (f"Literal[{type.const}]", ())
             typeddict_encoder.append("x")
             return ("int", ())
-        if type.type == "boolean":
+        elif type.type == "boolean":
             typeddict_encoder.append("x")
             return ("bool", ())
-        if type.type == "null":
+        elif type.type == "null":
             typeddict_encoder.append("None")
             return ("None", ())
-        if type.type == "Date":
+        elif type.type == "Date":
             typeddict_encoder.append("TODO: dstewart")
             return ("datetime.datetime", ())
-        if type.type == "array" and type.items:
+        elif type.type == "array" and type.items:
             type_name, type_chunks = encode_type(type.items, prefix, base_model)
             typeddict_encoder.append("TODO: dstewart")
             return (f"List[{type_name}]", type_chunks)
-        if (
+        elif (
             type.type == "object"
             and type.patternProperties
             and "^(.*)$" in type.patternProperties
@@ -323,7 +325,11 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
     assert type.type == "object", type.type
 
     current_chunks: List[str] = [f"class {prefix}({base_model}):"]
+    # For the encoder path, do we need "x" to be bound?
+    # lambda x: ... vs lambda _: {}
+    needs_binding = False
     if type.properties:
+        needs_binding = True
         typeddict_encoder.append("{")
         for name, prop in type.properties.items():
             typeddict_encoder.append(f"'{name}':")
@@ -353,18 +359,35 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
                 )
                 if name not in prop.required:
                     typeddict_encoder.append(
-                        f"if x['{safe_name}'] else None"
+                        dedent(
+                            f"""
+                            if '{safe_name}' in x
+                            and x['{safe_name}'] is not None
+                            else None
+                            """
+                        )
                     )
             elif prop.type == "array":
-                assert type_name.startswith(
-                    "List["
-                )  # in case we change to list[...]
-                _inner_type_name = type_name[len("List[") : -len("]")]
-                typeddict_encoder.append(
-                    f"[encode_{_inner_type_name}(y) for y in x['{name}']]"
-                )
+                items = cast(RiverConcreteType, prop).items
+                assert items, "Somehow items was none"
+                if is_literal(cast(RiverType, items)):
+                    typeddict_encoder.append(f"x['{name}']")
+                else:
+                    assert type_name.startswith(
+                        "List["
+                    )  # in case we change to list[...]
+                    _inner_type_name = type_name[len("List[") : -len("]")]
+                    typeddict_encoder.append(
+                        f"""[
+                            encode_{_inner_type_name}(y)
+                            for y in x['{name}']
+                        ]"""
+                    )
             else:
-                typeddict_encoder.append(f"x['{safe_name}']")
+                if name in prop.required:
+                    typeddict_encoder.append(f"x['{safe_name}']")
+                else:
+                    typeddict_encoder.append(f"x.get('{safe_name}')")
 
             if name == "$kind":
                 # If the field is a literal, the Python type-checker will complain
@@ -403,8 +426,9 @@ def extract_props(tpe: RiverType) -> list[dict[str, RiverType]]:
     current_chunks.append("")
 
     if base_model == "TypedDict":
+        binding = "x" if needs_binding else "_"
         current_chunks = (
-            [f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda x: "]
+            [f"encode_{prefix}: Callable[['{prefix}'], Any] = (lambda {binding}: "]
             + typeddict_encoder
             + [")"]
             + current_chunks
@@ -449,7 +473,7 @@ def generate_river_client_module(
 
     if schema_root.handshakeSchema is not None:
         (handshake_type, handshake_chunks) = encode_type(
-            schema_root.handshakeSchema, "HandshakeSchema"
+            schema_root.handshakeSchema, "HandshakeSchema", "BaseModel"
        )
         chunks.extend(handshake_chunks)
     else:
@@ -482,7 +506,9 @@ def __init__(self, client: river.Client[{handshake_type}]):
             )
             chunks.extend(input_chunks)
             output_type, output_chunks = encode_type(
-                procedure.output, f"{schema_name.title()}{name.title()}Output"
+                procedure.output,
+                f"{schema_name.title()}{name.title()}Output",
+                "BaseModel",
             )
             chunks.extend(output_chunks)
             if procedure.errors:
@@ -517,7 +543,20 @@ def __init__(self, client: river.Client[{handshake_type}]):
             """.rstrip()
 
             if typed_dict_inputs:
-                render_input_method = f"encode_{input_type}"
+                if is_literal(procedure.input):
+                    render_input_method = "lambda x: x"
+                elif isinstance(
+                    procedure.input, RiverConcreteType
+                ) and procedure.input.type in ["array"]:
+                    assert input_type.startswith(
+                        "List["
+                    )  # in case we change to list[...]
+                    _input_type_name = input_type[len("List[") : -len("]")]
+                    render_input_method = (
+                        f"lambda xs: [encode_{_input_type_name}(x) for x in xs]"
+                    )
+                else:
+                    render_input_method = f"encode_{input_type}"
             else:
                 render_input_method = f"""\
                     lambda x: TypeAdapter({input_type})
```

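Put together, the TypedDict path now emits encoders shaped roughly like the
following. This is a hand-written illustration for a hypothetical
`ExampleInput` schema (including whether the generated code uses
`NotRequired`), not actual generated output:

```python
from typing import Any, Callable, Literal

from typing_extensions import NotRequired, TypedDict


class ExampleInput(TypedDict):
    kind: Literal["example"]
    name: str
    count: NotRequired[int]


# Const fields are inlined, required fields are indexed, optional fields go
# through .get(), and a schema with no properties would bind `_` instead of `x`.
encode_ExampleInput: Callable[["ExampleInput"], Any] = (
    lambda x: {
        "kind": "example",
        "name": x["name"],
        "count": x.get("count"),
    }
)
```
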
40 changes: 40 additions & 0 deletions scripts/parity.sh
```diff
@@ -0,0 +1,40 @@
+#!/usr/bin/env bash
+#
+# parity.sh: Generate Pydantic and TypedDict models and check for deep equality.
+# This script expects that ai-infra is cloned alongside river-python.
+
+set -e
+
+scripts="$(dirname "$0")"
+cd "${scripts}/.."
+
+root="$(mktemp -d --tmpdir 'river-codegen-parity.XXX')"
+mkdir "$root/src"
+
+echo "Using $root" >&2
+
+function cleanup {
+  if [ -z "${DEBUG}" ]; then
+    echo "Cleaning up..." >&2
+    rm -rfv "${root}" >&2
+  fi
+}
+trap "cleanup" 0 2 3 15
+
+gen() {
+  fname="$1"; shift
+  name="$1"; shift
+  poetry run python -m replit_river.codegen \
+    client \
+    --output "${root}/src/${fname}" \
+    --client-name "${name}" \
+    ../ai-infra/pkgs/pid2_client/src/schema/schema.json \
+    "$@"
+}
+
+gen tyd.py Pid2TypedDict --typed-dict-inputs
+gen pyd.py Pid2Pydantic
+
+PYTHONPATH="${root}/src:${scripts}"
+poetry run bash -c "MYPYPATH='$PYTHONPATH' mypy -m parity.check_parity"
+poetry run bash -c "PYTHONPATH='$PYTHONPATH' python -m parity.check_parity"
```