
Commit

refactor
jqin61 committed Feb 4, 2024
1 parent 7a583ab commit fcd94fd
Showing 3 changed files with 62 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Makefile
@@ -40,9 +40,9 @@ test-integration:
docker-compose -f dev/docker-compose-integration.yml kill
docker-compose -f dev/docker-compose-integration.yml rm -f
docker-compose -f dev/docker-compose-integration.yml up -d
sleep 10
sleep 5
docker-compose -f dev/docker-compose-integration.yml exec -T spark-iceberg ipython ./provision.py
poetry run pytest tests/ -v -m integration ${PYTEST_ARGS} -s
poetry run pytest tests/ -v -m newyork ${PYTEST_ARGS} -s

test-integration-rebuild:
docker-compose -f dev/docker-compose-integration.yml kill
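The Makefile hunk above temporarily points the integration target at an ad-hoc `newyork` pytest marker (and shortens the provision sleep) while the partitioned-write tests are being iterated on. As a hedged sketch only, not part of this commit: such ad-hoc markers are normally registered with pytest so `-m newyork` selects them cleanly instead of raising unknown-marker warnings, e.g. in a hypothetical conftest.py:

def pytest_configure(config):
    # register the temporary markers used by this branch (illustrative snippet, not in the diff)
    config.addinivalue_line("markers", "newyork: temporary marker for partitioned-write tests under development")
    config.addinivalue_line("markers", "adrian: temporary marker for partitioned-write tests under development")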
18 changes: 9 additions & 9 deletions pyiceberg/table/__init__.py
@@ -2469,7 +2469,7 @@ class TablePartition:
arrow_table_partition: pa.Table


def get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
def _get_partition_sort_order(partition_columns: list[str], reverse: bool = False) -> dict[str, Any]:
order = 'ascending' if not reverse else 'descending'
null_placement = 'at_start' if reverse else 'at_end'
return {'sort_keys': [(column_name, order) for column_name in partition_columns], 'null_placement': null_placement}
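For context, a minimal usage sketch (column names invented) of how the options built by _get_partition_sort_order feed PyArrow, mirroring the sort_by call in group_by_partition_scheme below:

import pyarrow as pa

opts = _get_partition_sort_order(["region", "day"], reverse=False)
# opts == {'sort_keys': [('region', 'ascending'), ('day', 'ascending')], 'null_placement': 'at_end'}
tbl = pa.table({"region": ["b", "a", None], "day": [2, 1, 3], "v": [10, 20, 30]})
sorted_tbl = tbl.sort_by(sorting=opts["sort_keys"], null_placement=opts["null_placement"])
# rows sharing the same partition values are now contiguous, with nulls pushed to the end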
@@ -2487,12 +2487,12 @@ def group_by_partition_scheme(iceberg_table: Table, arrow_table: pa.Table, parti
)

# only works for identity
sort_options = get_partition_sort_order(partition_columns, reverse=False)
sort_options = _get_partition_sort_order(partition_columns, reverse=False)
sorted_arrow_table = arrow_table.sort_by(sorting=sort_options['sort_keys'], null_placement=sort_options['null_placement'])
return sorted_arrow_table


def get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[str]:
def _get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[str]:
arrow_table_cols = set(arrow_table.column_names)
partition_cols = []
for transform_field in iceberg_table.spec().fields:
@@ -2505,13 +2505,13 @@ def get_partition_columns(iceberg_table: Table, arrow_table: pa.Table) -> list[s
return partition_cols


def get_partition_key(arrow_table: pa.Table, partition_columns: list[str], offset: int) -> Record:
def _get_partition_key(arrow_table: pa.Table, partition_columns: list[str], offset: int) -> Record:
# todo: Instead of fetching partition keys one at a time, try filtering by a mask made of offsets, and convert to py together,
# possibly slightly more efficient.
return Record(**{col: arrow_table.column(col)[offset].as_py() for col in partition_columns})
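A tiny illustration (invented data) of the key this helper builds for one row offset:

import pyarrow as pa

tbl = pa.table({"int": [1, 1, 2], "payload": ["a", "b", "c"]})
key = _get_partition_key(tbl, ["int"], offset=2)
# key is a Record carrying int=2: one value per identity partition column, read at row offset 2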


def partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePartition]:
def _partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
Example:
@@ -2539,11 +2539,11 @@ def partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePart
"""
import pyarrow as pa

partition_columns = get_partition_columns(iceberg_table, arrow_table)
partition_columns = _get_partition_columns(iceberg_table, arrow_table)

arrow_table = group_by_partition_scheme(iceberg_table, arrow_table, partition_columns)

reversing_sort_order_options = get_partition_sort_order(partition_columns, reverse=True)
reversing_sort_order_options = _get_partition_sort_order(partition_columns, reverse=True)
reversed_indices = pa.compute.sort_indices(arrow_table, **reversing_sort_order_options).to_pylist()

slice_instructions = []
@@ -2559,7 +2559,7 @@ def partition(iceberg_table: Table, arrow_table: pa.Table) -> Iterable[TablePart

table_partitions: list[TablePartition] = [
TablePartition(
partition_key=get_partition_key(arrow_table, partition_columns, inst["offset"]),
partition_key=_get_partition_key(arrow_table, partition_columns, inst["offset"]),
arrow_table_partition=arrow_table.slice(**inst),
)
for inst in slice_instructions
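Putting the pieces together, a hedged end-to-end sketch (data invented; `iceberg_table` assumed to be identity-partitioned on the `int` column) of what _partition yields: one TablePartition per distinct partition key, each wrapping the matching contiguous slice of the sorted table:

import pyarrow as pa

arrow_table = pa.table({"int": [1, 2, 1], "payload": ["a", "b", "c"]})
# `iceberg_table` is assumed to carry an IdentityTransform partition field on `int`
for part in _partition(iceberg_table, arrow_table):
    print(part.partition_key, part.arrow_table_partition.num_rows)
# expected: a partition keyed on int=1 with 2 rows and one keyed on int=2 with 1 row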
@@ -2574,7 +2574,7 @@ def _dataframe_to_data_files(table: Table, df: pa.Table) -> Iterable[DataFile]:
counter = itertools.count(0)

if len(table.spec().fields) > 0:
partitions = partition(table, df)
partitions = _partition(table, df)
yield from write_file(
table,
iter([
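This is the branch the new integration fixtures below exercise: appending to a table whose spec has partition fields routes the dataframe through _partition before write_file. A condensed sketch of that call path, assembled from the fixtures in the test file (session_catalog, TABLE_SCHEMA and arrow_table_without_null come from the test module; import paths are the usual pyiceberg ones):

from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.transforms import IdentityTransform

tbl = session_catalog.create_table(
    identifier="default.arrow_table_v1_without_null_partitioned",
    schema=TABLE_SCHEMA,
    partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
    properties={'format-version': '1'},
)
tbl.append(arrow_table_without_null)  # partitioned spec, so _dataframe_to_data_files takes the _partition path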
64 changes: 51 additions & 13 deletions tests/integration/test_partitioned_writes.py
@@ -64,7 +64,7 @@ def catalog() -> Catalog:
return catalog


TEST_DATA_without_null = {
TEST_DATA_WITH_NULL = {
'bool': [False, None, True],
'string': ['a', None, 'z'],
# Go over the 16 bytes to kick in truncation
@@ -168,7 +168,30 @@ def arrow_table_without_null() -> pa.Table:
return pa.Table.from_pydict(TEST_DATA_WITHOUT_NULL, schema=pa_schema)


# working on
@pytest.fixture(scope="session")
def arrow_table_with_null() -> pa.Table:
"""PyArrow table with all kinds of columns"""
pa_schema = pa.schema([
("bool", pa.bool_()),
("string", pa.string()),
("string_long", pa.string()),
("int", pa.int32()),
("long", pa.int64()),
("float", pa.float32()),
("double", pa.float64()),
("timestamp", pa.timestamp(unit="us")),
("timestamptz", pa.timestamp(unit="us", tz="UTC")),
("date", pa.date32()),
# Not supported by Spark
# ("time", pa.time64("us")),
# Not natively supported by Arrow
# ("uuid", pa.fixed(16)),
("binary", pa.binary()),
("fixed", pa.binary(16)),
])
return pa.Table.from_pydict(TEST_DATA_WITH_NULL, schema=pa_schema)

# stub
@pytest.fixture(scope="session", autouse=True)
def table_v1_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_v1_without_null_partitioned"
@@ -188,6 +211,25 @@ def table_v1_without_null_partitioned(session_catalog: Catalog, arrow_table_with

assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"

# # for above
# @pytest.fixture(scope="session", autouse=True)
# def table_v1_with_null_partitioned(session_catalog: Catalog, arrow_table_with_null: pa.Table) -> None:
# identifier = "default.arrow_table_v1_without_null_partitioned"

# try:
# session_catalog.drop_table(identifier=identifier)
# except NoSuchTableError:
# pass

# tbl = session_catalog.create_table(
# identifier=identifier,
# schema=TABLE_SCHEMA,
# partition_spec=PartitionSpec(PartitionField(source_id=4, field_id=1001, transform=IdentityTransform(), name="int")),
# properties={'format-version': '1'},
# )
# tbl.append(arrow_table_with_null)

# assert tbl.format_version == 1, f"Expected v1, got: v{tbl.format_version}"

@pytest.fixture(scope="session", autouse=True)
def table_v1_appended_without_null_partitioned(session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
@@ -243,11 +285,7 @@ def table_v2_appended_without_null_partitioned(session_catalog: Catalog, arrow_t
assert tbl.format_version == 2, f"Expected v2, got: v{tbl.format_version}"


# 4 table creation finished


# working on
@pytest.mark.integration
@pytest.mark.newyork
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -> None:
@@ -256,7 +294,7 @@ def test_query_filter_null(spark: SparkSession, col: str, format_version: int) -
assert df.where(f"{col} is not null").count() == 3, f"Expected 3 rows for {col}"


@pytest.mark.integration
@pytest.mark.adrian
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
@pytest.mark.parametrize("format_version", [1, 2])
def test_query_filter_appended_null_partitioned(spark: SparkSession, col: str, format_version: int) -> None:
@@ -326,15 +364,15 @@ def spark() -> SparkSession:
return spark


@pytest.mark.integration
@pytest.mark.parametrize("col", TEST_DATA_without_null.keys())
@pytest.mark.adrian
@pytest.mark.parametrize("col", TEST_DATA_WITHOUT_NULL.keys())
def test_query_filter_v1_v2_append_null(spark: SparkSession, col: str) -> None:
identifier = "default.arrow_table_v1_v2_appended_without_null"
df = spark.table(identifier)
assert df.where(f"{col} is not null").count() == 6, f"Expected 6 rows for {col}"


@pytest.mark.integration
@pytest.mark.adrian
def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_table_summaries"

@@ -390,7 +428,7 @@ def test_summaries(spark: SparkSession, session_catalog: Catalog, arrow_table_wi
}


@pytest.mark.integration
@pytest.mark.adrian
def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_data_files"

@@ -425,7 +463,7 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w
assert [row.deleted_data_files_count for row in rows] == [0, 0, 0]


@pytest.mark.integration
@pytest.mark.adrian
def test_invalid_arguments(spark: SparkSession, session_catalog: Catalog, arrow_table_without_null: pa.Table) -> None:
identifier = "default.arrow_data_files"


