Skip to content

Commit

Permalink
Construction of filenames for partitioned writes (apache#453)
Browse files Browse the repository at this point in the history
* PartitionKey Class And Tests

* fix linting; add decimal input transform test

* fix bool to path lower case; fix timestamptz tests; other pr comments

* clean up

* add uuid partition type

* clean up; rename ambiguous function name
  • Loading branch information
jqin61 authored and hpal committed Mar 1, 2024
1 parent a5eeac4 commit 681fed3
Show file tree
Hide file tree
Showing 5 changed files with 925 additions and 52 deletions.
101 changes: 98 additions & 3 deletions pyiceberg/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,21 @@
# under the License.
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import date, datetime
from functools import cached_property, singledispatch
from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
from typing import (
Any,
Dict,
Generic,
List,
Optional,
Tuple,
TypeVar,
)
from urllib.parse import quote

from pydantic import (
BeforeValidator,
Expand All @@ -41,8 +53,18 @@
YearTransform,
parse_transform,
)
from pyiceberg.typedef import IcebergBaseModel
from pyiceberg.types import NestedField, StructType
from pyiceberg.typedef import IcebergBaseModel, Record
from pyiceberg.types import (
DateType,
IcebergType,
NestedField,
PrimitiveType,
StructType,
TimestampType,
TimestamptzType,
UUIDType,
)
from pyiceberg.utils.datetime import date_to_days, datetime_to_micros

INITIAL_PARTITION_SPEC_ID = 0
PARTITION_FIELD_ID_START: int = 1000
Expand Down Expand Up @@ -199,6 +221,23 @@ def partition_type(self, schema: Schema) -> StructType:
nested_fields.append(NestedField(field.field_id, field.name, result_type, required=False))
return StructType(*nested_fields)

def partition_to_path(self, data: Record, schema: Schema) -> str:
partition_type = self.partition_type(schema)
field_types = partition_type.fields

field_strs = []
value_strs = []
for pos, value in enumerate(data.record_fields()):
partition_field = self.fields[pos]
value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=value)

value_str = quote(value_str, safe='')
value_strs.append(value_str)
field_strs.append(partition_field.name)

path = "/".join([field_str + "=" + value_str for field_str, value_str in zip(field_strs, value_strs)])
return path


UNPARTITIONED_PARTITION_SPEC = PartitionSpec(spec_id=0)

Expand Down Expand Up @@ -326,3 +365,59 @@ def _visit_partition_field(schema: Schema, field: PartitionField, visitor: Parti
return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
else:
raise ValueError(f"Unknown transform {transform}")


@dataclass(frozen=True)
class PartitionFieldValue:
field: PartitionField
value: Any


@dataclass(frozen=True)
class PartitionKey:
raw_partition_field_values: List[PartitionFieldValue]
partition_spec: PartitionSpec
schema: Schema

@cached_property
def partition(self) -> Record: # partition key transformed with iceberg internal representation as input
iceberg_typed_key_values = {}
for raw_partition_field_value in self.raw_partition_field_values:
partition_fields = self.partition_spec.source_id_to_fields_map[raw_partition_field_value.field.source_id]
if len(partition_fields) != 1:
raise ValueError("partition_fields must contain exactly one field.")
partition_field = partition_fields[0]
iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type
iceberg_typed_value = _to_partition_representation(iceberg_type, raw_partition_field_value.value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
iceberg_typed_key_values[partition_field.name] = transformed_value
return Record(**iceberg_typed_key_values)

def to_path(self) -> str:
return self.partition_spec.partition_to_path(self.partition, self.schema)


@singledispatch
def _to_partition_representation(type: IcebergType, value: Any) -> Any:
return TypeError(f"Unsupported partition field type: {type}")


@_to_partition_representation.register(TimestampType)
@_to_partition_representation.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None


@_to_partition_representation.register(DateType)
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None


@_to_partition_representation.register(UUIDType)
def _(type: IcebergType, value: Optional[uuid.UUID]) -> Optional[str]:
return str(value) if value is not None else None


@_to_partition_representation.register(PrimitiveType)
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
return value
5 changes: 5 additions & 0 deletions pyiceberg/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,11 @@ def _(value: int, _type: IcebergType) -> str:
return _int_to_human_string(_type, value)


@_human_string.register(bool)
def _(value: bool, _type: IcebergType) -> str:
return str(value).lower()


@singledispatch
def _int_to_human_string(_type: IcebergType, value: int) -> str:
return str(value)
Expand Down
51 changes: 50 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@
import boto3
import pytest
from moto import mock_aws
from pyspark.sql import SparkSession

from pyiceberg import schema
from pyiceberg.catalog import Catalog
from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.catalog.noop import NoopCatalog
from pyiceberg.expressions import BoundReference
from pyiceberg.io import (
Expand Down Expand Up @@ -1925,3 +1926,51 @@ def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table:
@pytest.fixture
def bound_reference_str() -> BoundReference[str]:
return BoundReference(field=NestedField(1, "field", StringType(), required=False), accessor=Accessor(position=0, inner=None))


@pytest.fixture(scope="session")
def session_catalog() -> Catalog:
return load_catalog(
"local",
**{
"type": "rest",
"uri": "http://localhost:8181",
"s3.endpoint": "http://localhost:9000",
"s3.access-key-id": "admin",
"s3.secret-access-key": "password",
},
)


@pytest.fixture(scope="session")
def spark() -> SparkSession:
import importlib.metadata
import os

spark_version = ".".join(importlib.metadata.version("pyspark").split(".")[:2])
scala_version = "2.12"
iceberg_version = "1.4.3"

os.environ["PYSPARK_SUBMIT_ARGS"] = (
f"--packages org.apache.iceberg:iceberg-spark-runtime-{spark_version}_{scala_version}:{iceberg_version},"
f"org.apache.iceberg:iceberg-aws-bundle:{iceberg_version} pyspark-shell"
)
os.environ["AWS_REGION"] = "us-east-1"
os.environ["AWS_ACCESS_KEY_ID"] = "admin"
os.environ["AWS_SECRET_ACCESS_KEY"] = "password"

spark = (
SparkSession.builder.appName("PyIceberg integration test")
.config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
.config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog")
.config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog")
.config("spark.sql.catalog.integration.uri", "http://localhost:8181")
.config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO")
.config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/")
.config("spark.sql.catalog.integration.s3.endpoint", "http://localhost:9000")
.config("spark.sql.catalog.integration.s3.path-style-access", "true")
.config("spark.sql.defaultCatalog", "integration")
.getOrCreate()
)

return spark
Loading

0 comments on commit 681fed3

Please sign in to comment.