From e92e10ae4924d4a2e51a83acc3321c0896f9fa54 Mon Sep 17 00:00:00 2001 From: HonahX <140284484+HonahX@users.noreply.github.com> Date: Mon, 11 Dec 2023 08:02:54 -0800 Subject: [PATCH] Table Requirements Validation (#200) * implement requirements validation * change the exception to CommitFailedException * add docstring * fix CI issue * make base_metadata optional and add null check --- pyiceberg/exceptions.py | 2 +- pyiceberg/table/__init__.py | 80 ++++++++++++++++++++++- pyiceberg/table/refs.py | 4 ++ tests/table/test_init.py | 127 ++++++++++++++++++++++++++++++++++++ 4 files changed, 210 insertions(+), 3 deletions(-) diff --git a/pyiceberg/exceptions.py b/pyiceberg/exceptions.py index f555543723..64356b11a4 100644 --- a/pyiceberg/exceptions.py +++ b/pyiceberg/exceptions.py @@ -104,7 +104,7 @@ class GenericDynamoDbError(DynamoDbError): pass -class CommitFailedException(RESTError): +class CommitFailedException(Exception): """Commit failed, refresh and try again.""" diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4768706d1d..e4ca71f0a1 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -44,7 +44,7 @@ from sortedcontainers import SortedList from typing_extensions import Annotated -from pyiceberg.exceptions import ResolveError, ValidationError +from pyiceberg.exceptions import CommitFailedException, ResolveError, ValidationError from pyiceberg.expressions import ( AlwaysTrue, And, @@ -540,18 +540,40 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda class TableRequirement(IcebergBaseModel): type: str + @abstractmethod + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + """Validate the requirement against the base metadata. + + Args: + base_metadata: The base metadata to be validated against. + + Raises: + CommitFailedException: When the requirement is not met. + """ + ... + class AssertCreate(TableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is not None: + raise CommitFailedException("Table already exists") + class AssertTableUUID(TableRequirement): """The table UUID must match the requirement's `uuid`.""" type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") - uuid: str + uuid: uuid.UUID + + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.uuid != base_metadata.table_uuid: + raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") class AssertRefSnapshotId(TableRequirement): @@ -564,6 +586,20 @@ class AssertRefSnapshotId(TableRequirement): ref: str snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif snapshot_ref := base_metadata.refs.get(self.ref): + ref_type = snapshot_ref.snapshot_ref_type + if self.snapshot_id is None: + raise CommitFailedException(f"Requirement failed: {ref_type} {self.ref} was created concurrently") + elif self.snapshot_id != snapshot_ref.snapshot_id: + raise CommitFailedException( + f"Requirement failed: {ref_type} {self.ref} has changed: expected id {self.snapshot_id}, found {snapshot_ref.snapshot_id}" + ) + elif self.snapshot_id is not None: + raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") + class AssertLastAssignedFieldId(TableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" @@ -571,6 +607,14 @@ class AssertLastAssignedFieldId(TableRequirement): type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") last_assigned_field_id: int = Field(..., alias="last-assigned-field-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif base_metadata.last_column_id != self.last_assigned_field_id: + raise CommitFailedException( + f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}" + ) + class AssertCurrentSchemaId(TableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" @@ -578,6 +622,14 @@ class AssertCurrentSchemaId(TableRequirement): type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") current_schema_id: int = Field(..., alias="current-schema-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.current_schema_id != base_metadata.current_schema_id: + raise CommitFailedException( + f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}" + ) + class AssertLastAssignedPartitionId(TableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" @@ -585,6 +637,14 @@ class AssertLastAssignedPartitionId(TableRequirement): type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif base_metadata.last_partition_id != self.last_assigned_partition_id: + raise CommitFailedException( + f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}" + ) + class AssertDefaultSpecId(TableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" @@ -592,6 +652,14 @@ class AssertDefaultSpecId(TableRequirement): type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") default_spec_id: int = Field(..., alias="default-spec-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.default_spec_id != base_metadata.default_spec_id: + raise CommitFailedException( + f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}" + ) + class AssertDefaultSortOrderId(TableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" @@ -599,6 +667,14 @@ class AssertDefaultSortOrderId(TableRequirement): type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") default_sort_order_id: int = Field(..., alias="default-sort-order-id") + def validate(self, base_metadata: Optional[TableMetadata]) -> None: + if base_metadata is None: + raise CommitFailedException("Requirement failed: current table metadata is missing") + elif self.default_sort_order_id != base_metadata.default_sort_order_id: + raise CommitFailedException( + f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}" + ) + class Namespace(IcebergRootModel[List[str]]): """Reference to one or more levels of a namespace.""" diff --git a/pyiceberg/table/refs.py b/pyiceberg/table/refs.py index 6f17880cac..df18fadd31 100644 --- a/pyiceberg/table/refs.py +++ b/pyiceberg/table/refs.py @@ -34,6 +34,10 @@ def __repr__(self) -> str: """Return the string representation of the SnapshotRefType class.""" return f"SnapshotRefType.{self.name}" + def __str__(self) -> str: + """Return the string representation of the SnapshotRefType class.""" + return self.value + class SnapshotRef(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 8d13a82f3a..04d467c318 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -15,12 +15,14 @@ # specific language governing permissions and limitations # under the License. # pylint:disable=redefined-outer-name +import uuid from copy import copy from typing import Dict import pytest from sortedcontainers import SortedList +from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysTrue, And, @@ -39,6 +41,14 @@ from pyiceberg.schema import Schema from pyiceberg.table import ( AddSnapshotUpdate, + AssertCreate, + AssertCurrentSchemaId, + AssertDefaultSortOrderId, + AssertDefaultSpecId, + AssertLastAssignedFieldId, + AssertLastAssignedPartitionId, + AssertRefSnapshotId, + AssertTableUUID, SetPropertiesUpdate, SetSnapshotRefUpdate, SnapshotRef, @@ -721,3 +731,120 @@ def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None: def test_generate_snapshot_id(table_v2: Table) -> None: assert isinstance(_generate_snapshot_id(), int) assert isinstance(table_v2.new_snapshot_id(), int) + + +def test_assert_create(table_v2: Table) -> None: + AssertCreate().validate(None) + + with pytest.raises(CommitFailedException, match="Table already exists"): + AssertCreate().validate(table_v2.metadata) + + +def test_assert_table_uuid(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertTableUUID(uuid=base_metadata.table_uuid).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(None) + + with pytest.raises( + CommitFailedException, + match="Table UUID does not match: 9c12d441-03fe-4693-9a96-a0705ddf69c2 != 9c12d441-03fe-4693-9a96-a0705ddf69c1", + ): + AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(base_metadata) + + +def test_assert_ref_snapshot_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertRefSnapshotId(ref="main", snapshot_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: branch main was created concurrently", + ): + AssertRefSnapshotId(ref="main", snapshot_id=None).validate(base_metadata) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: branch main has changed: expected id 1, found 3055729675574597004", + ): + AssertRefSnapshotId(ref="main", snapshot_id=1).validate(base_metadata) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: branch or tag not_exist is missing, expected 1", + ): + AssertRefSnapshotId(ref="not_exist", snapshot_id=1).validate(base_metadata) + + +def test_assert_last_assigned_field_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertLastAssignedFieldId(last_assigned_field_id=base_metadata.last_column_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertLastAssignedFieldId(last_assigned_field_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: last assigned field id has changed: expected 1, found 3", + ): + AssertLastAssignedFieldId(last_assigned_field_id=1).validate(base_metadata) + + +def test_assert_current_schema_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertCurrentSchemaId(current_schema_id=base_metadata.current_schema_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertCurrentSchemaId(current_schema_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: current schema id has changed: expected 2, found 1", + ): + AssertCurrentSchemaId(current_schema_id=2).validate(base_metadata) + + +def test_last_assigned_partition_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertLastAssignedPartitionId(last_assigned_partition_id=base_metadata.last_partition_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: last assigned partition id has changed: expected 1, found 1000", + ): + AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(base_metadata) + + +def test_assert_default_spec_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertDefaultSpecId(default_spec_id=base_metadata.default_spec_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertDefaultSpecId(default_spec_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: default spec id has changed: expected 1, found 0", + ): + AssertDefaultSpecId(default_spec_id=1).validate(base_metadata) + + +def test_assert_default_sort_order_id(table_v2: Table) -> None: + base_metadata = table_v2.metadata + AssertDefaultSortOrderId(default_sort_order_id=base_metadata.default_sort_order_id).validate(base_metadata) + + with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"): + AssertDefaultSortOrderId(default_sort_order_id=1).validate(None) + + with pytest.raises( + CommitFailedException, + match="Requirement failed: default sort order id has changed: expected 1, found 3", + ): + AssertDefaultSortOrderId(default_sort_order_id=1).validate(base_metadata)