From a40098c418df2fc28822f50c913aa3f2a16e88e8 Mon Sep 17 00:00:00 2001 From: amogh-jahagirdar Date: Wed, 27 Dec 2023 20:06:32 -0800 Subject: [PATCH] Partition Evolution Support --- pyiceberg/partitioning.py | 129 +++++- pyiceberg/table/__init__.py | 296 +++++++++++++- tests/test_integration_partition_evolution.py | 366 ++++++++++++++++++ 3 files changed, 781 insertions(+), 10 deletions(-) create mode 100644 tests/test_integration_partition_evolution.py diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py index f6307f0f8c..8309b81434 100644 --- a/pyiceberg/partitioning.py +++ b/pyiceberg/partitioning.py @@ -16,14 +16,9 @@ # under the License. from __future__ import annotations -from functools import cached_property -from typing import ( - Any, - Dict, - List, - Optional, - Tuple, -) +from abc import ABC, abstractmethod +from functools import cached_property, singledispatch +from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar from pydantic import ( BeforeValidator, @@ -34,7 +29,18 @@ from typing_extensions import Annotated from pyiceberg.schema import Schema -from pyiceberg.transforms import Transform, parse_transform +from pyiceberg.transforms import ( + BucketTransform, + DayTransform, + HourTransform, + IdentityTransform, + Transform, + TruncateTransform, + UnknownTransform, + VoidTransform, + YearTransform, + parse_transform, +) from pyiceberg.typedef import IcebergBaseModel from pyiceberg.types import NestedField, StructType @@ -215,3 +221,108 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fre ) ) return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID) + + +T = TypeVar("T") + + +class PartitionSpecVisitor(Generic[T], ABC): + @abstractmethod + def identity(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit identity partition field.""" + + @abstractmethod + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T: + """Visit bucket partition field.""" + + @abstractmethod + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T: + """Visit truncate partition field.""" + + @abstractmethod + def year(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit year partition field.""" + + @abstractmethod + def month(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit month partition field.""" + + @abstractmethod + def day(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit day partition field.""" + + @abstractmethod + def hour(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit hour partition field.""" + + @abstractmethod + def always_null(self, field_id: int, source_name: str, source_id: int) -> T: + """Visit void partition field.""" + + @abstractmethod + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T: + """Visit unknown partition field.""" + raise ValueError(f"Unknown transform is not supported: {transform}") + + +class _PartitionNameGenerator(PartitionSpecVisitor[str]): + def identity(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + + def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str: + return f"{source_name}_bucket_{num_buckets}" + + def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str: + return source_name + "_trunc_" + str(width) + + def year(self, field_id: int, source_name: str, source_id: int) -> 
str: + return source_name + "_year" + + def month(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_month" + + def day(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_day" + + def hour(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_hour" + + def always_null(self, field_id: int, source_name: str, source_id: int) -> str: + return source_name + "_null" + + def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str: + return super().unknown(field_id, source_name, source_id, transform) + + +R = TypeVar("R") + + +@singledispatch +def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]: + return [_visit_partition_field(schema, field, visitor) for field in spec.fields] + + +def _visit_partition_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R: + source_name = schema.find_column_name(field.source_id) + if not source_name: + raise ValueError(f"Could not find field with id {field.source_id}") + + transform = field.transform + if isinstance(transform, IdentityTransform): + return visitor.identity(field.field_id, source_name, field.source_id) + elif isinstance(transform, BucketTransform): + return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets) + elif isinstance(transform, TruncateTransform): + return visitor.truncate(field.field_id, source_name, field.source_id, transform.width) + elif isinstance(transform, DayTransform): + return visitor.day(field.field_id, source_name, field.source_id) + elif isinstance(transform, HourTransform): + return visitor.hour(field.field_id, source_name, field.source_id) + elif isinstance(transform, YearTransform): + return visitor.year(field.field_id, source_name, field.source_id) + elif isinstance(transform, VoidTransform): + return visitor.always_null(field.field_id, source_name, field.source_id) + elif isinstance(transform, UnknownTransform): + return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform)) + else: + raise ValueError(f"Unknown transform {transform}") diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 057dd8427c..c6b59f9cd0 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -51,6 +51,7 @@ And, BooleanExpression, EqualTo, + Reference, parser, visitors, ) @@ -67,7 +68,15 @@ write_manifest, write_manifest_list, ) -from pyiceberg.partitioning import PartitionSpec +from pyiceberg.partitioning import ( + INITIAL_PARTITION_SPEC_ID, + PARTITION_FIELD_ID_START, + IdentityTransform, + PartitionField, + PartitionSpec, + _PartitionNameGenerator, + _visit_partition_field, +) from pyiceberg.schema import ( Schema, SchemaVisitor, @@ -97,6 +106,7 @@ update_snapshot_summaries, ) from pyiceberg.table.sorting import SortOrder +from pyiceberg.transforms import TimeTransform, Transform, VoidTransform from pyiceberg.typedef import ( EMPTY_DICT, IcebergBaseModel, @@ -277,6 +287,14 @@ def update_schema(self) -> UpdateSchema: """ return UpdateSchema(self._table, self) + def update_spec(self) -> UpdateSpec: + """Create a new UpdateSpec to update the partitioning of the table. + + Returns: + A new UpdateSpec. + """ + return UpdateSpec(self._table, self) + def remove_properties(self, *removals: str) -> Transaction: """Remove properties. 
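
An illustrative aside (not part of the diff): `_visit_partition_field` with `_PartitionNameGenerator` is what produces the default partition field names used by `UpdateSpec.add_field` below when no explicit name is passed. A minimal sketch, assuming a schema with a single `ts` timestamp column:

    from pyiceberg.partitioning import PartitionField, _PartitionNameGenerator, _visit_partition_field
    from pyiceberg.schema import Schema
    from pyiceberg.transforms import DayTransform
    from pyiceberg.types import NestedField, TimestampType

    schema = Schema(NestedField(2, "ts", TimestampType(), required=False))
    field = PartitionField(2, 1000, DayTransform(), "placeholder")
    print(_visit_partition_field(schema, field, _PartitionNameGenerator()))  # prints "ts_day"
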
@@ -533,6 +551,39 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
     return base_metadata.model_copy(update={"current_schema_id": new_schema_id})
 
 
+@_apply_table_update.register(AddPartitionSpecUpdate)
+def _(update: AddPartitionSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
+    for spec in base_metadata.partition_specs:
+        if spec.spec_id == update.spec_id:
+            raise ValueError(f"Partition spec with id {spec.spec_id} already exists: {spec}")
+
+    context.add_update(update)
+    return base_metadata.model_copy(
+        update={
+            "partition_specs": base_metadata.partition_specs + [update.spec],
+        }
+    )
+
+
+@_apply_table_update.register(SetDefaultSpecUpdate)
+def _(update: SetDefaultSpecUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
+    new_spec_id = update.spec_id
+    if new_spec_id == base_metadata.default_spec_id:
+        return base_metadata
+
+    found_spec_id = False
+    for spec in base_metadata.partition_specs:
+        found_spec_id = spec.spec_id == new_spec_id
+        if found_spec_id:
+            break
+
+    if not found_spec_id:
+        raise ValueError(f"Failed to find spec with id {new_spec_id}")
+
+    context.add_update(update)
+    return base_metadata.model_copy(update={"default_spec_id": new_spec_id})
+
+
 @_apply_table_update.register(AddSnapshotUpdate)
 def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata:
     if len(base_metadata.schemas) == 0:
@@ -971,6 +1022,9 @@ def overwrite(self, df: pa.Table, overwrite_filter: BooleanExpression = ALWAYS_T
 
         merge.commit()
 
+    def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
+        return UpdateSpec(self, case_sensitive=case_sensitive)
+
     def refs(self) -> Dict[str, SnapshotRef]:
         """Return the snapshot references in the table."""
         return self.metadata.refs
@@ -2271,3 +2325,243 @@ def commit(self) -> Snapshot:
         )
 
         return snapshot
+
+
+class UpdateSpec:
+    _table: Table
+    _name_to_field: Dict[str, PartitionField]
+    _name_to_added_field: Dict[str, PartitionField]
+    _transform_to_field: Dict[Tuple[int, str], PartitionField]
+    _transform_to_added_field: Dict[Tuple[int, str], PartitionField]
+    _renames: Dict[str, str]
+    _added_time_fields: Dict[int, PartitionField]
+    _case_sensitive: bool
+    _adds: List[PartitionField]
+    _deletes: Set[int]
+    _last_assigned_partition_id: int
+    _transaction: Optional[Transaction]
+    _unassigned_field_name = 'unassigned_field_name'
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = None, case_sensitive: bool = True) -> None:
+        self._table = table
+        self._name_to_field = {field.name: field for field in table.spec().fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in table.spec().fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = set()
+        if len(table.specs()) == 1:
+            self._last_assigned_partition_id = PARTITION_FIELD_ID_START - 1
+        else:
+            self._last_assigned_partition_id = table.spec().last_assigned_field_id
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+        self._added_time_fields = {}
+
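+    # add_field binds source_column_name against the current schema, checks that the transform
+    # can handle the column's type, and, when no explicit name is given, derives a default name
+    # (e.g. "id_bucket_16") via _PartitionNameGenerator.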
+    def add_field(
+        self,
+        source_column_name: str,
+        transform: Transform[Any, Any],
+        partition_field_name: Optional[str] = _unassigned_field_name,
+    ) -> UpdateSpec:
+        ref = Reference(source_column_name)
+        bound_ref = ref.bind(self._table.schema(), self._case_sensitive)
+        # verify transform can actually bind it
+        output_type = bound_ref.field.field_type
+        if not transform.can_transform(output_type):
+            raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")
+
+        transform_key = (bound_ref.field.field_id, repr(transform))
+        existing_partition_field = self._transform_to_field.get(transform_key)
+        if existing_partition_field and self._is_duplicate_partition(transform, existing_partition_field):
+            raise ValueError(f"Duplicate partition field for {ref.name}={ref}, {existing_partition_field} already exists")
+
+        added = self._transform_to_added_field.get(transform_key)
+        if added:
+            raise ValueError(f"Already added partition: {added.name}")
+
+        new_field = self._partition_field((bound_ref.field.field_id, transform), partition_field_name)
+        if new_field.name == self._unassigned_field_name:
+            name = _visit_partition_field(self._table.schema(), new_field, _PartitionNameGenerator())
+            new_field = PartitionField(new_field.source_id, new_field.field_id, new_field.transform, name)
+
+        if new_field.name in self._name_to_added_field:
+            raise ValueError(f"Already added partition field with name: {new_field.name}")
+
+        self._redundant_time_partition(new_field)
+        self._transform_to_added_field[transform_key] = new_field
+
+        existing_partition_field = self._name_to_field.get(new_field.name)
+        if existing_partition_field and existing_partition_field.field_id not in self._deletes:
+            if isinstance(existing_partition_field.transform, VoidTransform):
+                self.rename_field(
+                    existing_partition_field.name, existing_partition_field.name + "_" + str(existing_partition_field.field_id)
+                )
+            else:
+                raise ValueError(f"Cannot add duplicate partition field name: {existing_partition_field.name}")
+
+        self._name_to_added_field[new_field.name] = new_field
+        self._adds.append(new_field)
+        return self
+
+    def add_identity(self, source_column_name: str) -> UpdateSpec:
+        return self.add_field(source_column_name, IdentityTransform(), self._unassigned_field_name)
+
+    def remove_field(self, name: str) -> UpdateSpec:
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot delete newly added field {name}")
+        renamed = self._renames.get(name)
+        if renamed:
+            raise ValueError(f"Cannot rename and delete field {name}")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"No such partition field: {name}")
+
+        self._deletes.add(field.field_id)
+        return self
+
+    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
+        existing_field = self._name_to_field.get(new_name)
+        if existing_field and isinstance(existing_field.transform, VoidTransform):
+            # rename the existing void field out of the way so that new_name becomes available
+            self.rename_field(new_name, new_name + "_" + str(existing_field.field_id))
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError("Cannot rename recently added partitions")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"Cannot find partition field {name}")
+        if field.field_id in self._deletes:
+            raise ValueError(f"Cannot delete and rename partition field {name}")
+        self._renames[name] = new_name
+        return self
+
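+    # commit() converts the staged adds/deletes/renames into an AddPartitionSpecUpdate and a
+    # SetDefaultSpecUpdate, guarded by AssertLastAssignedPartitionId (plus AssertDefaultSpecId on
+    # the non-transactional path) so that concurrent changes to the spec fail the commit.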
+    def commit(self) -> None:
+        new_spec = self._apply()
+        updates: List[TableUpdate] = []
+        requirements: List[TableRequirement] = []
+        if self._table.metadata.default_spec_id != new_spec.spec_id:
+            if new_spec.spec_id not in self._table.specs():
+                spec_update = AddPartitionSpecUpdate(spec=new_spec)
+                updates.append(spec_update)
+                if len(self._table.specs()) == 1:
+                    required_last_assigned_partition_id = PARTITION_FIELD_ID_START - 1
+                else:
+                    required_last_assigned_partition_id = self._table.spec().last_assigned_field_id
+                requirements.append(
+                    AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partition_id)
+                )
+            updates.append(SetDefaultSpecUpdate(spec_id=new_spec.spec_id))
+
+        if self._transaction is not None:
+            self._transaction._append_updates(*updates)  # pylint: disable=W0212
+            self._transaction._append_requirements(*requirements)  # pylint: disable=W0212
+        else:
+            requirements.append(AssertDefaultSpecId(default_spec_id=self._table.spec().spec_id))
+            self._table._do_commit(updates=tuple(updates), requirements=tuple(requirements))  # pylint: disable=W0212
+
+    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+        """Close and commit the change."""
+        return self.commit()
+
+    def __enter__(self) -> UpdateSpec:
+        """Update the table."""
+        return self
+
+    def _apply(self) -> PartitionSpec:
+        def _check_and_add_partition_name(schema: Schema, name: str, source_id: int, partition_names: Set[str]) -> None:
+            try:
+                field = schema.find_field(name)
+            except ValueError:
+                field = None
+
+            if source_id is not None and field is not None and field.field_id != source_id:
+                raise ValueError(f"Cannot create identity partition from a different field in the schema: {name}")
+            elif field is not None and source_id != field.field_id:
+                raise ValueError(f"Cannot create partition from a name that exists in the schema: {name}")
+            if not name:
+                raise ValueError("Undefined name")
+            if name in partition_names:
+                raise ValueError(f"Partition name has to be unique: {name}")
+            partition_names.add(name)
+
+        def _add_new_field(
+            schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
+        ) -> PartitionField:
+            _check_and_add_partition_name(schema, name, source_id, partition_names)
+            return PartitionField(source_id, field_id, transform, name)
+
+        partition_fields = []
+        partition_names: Set[str] = set()
+        for field in self._table.spec().fields:
+            if field.field_id not in self._deletes:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(
+                        self._table.schema(), field.source_id, field.field_id, renamed, field.transform, partition_names
+                    )
+                else:
+                    new_field = _add_new_field(
+                        self._table.schema(), field.source_id, field.field_id, field.name, field.transform, partition_names
+                    )
+                partition_fields.append(new_field)
+            elif self._table.format_version == 1:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(
+                        self._table.schema(), field.source_id, field.field_id, renamed, VoidTransform(), partition_names
+                    )
+                else:
+                    new_field = _add_new_field(
+                        self._table.schema(), field.source_id, field.field_id, field.name, VoidTransform(), partition_names
+                    )
+
+                partition_fields.append(new_field)
+
+        for added_field in self._adds:
+            new_field = PartitionField(
+                source_id=added_field.source_id,
+                field_id=added_field.field_id,
+                transform=added_field.transform,
+                name=added_field.name,
+            )
+            partition_fields.append(new_field)
+        # Reuse spec id or create a new one.
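+        # If an existing spec already has exactly these fields, reuse its id; otherwise pick one
+        # greater than the highest existing spec id.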
+        new_spec = PartitionSpec(*partition_fields)
+        new_spec_id = INITIAL_PARTITION_SPEC_ID
+        for spec in self._table.specs().values():
+            if new_spec.compatible_with(spec):
+                new_spec_id = spec.spec_id
+                break
+            elif new_spec_id <= spec.spec_id:
+                new_spec_id = spec.spec_id + 1
+        return PartitionSpec(*partition_fields, spec_id=new_spec_id)
+
+    def _redundant_time_partition(self, field: PartitionField) -> None:
+        if isinstance(field.transform, TimeTransform):
+            existing_time_field = self._added_time_fields.get(field.source_id)
+            if existing_time_field:
+                raise ValueError(f"Cannot add time partition field: {field.name} conflicts with {existing_time_field.name}")
+            self._added_time_fields[field.source_id] = field
+
+    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
+        if self._table.metadata.format_version == 2:
+            source_id, transform = transform_key
+            historical_fields = []
+            for spec in self._table.specs().values():
+                for field in spec.fields:
+                    historical_fields.append(field)
+
+            for field in historical_fields:
+                if field.source_id == source_id and field.transform == transform:
+                    if name == self._unassigned_field_name or field.name == name:
+                        return field
+        return PartitionField(transform_key[0], self._new_field_id(), transform_key[1], name)
+
+    def _new_field_id(self) -> int:
+        self._last_assigned_partition_id += 1
+        return self._last_assigned_partition_id
+
+    def _is_duplicate_partition(self, transform: Transform[Any, Any], partition_field: PartitionField) -> bool:
+        return partition_field.field_id not in self._deletes and partition_field.transform == transform
diff --git a/tests/test_integration_partition_evolution.py b/tests/test_integration_partition_evolution.py
new file mode 100644
index 0000000000..1dff31ddc6
--- /dev/null
+++ b/tests/test_integration_partition_evolution.py
@@ -0,0 +1,366 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
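+
+# Integration tests for partition evolution through UpdateSpec. They assume the REST catalog
+# at http://localhost:8181 and the MinIO endpoint used by the catalog fixture below are running.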
+# pylint:disable=redefined-outer-name
+
+import pytest
+
+from pyiceberg.catalog import Catalog, load_catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.transforms import (
+    BucketTransform,
+    DayTransform,
+    HourTransform,
+    IdentityTransform,
+    MonthTransform,
+    TruncateTransform,
+    VoidTransform,
+    YearTransform,
+)
+from pyiceberg.types import (
+    LongType,
+    NestedField,
+    StringType,
+    TimestampType,
+)
+
+
+@pytest.fixture()
+def catalog() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "rest",
+            "uri": "http://localhost:8181",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
+@pytest.fixture()
+def simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table:
+    return _create_table_with_schema(catalog, table_schema_simple, "1")
+
+
+@pytest.fixture()
+def table(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "1")
+
+
+@pytest.fixture()
+def table_v2(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "2")
+
+
+def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
+    tbl_name = "default.test_partition_evolution"
+    try:
+        catalog.drop_table(tbl_name)
+    except NoSuchTableError:
+        pass
+    return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
+
+
+@pytest.mark.integration
+def test_add_identity_partition(simple_table: Table) -> None:
+    simple_table.update_spec().add_identity("foo").commit()
+    specs = simple_table.specs()
+    assert len(specs) == 2
+    spec = simple_table.spec()
+    assert spec.spec_id == 1
+    assert spec.last_assigned_field_id == 1000
+
+
+@pytest.mark.integration
+def test_add_year(table: Table) -> None:
+    table.update_spec().add_field("event_ts", YearTransform(), "year_transform").commit()
+    _validate_new_partition_fields(table, 1000, 1, PartitionField(2, 1000, YearTransform(), "year_transform"))
+
+
+@pytest.mark.integration
+def test_add_month(table: Table) -> None:
+    table.update_spec().add_field("event_ts", MonthTransform(), "month_transform").commit()
+    _validate_new_partition_fields(table, 1000, 1, PartitionField(2, 1000, MonthTransform(), "month_transform"))
+
+
+@pytest.mark.integration
+def test_add_day(table: Table) -> None:
+    table.update_spec().add_field("event_ts", DayTransform(), "day_transform").commit()
+    _validate_new_partition_fields(table, 1000, 1, PartitionField(2, 1000, DayTransform(), "day_transform"))
+
+
+@pytest.mark.integration
+def test_add_hour(table: Table) -> None:
+    table.update_spec().add_field("event_ts", HourTransform(), "hour_transform").commit()
+    _validate_new_partition_fields(table, 1000, 1, PartitionField(2, 1000, HourTransform(), "hour_transform"))
+
+
+@pytest.mark.integration
+def test_add_bucket(simple_table: Table) -> None:
+    simple_table.update_spec().add_field("foo", BucketTransform(12), "bucket_transform").commit()
+    _validate_new_partition_fields(simple_table, 1000, 1, PartitionField(1, 1000, BucketTransform(12), "bucket_transform"))
+
+
+@pytest.mark.integration
+def test_add_truncate(simple_table: Table) -> None:
+    simple_table.update_spec().add_field("foo", TruncateTransform(1), "truncate_transform").commit()
+    _validate_new_partition_fields(simple_table, 1000, 1, PartitionField(1, 1000, TruncateTransform(1), "truncate_transform"))
+
+
+@pytest.mark.integration
+def test_multiple_adds(table: Table) -> None:
+    table.update_spec().add_identity("id").add_field("event_ts", HourTransform(), "hourly_partitioned").add_field(
+        "str", TruncateTransform(2), "truncate_str"
+    ).commit()
+    _validate_new_partition_fields(
+        table,
+        1002,
+        1,
+        PartitionField(1, 1000, IdentityTransform(), "id"),
+        PartitionField(2, 1001, HourTransform(), "hourly_partitioned"),
+        PartitionField(3, 1002, TruncateTransform(2), "truncate_str"),
+    )
+
+
+@pytest.mark.integration
+def test_add_hour_to_day(table: Table) -> None:
+    table.update_spec().add_field("event_ts", DayTransform(), "daily_partitioned").commit()
+    table.update_spec().add_field("event_ts", HourTransform(), "hourly_partitioned").commit()
+    _validate_new_partition_fields(
+        table,
+        1001,
+        2,
+        PartitionField(2, 1000, DayTransform(), "daily_partitioned"),
+        PartitionField(2, 1001, HourTransform(), "hourly_partitioned"),
+    )
+
+
+@pytest.mark.integration
+def test_add_multiple_buckets(table: Table) -> None:
+    table.update_spec().add_field("id", BucketTransform(16)).add_field("id", BucketTransform(4)).commit()
+    _validate_new_partition_fields(
+        table,
+        1001,
+        1,
+        PartitionField(1, 1000, BucketTransform(16), "id_bucket_16"),
+        PartitionField(1, 1001, BucketTransform(4), "id_bucket_4"),
+    )
+
+
+@pytest.mark.integration
+def test_remove_identity(table: Table) -> None:
+    table.update_spec().add_identity("id").commit()
+    table.update_spec().remove_field("id").commit()
+    assert len(table.specs()) == 3
+    assert table.spec().spec_id == 2
+    assert table.spec() == PartitionSpec(
+        PartitionField(source_id=1, field_id=1000, transform=VoidTransform(), name='id'), spec_id=2
+    )
+
+
+@pytest.mark.integration
+def test_remove_identity_v2(table_v2: Table) -> None:
+    table_v2.update_spec().add_identity("id").commit()
+    table_v2.update_spec().remove_field("id").commit()
+    assert len(table_v2.specs()) == 2
+    assert table_v2.spec().spec_id == 0
+    assert table_v2.spec() == PartitionSpec(spec_id=0)
+
+
+@pytest.mark.integration
+def test_remove_bucket(table: Table) -> None:
+    with table.update_spec() as update:
+        update.add_field("id", BucketTransform(16), "bucketed_id")
+        update.add_field("event_ts", DayTransform(), "day_ts")
+    with table.update_spec() as remove:
+        remove.remove_field("bucketed_id")
+
+    assert len(table.specs()) == 3
+    _validate_new_partition_fields(
+        table,
+        1001,
+        2,
+        PartitionField(source_id=1, field_id=1000, transform=VoidTransform(), name='bucketed_id'),
+        PartitionField(source_id=2, field_id=1001, transform=DayTransform(), name='day_ts'),
+    )
+
+
+@pytest.mark.integration
+def test_remove_bucket_v2(table_v2: Table) -> None:
+    with table_v2.update_spec() as update:
+        update.add_field("id", BucketTransform(16), "bucketed_id")
+        update.add_field("event_ts", DayTransform(), "day_ts")
+    with table_v2.update_spec() as remove:
+        remove.remove_field("bucketed_id")
+    assert len(table_v2.specs()) == 3
+    _validate_new_partition_fields(
+        table_v2, 1001, 2, PartitionField(source_id=2, field_id=1001, transform=DayTransform(), name='day_ts')
+    )
+
+
+@pytest.mark.integration
+def test_remove_day(table: Table) -> None:
+    with table.update_spec() as update:
+        update.add_field("id", BucketTransform(16), "bucketed_id")
+        update.add_field("event_ts", DayTransform(), "day_ts")
+    with table.update_spec() as remove:
+        remove.remove_field("day_ts")
+
+    assert len(table.specs()) == 3
+    _validate_new_partition_fields(
+        table,
+        1001,
+        2,
+        PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name='bucketed_id'),
+        PartitionField(source_id=2, field_id=1001, transform=VoidTransform(), name='day_ts'),
+    )
+
+
+@pytest.mark.integration
+def test_remove_day_v2(table_v2: Table) -> None:
+    with table_v2.update_spec() as update:
+        update.add_field("id", BucketTransform(16), "bucketed_id")
+        update.add_field("event_ts", DayTransform(), "day_ts")
+    with table_v2.update_spec() as remove:
+        remove.remove_field("day_ts")
+    assert len(table_v2.specs()) == 3
+    _validate_new_partition_fields(
+        table_v2, 1000, 2, PartitionField(source_id=1, field_id=1000, transform=BucketTransform(16), name='bucketed_id')
+    )
+
+
+@pytest.mark.integration
+def test_rename(table: Table) -> None:
+    table.update_spec().add_identity("id").commit()
+    table.update_spec().rename_field("id", "sharded_id").commit()
+    assert len(table.specs()) == 3
+    assert table.spec().spec_id == 2
+    _validate_new_partition_fields(table, 1000, 2, PartitionField(1, 1000, IdentityTransform(), "sharded_id"))
+
+
+@pytest.mark.integration
+def test_cannot_add_and_remove(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_identity("id").remove_field("id").commit()
+    assert "Cannot delete newly added field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_add_redundant_time_partition(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("event_ts", YearTransform(), "year_transform").add_field(
+            "event_ts", HourTransform(), "hour_transform"
+        ).commit()
+    assert "Cannot add time partition field: hour_transform conflicts with year_transform" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_delete_and_rename(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_identity("id").commit()
+        table.update_spec().remove_field("id").rename_field("id", "sharded_id").commit()
+    assert "Cannot delete and rename partition field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_rename_and_delete(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_identity("id").commit()
+        table.update_spec().rename_field("id", "sharded_id").remove_field("id").commit()
+    assert "Cannot rename and delete field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_add_same_transform_for_same_field(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("str", TruncateTransform(4), "truncated_str").add_field(
+            "str", TruncateTransform(4)
+        ).commit()
+    assert "Already added partition" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_add_same_field_multiple_times(table: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table.update_spec().add_field("id", IdentityTransform(), "duplicate").add_field(
+            "id", IdentityTransform(), "duplicate"
+        ).commit()
+    assert "Already added partition" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_add_multiple_specs_same_name(table: Table) -> None: + with pytest.raises(ValueError) as exc_info: + table.update_spec().add_field("id", IdentityTransform(), "duplicate").add_field( + "event_ts", IdentityTransform(), "duplicate" + ).commit() + assert "Already added partition" in str(exc_info.value) + + +@pytest.mark.integration +def test_change_specs_and_schema_transaction(table: Table) -> None: + with table.transaction() as transaction: + with transaction.update_spec() as update_spec: + update_spec.add_identity("id").add_field("event_ts", HourTransform(), "hourly_partitioned").add_field( + "str", TruncateTransform(2), "truncate_str" + ) + + with transaction.update_schema() as update_schema: + update_schema.add_column("col_string", StringType()) + + _validate_new_partition_fields( + table, + 1002, + 1, + PartitionField(1, 1000, IdentityTransform(), "id"), + PartitionField(2, 1001, HourTransform(), "hourly_partitioned"), + PartitionField(3, 1002, TruncateTransform(2), "truncate_str"), + ) + + assert table.schema() == Schema( + NestedField(field_id=1, name='id', field_type=LongType(), required=False), + NestedField(field_id=2, name='event_ts', field_type=TimestampType(), required=False), + NestedField(field_id=3, name='str', field_type=StringType(), required=False), + NestedField(field_id=4, name='col_string', field_type=StringType(), required=False), + schema_id=1, + identifier_field_ids=[], + ) + + +def _validate_new_partition_fields( + table: Table, expected_last_assigned_field_id: int, expected_spec_id: int, *expected_partition_fields: PartitionField +) -> None: + spec = table.spec() + assert spec.spec_id == expected_spec_id + assert spec.last_assigned_field_id == expected_last_assigned_field_id + assert len(spec.fields) == len(expected_partition_fields) + for i in range(len(spec.fields)): + assert spec.fields[i] == expected_partition_fields[i]
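
Usage sketch (illustrative, not part of the patch): against a table like the ones created in the tests above, the new API composes as follows. `add_field`, `add_identity`, `remove_field`, and `rename_field` only stage changes; they take effect on `commit()`, or on exiting the context manager. The catalog settings and the `default.events` identifier here are placeholders.

    from pyiceberg.catalog import load_catalog
    from pyiceberg.transforms import BucketTransform, DayTransform

    catalog = load_catalog("local", type="rest", uri="http://localhost:8181")
    table = catalog.load_table("default.events")

    # chained, one-shot form
    table.update_spec().add_identity("region").add_field("event_ts", DayTransform(), "day_ts").commit()

    # context-manager form commits on exit
    with table.update_spec() as update:
        update.add_field("id", BucketTransform(16), "bucketed_id")
        update.rename_field("day_ts", "event_day")

    # inside a transaction, the spec change is committed together with the other updates
    with table.transaction() as transaction:
        with transaction.update_spec() as update:
            update.remove_field("bucketed_id")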