Skip to content

Commit

Permalink
Initial implementation for simple types
Browse files Browse the repository at this point in the history
  • Loading branch information
simw committed Nov 1, 2023
1 parent 959c7eb commit 969b4d8
Show file tree
Hide file tree
Showing 6 changed files with 260 additions and 3 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
# pydantic-to-pyarrow

[![CI](https://github.com/simw/pydantic-to-pyarrow/actions/workflows/test.yml/badge.svg?event=push)](https://github.com/simw/pydantic-to-pyarrow/actions/workflows/test.yml)

A library to convert a pydantic model to a pyarrow schema.
23 changes: 20 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,19 +41,35 @@ target-version = ["py38"]

[tool.ruff]
select = [
"F", # pyflakes
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"C90", # Mccabe complexity
"I", # isort
"C", # flake8-comprehensions
"B", # flake8-bugbear
"N", # pep8-naming
"UP", # pyupgrade
"YTT", # flake8-2020
"ANN", # flake8-annotations
"ASYNC", # flake8-ASYNC
"S", # flake8-bandit
"BLE", # flake8-blind-except
"FBT", # flake8-boolean-trap
"B", # flake8-bugbear
"A", # flake8-builtins
"C4", # flake8-comprehensions
"PT", # flake8-pyteststyle
"PD", # pandas-vet
"PL", # pylint
"PERF", # perflint
]
ignore = []
line-length = 88
indent-width = 4
target-version = "py38"

[tool.ruff.per-file-ignores]
"tests/**/**" = ["S"] # Don't run bandit on tests (eg flagging on assert)

[tool.ruff.format]
quote-style = "double"
indent-style = "space"
Expand All @@ -78,6 +94,7 @@ branch = true

[tool.coverage.report]
precision = 2
show_missing = true
exclude_lines = [
'pragma: no cover',
'raise NotImplementedError',
Expand Down
4 changes: 4 additions & 0 deletions src/pydantic_to_pyarrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from .schema import SchemaCreationError, get_pyarrow_schema

# Package version string (semantic versioning).
__version__ = "0.0.1"

# Public API re-exported from this package.
__all__ = [
    "__version__",
    "get_pyarrow_schema",
    "SchemaCreationError",
]
73 changes: 73 additions & 0 deletions src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import datetime
import types
from typing import Any, Type, TypeVar, Union, cast

import pyarrow as pa # type: ignore
from pydantic import BaseModel, NaiveDatetime
from typing_extensions import get_args, get_origin

# Any pydantic model class; used to type the argument of get_pyarrow_schema.
BaseModelType = TypeVar("BaseModelType", bound=BaseModel)


# Mapping from python / pydantic field types to their pyarrow data types.
# Only simple (scalar, non-nested) types are supported so far.
FIELD_MAP = {
    str: pa.string(),
    bool: pa.bool_(),
    int: pa.int64(),
    float: pa.float64(),
    # A date has no time component; 32-bit days-since-epoch suffices.
    datetime.date: pa.date32(),
    # Naive (timezone-less) datetimes map to a tz-free timestamp; ms
    # precision is used because parquet does not support "s" precision.
    NaiveDatetime: pa.timestamp("ms", tz=None),
}


class SchemaCreationError(Exception):
    """Error when creating pyarrow schema."""


def _is_optional(field_type: Type[Any]) -> bool:
    """Return True if field_type is an optional union (contains NoneType).

    Handles both ``typing.Optional`` / ``typing.Union`` and the PEP 604
    ``X | Y`` syntax (``types.UnionType``, Python 3.10+).

    Note: the annotation uses ``typing.Type`` rather than the builtin
    ``type[...]``, which is not subscriptable before Python 3.9 and would
    raise at import time on the project's py38 target.
    """
    origin = get_origin(field_type)
    is_python_39_union = origin is Union
    # types.UnionType only exists on Python 3.10+, hence the hasattr guard.
    is_python_310_union = hasattr(types, "UnionType") and origin is types.UnionType

    if not is_python_39_union and not is_python_310_union:
        return False

    # Optional[X] is Union[X, None]; optional iff NoneType is a member.
    return type(None) in get_args(field_type)


def _get_pyarrow_type(field_type: Type[Any]) -> pa.DataType:
    """Look up the pyarrow type for a (non-optional) python field type.

    Uses ``typing.Type`` rather than builtin ``type[...]`` so the
    annotation evaluates on Python 3.8 (the project's target version).

    Raises:
        SchemaCreationError: if field_type has no pyarrow equivalent.
    """
    if field_type in FIELD_MAP:
        return FIELD_MAP[field_type]

    raise SchemaCreationError(f"Unknown type: {field_type}")


def get_pyarrow_schema(
    pydantic_class: Type[BaseModelType],
) -> pa.Schema:
    """Build a pyarrow.Schema from a pydantic model class.

    ``Optional[X]`` fields become nullable pyarrow fields of X's type;
    all other fields are non-nullable.

    Raises:
        SchemaCreationError: if a field has no annotation, has a type
            with no pyarrow equivalent, or is a union of more than one
            non-None type.
    """
    fields = []
    for name, field_info in pydantic_class.model_fields.items():
        field_type = field_info.annotation

        if field_type is None:
            # Not sure how to get here through pydantic, hence nocover
            raise SchemaCreationError(
                f"Missing type for field {name}"
            )  # pragma: no cover

        try:
            nullable = _is_optional(field_type)
            if nullable:
                types_under_union = list(set(get_args(field_type)) - {type(None)})
                if len(types_under_union) != 1:
                    # Union[A, B, None] has no single pyarrow type; the
                    # set above is unordered, so picking [0] would choose
                    # an arbitrary member nondeterministically. Fail loudly.
                    raise SchemaCreationError(
                        f"Unions of multiple types are not supported: {field_type}"
                    )
                # mypy infers field_type as Optional[Type[Any]] here, hence casting
                field_type = cast(Type[Any], types_under_union[0])

            pa_field = _get_pyarrow_type(field_type)
        except Exception as err:  # noqa: BLE001 - ignore blind exception
            raise SchemaCreationError(
                f"Error processing field {name}: {field_type}, {err}"
            ) from err

        fields.append(pa.field(name, pa_field, nullable=nullable))

    return pa.schema(fields)
Empty file.
160 changes: 160 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import datetime
import tempfile
from pathlib import Path
from typing import Any, Deque, Dict, List, Optional, Tuple

import pyarrow as pa # type: ignore
import pyarrow.parquet as pq # type: ignore
import pytest
from pydantic import BaseModel, NaiveDatetime

from pydantic_to_pyarrow import SchemaCreationError, get_pyarrow_schema


def _write_pq_and_read(
    objs: List[Dict[str, Any]], schema: pa.Schema
) -> Tuple[pa.Schema, List[Dict[str, Any]]]:
    """
    Round-trip rows through python -> pyarrow -> parquet -> pyarrow -> python.

    Writes the given dictionaries to a temporary parquet file using the
    supplied schema, reads them back, and returns the schema and rows as
    observed after the round trip, so callers can detect whether the
    conversion affected either one.
    """
    table = pa.Table.from_pylist(objs, schema=schema)
    with tempfile.TemporaryDirectory() as tmp_dir:
        target = Path(tmp_dir) / "values.parquet"
        pq.write_table(table, target)
        round_tripped = pq.read_table(target)

    return round_tripped.schema, round_tripped.to_pylist()


def test_some_types_dont_read_as_written_part1() -> None:
    """
    Parquet files do not support pyarrow timestamps with second
    precision, only ms-level precision, so the schema changes on write.
    """
    original_schema = pa.schema([pa.field("a", pa.timestamp("s"), nullable=True)])
    rows = [{"a": datetime.datetime(2020, 1, 1)}]

    read_schema, read_rows = _write_pq_and_read(rows, original_schema)

    assert read_rows == rows
    assert len(read_schema) == 1
    # The "s" timestamp comes back widened to "ms".
    assert read_schema[0] == pa.field("a", pa.timestamp("ms"), nullable=True)
    with pytest.raises(AssertionError):
        assert read_schema == original_schema


def test_some_types_dont_read_as_written_part2() -> None:
    """
    Parquet correctly converts a timezone-aware python datetime to an
    'instant', but does not record the timezone in the file. Reading it
    back therefore yields a different datetime object than was written.
    """
    schema = pa.schema([pa.field("a", pa.timestamp("ms"), nullable=True)])
    plus_five = datetime.timezone(datetime.timedelta(hours=5))
    rows = [{"a": datetime.datetime(2020, 1, 1, 1, 0, 0, tzinfo=plus_five)}]

    read_schema, read_rows = _write_pq_and_read(rows, schema)

    assert read_schema == schema
    # The tzinfo is lost and the value is normalized to UTC, which
    # pushes it back into the previous year.
    assert read_rows[0]["a"] == datetime.datetime(2019, 12, 31, 20, 0, 0)
    with pytest.raises(AssertionError):
        assert read_rows == rows


def test_simple_types() -> None:
    """str, bool, int and float map to non-nullable pyarrow fields."""

    class SimpleModel(BaseModel):
        a: str
        b: bool
        c: int
        d: float

    expected = pa.schema(
        [
            pa.field("a", pa.string(), nullable=False),
            pa.field("b", pa.bool_(), nullable=False),
            pa.field("c", pa.int64(), nullable=False),
            pa.field("d", pa.float64(), nullable=False),
        ]
    )

    assert get_pyarrow_schema(SimpleModel) == expected

    rows = [{"a": "a", "b": True, "c": 1, "d": 1.01}]
    read_schema, read_rows = _write_pq_and_read(rows, expected)
    assert read_schema == expected
    assert read_rows == rows


def test_unknown_type() -> None:
    """A field type with no pyarrow mapping raises SchemaCreationError."""

    class SimpleModel(BaseModel):
        a: Deque[int]

    with pytest.raises(SchemaCreationError):
        get_pyarrow_schema(SimpleModel)


def test_nullable_types() -> None:
    """Optional[...] fields become nullable fields; plain fields do not."""

    class NullableModel(BaseModel):
        a: str
        b: Optional[str]
        c: int
        d: Optional[int]

    expected = pa.schema(
        [
            pa.field("a", pa.string(), nullable=False),
            pa.field("b", pa.string(), nullable=True),
            pa.field("c", pa.int64(), nullable=False),
            pa.field("d", pa.int64(), nullable=True),
        ]
    )

    assert get_pyarrow_schema(NullableModel) == expected

    rows = [{"a": "a", "b": "b", "c": 1, "d": 1}]
    read_schema, read_rows = _write_pq_and_read(rows, expected)
    assert read_schema == expected
    assert read_rows == rows


def test_date_types_with_no_tz() -> None:
    """date maps to date32 and NaiveDatetime to a ms timestamp, non-nullable."""

    class DateModel(BaseModel):
        a: datetime.date
        b: NaiveDatetime

    expected = pa.schema(
        [
            pa.field("a", pa.date32(), nullable=False),
            pa.field("b", pa.timestamp("ms"), nullable=False),
        ]
    )

    assert get_pyarrow_schema(DateModel) == expected

    rows = [
        {
            "a": datetime.date(2020, 1, 1),
            "b": datetime.datetime(2020, 1, 1, 0, 0, 0),
        }
    ]
    read_schema, read_rows = _write_pq_and_read(rows, expected)
    assert read_schema == expected
    assert read_rows == rows

0 comments on commit 969b4d8

Please sign in to comment.