Skip to content

Commit

Permalink
Adding conversion of Enum type (#12)
Browse files Browse the repository at this point in the history
* Add str and int enums to possible pydnatic field types

* Bump version to 0.1.2

* Update README for implementing enum type
  • Loading branch information
simw authored Mar 5, 2024
1 parent 8eb2a7e commit adbfdb2
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 4 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ should manually specify the full schema.
Python / Pydantic | Pyarrow | Overflow
--- | --- | ---
str | pa.string() |
Literal[strings] | pa.dictionary(pa.int32(), pa.string())
Literal[strings] | pa.dictionary(pa.int32(), pa.string()) |
. | . | .
int | pa.int64() if no minimum constraint, pa.uint64() if minimum is zero | Yes, at 2^63 (for signed) or 2^64 (for unsigned)
Literal[ints] | pa.int64() | Yes, at 2^63
Literal[ints] | pa.int64() |
float | pa.float64() | Yes
decimal.Decimal | pa.decimal128 ONLY if supplying max_digits and decimal_places for pydantic field | Yes
. | . | .
Expand All @@ -64,6 +64,8 @@ pydantic.types.AwareDatetime | pa.timestamp("ms", tz=None) ONLY if param allow_l
Optional[...] | The pyarrow field is nullable |
Pydantic Model | pa.struct() |
List[...] | pa.list_(...) |
Enum of str | pa.dictionary(pa.int32(), pa.string()) |
Enum of int | pa.int64() |

## An Example

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api"

[tool.poetry]
name = "pydantic_to_pyarrow"
version = "0.1.1"
version = "0.1.2"
description = "Conversion from pydantic models to pyarrow schemas"
authors = ["Simon Wicks <[email protected]>"]
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion src/pydantic_to_pyarrow/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .schema import SchemaCreationError, get_pyarrow_schema

__version__ = "0.1.1"
__version__ = "0.1.2"

__all__ = [
"__version__",
Expand Down
18 changes: 18 additions & 0 deletions src/pydantic_to_pyarrow/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import types
from decimal import Decimal
from enum import EnumMeta
from typing import Any, List, Literal, Optional, Type, TypeVar, Union, cast

import pyarrow as pa # type: ignore
Expand All @@ -9,6 +10,7 @@
from typing_extensions import Annotated, get_args, get_origin

BaseModelType = TypeVar("BaseModelType", bound=BaseModel)
EnumType = TypeVar("EnumType", bound=EnumMeta)


class SchemaCreationError(Exception):
Expand Down Expand Up @@ -116,6 +118,19 @@ def _get_annotated_type(
}


def _get_enum_type(field_type: Type[Any]) -> pa.DataType:
is_str = [isinstance(enum_value.value, str) for enum_value in field_type]
if all(is_str):
return pa.dictionary(pa.int32(), pa.string())

is_int = [isinstance(enum_value.value, int) for enum_value in field_type]
if all(is_int):
return pa.int64()

msg = "Enums only allowed if all str or all int"
raise SchemaCreationError(msg)


def _is_optional(field_type: Type[Any]) -> bool:
origin = get_origin(field_type)
is_python_39_union = origin is Union
Expand All @@ -141,6 +156,9 @@ def _get_pyarrow_type(
f"{field_type} only allowed if ok losing timezone information"
)

if isinstance(field_type, EnumMeta):
return _get_enum_type(field_type)

if field_type in TYPES_WITH_METADATA:
return TYPES_WITH_METADATA[field_type](metadata)

Expand Down
81 changes: 81 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import datetime
import tempfile
from decimal import Decimal
from enum import Enum, auto
from pathlib import Path
from typing import Any, Deque, Dict, List, Literal, Optional, Tuple

Expand Down Expand Up @@ -449,3 +450,83 @@ class ListModel(BaseModel):
new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected
assert new_objs == objs


def test_enum_str() -> None:
class MyEnum(Enum):
val1 = "val1"
val2 = "val2"
val3 = "val3"

class EnumModel(BaseModel):
a: MyEnum
b: List[MyEnum]
c: Optional[MyEnum]

expected = pa.schema(
[
pa.field("a", pa.dictionary(pa.int32(), pa.string()), nullable=False),
pa.field(
"b", pa.list_(pa.dictionary(pa.int32(), pa.string())), nullable=False
),
pa.field("c", pa.dictionary(pa.int32(), pa.string()), nullable=True),
]
)

actual = get_pyarrow_schema(EnumModel)
assert actual == expected

objs = [{"a": "val1", "b": ["val2", "val3"], "c": None}]
model = EnumModel.model_validate(objs[0])
assert model.a == MyEnum.val1
assert model.b == [MyEnum.val2, MyEnum.val3]
assert model.c is None

new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected
assert new_objs == objs


def test_enum_int() -> None:
class MyEnum(Enum):
val1 = 1
val2 = 2
val3 = auto()

class EnumModel(BaseModel):
a: MyEnum
b: List[MyEnum]
c: Optional[MyEnum]

expected = pa.schema(
[
pa.field("a", pa.int64(), nullable=False),
pa.field("b", pa.list_(pa.int64()), nullable=False),
pa.field("c", pa.int64(), nullable=True),
]
)

actual = get_pyarrow_schema(EnumModel)
assert actual == expected

objs = [{"a": 1, "b": [2, 3], "c": None}]
model = EnumModel.model_validate(objs[0])
assert model.a == MyEnum.val1
assert model.b == [MyEnum.val2, MyEnum.val3]
assert model.c is None

new_schema, new_objs = _write_pq_and_read(objs, expected)
assert new_schema == expected
assert new_objs == objs


def test_enum_mixed() -> None:
class MyEnum(Enum):
val1 = 1
val2 = "val2"

class EnumModel(BaseModel):
a: MyEnum

with pytest.raises(SchemaCreationError):
get_pyarrow_schema(EnumModel)

0 comments on commit adbfdb2

Please sign in to comment.