
Commit

fix linting; add decimal input transform test
jqin61 committed Feb 21, 2024
1 parent 132599b commit 7fcf75a
Showing 2 changed files with 69 additions and 31 deletions.
19 changes: 9 additions & 10 deletions pyiceberg/partitioning.py
@@ -26,6 +26,7 @@
Optional,
Tuple,
)
from urllib.parse import quote

from pydantic import (
BeforeValidator,
@@ -208,7 +209,6 @@ def partition_to_path(self, data: Record, schema: Schema) -> str:

partition_field = self.fields[pos] # partition field
value_str = partition_field.transform.to_human_string(field_types[pos].field_type, value=value)
from urllib.parse import quote

value_str = quote(value_str, safe='')
value_strs.append(value_str)
@@ -250,10 +250,9 @@ class PartitionFieldValue:

@dataclass(frozen=True)
class PartitionKey:
raw_partition_field_values: list[PartitionFieldValue]
raw_partition_field_values: List[PartitionFieldValue]
partition_spec: PartitionSpec
schema: Schema
from functools import cached_property

@cached_property
def partition(self) -> Record: # partition key in iceberg type
@@ -263,8 +262,8 @@ def partition(self) -> Record: # partition key in iceberg type
assert len(partition_fields) == 1
partition_field = partition_fields[0]
iceberg_type = self.schema.find_field(name_or_id=raw_partition_field_value.field.source_id).field_type
_iceberg_typed_value = iceberg_typed_value(iceberg_type, raw_partition_field_value.value)
transformed_value = partition_field.transform.transform(iceberg_type)(_iceberg_typed_value)
iceberg_typed_value = _to_iceberg_type(iceberg_type, raw_partition_field_value.value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
iceberg_typed_key_values[partition_field.name] = transformed_value
return Record(**iceberg_typed_key_values)

@@ -273,21 +272,21 @@ def to_path(self) -> str:


@singledispatch
def iceberg_typed_value(type: IcebergType, value: Any) -> Any:
def _to_iceberg_type(type: IcebergType, value: Any) -> Any:
raise TypeError(f"Unsupported partition field type: {type}")


@iceberg_typed_value.register(TimestampType)
@iceberg_typed_value.register(TimestamptzType)
@_to_iceberg_type.register(TimestampType)
@_to_iceberg_type.register(TimestamptzType)
def _(type: IcebergType, value: Optional[datetime]) -> Optional[int]:
return datetime_to_micros(value) if value is not None else None


@iceberg_typed_value.register(DateType)
@_to_iceberg_type.register(DateType)
def _(type: IcebergType, value: Optional[date]) -> Optional[int]:
return date_to_days(value) if value is not None else None


@iceberg_typed_value.register(PrimitiveType)
@_to_iceberg_type.register(PrimitiveType)
def _(type: IcebergType, value: Optional[Any]) -> Optional[Any]:
return value
81 changes: 60 additions & 21 deletions tests/integration/test_partitioning_key.py
@@ -16,11 +16,13 @@
# under the License.
# pylint:disable=redefined-outer-name
from datetime import date, datetime
from typing import Any
from decimal import Decimal
from typing import Any, List

import pytest
import pytz
from pyspark.sql import SparkSession
from pyspark.sql.utils import AnalysisException

from pyiceberg.catalog import Catalog, load_catalog
from pyiceberg.exceptions import NamespaceAlreadyExistsError
@@ -40,6 +42,7 @@
BinaryType,
BooleanType,
DateType,
DecimalType,
DoubleType,
FixedType,
FloatType,
@@ -136,6 +139,7 @@ def spark() -> SparkSession:
# NestedField(field_id=12, name="uuid", field_type=UuidType(), required=False),
NestedField(field_id=11, name="binary_field", field_type=BinaryType(), required=False),
NestedField(field_id=12, name="fixed_field", field_type=FixedType(16), required=False),
NestedField(field_id=13, name="decimal", field_type=DecimalType(5, 2), required=False),
)


@@ -332,6 +336,25 @@ def spark() -> SparkSession:
(CAST('example' AS BINARY), 'Associated string value for binary `example`')
""",
),
(
[PartitionField(source_id=13, field_id=1001, transform=IdentityTransform(), name="decimal_field")],
[Decimal('123.45')],
Record(decimal_field=Decimal('123.45')),
"decimal_field=123.45",
f"""CREATE TABLE {identifier} (
decimal_field decimal(5,2),
string_field string
)
USING iceberg
PARTITIONED BY (
identity(decimal_field)
)
""",
f"""INSERT INTO {identifier}
VALUES
(123.45, 'Associated string value for decimal 123.45')
""",
),
# Year Month Day Hour Transform
# Month Transform
(
@@ -533,7 +556,7 @@ def spark() -> SparkSession:
"bigint_field_trunc=4294967296",
f"""CREATE TABLE {identifier} (
bigint_field bigint,
other_data string
string_field string
)
USING iceberg
PARTITIONED BY (
@@ -552,7 +575,7 @@
"string_field_trunc=abc",
f"""CREATE TABLE {identifier} (
string_field string,
other_data string
another_string_field string
)
USING iceberg
PARTITIONED BY (
@@ -564,15 +587,33 @@
('abcdefg', 'Another sample for string');
""",
),
# It seems transform.to_human_string accepts a bytes value, so no extra conversion is needed in iceberg_typed_value() for BinaryType
(
[PartitionField(source_id=13, field_id=1001, transform=TruncateTransform(width=5), name="decimal_field_trunc")],
[Decimal('678.93')],
Record(decimal_field_trunc=Decimal('678.90')),
"decimal_field_trunc=678.90", # Assuming truncation width of 1 leads to truncating to 670
f"""CREATE TABLE {identifier} (
decimal_field decimal(5,2),
string_field string
)
USING iceberg
PARTITIONED BY (
truncate(decimal_field, 2)
)
""",
f"""INSERT INTO {identifier}
VALUES
(678.90, 'Associated string value for decimal 678.90')
""",
),
(
[PartitionField(source_id=11, field_id=1001, transform=TruncateTransform(10), name="binary_field_trunc")],
[b'HELLOICEBERG'],
Record(binary_field_trunc=b'HELLOICEBE'),
"binary_field_trunc=SEVMTE9JQ0VCRQ%3D%3D",
f"""CREATE TABLE {identifier} (
binary_field binary,
other_data string
string_field string
)
USING iceberg
PARTITIONED BY (
@@ -592,7 +633,7 @@ def spark() -> SparkSession:
"int_field_bucket=0",
f"""CREATE TABLE {identifier} (
int_field int,
other_data string
string_field string
)
USING iceberg
PARTITIONED BY (
@@ -638,8 +679,8 @@
def test_partition_key(
session_catalog: Catalog,
spark: SparkSession,
partition_fields: list[PartitionField],
partition_values: list[Any],
partition_fields: List[PartitionField],
partition_values: List[Any],
expected_partition_record: Record,
expected_hive_partition_path_slice: str,
spark_create_table_sql_for_justification: str,
@@ -653,16 +694,12 @@ def test_partition_key(
partition_spec=spec,
schema=TABLE_SCHEMA,
)
# print(f"{key.partition=}")
# print(f"{key.to_path()=}")
# this affects the metadata in DataFile and all above layers
# key.partition is used to write the metadata in DataFile, ManifestFile and all layers above
assert key.partition == expected_partition_record
# this affects the hive partitioning part in the parquet file path
# key.to_path() generates the hive-partitioning part of the path of the parquet file to be written
assert key.to_path() == expected_hive_partition_path_slice

from pyspark.sql.utils import AnalysisException

# verify expected values are not made up but conform to spark behaviors
# Justify that the expected values are not made up but conform to Spark's behavior
if spark_create_table_sql_for_justification is not None and spark_data_insert_sql_for_justification is not None:
try:
spark.sql(f"drop table {identifier}")
@@ -675,9 +712,11 @@
iceberg_table = session_catalog.load_table(identifier=identifier)
snapshot = iceberg_table.current_snapshot()
assert snapshot
verify_partition = snapshot.manifests(iceberg_table.io)[0].fetch_manifest_entry(iceberg_table.io)[0].data_file.partition
verify_path = snapshot.manifests(iceberg_table.io)[0].fetch_manifest_entry(iceberg_table.io)[0].data_file.file_path
# print(f"{verify_partition=}")
# print(f"{verify_path=}")
assert verify_partition == expected_partition_record
assert expected_hive_partition_path_slice in verify_path
spark_partition_for_justification = (
snapshot.manifests(iceberg_table.io)[0].fetch_manifest_entry(iceberg_table.io)[0].data_file.partition
)
spark_path_for_justification = (
snapshot.manifests(iceberg_table.io)[0].fetch_manifest_entry(iceberg_table.io)[0].data_file.file_path
)
assert spark_partition_for_justification == expected_partition_record
assert expected_hive_partition_path_slice in spark_path_for_justification
