diff --git a/README.md b/README.md index fadb640..f0d1200 100644 --- a/README.md +++ b/README.md @@ -48,10 +48,10 @@ should manually specify the full schema. Python / Pydantic | Pyarrow | Overflow --- | --- | --- str | pa.string() | -Literal[strings] | pa.dictionary(pa.int32(), pa.string()) +Literal[strings] | pa.dictionary(pa.int32(), pa.string()) | . | . | . int | pa.int64() if no minimum constraint, pa.uint64() if minimum is zero | Yes, at 2^63 (for signed) or 2^64 (for unsigned) -Literal[ints] | pa.int64() | Yes, at 2^63 +Literal[ints] | pa.int64() | float | pa.float64() | Yes decimal.Decimal | pa.decimal128 ONLY if supplying max_digits and decimal_places for pydantic field | Yes . | . | . @@ -64,6 +64,8 @@ pydantic.types.AwareDatetime | pa.timestamp("ms", tz=None) ONLY if param allow_l Optional[...] | The pyarrow field is nullable | Pydantic Model | pa.struct() | List[...] | pa.list_(...) | +Enum of str | pa.dictionary(pa.int32(), pa.string()) | +Enum of int | pa.int64() | ## An Example diff --git a/pyproject.toml b/pyproject.toml index 76db441..86996ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "pydantic_to_pyarrow" -version = "0.1.1" +version = "0.1.2" description = "Conversion from pydantic models to pyarrow schemas" authors = ["Simon Wicks "] readme = "README.md" diff --git a/src/pydantic_to_pyarrow/__init__.py b/src/pydantic_to_pyarrow/__init__.py index 4fcf37e..f44a69e 100644 --- a/src/pydantic_to_pyarrow/__init__.py +++ b/src/pydantic_to_pyarrow/__init__.py @@ -1,6 +1,6 @@ from .schema import SchemaCreationError, get_pyarrow_schema -__version__ = "0.1.1" +__version__ = "0.1.2" __all__ = [ "__version__", diff --git a/src/pydantic_to_pyarrow/schema.py b/src/pydantic_to_pyarrow/schema.py index d5bb101..c52f7a2 100644 --- a/src/pydantic_to_pyarrow/schema.py +++ b/src/pydantic_to_pyarrow/schema.py @@ -1,6 +1,7 @@ import datetime import types from decimal import Decimal +from 
def _get_enum_type(field_type: Type[Any]) -> pa.DataType:
    """Map an Enum class to the pyarrow type used to store its values.

    Args:
        field_type: The Enum class; iterating it yields its members.

    Returns:
        ``pa.dictionary(pa.int32(), pa.string())`` when every member value
        is a ``str``, or ``pa.int64()`` when every member value is an
        ``int``.

    Raises:
        SchemaCreationError: If the member values are neither all ``str``
            nor all ``int``. ``bool`` values are not accepted as ``int``.
    """
    values = [member.value for member in field_type]

    # An enum without members vacuously satisfies this check.
    if all(isinstance(value, str) for value in values):
        return pa.dictionary(pa.int32(), pa.string())

    # bool is a subclass of int, so it has to be excluded explicitly;
    # otherwise an enum of True/False would be given an int64 column.
    if all(
        isinstance(value, int) and not isinstance(value, bool)
        for value in values
    ):
        return pa.int64()

    msg = "Enums only allowed if all str or all int"
    raise SchemaCreationError(msg)
def test_enum_str() -> None:
    """String-valued enums map to dictionary-encoded string columns."""

    class MyEnum(Enum):
        val1 = "val1"
        val2 = "val2"
        val3 = "val3"

    class EnumModel(BaseModel):
        a: MyEnum
        b: List[MyEnum]
        c: Optional[MyEnum]

    # Plain, list-wrapped and optional fields all share this encoding.
    encoded = pa.dictionary(pa.int32(), pa.string())
    expected = pa.schema(
        [
            pa.field("a", encoded, nullable=False),
            pa.field("b", pa.list_(encoded), nullable=False),
            pa.field("c", encoded, nullable=True),
        ]
    )
    assert get_pyarrow_schema(EnumModel) == expected

    row = {"a": "val1", "b": ["val2", "val3"], "c": None}
    objs = [row]
    model = EnumModel.model_validate(row)
    assert model.a == MyEnum.val1
    assert model.b == [MyEnum.val2, MyEnum.val3]
    assert model.c is None

    # Round-trip through parquet: schema and rows must come back unchanged.
    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    assert new_objs == objs


def test_enum_int() -> None:
    """Integer-valued enums (including auto()) map to int64 columns."""

    class MyEnum(Enum):
        val1 = 1
        val2 = 2
        val3 = auto()

    class EnumModel(BaseModel):
        a: MyEnum
        b: List[MyEnum]
        c: Optional[MyEnum]

    int_type = pa.int64()
    expected = pa.schema(
        [
            pa.field("a", int_type, nullable=False),
            pa.field("b", pa.list_(int_type), nullable=False),
            pa.field("c", int_type, nullable=True),
        ]
    )
    assert get_pyarrow_schema(EnumModel) == expected

    row = {"a": 1, "b": [2, 3], "c": None}
    objs = [row]
    model = EnumModel.model_validate(row)
    assert model.a == MyEnum.val1
    assert model.b == [MyEnum.val2, MyEnum.val3]
    assert model.c is None

    # Round-trip through parquet: schema and rows must come back unchanged.
    new_schema, new_objs = _write_pq_and_read(objs, expected)
    assert new_schema == expected
    assert new_objs == objs


def test_enum_mixed() -> None:
    """An enum mixing int and str values is rejected."""

    class MyEnum(Enum):
        val1 = 1
        val2 = "val2"

    class EnumModel(BaseModel):
        a: MyEnum

    with pytest.raises(SchemaCreationError):
        get_pyarrow_schema(EnumModel)