Skip to content

Commit

Permalink
Partition Evolution Support
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh-jahagirdar committed Jan 26, 2024
1 parent 0f08806 commit a40098c
Show file tree
Hide file tree
Showing 3 changed files with 781 additions and 10 deletions.
129 changes: 120 additions & 9 deletions pyiceberg/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,9 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)
from abc import ABC, abstractmethod
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar

from pydantic import (
BeforeValidator,
Expand All @@ -34,7 +29,18 @@
from typing_extensions import Annotated

from pyiceberg.schema import Schema
from pyiceberg.transforms import Transform, parse_transform
from pyiceberg.transforms import (
BucketTransform,
DayTransform,
HourTransform,
IdentityTransform,
Transform,
TruncateTransform,
UnknownTransform,
VoidTransform,
YearTransform,
parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel
from pyiceberg.types import NestedField, StructType

Expand Down Expand Up @@ -215,3 +221,108 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre
)
)
return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID)


T = TypeVar("T")


class PartitionSpecVisitor(Generic[T], ABC):
@abstractmethod
def identity(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit identity partition field."""

@abstractmethod
def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T:
"""Visit bucket partition field."""

@abstractmethod
def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T:
"""Visit truncate partition field."""

@abstractmethod
def year(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit year partition field."""

@abstractmethod
def month(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit month partition field."""

@abstractmethod
def day(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit day partition field."""

@abstractmethod
def hour(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit hour partition field."""

@abstractmethod
def always_null(self, field_id: int, source_name: str, source_id: int) -> T:
"""Visit void partition field."""

@abstractmethod
def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T:
"""Visit unknown partition field."""
raise ValueError(f"Unknown transform is not supported: {transform}")


class _PartitionNameGenerator(PartitionSpecVisitor[str]):
def identity(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name

def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str:
return f"{source_name}_bucket_{num_buckets}"

def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str:
return source_name + "_trunc_" + str(width)

def year(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name + "_year"

def month(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name + "_month"

def day(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name + "_day"

def hour(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name + "_hour"

def always_null(self, field_id: int, source_name: str, source_id: int) -> str:
return source_name + "_null"

def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str:
return super().unknown(field_id, source_name, source_id, transform)


R = TypeVar("R")


@singledispatch
def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]:
return [_visit_partition_field(schema, field, visitor) for field in spec.fields]


def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R:
source_name = schema.find_column_name(field.source_id)
if not source_name:
raise ValueError(f"Could not find field with id {field.source_id}")

transform = field.transform
if isinstance(transform, IdentityTransform):
return visitor.identity(field.field_id, source_name, field.source_id)
elif isinstance(transform, BucketTransform):
return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets)
elif isinstance(transform, TruncateTransform):
return visitor.truncate(field.field_id, source_name, field.source_id, transform.width)
elif isinstance(transform, DayTransform):
return visitor.day(field.field_id, source_name, field.source_id)
elif isinstance(transform, HourTransform):
return visitor.hour(field.field_id, source_name, field.source_id)
elif isinstance(transform, YearTransform):
return visitor.year(field.field_id, source_name, field.source_id)
elif isinstance(transform, VoidTransform):
return visitor.always_null(field.field_id, source_name, field.source_id)
elif isinstance(transform, UnknownTransform):
return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
else:
raise ValueError(f"Unknown transform {transform}")
Loading

0 comments on commit a40098c

Please sign in to comment.