Skip to content

Commit

Permalink
add support for dictionary field type as pyarrow map array (#13)
Browse files Browse the repository at this point in the history
* add support for dictionary field type as pyarrow map array

https://arrow.apache.org/docs/python/data.html#map-arrays

* add metadata and allow_losing_tz
  • Loading branch information
mae5357 authored May 23, 2024
1 parent adbfdb2 commit fe2cc1a
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
11 changes: 11 additions & 0 deletions src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,21 @@ def _get_annotated_type(
return _get_pyarrow_type(field_type, metadata, allow_losing_tz)


def _get_dict_type(
field_type: Type[Any], metadata: List[Any], allow_losing_tz: bool
) -> pa.DataType:
key_type, value_type = get_args(field_type)
return pa.map_(
_get_pyarrow_type(key_type, metadata, allow_losing_tz=allow_losing_tz),
_get_pyarrow_type(value_type, metadata, allow_losing_tz=allow_losing_tz),
)


FIELD_TYPES = {
Literal: _get_literal_type,
list: _get_list_type,
Annotated: _get_annotated_type,
dict: _get_dict_type,
}


Expand Down
25 changes: 25 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,3 +530,28 @@ class EnumModel(BaseModel):

with pytest.raises(SchemaCreationError):
get_pyarrow_schema(EnumModel)


def test_dict() -> None:
class DictModel(BaseModel):
foo: Dict[str, int]

expected = pa.schema(
[
pa.field("foo", pa.map_(pa.string(), pa.int64()), nullable=False),
]
)

objs = [
{"foo": {"a": 1, "b": 2}},
{"foo": {"c": 3, "d": 4, "e": 5}},
]

actual = get_pyarrow_schema(DictModel)
assert actual == expected

new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected

# pyarrow converts to tuples, need to convert back to dicts
assert objs == [{"foo": dict(t["foo"])} for t in new_objs]

0 comments on commit fe2cc1a

Please sign in to comment.