Support Appends with TimeTransform Partitions (apache#784)
* checkpoint

* checkpoint2

* todo: sort with pyarrow_transform vals

* checkpoint

* checkpoint

* fix

* tests

* more tests

* adopt review feedback

* comment

* rebase
sungwy authored May 31, 2024
1 parent 20f6afd commit 65a03d2
Showing 6 changed files with 392 additions and 54 deletions.
2 changes: 1 addition & 1 deletion pyiceberg/partitioning.py
@@ -387,7 +387,7 @@ def partition(self) -> Record: # partition key transformed with iceberg interna
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
raise ValueError("partition_fields must contain exactly one field.")
raise ValueError(f"Cannot have redundant partitions: {partition_fields}")
partition_field = partition_fields[0]
iceberg_typed_key_values[partition_field.name] = partition_record_value(
partition_field=partition_field,
67 changes: 29 additions & 38 deletions pyiceberg/table/__init__.py
@@ -392,10 +392,11 @@ def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT)
if not isinstance(df, pa.Table):
raise ValueError(f"Expected PyArrow table, got: {df}")

-        supported_transforms = {IdentityTransform}
-        if not all(type(field.transform) in supported_transforms for field in self.table_metadata.spec().fields):
+        if unsupported_partitions := [
+            field for field in self.table_metadata.spec().fields if not field.transform.supports_pyarrow_transform
+        ]:
            raise ValueError(
-                f"All transforms are not supported, expected: {supported_transforms}, but get: {[str(field) for field in self.table_metadata.spec().fields if field.transform not in supported_transforms]}."
+                f"Not all partition types are supported for writes. Following partitions cannot be written using pyarrow: {unsupported_partitions}."
            )
)

_check_schema_compatible(self._table.schema(), other_schema=df.schema)
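
With this change, an append is accepted whenever every partition field's transform implements a pyarrow-level transform, rather than only IdentityTransform. A minimal usage sketch (catalog and table names are illustrative, and the table is assumed to be partitioned by month(ts); this code is not part of the commit):

from datetime import datetime

import pyarrow as pa

from pyiceberg.catalog import load_catalog

catalog = load_catalog("default")  # hypothetical catalog name
tbl = catalog.load_table("default.events")  # assumed: partitioned by month(ts)
df = pa.table({"ts": pa.array([datetime(2024, 1, 1), datetime(2024, 2, 1)], type=pa.timestamp("us"))})
tbl.append(df)  # accepted: MonthTransform.supports_pyarrow_transform is True
# A BucketTransform or TruncateTransform partition would still raise ValueError.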
@@ -3643,33 +3644,6 @@ class TablePartition:
arrow_table_partition: pa.Table


-def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
-    order = "ascending" if not reverse else "descending"
-    null_placement = "at_start" if reverse else "at_end"
-    return {"sort_keys": [(column_name, order) for column_name in partition_columns], "null_placement": null_placement}
-
-
-def group_by_partition_scheme(arrow_table: pa.Table, partition_columns: list[str]) -> pa.Table:
-    """Given a table, sort it by current partition scheme."""
-    # only works for identity for now
-    sort_options = _get_partition_sort_order(partition_columns, reverse=False)
-    sorted_arrow_table = arrow_table.sort_by(sorting=sort_options["sort_keys"], null_placement=sort_options["null_placement"])
-    return sorted_arrow_table
-
-
-def get_partition_columns(
-    spec: PartitionSpec,
-    schema: Schema,
-) -> list[str]:
-    partition_cols = []
-    for partition_field in spec.fields:
-        column_name = schema.find_column_name(partition_field.source_id)
-        if not column_name:
-            raise ValueError(f"{partition_field=} could not be found in {schema}.")
-        partition_cols.append(column_name)
-    return partition_cols
-
-
def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
@@ -3724,13 +3698,30 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
"""
import pyarrow as pa

-    partition_columns = get_partition_columns(spec=spec, schema=schema)
-    arrow_table = group_by_partition_scheme(arrow_table, partition_columns)
-
-    reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
-    reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()
-
-    slice_instructions: list[dict[str, Any]] = []
+    partition_columns: List[Tuple[PartitionField, NestedField]] = [
+        (partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
+    ]
+    partition_values_table = pa.table({
+        str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
+        for partition, field in partition_columns
+    })
+
+    # Sort by partitions
+    sort_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
+        null_placement="at_end",
+    ).to_pylist()
+    arrow_table = arrow_table.take(sort_indices)
+
+    # Get slice_instructions to group by partitions
+    partition_values_table = partition_values_table.take(sort_indices)
+    reversed_indices = pa.compute.sort_indices(
+        partition_values_table,
+        sort_keys=[(col, "descending") for col in partition_values_table.column_names],
+        null_placement="at_start",
+    ).to_pylist()
+    slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
@@ -3741,6 +3732,6 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
last = reversed_indices[ptr]
ptr = ptr + group_size

-    table_partitions: list[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
+    table_partitions: List[TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)

return table_partitions
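
The double sort above is the heart of the grouping. A stable ascending sort over the transformed partition values puts equal partition keys on contiguous rows; a stable descending sort of the already-sorted key table then visits the groups from last to first, so the first index seen per group is that group's start offset, which yields the slice offsets and lengths. A standalone sketch of the trick (illustrative values, not library code):

import pyarrow as pa
import pyarrow.compute as pc

keys = pa.table({"k": [2, 1, 2, 1, 3]})
asc = pc.sort_indices(keys, sort_keys=[("k", "ascending")]).to_pylist()    # [1, 3, 0, 2, 4]
keys = keys.take(asc)                                                      # k = [1, 1, 2, 2, 3]
desc = pc.sort_indices(keys, sort_keys=[("k", "descending")]).to_pylist()  # [4, 2, 3, 0, 1]

slices, last, ptr = [], len(desc), 0
while ptr < len(desc):
    group_size = last - desc[ptr]  # rows between this group's start and the previous boundary
    slices.append({"offset": desc[ptr], "length": group_size})
    last = desc[ptr]
    ptr += group_size
print(slices)  # [{'offset': 4, 'length': 1}, {'offset': 2, 'length': 2}, {'offset': 0, 'length': 2}]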
99 changes: 98 additions & 1 deletion pyiceberg/transforms.py
@@ -20,7 +20,7 @@
from abc import ABC, abstractmethod
from enum import IntEnum
from functools import singledispatch
-from typing import Any, Callable, Generic, Optional, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, Optional, TypeVar
from typing import Literal as LiteralType
from uuid import UUID

@@ -82,6 +82,9 @@
from pyiceberg.utils.parsing import ParseNumberFromBrackets
from pyiceberg.utils.singleton import Singleton

+if TYPE_CHECKING:
+    import pyarrow as pa
+
S = TypeVar("S")
T = TypeVar("T")

@@ -175,6 +178,13 @@ def __eq__(self, other: Any) -> bool:
return self.root == other.root
return False

+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return False
+
+    @abstractmethod
+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]": ...
+

class BucketTransform(Transform[S, int]):
"""Base Transform class to transform a value into a bucket partition value.
@@ -290,6 +300,9 @@ def __repr__(self) -> str:
"""Return the string representation of the BucketTransform class."""
return f"BucketTransform(num_buckets={self._num_buckets})"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+

class TimeResolution(IntEnum):
YEAR = 6
@@ -349,6 +362,10 @@ def dedup_name(self) -> str:
def preserves_order(self) -> bool:
return True

+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+

class YearTransform(TimeTransform[S]):
"""Transforms a datetime value into a year value.
@@ -391,6 +408,21 @@ def __repr__(self) -> str:
"""Return the string representation of the YearTransform class."""
return "YearTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply year transform for type: {source}")
+
+        return lambda v: pc.years_between(pa.scalar(epoch), v) if v is not None else None
+
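A quick sketch of what the returned callable computes, years since the Unix epoch (values worked by hand; not part of the commit):

from datetime import date

import pyarrow as pa

from pyiceberg.transforms import YearTransform
from pyiceberg.types import DateType

arr = pa.array([date(1970, 6, 1), date(2024, 1, 1)], type=pa.date32())
YearTransform().pyarrow_transform(DateType())(arr)  # -> [0, 54]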

class MonthTransform(TimeTransform[S]):
"""Transforms a datetime value into a month value.
@@ -433,6 +465,27 @@ def __repr__(self) -> str:
"""Return the string representation of the MonthTransform class."""
return "MonthTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply month transform for type: {source}")
+
+        def month_func(v: pa.Array) -> pa.Array:
+            return pc.add(
+                pc.multiply(pc.years_between(pa.scalar(epoch), v), pa.scalar(12)),
+                pc.add(pc.month(v), pa.scalar(-1)),
+            )
+
+        return lambda v: month_func(v) if v is not None else None
+
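The arithmetic above matches Iceberg's month ordinal, months since the Unix epoch: years_between * 12 + (month - 1). For example, 2024-02-01 gives 54 * 12 + (2 - 1) = 649. A small sketch (not part of the commit):

from datetime import date

import pyarrow as pa

from pyiceberg.transforms import MonthTransform
from pyiceberg.types import DateType

arr = pa.array([date(2024, 2, 1)], type=pa.date32())
MonthTransform().pyarrow_transform(DateType())(arr)  # -> [649]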

class DayTransform(TimeTransform[S]):
"""Transforms a datetime value into a day value.
@@ -478,6 +531,21 @@ def __repr__(self) -> str:
"""Return the string representation of the DayTransform class."""
return "DayTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, DateType):
+            epoch = datetime.EPOCH_DATE
+        elif isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply day transform for type: {source}")
+
+        return lambda v: pc.days_between(pa.scalar(epoch), v) if v is not None else None
+
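The day ordinal is simply days since the Unix epoch, e.g. (sketch, not part of the commit):

from datetime import date

import pyarrow as pa

from pyiceberg.transforms import DayTransform
from pyiceberg.types import DateType

arr = pa.array([date(2024, 2, 1)], type=pa.date32())
DayTransform().pyarrow_transform(DateType())(arr)  # -> [19754], days since 1970-01-01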

class HourTransform(TimeTransform[S]):
"""Transforms a datetime value into a hour value.
@@ -515,6 +583,19 @@ def __repr__(self) -> str:
"""Return the string representation of the HourTransform class."""
return "HourTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        import pyarrow as pa
+        import pyarrow.compute as pc
+
+        if isinstance(source, TimestampType):
+            epoch = datetime.EPOCH_TIMESTAMP
+        elif isinstance(source, TimestamptzType):
+            epoch = datetime.EPOCH_TIMESTAMPTZ
+        else:
+            raise ValueError(f"Cannot apply hour transform for type: {source}")
+
+        return lambda v: pc.hours_between(pa.scalar(epoch), v) if v is not None else None
+
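Unlike the other time transforms, the hour transform has no DateType branch, since hour granularity is undefined for dates. A sketch, with the value worked by hand (19754 days * 24 + 6; not part of the commit):

from datetime import datetime

import pyarrow as pa

from pyiceberg.transforms import HourTransform
from pyiceberg.types import TimestampType

arr = pa.array([datetime(2024, 2, 1, 6, 0, 0)], type=pa.timestamp("us"))
HourTransform().pyarrow_transform(TimestampType())(arr)  # -> [474102]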

def _base64encode(buffer: bytes) -> str:
"""Convert bytes to base64 string."""
@@ -585,6 +666,13 @@ def __repr__(self) -> str:
"""Return the string representation of the IdentityTransform class."""
return "IdentityTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        return lambda v: v
+
+    @property
+    def supports_pyarrow_transform(self) -> bool:
+        return True
+

class TruncateTransform(Transform[S, S]):
"""A transform for truncating a value to a specified width.
@@ -725,6 +813,9 @@ def __repr__(self) -> str:
"""Return the string representation of the TruncateTransform class."""
return f"TruncateTransform(width={self._width})"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+

@singledispatch
def _human_string(value: Any, _type: IcebergType) -> str:
@@ -807,6 +898,9 @@ def __repr__(self) -> str:
"""Return the string representation of the UnknownTransform class."""
return f"UnknownTransform(transform={repr(self._transform)})"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+

class VoidTransform(Transform[S, None], Singleton):
"""A transform that always returns None."""
@@ -835,6 +929,9 @@ def __repr__(self) -> str:
"""Return the string representation of the VoidTransform class."""
return "VoidTransform()"

+    def pyarrow_transform(self, source: IcebergType) -> "Callable[[pa.Array], pa.Array]":
+        raise NotImplementedError()
+

def _truncate_number(
name: str, pred: BoundLiteralPredicate[L], transform: Callable[[Optional[L]], Optional[L]]
43 changes: 43 additions & 0 deletions tests/conftest.py
@@ -2158,3 +2158,46 @@ def arrow_table_with_only_nulls(pa_schema: "pa.Schema") -> "pa.Table":
import pyarrow as pa

return pa.Table.from_pylist([{}, {}], schema=pa_schema)


+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps() -> "pa.Table":
+    """Pyarrow table with only date, timestamp and timestamptz values."""
+    import pyarrow as pa
+
+    return pa.Table.from_pydict(
+        {
+            "date": [date(2023, 12, 31), date(2024, 1, 1), date(2024, 1, 31), date(2024, 2, 1), date(2024, 2, 1), None],
+            "timestamp": [
+                datetime(2023, 12, 31, 0, 0, 0),
+                datetime(2024, 1, 1, 0, 0, 0),
+                datetime(2024, 1, 31, 0, 0, 0),
+                datetime(2024, 2, 1, 0, 0, 0),
+                datetime(2024, 2, 1, 6, 0, 0),
+                None,
+            ],
+            "timestamptz": [
+                datetime(2023, 12, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 1, 31, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 0, 0, 0, tzinfo=timezone.utc),
+                datetime(2024, 2, 1, 6, 0, 0, tzinfo=timezone.utc),
+                None,
+            ],
+        },
+        schema=pa.schema([
+            ("date", pa.date32()),
+            ("timestamp", pa.timestamp(unit="us")),
+            ("timestamptz", pa.timestamp(unit="us", tz="UTC")),
+        ]),
+    )
+
+
+@pytest.fixture(scope="session")
+def arrow_table_date_timestamps_schema() -> Schema:
+    """Pyarrow table Schema with only date, timestamp and timestamptz values."""
+    return Schema(
+        NestedField(field_id=1, name="date", field_type=DateType(), required=False),
+        NestedField(field_id=2, name="timestamp", field_type=TimestampType(), required=False),
+        NestedField(field_id=3, name="timestamptz", field_type=TimestamptzType(), required=False),
+    )
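
A sketch of the kind of test these fixtures enable (the session_catalog fixture, table name, and assertion are assumptions for illustration, not code from this commit):

import pytest

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import DayTransform, MonthTransform, YearTransform


@pytest.mark.parametrize("transform", [YearTransform(), MonthTransform(), DayTransform()])
def test_append_time_transform_partition(
    session_catalog, arrow_table_date_timestamps, arrow_table_date_timestamps_schema, transform
):
    # Partition on the timestamp column (field_id=2 in the fixture schema).
    spec = PartitionSpec(PartitionField(source_id=2, field_id=1001, transform=transform, name="ts_part"))
    tbl = session_catalog.create_table(
        "default.test_time_transform_appends",
        schema=arrow_table_date_timestamps_schema,
        partition_spec=spec,
    )
    tbl.append(arrow_table_date_timestamps)
    assert len(tbl.scan().to_arrow()) == len(arrow_table_date_timestamps)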
