Initial implementation for simple types
Showing 6 changed files with 260 additions and 3 deletions.
@@ -1,2 +1,5 @@
# pydantic-to-pyarrow

[![CI](https://github.com/simw/pydantic-to-pyarrow/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/simw/pydantic-to-pyarrow/actions/workflows/test.yml)

A library to convert a pydantic model to a pyarrow schema
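For context, a minimal usage sketch of the API added in this commit (the model and field names here are illustrative, not part of the diff):

from pydantic import BaseModel

from pydantic_to_pyarrow import get_pyarrow_schema


class Item(BaseModel):
    name: str
    count: int
    price: float


# Prints a schema along the lines of:
#   name: string not null
#   count: int64 not null
#   price: double not null
print(get_pyarrow_schema(Item))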
@@ -1,5 +1,9 @@
from .schema import SchemaCreationError, get_pyarrow_schema

__version__ = "0.0.1"

__all__ = [
    "__version__",
    "get_pyarrow_schema",
    "SchemaCreationError",
]
@@ -0,0 +1,73 @@
import datetime
import types
from typing import Any, Type, TypeVar, Union, cast

import pyarrow as pa  # type: ignore
from pydantic import BaseModel, NaiveDatetime
from typing_extensions import get_args, get_origin

BaseModelType = TypeVar("BaseModelType", bound=BaseModel)


FIELD_MAP = {
    str: pa.string(),
    bool: pa.bool_(),
    int: pa.int64(),
    float: pa.float64(),
    datetime.date: pa.date32(),
    NaiveDatetime: pa.timestamp("ms", tz=None),
}


class SchemaCreationError(Exception):
    """Error when creating pyarrow schema."""


def _is_optional(field_type: type[Any]) -> bool:
    origin = get_origin(field_type)
    is_python_39_union = origin is Union
    is_python_310_union = hasattr(types, "UnionType") and origin is types.UnionType

    if not is_python_39_union and not is_python_310_union:
        return False

    return type(None) in get_args(field_type)


def _get_pyarrow_type(field_type: type[Any]) -> pa.DataType:
    if field_type in FIELD_MAP:
        return FIELD_MAP[field_type]

    raise SchemaCreationError(f"Unknown type: {field_type}")


def get_pyarrow_schema(
    pydantic_class: Type[BaseModelType],
) -> pa.Schema:
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        field_type = field_info.annotation

        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover

        try:
            nullable = False
            if _is_optional(field_type):
                nullable = True
                types_under_union = list(set(get_args(field_type)) - {type(None)})
                # mypy infers field_type as type[Any] | None here, hence casting
                field_type = cast(type[Any], types_under_union[0])

            pa_field = _get_pyarrow_type(field_type)
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err

        fields.append(pa.field(name, pa_field, nullable=nullable))

    return pa.schema(fields)
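As a quick illustration of the union handling above: an Optional field should come out as a nullable arrow field, and a type outside FIELD_MAP should raise SchemaCreationError. This is a sketch based on the module above; the Example and Unsupported models are illustrative, not part of the diff.

from typing import Optional

import pyarrow as pa
from pydantic import BaseModel

from pydantic_to_pyarrow import SchemaCreationError, get_pyarrow_schema


class Example(BaseModel):
    id: int
    label: Optional[str]  # Optional[...] -> nullable field


schema = get_pyarrow_schema(Example)
assert schema.field("id") == pa.field("id", pa.int64(), nullable=False)
assert schema.field("label") == pa.field("label", pa.string(), nullable=True)


class Unsupported(BaseModel):
    values: dict  # not in FIELD_MAP, so schema creation fails


try:
    get_pyarrow_schema(Unsupported)
except SchemaCreationError as err:
    print(err)  # Error processing field values: <class 'dict'>, ...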
Empty file.
@@ -0,0 +1,160 @@
import datetime
import tempfile
from pathlib import Path
from typing import Any, Deque, Dict, List, Optional, Tuple

import pyarrow as pa  # type: ignore
import pyarrow.parquet as pq  # type: ignore
import pytest
from pydantic import BaseModel, NaiveDatetime

from pydantic_to_pyarrow import SchemaCreationError, get_pyarrow_schema


def _write_pq_and_read(
    objs: List[Dict[str, Any]], schema: pa.Schema
) -> Tuple[pa.Schema, List[Dict[str, Any]]]:
    """
    This helper function takes a list of dictionaries, and transfers
    them through -> pyarrow -> parquet -> pyarrow -> list of dictionaries,
    returning the schema and the list of dictionaries.
    In this way, it can be checked whether the data conversion has
    affected either the schema or the data.
    """
    tbl = pa.Table.from_pylist(objs, schema=schema)
    with tempfile.TemporaryDirectory() as temp_dir:
        path = Path(temp_dir) / "values.parquet"
        pq.write_table(tbl, path)
        new_tbl = pq.read_table(path)

    new_objs = new_tbl.to_pylist()
    return new_tbl.schema, new_objs


def test_some_types_dont_read_as_written_part1() -> None:
    """
    The pyarrow timestamp with a precision of seconds is not
    supported by parquet files, only ms-level precision.
    """
    schema = pa.schema(
        [
            pa.field("a", pa.timestamp("s"), nullable=True),
        ]
    )
    objs = [{"a": datetime.datetime(2020, 1, 1)}]
    new_schema, new_objs = _write_pq_and_read(objs, schema)
    assert new_objs == objs
    assert len(new_schema) == 1
    assert new_schema[0] == pa.field("a", pa.timestamp("ms"), nullable=True)
    with pytest.raises(AssertionError):
        assert new_schema == schema


def test_some_types_dont_read_as_written_part2() -> None:
    """
    While parquet correctly converts a timezone-aware python
    datetime to an 'instant', it doesn't record the timezone in the
    parquet file. Hence, when it's read back, we don't get exactly the
    same datetime object.
    """
    schema = pa.schema(
        [
            pa.field("a", pa.timestamp("ms"), nullable=True),
        ]
    )
    tz = datetime.timezone(datetime.timedelta(hours=5))
    objs = [{"a": datetime.datetime(2020, 1, 1, 1, 0, 0, tzinfo=tz)}]
    new_schema, new_objs = _write_pq_and_read(objs, schema)
    assert new_schema == schema
    # The tzinfo is lost, and the datetime is converted to UTC,
    # which pushes it back into the previous year.
    assert new_objs[0]["a"] == datetime.datetime(2019, 12, 31, 20, 0, 0)
    with pytest.raises(AssertionError):
        assert new_objs == objs


def test_simple_types() -> None:
    class SimpleModel(BaseModel):
        a: str
        b: bool
        c: int
        d: float

    expected = pa.schema(
        [
            pa.field("a", pa.string(), nullable=False),
            pa.field("b", pa.bool_(), nullable=False),
            pa.field("c", pa.int64(), nullable=False),
            pa.field("d", pa.float64(), nullable=False),
        ]
    )

    actual = get_pyarrow_schema(SimpleModel)
    assert actual == expected

    objs = [{"a": "a", "b": True, "c": 1, "d": 1.01}]
    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    assert new_objs == objs


def test_unknown_type() -> None:
    class SimpleModel(BaseModel):
        a: Deque[int]

    with pytest.raises(SchemaCreationError):
        get_pyarrow_schema(SimpleModel)


def test_nullable_types() -> None:
    class NullableModel(BaseModel):
        a: str
        b: Optional[str]
        c: int
        d: Optional[int]

    expected = pa.schema(
        [
            pa.field("a", pa.string(), nullable=False),
            pa.field("b", pa.string(), nullable=True),
            pa.field("c", pa.int64(), nullable=False),
            pa.field("d", pa.int64(), nullable=True),
        ]
    )

    actual = get_pyarrow_schema(NullableModel)
    assert actual == expected

    objs = [{"a": "a", "b": "b", "c": 1, "d": 1}]
    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    assert new_objs == objs


def test_date_types_with_no_tz() -> None:
    class DateModel(BaseModel):
        a: datetime.date
        b: NaiveDatetime

    expected = pa.schema(
        [
            pa.field("a", pa.date32(), nullable=False),
            pa.field("b", pa.timestamp("ms"), nullable=False),
        ]
    )

    actual = get_pyarrow_schema(DateModel)
    assert actual == expected

    objs = [
        {
            "a": datetime.date(2020, 1, 1),
            "b": datetime.datetime(2020, 1, 1, 0, 0, 0),
        }
    ]
    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    assert new_objs == objs