From b29f5f937c020ebbc7b431caaca198f028aba48e Mon Sep 17 00:00:00 2001
From: amogh-jahagirdar
Date: Wed, 27 Dec 2023 20:06:32 -0800
Subject: [PATCH] Initial partition evolution

---
 pyiceberg/partitioning.py                      | 149 +++++++++-
 pyiceberg/table/__init__.py                    | 252 ++++++++++++++++-
 tests/test_integration_partition_evolution.py  | 259 ++++++++++++++++++
 3 files changed, 650 insertions(+), 10 deletions(-)
 create mode 100644 tests/test_integration_partition_evolution.py

diff --git a/pyiceberg/partitioning.py b/pyiceberg/partitioning.py
index f6307f0f8c..5ff32dce7a 100644
--- a/pyiceberg/partitioning.py
+++ b/pyiceberg/partitioning.py
@@ -16,14 +16,9 @@
 # under the License.
 from __future__ import annotations
 
-from functools import cached_property
-from typing import (
-    Any,
-    Dict,
-    List,
-    Optional,
-    Tuple,
-)
+from abc import ABC, abstractmethod
+from functools import cached_property, singledispatch
+from typing import Any, Dict, Generic, List, Optional, Tuple, TypeVar
 
 from pydantic import (
     BeforeValidator,
@@ -34,7 +29,19 @@
 from typing_extensions import Annotated
 
 from pyiceberg.schema import Schema
-from pyiceberg.transforms import Transform, parse_transform
+from pyiceberg.transforms import (
+    BucketTransform,
+    DayTransform,
+    HourTransform,
+    IdentityTransform,
+    MonthTransform,
+    Transform,
+    TruncateTransform,
+    UnknownTransform,
+    VoidTransform,
+    YearTransform,
+    parse_transform,
+)
 from pyiceberg.typedef import IcebergBaseModel
 from pyiceberg.types import NestedField, StructType
@@ -85,6 +92,20 @@ def __str__(self) -> str:
         """Return the string representation of the PartitionField class."""
         return f"{self.field_id}: {self.name}: {self.transform}({self.source_id})"
 
+    def __hash__(self) -> int:
+        """Return the hash of the partition field."""
+        return hash((self.name, self.source_id, self.field_id, repr(self.transform)))
+
+    def __eq__(self, other: Any) -> bool:
+        """Return True if two partition fields are considered equal, False otherwise."""
+        return (
+            isinstance(other, PartitionField)
+            and other.field_id == self.field_id
+            and other.name == self.name
+            and other.source_id == self.source_id
+            and repr(other.transform) == repr(self.transform)
+        )
+
 
 class PartitionSpec(IcebergBaseModel):
     """
@@ -215,3 +236,113 @@ def assign_fresh_partition_spec_ids(spec: PartitionSpec, old_schema: Schema, fresh_schema: Schema) -> PartitionSpec:
             )
         )
     return PartitionSpec(*partition_fields, spec_id=INITIAL_PARTITION_SPEC_ID)
+
+
+T = TypeVar("T")
+
+
+class PartitionSpecVisitor(Generic[T], ABC):
+    @abstractmethod
+    def identity(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit identity partition field."""
+
+    @abstractmethod
+    def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> T:
+        """Visit bucket partition field."""
+
+    @abstractmethod
+    def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> T:
+        """Visit truncate partition field."""
+
+    @abstractmethod
+    def year(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit year partition field."""
+
+    @abstractmethod
+    def month(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit month partition field."""
+
+    @abstractmethod
+    def day(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit day partition field."""
+
+    @abstractmethod
+    def hour(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit hour partition field."""
+
+    @abstractmethod
+    def always_null(self, field_id: int, source_name: str, source_id: int) -> T:
+        """Visit void partition field."""
+
+    @abstractmethod
+    def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> T:
+        """Visit unknown partition field."""
+        raise ValueError(f"Unknown transform {transform} is not supported")
+
+
+class _PartitionNameGenerator(PartitionSpecVisitor[str]):
+    def identity(self, field_id: int, source_name: str, source_id: int) -> str:
+        return source_name
+
+    def bucket(self, field_id: int, source_name: str, source_id: int, num_buckets: int) -> str:
+        return f"{source_name}_bucket_{num_buckets}"
+
+    def truncate(self, field_id: int, source_name: str, source_id: int, width: int) -> str:
+        return f"{source_name}_trunc_{width}"
+
+    def year(self, field_id: int, source_name: str, source_id: int) -> str:
+        return f"{source_name}_year"
+
+    def month(self, field_id: int, source_name: str, source_id: int) -> str:
+        return f"{source_name}_month"
+
+    def day(self, field_id: int, source_name: str, source_id: int) -> str:
+        return f"{source_name}_day"
+
+    def hour(self, field_id: int, source_name: str, source_id: int) -> str:
+        return f"{source_name}_hour"
+
+    def always_null(self, field_id: int, source_name: str, source_id: int) -> str:
+        return f"{source_name}_null"
+
+    def unknown(self, field_id: int, source_name: str, source_id: int, transform: str) -> str:
+        return super().unknown(field_id, source_name, source_id, transform)
+
+
+R = TypeVar("R")
+
+
+@singledispatch
+def _visit(spec: PartitionSpec, schema: Schema, visitor: PartitionSpecVisitor[R]) -> List[R]:
+    results = []
+    for field in spec.fields:
+        results.append(_visit_field(schema, field, visitor))
+    return results
+
+
+def _visit_field(schema: Schema, field: PartitionField, visitor: PartitionSpecVisitor[R]) -> R:
+    source_name = schema.find_column_name(field.source_id)
+    if not source_name:
+        raise ValueError(f"Could not find column with id {field.source_id}")
+
+    transform = field.transform
+    if isinstance(transform, IdentityTransform):
+        return visitor.identity(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, BucketTransform):
+        return visitor.bucket(field.field_id, source_name, field.source_id, transform.num_buckets)
+    elif isinstance(transform, TruncateTransform):
+        return visitor.truncate(field.field_id, source_name, field.source_id, transform.width)
+    elif isinstance(transform, DayTransform):
+        return visitor.day(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, HourTransform):
+        return visitor.hour(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, YearTransform):
+        return visitor.year(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, MonthTransform):
+        return visitor.month(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, VoidTransform):
+        return visitor.always_null(field.field_id, source_name, field.source_id)
+    elif isinstance(transform, UnknownTransform):
+        return visitor.unknown(field.field_id, source_name, field.source_id, repr(transform))
+    else:
+        raise ValueError(f"Unknown transform {transform}")
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 40b37cc248..7a45ad7703 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -50,6 +50,7 @@
     And,
     BooleanExpression,
     EqualTo,
+    Reference,
     parser,
     visitors,
 )
@@ -63,7 +64,15 @@
     ManifestEntry,
     ManifestFile,
 )
-from pyiceberg.partitioning import PartitionSpec
+from pyiceberg.partitioning import (
+    INITIAL_PARTITION_SPEC_ID,
+    PARTITION_FIELD_ID_START,
+    IdentityTransform,
+    PartitionField,
+    PartitionSpec,
+    _PartitionNameGenerator,
+    _visit_field,
+)
 from pyiceberg.schema import (
     Schema,
     SchemaVisitor,
@@ -80,6 +89,7 @@
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
 from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry
 from pyiceberg.table.sorting import SortOrder
+from pyiceberg.transforms import TimeTransform, Transform, VoidTransform
 from pyiceberg.typedef import (
     EMPTY_DICT,
     IcebergBaseModel,
@@ -830,6 +840,9 @@ def history(self) -> List[SnapshotLogEntry]:
     def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive: bool = True) -> UpdateSchema:
         return UpdateSchema(self, allow_incompatible_changes=allow_incompatible_changes, case_sensitive=case_sensitive)
 
+    def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
+        return UpdateSpec(self, case_sensitive=case_sensitive)
+
     def refs(self) -> Dict[str, SnapshotRef]:
         """Return the snapshot references in the table."""
         return self.metadata.refs
@@ -1904,3 +1917,240 @@ def _generate_snapshot_id() -> int:
     snapshot_id = snapshot_id if snapshot_id >= 0 else snapshot_id * -1
 
     return snapshot_id
+
+
+class UpdateSpec:
+    _table: Table
+    _schema: Schema
+    _spec: PartitionSpec
+    _name_to_field: Dict[str, PartitionField]
+    _name_to_added_field: Dict[str, PartitionField]
+    _transform_to_field: Dict[Tuple[int, str], PartitionField]
+    _transform_to_added_field: Dict[Tuple[int, str], PartitionField]
+    _renames: Dict[str, str]
+    _added_time_fields: Dict[int, PartitionField]
+    _case_sensitive: bool
+    _adds: List[PartitionField]
+    _deletes: Set[int]
+    _last_assigned_partition_id: int
+    _transaction: Optional[Transaction]
+    _unassigned_field_name = 'unassigned_field_name'
+
+    def __init__(self, table: Table, transaction: Optional[Transaction] = None, case_sensitive: bool = True) -> None:
+        self._table = table
+        self._schema = table.schema()
+        self._spec = table.spec()
+        self._name_to_field = {field.name: field for field in self._spec.fields}
+        self._name_to_added_field = {}
+        self._transform_to_field = {(field.source_id, repr(field.transform)): field for field in self._spec.fields}
+        self._transform_to_added_field = {}
+        self._adds = []
+        self._deletes = set()
+        if len(table.specs()) == 1:
+            self._last_assigned_partition_id = PARTITION_FIELD_ID_START - 1
+        else:
+            self._last_assigned_partition_id = table.spec().last_assigned_field_id
+        self._renames = {}
+        self._transaction = transaction
+        self._case_sensitive = case_sensitive
+        self._added_time_fields = {}
+
+    def add_field(
+        self, partition_field_name: Optional[str], source_column_name: str, transform: Transform[Any, Any]
+    ) -> UpdateSpec:
+        ref = Reference(source_column_name)
+        bound_ref = ref.bind(self._schema, self._case_sensitive)
+        # verify transform can actually bind it
+        output_type = bound_ref.field.field_type
+        if not transform.can_transform(output_type):
+            raise ValueError(f"{transform} cannot transform {output_type} values from {bound_ref.field.name}")
+
+        transform_key = (bound_ref.field.field_id, repr(transform))
+        existing_partition_field = self._transform_to_field.get(transform_key)
+        if existing_partition_field and self._is_duplicate_partition(transform, existing_partition_field):
+            raise ValueError(f"Duplicate partition field for {ref.name}={transform}: {existing_partition_field} already exists")
+
+        added = self._transform_to_added_field.get(transform_key)
+        if added:
+            raise ValueError(f"Already added partition {added}")
+        new_field = self._partition_field((bound_ref.field.field_id, transform), partition_field_name)
+        if new_field.name == self._unassigned_field_name:
+            name = _visit_field(self._schema, new_field, _PartitionNameGenerator())
+            new_field = PartitionField(new_field.source_id, new_field.field_id, new_field.transform, name)
+        self._check_redundant_partitions(new_field)
+        self._transform_to_added_field[transform_key] = new_field
+
+        existing_partition_field = self._name_to_field.get(new_field.name)
+        if existing_partition_field and new_field.field_id not in self._deletes:
+            if isinstance(existing_partition_field.transform, VoidTransform):
+                self.rename_field(
+                    existing_partition_field.name, existing_partition_field.name + "_" + str(existing_partition_field.field_id)
+                )
+            else:
+                raise ValueError(f"Cannot add duplicate partition field name: {existing_partition_field.name}")
+
+        self._name_to_added_field[new_field.name] = new_field
+        self._adds.append(new_field)
+        return self
+
+    def add_identity(self, source_column_name: str) -> UpdateSpec:
+        return self.add_field(self._unassigned_field_name, source_column_name, IdentityTransform())
+
+    def remove_field(self, name: str) -> UpdateSpec:
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError(f"Cannot delete newly added field {name}")
+        renamed = self._renames.get(name)
+        if renamed:
+            raise ValueError(f"Cannot rename and delete field {name}")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"No such partition field: {name}")
+
+        self._deletes.add(field.field_id)
+        return self
+
+    def rename_field(self, name: str, new_name: str) -> UpdateSpec:
+        existing_field = self._name_to_field.get(new_name)
+        if existing_field and isinstance(existing_field.transform, VoidTransform):
+            return self.rename_field(name, name + "_" + str(existing_field.field_id))
+        added = self._name_to_added_field.get(name)
+        if added:
+            raise ValueError("Cannot rename recently added partitions")
+        field = self._name_to_field.get(name)
+        if not field:
+            raise ValueError(f"Cannot find partition field {name}")
+        if field.field_id in self._deletes:
+            raise ValueError(f"Cannot delete and rename partition field {name}")
+        self._renames[name] = new_name
+        return self
+
+    def commit(self) -> None:
+        new_spec = self._apply()
+        updates = []
+        requirements = []
+        if self._table.metadata.default_spec_id != new_spec.spec_id:
+            if new_spec.spec_id not in self._table.specs():
+                spec_update = AddPartitionSpecUpdate(spec=new_spec)
+                updates.append(spec_update)
+                if len(self._table.specs()) == 1:
+                    required_last_assigned_partitioned_id = PARTITION_FIELD_ID_START - 1
+                else:
+                    required_last_assigned_partitioned_id = self._table.spec().last_assigned_field_id
+                requirements.append(
+                    AssertLastAssignedPartitionId(last_assigned_partition_id=required_last_assigned_partitioned_id)
+                )
+            updates.append(SetDefaultSpecUpdate(spec_id=new_spec.spec_id))
+            requirements.append(AssertDefaultSpecId(default_spec_id=self._table.metadata.default_spec_id))
+        if self._transaction:
+            self._transaction._append_updates(*updates)  # pylint: disable=W0212
+            self._transaction._append_requirements(*requirements)  # pylint: disable=W0212
+        else:
+            self._table._do_commit(updates=tuple(updates), requirements=tuple(requirements))  # pylint: disable=W0212
+
+    def __exit__(self, _: Any, value: Any, traceback: Any) -> None:
+        """Close and commit the change."""
+        return self.commit()
+
+    def __enter__(self) -> UpdateSpec:
+        """Update the table."""
+        return self
+
+    def _apply(self) -> PartitionSpec:
+        def _check_and_add_partition_name(schema: Schema, name: str, source_id: Optional[int], partition_names: Set[str]) -> None:
+            try:
+                field = schema.find_field(name)
+            except ValueError:
+                field = None
+
+            if source_id is not None and field is not None and field.field_id != source_id:
+                raise ValueError(f"Cannot create identity partition from a different field in the schema: {name}")
+            elif source_id is None and field is not None:
+                raise ValueError(f"Cannot create partition from a name that exists in the schema: {name}")
+            if not name:
+                raise ValueError("Undefined name")
+            if name in partition_names:
+                raise ValueError(f"Cannot use partition name more than once: {name}")
+            partition_names.add(name)
+
+        def _add_new_field(
+            schema: Schema, source_id: int, field_id: int, name: str, transform: Transform[Any, Any], partition_names: Set[str]
+        ) -> PartitionField:
+            _check_and_add_partition_name(schema, name, source_id, partition_names)
+            return PartitionField(source_id, field_id, transform, name)
+
+        partition_fields = []
+        partition_names: Set[str] = set()
+        for field in self._spec.fields:
+            if field.field_id not in self._deletes:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(
+                        self._schema, field.source_id, field.field_id, renamed, field.transform, partition_names
+                    )
+                else:
+                    new_field = _add_new_field(
+                        self._schema, field.source_id, field.field_id, field.name, field.transform, partition_names
+                    )
+                partition_fields.append(new_field)
+            elif self._table.format_version == 1:
+                renamed = self._renames.get(field.name)
+                if renamed:
+                    new_field = _add_new_field(
+                        self._schema, field.source_id, field.field_id, renamed, VoidTransform(), partition_names
+                    )
+                else:
+                    new_field = _add_new_field(
+                        self._schema, field.source_id, field.field_id, field.name, VoidTransform(), partition_names
+                    )
+
+                partition_fields.append(new_field)
+
+        for added_field in self._adds:
+            new_field = PartitionField(
+                source_id=added_field.source_id,
+                field_id=added_field.field_id,
+                transform=added_field.transform,
+                name=added_field.name,
+            )
+            partition_fields.append(new_field)
+        # Reuse spec id or create a new one.
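+        # A candidate spec is first built with the default spec id so that
+        # compatible_with() can compare its fields against every historical spec:
+        # if an identical spec already exists, its id is reused; otherwise the new
+        # id is one greater than the highest spec id seen so far.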
+        new_spec = PartitionSpec(*partition_fields)
+        new_spec_id = INITIAL_PARTITION_SPEC_ID
+        for spec in self._table.specs().values():
+            if new_spec.compatible_with(spec):
+                new_spec_id = spec.spec_id
+                break
+            elif new_spec_id <= spec.spec_id:
+                new_spec_id = spec.spec_id + 1
+
+        return PartitionSpec(*partition_fields, spec_id=new_spec_id)
+
+    def _check_redundant_partitions(self, field: PartitionField) -> None:
+        if isinstance(field.transform, TimeTransform):
+            existing_time_field = self._added_time_fields.get(field.source_id)
+            if existing_time_field:
+                raise ValueError(f"Cannot add time partition field: {field.name} conflicts with {existing_time_field.name}")
+            self._added_time_fields[field.source_id] = field
+
+    def _partition_field(self, transform_key: Tuple[int, Transform[Any, Any]], name: Optional[str]) -> PartitionField:
+        if self._table.metadata.format_version == 2:
+            source_id = transform_key[0]
+            transform = transform_key[1]
+            historical_fields = set()
+            for spec in self._table.specs().values():
+                for field in spec.fields:
+                    historical_fields.add(field)
+
+            for field in historical_fields:
+                if field.source_id == source_id and field.transform == transform:
+                    if not name or field.name == name:
+                        return field
+        return PartitionField(transform_key[0], self._new_field_id(), transform_key[1], name)
+
+    def _new_field_id(self) -> int:
+        self._last_assigned_partition_id += 1
+        return self._last_assigned_partition_id
+
+    def _is_duplicate_partition(self, transform: Transform[Any, Any], partition_field: PartitionField) -> bool:
+        return partition_field.field_id not in self._deletes and partition_field.transform == transform
diff --git a/tests/test_integration_partition_evolution.py b/tests/test_integration_partition_evolution.py
new file mode 100644
index 0000000000..d9d733394c
--- /dev/null
+++ b/tests/test_integration_partition_evolution.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
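+# NOTE: these tests assume an already-running integration environment (for
+# example the docker-compose setup used by this repository's other integration
+# tests): a REST catalog at localhost:8181 and an S3-compatible endpoint at
+# localhost:9000, matching the `catalog` fixture below.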
+# pylint:disable=redefined-outer-name
+
+import pytest
+
+from pyiceberg.catalog import Catalog, load_catalog
+from pyiceberg.exceptions import NoSuchTableError
+from pyiceberg.partitioning import PartitionField, PartitionSpec
+from pyiceberg.schema import Schema
+from pyiceberg.table import Table
+from pyiceberg.transforms import (
+    BucketTransform,
+    DayTransform,
+    HourTransform,
+    IdentityTransform,
+    MonthTransform,
+    TruncateTransform,
+    VoidTransform,
+    YearTransform,
+)
+from pyiceberg.types import (
+    LongType,
+    NestedField,
+    StringType,
+    TimestampType,
+)
+
+
+@pytest.fixture()
+def catalog() -> Catalog:
+    return load_catalog(
+        "local",
+        **{
+            "type": "rest",
+            "uri": "http://localhost:8181",
+            "s3.endpoint": "http://localhost:9000",
+            "s3.access-key-id": "admin",
+            "s3.secret-access-key": "password",
+        },
+    )
+
+
+@pytest.fixture()
+def simple_table(catalog: Catalog, table_schema_simple: Schema) -> Table:
+    return _create_table_with_schema(catalog, table_schema_simple, "1")
+
+
+@pytest.fixture()
+def table_with_timestamp(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "1")
+
+
+@pytest.fixture()
+def table_with_timestamp_v2(catalog: Catalog) -> Table:
+    schema_with_timestamp = Schema(
+        NestedField(1, "id", LongType(), required=False),
+        NestedField(2, "event_ts", TimestampType(), required=False),
+        NestedField(3, "str", StringType(), required=False),
+    )
+    return _create_table_with_schema(catalog, schema_with_timestamp, "2")
+
+
+def _create_table_with_schema(catalog: Catalog, schema: Schema, format_version: str) -> Table:
+    tbl_name = "default.test_partition_evolution"
+    try:
+        catalog.drop_table(tbl_name)
+    except NoSuchTableError:
+        pass
+    return catalog.create_table(identifier=tbl_name, schema=schema, properties={"format-version": format_version})
+
+
+@pytest.mark.integration
+def test_add_identity_partition(simple_table: Table) -> None:
+    simple_table.update_spec().add_identity("foo").commit()
+    specs = simple_table.specs()
+    assert len(specs) == 2
+    spec = simple_table.spec()
+    assert spec.spec_id == 1
+    assert spec.last_assigned_field_id == 1000
+
+
+@pytest.mark.integration
+def test_add_year(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("year_transform", "event_ts", YearTransform()).commit()
+    _validate_new_partition_fields(table_with_timestamp, 1000, 1, PartitionField(2, 1000, YearTransform(), "year_transform"))
+
+
+@pytest.mark.integration
+def test_add_month(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("month_transform", "event_ts", MonthTransform()).commit()
+    _validate_new_partition_fields(table_with_timestamp, 1000, 1, PartitionField(2, 1000, MonthTransform(), "month_transform"))
+
+
+@pytest.mark.integration
+def test_add_day(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("day_transform", "event_ts", DayTransform()).commit()
+    _validate_new_partition_fields(table_with_timestamp, 1000, 1, PartitionField(2, 1000, DayTransform(), "day_transform"))
+
+
+@pytest.mark.integration
+def test_add_hour(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("hour_transform", "event_ts", HourTransform()).commit()
+    _validate_new_partition_fields(table_with_timestamp, 1000, 1, PartitionField(2, 1000, HourTransform(), "hour_transform"))
+
+
+@pytest.mark.integration
+def test_add_bucket(simple_table: Table) -> None:
+    simple_table.update_spec().add_field("bucket_transform", "foo", BucketTransform(12)).commit()
+    _validate_new_partition_fields(simple_table, 1000, 1, PartitionField(1, 1000, BucketTransform(12), "bucket_transform"))
+
+
+@pytest.mark.integration
+def test_add_truncate(simple_table: Table) -> None:
+    simple_table.update_spec().add_field("truncate_transform", "foo", TruncateTransform(1)).commit()
+    _validate_new_partition_fields(simple_table, 1000, 1, PartitionField(1, 1000, TruncateTransform(1), "truncate_transform"))
+
+
+@pytest.mark.integration
+def test_multiple_adds(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_identity("id").add_field("hourly_partitioned", "event_ts", HourTransform()).add_field(
+        "truncate_str", "str", TruncateTransform(2)
+    ).commit()
+    _validate_new_partition_fields(
+        table_with_timestamp,
+        1002,
+        1,
+        PartitionField(1, 1000, IdentityTransform(), "id"),
+        PartitionField(2, 1001, HourTransform(), "hourly_partitioned"),
+        PartitionField(3, 1002, TruncateTransform(2), "truncate_str"),
+    )
+
+
+@pytest.mark.integration
+def test_add_hour_to_day(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("daily_partitioned", "event_ts", DayTransform()).commit()
+    table_with_timestamp.update_spec().add_field("hourly_partitioned", "event_ts", HourTransform()).commit()
+    _validate_new_partition_fields(
+        table_with_timestamp,
+        1001,
+        2,
+        PartitionField(2, 1000, DayTransform(), "daily_partitioned"),
+        PartitionField(2, 1001, HourTransform(), "hourly_partitioned"),
+    )
+
+
+@pytest.mark.integration
+def test_add_multiple_buckets(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_field("bucket_16", "id", BucketTransform(16)).add_field(
+        "bucket_4", "id", BucketTransform(4)
+    ).commit()
+    _validate_new_partition_fields(
+        table_with_timestamp,
+        1001,
+        1,
+        PartitionField(1, 1000, BucketTransform(16), "bucket_16"),
+        PartitionField(1, 1001, BucketTransform(4), "bucket_4"),
+    )
+
+
+@pytest.mark.integration
+def test_remove_identity(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_identity("id").commit()
+    table_with_timestamp.update_spec().remove_field("id").commit()
+    assert len(table_with_timestamp.specs()) == 3
+    assert table_with_timestamp.spec().spec_id == 2
+    assert table_with_timestamp.spec() == PartitionSpec(
+        PartitionField(source_id=1, field_id=1000, transform=VoidTransform(), name='id'), spec_id=2
+    )
+
+
+@pytest.mark.integration
+def test_remove_identity_v2(table_with_timestamp_v2: Table) -> None:
+    table_with_timestamp_v2.update_spec().add_identity("id").commit()
+    table_with_timestamp_v2.update_spec().remove_field("id").commit()
+    assert len(table_with_timestamp_v2.specs()) == 2
+    assert table_with_timestamp_v2.spec().spec_id == 0
+    assert table_with_timestamp_v2.spec() == PartitionSpec(spec_id=0)
+
+
+@pytest.mark.integration
+def test_remove_bucket(table_with_timestamp: Table) -> None:
+    pass
+
+
+@pytest.mark.integration
+def test_remove_day(table_with_timestamp: Table) -> None:
+    pass
+
+
+@pytest.mark.integration
+def test_rename(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_identity("id").commit()
+    table_with_timestamp.update_spec().rename_field("id", "sharded_id").commit()
+    assert len(table_with_timestamp.specs()) == 3
+    assert table_with_timestamp.spec().spec_id == 2
+    _validate_new_partition_fields(table_with_timestamp, 1000, 2, PartitionField(1, 1000, IdentityTransform(), "sharded_id"))
+
+
+@pytest.mark.integration
+def test_cannot_add_and_remove(table_with_timestamp: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table_with_timestamp.update_spec().add_identity("id").remove_field("id").commit()
+    assert "Cannot delete newly added field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_add_redundant_time_partition(table_with_timestamp: Table) -> None:
+    with pytest.raises(ValueError) as exc_info:
+        table_with_timestamp.update_spec().add_field("year_transform", "event_ts", YearTransform()).add_field(
+            "hour_transform", "event_ts", HourTransform()
+        ).commit()
+    assert "Cannot add time partition field: hour_transform conflicts with year_transform" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_delete_and_rename(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_identity("id").commit()
+    with pytest.raises(ValueError) as exc_info:
+        table_with_timestamp.update_spec().remove_field("id").rename_field("id", "sharded_id").commit()
+    assert "Cannot delete and rename partition field id" in str(exc_info.value)
+
+
+@pytest.mark.integration
+def test_cannot_rename_and_delete(table_with_timestamp: Table) -> None:
+    table_with_timestamp.update_spec().add_identity("id").commit()
+    with pytest.raises(ValueError) as exc_info:
+        table_with_timestamp.update_spec().rename_field("id", "sharded_id").remove_field("id").commit()
+    assert "Cannot rename and delete field id" in str(exc_info.value)
+
+
+def _validate_new_partition_fields(
+    table: Table, expected_last_assigned_field_id: int, expected_spec_id: int, *expected_partition_fields: PartitionField
+) -> None:
+    spec = table.spec()
+    assert spec.spec_id == expected_spec_id
+    assert spec.last_assigned_field_id == expected_last_assigned_field_id
+    assert len(spec.fields) == len(expected_partition_fields)
+    for actual_field, expected_field in zip(spec.fields, expected_partition_fields):
+        assert actual_field == expected_field
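
Note for reviewers: the new API chains in the same style as UpdateSchema and can
also be used as a context manager. A minimal usage sketch (the catalog name,
table identifier, and column names below are illustrative, not part of this
patch):

    from pyiceberg.catalog import load_catalog
    from pyiceberg.transforms import BucketTransform, DayTransform

    table = load_catalog("local").load_table("default.events")

    # Stage several changes against the current spec; commit() writes them as a
    # single new partition spec and makes it the table default.
    table.update_spec().add_field("event_day", "event_ts", DayTransform()).add_field(
        "id_bucket", "id", BucketTransform(16)
    ).commit()

    # The context-manager form commits on exit.
    with table.update_spec() as update:
        update.rename_field("id_bucket", "sharded_id")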
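The PartitionSpecVisitor machinery added in partitioning.py is what backs default
naming: when add_identity() stages a field without an explicit name, add_field()
falls back to _PartitionNameGenerator via _visit_field(). A self-contained sketch
of that path (the schema and spec here are made up for illustration):

    from pyiceberg.partitioning import PartitionField, PartitionSpec, _PartitionNameGenerator, _visit
    from pyiceberg.schema import Schema
    from pyiceberg.transforms import BucketTransform
    from pyiceberg.types import LongType, NestedField

    schema = Schema(NestedField(1, "id", LongType(), required=False))
    spec = PartitionSpec(PartitionField(1, 1000, BucketTransform(16), "id_bucket"))

    # _visit resolves each field's source column name and dispatches on the
    # transform type; the generator renders "<column>_bucket_<n>" for buckets.
    assert _visit(spec, schema, _PartitionNameGenerator()) == ["id_bucket_16"]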