From 499e36e295fc94b2eb3484e354ee64a897fbcf50 Mon Sep 17 00:00:00 2001
From: amogh-jahagirdar
Date: Wed, 27 Dec 2023 20:06:32 -0800
Subject: [PATCH] Partition Evolution Support

---
 mkdocs/docs/api.md                            |  57 ++
 pyiceberg/partitioning.py                     | 131 ++++-
 pyiceberg/table/__init__.py                   | 295 ++++++++++-
 pyiceberg/table/metadata.py                   |   3 +-
 tests/catalog/test_hive.py                    |   4 +-
 tests/integration/test_partition_evolution.py | 490 ++++++++++++++++++
 6 files changed, 966 insertions(+), 14 deletions(-)
 create mode 100644 tests/integration/test_partition_evolution.py

diff --git a/mkdocs/docs/api.md b/mkdocs/docs/api.md
index 53801922fc..724a45c52f 100644
--- a/mkdocs/docs/api.md
+++ b/mkdocs/docs/api.md
@@ -418,6 +418,63 @@ with table.update_schema(allow_incompatible_changes=True) as update:
     update.delete_column("some_field")
 ```
 
+## Partition evolution
+
+PyIceberg supports partition evolution. See the [partition evolution](https://iceberg.apache.org/spec/#partition-evolution)
+section of the Iceberg spec for more details.
+
+The API to use when evolving partitions is the `update_spec` API on the table.
+
+```python
+with table.update_spec() as update:
+    update.add_field("id", BucketTransform(16), "bucketed_id")
+    update.add_field("event_ts", DayTransform(), "day_ts")
+```
+
+Updating the partition spec can also be done as part of a transaction with other operations.
+
+```python
+with table.transaction() as transaction:
+    with transaction.update_spec() as update_spec:
+        update_spec.add_field("id", BucketTransform(16), "bucketed_id")
+        update_spec.add_field("event_ts", DayTransform(), "day_ts")
+    # ... Update properties etc
+```
+
+### Add fields
+
+New partition fields can be added via the `add_field` API, which takes the name of the source field to partition on,
+the partition transform, and an optional partition name. If the partition name is not specified,
+one is generated automatically.
+
+```python
+with table.update_spec() as update:
+    update.add_field("id", BucketTransform(16), "bucketed_id")
+    update.add_field("event_ts", DayTransform(), "day_ts")
+    # add_identity is a shortcut API for adding an IdentityTransform
+    update.add_identity("some_field")
+```
+
+### Remove fields
+
+Partition fields can also be removed via the `remove_field` API if it no longer makes sense to partition on those fields.
+
+```python
+with table.update_spec() as update:
+    # Remove the partition field with the name some_partition_name
+    update.remove_field("some_partition_name")
+```
+
+### Rename fields
+
+Partition fields can also be renamed via the `rename_field` API.
+
+```python
+with table.update_spec() as update:
+    # Rename the partition field with the name bucketed_id to sharded_id
+    update.rename_field("bucketed_id", "sharded_id")
+```
+
 ## Table properties
 
 Set and remove properties through the `Transaction` API:
diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index f6307f0f8c..cd5a957b22 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -16,14 +16,9 @@
 # under the License.
from __future__ import annotations -from functools import cached_property -from typing import ( - Any, - Dict, - List, - Optional, - Tuple, -) +from abc import ABC, abstractmethod +from functools import cached_property, singledispatch +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar from pydantic import ( BeforeValidator, @@ -34,7 +29,18 @@ from typing_extensions import Annotated from pyiceberg.schema import Schema -from pyiceberg.transforms import Transform, parse_transform +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + Transform, + TruncateTransform, + UnknownTransform, + VoidTransform, + YearTransform, + parse_transform, +) from pyiceberg.typedef import IcebergBaseModel from pyiceberg.types import NestedField, StructType @@ -143,7 +149,7 @@ def is_unpartitioned(self) -> bool: def last_assigned_field_id(self) -> int: if self.fields: return max(pf.field_id for pf in self.fields) - return PARTITION_FIELD_ID_START + return PARTITION_FIELD_ID_START - 1 @cached_property def source_id_to_fields_map(self) -> Dict[int, List[PartitionField]]: @@ -215,3 +221,108 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre ) ) return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) + + +T = TypeVar("T") + + +class PartitionSpecVisitor(Generic[T], ABC): + @abstractmethod + def identity(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit identity partition field.""" + + @abstractmethod + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T: + """Visit bucket partition field.""" + + @abstractmethod + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T: + """Visit truncate partition field.""" + + @abstractmethod + def year(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit year partition field.""" + + @abstractmethod + def month(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit month partition field.""" + + @abstractmethod + def day(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit day partition field.""" + + @abstractmethod + def hour(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit hour partition field.""" + + @abstractmethod + def always_null(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit void partition field.""" + + @abstractmethod + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T: + """Visit unknown partition field.""" + raise ValueError(f"Unknown transform is not supported: {transform}") + + +class _PartitionNameGenerator(PartitionSpecVisitor[str]): + def identity(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str: + return f"{source_name}_bucket_{num_buckets}" + + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str: + return source_name + "_trunc_" + str(width) + + def year(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_year" + + def month(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_month" + + def day(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_day" + + def hour(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + 
"_hour" + + def always_null(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_null" + + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str: + return super().unknown(field_id, source_name, source_id, transform) + + +R = TypeVar("R") + + +@singledispatch +def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]: + return [_visit_partition_field(schema, field, visitor) for field in spec.fields] + + +def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R: + source_name = schema.find_column_name(field.source_id) + if not source_name: + raise ValueError(f"Could not find field with id {field.source_id}") + + transform = field.transform + if isinstance(transform, IdentityTransform): + return visitor.identity(field.field_id, source_name, field.source_id) + elif isinstance(transform, BucketTransform): + return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets) + elif isinstance(transform, TruncateTransform): + return visitor.truncate(field.field_id, source_name, field.source_id, transform.width) + elif isinstance(transform, DayTransform): + return visitor.day(field.field_id, source_name, field.source_id) + elif isinstance(transform, HourTransform): + return visitor.hour(field.field_id, source_name, field.source_id) + elif isinstance(transform, YearTransform): + return visitor.year(field.field_id, source_name, field.source_id) + elif isinstance(transform, VoidTransform): + return visitor.always_null(field.field_id, source_name, field.source_id) + elif isinstance(transform, UnknownTransform): + return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform)) + else: + raise ValueError(f"Unknown transform {transform}") diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 060f13772b..e49f9400fe 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -51,6 +51,7 @@ And, BooleanExpression, EqualTo, + Reference, parser, visitors, ) @@ -67,7 +68,15 @@ write_manifest, write_manifest_list, ) -from pyiceberg.partitioning import PartitionSpec +from pyiceberg.partitioning import ( + INITIAL_PARTITION_SPEC_ID, + PARTITION_FIELD_ID_START, + IdentityTransform, + PartitionField, + PartitionSpec, + _PartitionNameGenerator, + _visit_partition_field, +) from pyiceberg.schema import ( PartnerAccessor, Schema, @@ -99,6 +108,7 @@ update_snapshot_summaries, ) from pyiceberg.table.sorting import SortOrder +from pyiceberg.transforms import TimeTransform, Transform, VoidTransform from pyiceberg.typedef import ( EMPTY_DICT, IcebergBaseModel, @@ -372,6 +382,14 @@ def update_snapshot(self) -> UpdateSnapshot: """ return UpdateSnapshot(self._table, self) + def update_spec(self) -> UpdateSpec: + """Create a new UpdateSpec to update the partitioning of the table. + + Returns: + A new UpdateSpec. + """ + return UpdateSpec(self._table, self) + def remove_properties(self, *removals: str) -> Transaction: """Remove properties. 
@@ -634,6 +652,43 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta return base_metadata.model_copy(update={"current_schema_id": new_schema_id}) +@_apply_table_update.register(AddPartitionSpecUpdate) +def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + for spec in base_metadata.partition_specs: + if spec.spec_id == update.spec.spec_id: + raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}") + context.add_update(update) + return base_metadata.model_copy( + update={ + "partition_specs": base_metadata.partition_specs + [update.spec], + "last_partition_id": max( + max(field.field_id for field in update.spec.fields), + base_metadata.last_partition_id or PARTITION_FIELD_ID_START - 1, + ), + } + ) + + +@_apply_table_update.register(SetDefaultSpecUpdate) +def _(update: SetDefaultSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + new_spec_id = update.spec_id + if new_spec_id == -1: + new_spec_id = max(spec.spec_id for spec in base_metadata.partition_specs) + if new_spec_id == base_metadata.default_spec_id: + return base_metadata + found_spec_id = False + for spec in base_metadata.partition_specs: + found_spec_id = spec.spec_id == new_spec_id + if found_spec_id: + break + + if not found_spec_id: + raise ValueError(f"Failed to find spec with id {new_spec_id}") + + context.add_update(update) + return base_metadata.model_copy(update={"default_spec_id": new_spec_id}) + + @_apply_table_update.register(AddSnapshotUpdate) def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: if len(base_metadata.schemas) == 0: @@ -969,6 +1024,12 @@ def sort_orders(self) -> Dict[int, SortOrder]: """Return a dict of the sort orders of this table.""" return {sort_order.order_id: sort_order for sort_order in self.metadata.sort_orders} + def last_partition_id(self) -> int: + """Return the highest assigned partition field ID across all specs or 999 if only the unpartitioned spec exists.""" + if self.metadata.last_partition_id: + return self.metadata.last_partition_id + return PARTITION_FIELD_ID_START - 1 + @property def properties(self) -> Dict[str, str]: """Properties of the table.""" @@ -1095,6 +1156,9 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T for data_file in data_files: update_snapshot.append_data_file(data_file) + def update_spec(self, case_sensitive: bool = True) -> UpdateSpec: + return UpdateSpec(self, case_sensitive=case_sensitive) + def refs(self) -> Dict[str, SnapshotRef]: """Return the snapshot references in the table.""" return self.metadata.refs @@ -2655,3 +2719,232 @@ def overwrite(self) -> OverwriteFiles: operation=Operation.OVERWRITE if self._table.current_snapshot() is not None else Operation.APPEND, transaction=self._transaction, ) + + +class UpdateSpec: + _table: Table + _name_to_field: Dict[str, PartitionField] = {} + _name_to_added_field: Dict[str, PartitionField] = {} + _transform_to_field: Dict[Tuple[int, str], PartitionField] = {} + _transform_to_added_field: Dict[Tuple[int, str], PartitionField] = {} + _renames: Dict[str, str] = {} + _added_time_fields: Dict[int, PartitionField] = {} + _case_sensitive: bool + _adds: List[PartitionField] + _deletes: Set[int] + _last_assigned_partition_id: int + _transaction: Optional[Transaction] + + def __init__(self, table: Table, transaction: Optional[Transaction] = None, 
case_sensitive: bool = True) -> None:
+        self._table = table
+        self._name_to_field = {field.name: field for field in table.spec().fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in table.spec().fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = set()
+        self._last_assigned_partition_id = table.last_partition_id()
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+        self._added_time_fields = {}
+
+    def add_field(
+        self,
+        source_column_name: str,
+        transform: Transform[Any, Any],
+        partition_field_name: Optional[str] = None,
+    ) -> UpdateSpec:
+        ref = Reference(source_column_name)
+        bound_ref = ref.bind(self._table.schema(), self._case_sensitive)
+        # verify that the transform can be applied to the bound field's type
+        output_type = bound_ref.field.field_type
+        if not transform.can_transform(output_type):
+            raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")
+
+        transform_key = (bound_ref.field.field_id, repr(transform))
+        existing_partition_field = self._transform_to_field.get(transform_key)
+        if existing_partition_field and self._is_duplicate_partition(transform, existing_partition_field):
+            raise ValueError(f"Duplicate partition field for {ref.name}={ref}, {existing_partition_field} already exists")
+
+        added = self._transform_to_added_field.get(transform_key)
+        if added:
+            raise ValueError(f"Already added partition: {added.name}")
+
+        new_field = self._partition_field((bound_ref.field.field_id, transform), partition_field_name)
+        if new_field.name in self._name_to_added_field:
+            raise ValueError(f"Already added partition field with name: {new_field.name}")
+
+        if isinstance(new_field.transform, TimeTransform):
+            existing_time_field = self._added_time_fields.get(new_field.source_id)
+            if existing_time_field:
+                raise ValueError(f"Cannot add time partition field: {new_field.name} conflicts with {existing_time_field.name}")
+            self._added_time_fields[new_field.source_id] = new_field
+        self._transform_to_added_field[transform_key] = new_field
+
+        existing_partition_field = self._name_to_field.get(new_field.name)
+        if existing_partition_field and new_field.field_id not in self._deletes:
+            if isinstance(existing_partition_field.transform, VoidTransform):
+                self.rename_field(
+                    existing_partition_field.name, existing_partition_field.name + "_" + str(existing_partition_field.field_id)
+                )
+            else:
+                raise ValueError(f"Cannot add duplicate partition field name: {existing_partition_field.name}")
+
+        self._name_to_added_field[new_field.name] = new_field
+        self._adds.append(new_field)
+        return self
+
+    def add_identity(self, source_column_name: str) -> UpdateSpec:
+        return self.add_field(source_column_name, IdentityTransform(), None)
+
+    def remove_field(self, name: str) -> UpdateSpec:
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot delete newly added field {name}")
+        renamed = self._renames.get(name)
+        if renamed:
+            raise ValueError(f"Cannot rename and delete field {name}")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"No such partition field: {name}")
+
+        self._deletes.add(field.field_id)
+        return self
+
+    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
+        existing_field = self._name_to_field.get(new_name)
+        if existing_field and isinstance(existing_field.transform, VoidTransform):
+            return self.rename_field(name, name + "_" + 
str(existing_field.field_id)) + added = self._name_to_added_field.get(name) + if added: + raise ValueError("Cannot rename recently added partitions") + field = self._name_to_field.get(name) + if not field: + raise ValueError(f"Cannot find partition field {name}") + if field.field_id in self._deletes: + raise ValueError(f"Cannot delete and rename partition field {name}") + self._renames[name] = new_name + return self + + def commit(self) -> None: + new_spec = self._apply() + if self._table.metadata.default_spec_id != new_spec.spec_id: + if new_spec.spec_id not in self._table.specs(): + updates = [AddPartitionSpecUpdate(spec=new_spec), SetDefaultSpecUpdate(spec_id=-1)] + else: + updates = [SetDefaultSpecUpdate(spec_id=new_spec.spec_id)] + + required_last_assigned_partitioned_id = self._table.last_partition_id() + requirements = [AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)] + + if self._transaction is not None: + self._transaction._append_updates(*updates) # pylint: disable=W0212 + self._transaction._append_requirements(*requirements) # pylint: disable=W0212 + else: + requirements.append(AssertDefaultSpecId(default_spec_id=self._table.spec().spec_id)) + self._table._do_commit(updates=tuple(updates), requirements=tuple(requirements)) # pylint: disable=W0212 + + def __exit__(self, _: Any, value: Any, traceback: Any) -> None: + """Close and commit the change.""" + return self.commit() + + def __enter__(self) -> UpdateSpec: + """Update the table.""" + return self + + def _apply(self) -> PartitionSpec: + def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None: + try: + field = schema.find_field(name) + except ValueError: + field = None + + if source_id is not None and field is not None and field.field_id != source_id: + raise ValueError(f"Cannot create identity partition from a different field in the schema {name}") + elif field is not None and source_id != field.field_id: + raise ValueError(f"Cannot create partition from name that exists in schema {name}") + if not name: + raise ValueError("Undefined name") + if name in partition_names: + raise ValueError(f"Partition name has to be unique: {name}") + partition_names.add(name) + + def _add_new_field( + schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str] + ) -> PartitionField: + _check_and_add_partition_name(schema, name, source_id, partition_names) + return PartitionField(source_id, field_id, transform, name) + + partition_fields = [] + partition_names: Set[str] = set() + for field in self._table.spec().fields: + if field.field_id not in self._deletes: + renamed = self._renames.get(field.name) + if renamed: + new_field = _add_new_field( + self._table.schema(), field.source_id, field.field_id, renamed, field.transform, partition_names + ) + else: + new_field = _add_new_field( + self._table.schema(), field.source_id, field.field_id, field.name, field.transform, partition_names + ) + partition_fields.append(new_field) + elif self._table.format_version == 1: + renamed = self._renames.get(field.name) + if renamed: + new_field = _add_new_field( + self._table.schema(), field.source_id, field.field_id, renamed, VoidTransform(), partition_names + ) + else: + new_field = _add_new_field( + self._table.schema(), field.source_id, field.field_id, field.name, VoidTransform(), partition_names + ) + + partition_fields.append(new_field) + + for added_field in self._adds: + new_field = 
PartitionField(
+                source_id=added_field.source_id,
+                field_id=added_field.field_id,
+                transform=added_field.transform,
+                name=added_field.name,
+            )
+            partition_fields.append(new_field)
+
+        # Reuse the spec id of a compatible existing spec, or assign the next unused id.
+        new_spec = PartitionSpec(*partition_fields)
+        new_spec_id = INITIAL_PARTITION_SPEC_ID
+        for spec in self._table.specs().values():
+            if new_spec.compatible_with(spec):
+                new_spec_id = spec.spec_id
+                break
+            elif new_spec_id <= spec.spec_id:
+                new_spec_id = spec.spec_id + 1
+        return PartitionSpec(*partition_fields, spec_id=new_spec_id)
+
+    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
+        if self._table.metadata.format_version == 2:
+            source_id, transform = transform_key
+            historical_fields = []
+            for spec in self._table.specs().values():
+                for field in spec.fields:
+                    historical_fields.append((field.source_id, field.field_id, repr(field.transform), field.name))
+
+            for field_key in historical_fields:
+                if field_key[0] == source_id and field_key[2] == repr(transform):
+                    if name is None or field_key[3] == name:
+                        return PartitionField(source_id, field_key[1], transform, field_key[3])
+
+        new_field_id = self._new_field_id()
+        if name is None:
+            tmp_field = PartitionField(transform_key[0], new_field_id, transform_key[1], 'unassigned_field_name')
+            name = _visit_partition_field(self._table.schema(), tmp_field, _PartitionNameGenerator())
+        return PartitionField(transform_key[0], new_field_id, transform_key[1], name)
+
+    def _new_field_id(self) -> int:
+        self._last_assigned_partition_id += 1
+        return self._last_assigned_partition_id
+
+    def _is_duplicate_partition(self, transform: Transform[Any, Any], partition_field: PartitionField) -> bool:
+        return partition_field.field_id not in self._deletes and partition_field.transform == transform
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index a5dfb6ce4c..ea7a02f715 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -310,7 +310,8 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:
         data[PARTITION_SPECS] = [{"field-id": 0, "fields": ()}]
 
     data[LAST_PARTITION_ID] = max(
-        [field.get(FIELD_ID) for spec in data[PARTITION_SPECS] for field in spec[FIELDS]], default=PARTITION_FIELD_ID_START
+        [field.get(FIELD_ID) for spec in data[PARTITION_SPECS] for field in spec[FIELDS]],
+        default=PARTITION_FIELD_ID_START - 1,
     )
 
     return data
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index dc2689e0d8..e59b7599bc 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -277,7 +277,7 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, 
             )
         ],
         current_schema_id=0,
-        last_partition_id=1000,
+        last_partition_id=999,
         properties={"owner": "javaberg", 'write.parquet.compression-codec': 'zstd'},
         partition_specs=[PartitionSpec()],
         default_spec_id=0,
@@ -330,7 +330,7 @@ def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabas
         schema=expected_schema,
         schemas=[expected_schema],
         current_schema_id=0,
-        last_partition_id=1000,
+        last_partition_id=999,
         properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
         partition_spec=[],
         partition_specs=[expected_spec],
diff --git a/tests/integration/test_partition_evolution.py b/tests/integration/test_partition_evolution.py
new file mode 100644
index 0000000000..16feef565d
--- /dev/null
+++ b/tests/integration/test_partition_evolution.py
@@ -0,0 +1,490 @@
+# Licensed to the 
Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint:disable=redefined-outer-name
+
+import pytest
+
+from pyiceberg.catalog import Catalog, load_catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.transforms import (
+    BucketTransform,
+    DayTransform,
+    HourTransform,
+    IdentityTransform,
+    MonthTransform,
+    TruncateTransform,
+    VoidTransform,
+    YearTransform,
+)
+from pyiceberg.types import (
+    LongType,
+    NestedField,
+    StringType,
+    TimestampType,
+)
+
+
+@pytest.fixture()
+def catalog_rest() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "rest",
+            "uri": "http://localhost:8181",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
+@pytest.fixture()
+def catalog_hive() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "hive",
+            "uri": "http://localhost:9083",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
+def _simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table:
+    return _create_table_with_schema(catalog, table_schema_simple, "1")
+
+
+def _table(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "1")
+
+
+def _table_v2(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "2")
+
+
+def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
+    tbl_name = "default.test_partition_evolution"
+    try:
+        catalog.drop_table(tbl_name)
+    except NoSuchTableError:
+        pass
+    return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_add_identity_partition(catalog: Catalog, table_schema_simple: Schema) -> None:
+    simple_table = _simple_table(catalog, table_schema_simple)
+    simple_table.update_spec().add_identity("foo").commit()
+    specs = simple_table.specs()
+    assert len(specs) == 2
+    spec = simple_table.spec()
+    assert spec.spec_id == 1
+    assert spec.last_assigned_field_id == 1000
+
+
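+# NOTE: this test is an editorial addition, not part of the original patch: a
+# minimal sketch of default partition-name generation. add_identity passes
+# partition_field_name=None, so UpdateSpec derives the name via
+# _visit_partition_field with _PartitionNameGenerator, which for an identity
+# transform returns the source column name itself.
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_add_identity_uses_source_column_name(catalog: Catalog, table_schema_simple: Schema) -> None:
+    simple_table = _simple_table(catalog, table_schema_simple)
+    simple_table.update_spec().add_identity("foo").commit()
+    # The generated partition field keeps the source column name "foo"
+    assert simple_table.spec().fields[0].name == "foo"
+
+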
+@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_year(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", YearTransform(), "year_transform").commit() + _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, YearTransform(), "year_transform")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_month(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", MonthTransform(), "month_transform").commit() + _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, MonthTransform(), "month_transform")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_day(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", DayTransform(), "day_transform").commit() + _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, DayTransform(), "day_transform")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_hour(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", HourTransform(), "hour_transform").commit() + _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, HourTransform(), "hour_transform")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_bucket(catalog: Catalog, table_schema_simple: Schema) -> None: + simple_table = _create_table_with_schema(catalog, table_schema_simple, "1") + simple_table.update_spec().add_field("foo", BucketTransform(12), "bucket_transform").commit() + _validate_new_partition_fields(simple_table, 1000, 1, 1000, PartitionField(1, 1000, BucketTransform(12), "bucket_transform")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_truncate(catalog: Catalog, table_schema_simple: Schema) -> None: + simple_table = _create_table_with_schema(catalog, table_schema_simple, "1") + simple_table.update_spec().add_field("foo", TruncateTransform(1), "truncate_transform").commit() + _validate_new_partition_fields( + simple_table, 1000, 1, 1000, PartitionField(1, 1000, TruncateTransform(1), "truncate_transform") + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_multiple_adds(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_identity("id").add_field("event_ts", HourTransform(), "hourly_partitioned").add_field( + "str", TruncateTransform(2), "truncate_str" + ).commit() + _validate_new_partition_fields( + table, + 1002, + 1, + 1002, + PartitionField(1, 1000, IdentityTransform(), "id"), + PartitionField(2, 1001, HourTransform(), "hourly_partitioned"), + PartitionField(3, 1002, TruncateTransform(2), "truncate_str"), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_hour_to_day(catalog: 
Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("event_ts", DayTransform(), "daily_partitioned").commit() + table.update_spec().add_field("event_ts", HourTransform(), "hourly_partitioned").commit() + _validate_new_partition_fields( + table, + 1001, + 2, + 1001, + PartitionField(2, 1000, DayTransform(), "daily_partitioned"), + PartitionField(2, 1001, HourTransform(), "hourly_partitioned"), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_add_multiple_buckets(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_field("id", BucketTransform(16)).add_field("id", BucketTransform(4)).commit() + _validate_new_partition_fields( + table, + 1001, + 1, + 1001, + PartitionField(1, 1000, BucketTransform(16), "id_bucket_16"), + PartitionField(1, 1001, BucketTransform(4), "id_bucket_4"), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_identity(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_identity("id").commit() + table.update_spec().remove_field("id").commit() + assert len(table.specs()) == 3 + assert table.spec().spec_id == 2 + assert table.spec() == PartitionSpec( + PartitionField(source_id=1, field_id=1000, transform=VoidTransform(), name='id'), spec_id=2 + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_identity_v2(catalog: Catalog) -> None: + table_v2 = _table_v2(catalog) + table_v2.update_spec().add_identity("id").commit() + table_v2.update_spec().remove_field("id").commit() + assert len(table_v2.specs()) == 2 + assert table_v2.spec().spec_id == 0 + assert table_v2.spec() == PartitionSpec(spec_id=0) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_bucket(catalog: Catalog) -> None: + table = _table(catalog) + with table.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table.update_spec() as remove: + remove.remove_field("bucketed_id") + + assert len(table.specs()) == 3 + _validate_new_partition_fields( + table, + 1001, + 2, + 1001, + PartitionField(source_id=1, field_id=1000, transform=VoidTransform(), name='bucketed_id'), + PartitionField(source_id=2, field_id=1001, transform=DayTransform(), name='day_ts'), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_bucket_v2(catalog: Catalog) -> None: + table_v2 = _table_v2(catalog) + with table_v2.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table_v2.update_spec() as remove: + remove.remove_field("bucketed_id") + assert len(table_v2.specs()) == 3 + _validate_new_partition_fields( + table_v2, 1001, 2, 1001, PartitionField(source_id=2, field_id=1001, transform=DayTransform(), name='day_ts') + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_day(catalog: Catalog) -> None: + table = _table(catalog) + with 
table.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table.update_spec() as remove: + remove.remove_field("day_ts") + + assert len(table.specs()) == 3 + _validate_new_partition_fields( + table, + 1001, + 2, + 1001, + PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name='bucketed_id'), + PartitionField(source_id=2, field_id=1001, transform=VoidTransform(), name='day_ts'), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_remove_day_v2(catalog: Catalog) -> None: + table_v2 = _table_v2(catalog) + with table_v2.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table_v2.update_spec() as remove: + remove.remove_field("day_ts") + assert len(table_v2.specs()) == 3 + _validate_new_partition_fields( + table_v2, 1000, 2, 1001, PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name='bucketed_id') + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_rename(catalog: Catalog) -> None: + table = _table(catalog) + table.update_spec().add_identity("id").commit() + table.update_spec().rename_field("id", "sharded_id").commit() + assert len(table.specs()) == 3 + assert table.spec().spec_id == 2 + _validate_new_partition_fields(table, 1000, 2, 1000, PartitionField(1, 1000, IdentityTransform(), "sharded_id")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_cannot_add_and_remove(catalog: Catalog) -> None: + table = _table(catalog) + with pytest.raises(ValueError) as exc_info: + table.update_spec().add_identity("id").remove_field("id").commit() + assert "Cannot delete newly added field id" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_cannot_add_redundant_time_partition(catalog: Catalog) -> None: + table = _table(catalog) + with pytest.raises(ValueError) as exc_info: + table.update_spec().add_field("event_ts", YearTransform(), "year_transform").add_field( + "event_ts", HourTransform(), "hour_transform" + ).commit() + assert "Cannot add time partition field: hour_transform conflicts with year_transform" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_cannot_delete_and_rename(catalog: Catalog) -> None: + table = _table(catalog) + with pytest.raises(ValueError) as exc_info: + table.update_spec().add_identity("id").commit() + table.update_spec().remove_field("id").rename_field("id", "sharded_id").commit() + assert "Cannot delete and rename partition field id" in str(exc_info.value) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_cannot_rename_and_delete(catalog: Catalog) -> None: + table = _table(catalog) + with pytest.raises(ValueError) as exc_info: + table.update_spec().add_identity("id").commit() + table.update_spec().rename_field("id", "sharded_id").remove_field("id").commit() + assert "Cannot rename 
and delete field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_cannot_add_same_transform_for_same_field(catalog: Catalog) -> None:
+    table = _table(catalog)
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("str", TruncateTransform(4), "truncated_str").add_field(
+            "str", TruncateTransform(4)
+        ).commit()
+    assert "Already added partition" in str(exc_info.value)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_cannot_add_same_field_multiple_times(catalog: Catalog) -> None:
+    table = _table(catalog)
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("id", IdentityTransform(), "duplicate").add_field(
+            "id", IdentityTransform(), "duplicate"
+        ).commit()
+    assert "Already added partition" in str(exc_info.value)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_cannot_add_multiple_specs_same_name(catalog: Catalog) -> None:
+    table = _table(catalog)
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("id", IdentityTransform(), "duplicate").add_field(
+            "event_ts", IdentityTransform(), "duplicate"
+        ).commit()
+    assert "Already added partition" in str(exc_info.value)
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_change_specs_and_schema_transaction(catalog: Catalog) -> None:
+    table = _table(catalog)
+    with table.transaction() as transaction:
+        with transaction.update_spec() as update_spec:
+            update_spec.add_identity("id").add_field("event_ts", HourTransform(), "hourly_partitioned").add_field(
+                "str", TruncateTransform(2), "truncate_str"
+            )
+
+        with transaction.update_schema() as update_schema:
+            update_schema.add_column("col_string", StringType())
+
+    _validate_new_partition_fields(
+        table,
+        1002,
+        1,
+        1002,
+        PartitionField(1, 1000, IdentityTransform(), "id"),
+        PartitionField(2, 1001, HourTransform(), "hourly_partitioned"),
+        PartitionField(3, 1002, TruncateTransform(2), "truncate_str"),
+    )
+
+    assert table.schema() == Schema(
+        NestedField(field_id=1, name='id', field_type=LongType(), required=False),
+        NestedField(field_id=2, name='event_ts', field_type=TimestampType(), required=False),
+        NestedField(field_id=3, name='str', field_type=StringType(), required=False),
+        NestedField(field_id=4, name='col_string', field_type=StringType(), required=False),
+        schema_id=1,
+        identifier_field_ids=[],
+    )
+
+
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_multiple_adds_and_remove_v1(catalog: Catalog) -> None:
+    table = _table(catalog)
+    with table.update_spec() as update:
+        update.add_field("id", BucketTransform(16), "bucketed_id")
+        update.add_field("event_ts", DayTransform(), "day_ts")
+    with table.update_spec() as update:
+        update.remove_field("day_ts").remove_field("bucketed_id")
+    with table.update_spec() as update:
+        update.add_field("str", TruncateTransform(2), "truncated_str")
+    _validate_new_partition_fields(
+        table,
+        1002,
+        3,
+        1002,
+        PartitionField(1, 1000, VoidTransform(), "bucketed_id"),
+        PartitionField(2, 1001, VoidTransform(), "day_ts"),
+        
PartitionField(3, 1002, TruncateTransform(2), "truncated_str"), + ) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_multiple_adds_and_remove_v2(catalog: Catalog) -> None: + table_v2 = _table_v2(catalog) + with table_v2.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table_v2.update_spec() as update: + update.remove_field("day_ts").remove_field("bucketed_id") + with table_v2.update_spec() as update: + update.add_field("str", TruncateTransform(2), "truncated_str") + _validate_new_partition_fields(table_v2, 1002, 2, 1002, PartitionField(3, 1002, TruncateTransform(2), "truncated_str")) + + +@pytest.mark.integration +@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')]) +def test_multiple_remove_and_add_reuses_v2(catalog: Catalog) -> None: + table_v2 = _table_v2(catalog) + with table_v2.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + update.add_field("event_ts", DayTransform(), "day_ts") + with table_v2.update_spec() as update: + update.remove_field("day_ts").remove_field("bucketed_id") + with table_v2.update_spec() as update: + update.add_field("id", BucketTransform(16), "bucketed_id") + _validate_new_partition_fields(table_v2, 1000, 2, 1001, PartitionField(1, 1000, BucketTransform(16), "bucketed_id")) + + +def _validate_new_partition_fields( + table: Table, + expected_spec_last_assigned_field_id: int, + expected_spec_id: int, + expected_metadata_last_assigned_field_id: int, + *expected_partition_fields: PartitionField, +) -> None: + spec = table.spec() + assert spec.spec_id == expected_spec_id + assert spec.last_assigned_field_id == expected_spec_last_assigned_field_id + assert table.last_partition_id() == expected_metadata_last_assigned_field_id + assert len(spec.fields) == len(expected_partition_fields) + for i in range(len(spec.fields)): + assert spec.fields[i] == expected_partition_fields[i]
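+
+
+# NOTE: this test is an editorial addition, not part of the original patch: a
+# minimal sketch of the default-name path for time transforms. When add_field
+# is called without a partition name, _PartitionNameGenerator derives one from
+# the source column and the transform, so DayTransform on "event_ts" is
+# expected to produce the generated name "event_ts_day".
+@pytest.mark.integration
+@pytest.mark.parametrize('catalog', [pytest.lazy_fixture('catalog_hive'), pytest.lazy_fixture('catalog_rest')])
+def test_add_day_generates_default_name(catalog: Catalog) -> None:
+    table = _table(catalog)
+    table.update_spec().add_field("event_ts", DayTransform()).commit()
+    _validate_new_partition_fields(table, 1000, 1, 1000, PartitionField(2, 1000, DayTransform(), "event_ts_day"))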