Add pyarrow map support for dict-annotated pydantic fields.

schema.py
  * New helper _get_dict_type: splits the annotation's type arguments into a
    key type and a value type with get_args(), converts each one through
    _get_pyarrow_type, and wraps the two results in pa.map_(key, value).
    The same `metadata` list and `allow_losing_tz` flag are forwarded
    unchanged to both the key and the value conversion.
  * FIELD_TYPES gains the entry `dict: _get_dict_type`, next to the existing
    Literal / list / Annotated entries.
  * NOTE(review): the two-name unpacking of get_args(field_type) assumes the
    annotation carries exactly two type arguments (Dict[K, V]); confirm how
    an unparameterized Dict reaches (or is kept away from) this helper.

tests/test_schema.py
  * New test_dict: a model with `foo: Dict[str, int]` must yield a
    non-nullable field of type map<string, int64>; the sample objects are then
    passed through _write_pq_and_read and compared again. The values read
    back are converted with dict() before the final comparison, as the
    in-test comment explains.

diff --git a/src/pydantic_to_pyarrow/schema.py b/src/pydantic_to_pyarrow/schema.py
index c52f7a2..ba80ec3 100644
--- a/src/pydantic_to_pyarrow/schema.py
+++ b/src/pydantic_to_pyarrow/schema.py
@@ -111,10 +111,21 @@ def _get_annotated_type(
     return _get_pyarrow_type(field_type, metadata, allow_losing_tz)
 
 
+def _get_dict_type(
+    field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
+) -> pa.DataType:
+    key_type, value_type = get_args(field_type)
+    return pa.map_(
+        _get_pyarrow_type(key_type, metadata, allow_losing_tz=allow_losing_tz),
+        _get_pyarrow_type(value_type, metadata, allow_losing_tz=allow_losing_tz),
+    )
+
+
 FIELD_TYPES = {
     Literal: _get_literal_type,
     list: _get_list_type,
     Annotated: _get_annotated_type,
+    dict: _get_dict_type,
 }
 
 
diff --git a/tests/test_schema.py b/tests/test_schema.py
index 231a41c..d287589 100644
--- a/tests/test_schema.py
+++ b/tests/test_schema.py
@@ -530,3 +530,28 @@ class EnumModel(BaseModel):
 
     with pytest.raises(SchemaCreationError):
         get_pyarrow_schema(EnumModel)
+
+
+def test_dict() -> None:
+    class DictModel(BaseModel):
+        foo: Dict[str, int]
+
+    expected = pa.schema(
+        [
+            pa.field("foo", pa.map_(pa.string(), pa.int64()), nullable=False),
+        ]
+    )
+
+    objs = [
+        {"foo": {"a": 1, "b": 2}},
+        {"foo": {"c": 3, "d": 4, "e": 5}},
+    ]
+
+    actual = get_pyarrow_schema(DictModel)
+    assert actual == expected
+
+    new_schema, new_objs = _write_pq_and_read(objs, expected)
+    assert new_schema == expected
+
+    # pyarrow converts to tuples, need to convert back to dicts
+    assert objs == [{"foo": dict(t["foo"])} for t in new_objs]