diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 4fb14e7d05..c75a0a5983 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -43,7 +43,7 @@ Union, ) -from pydantic import Field, SerializeAsAny, field_validator +from pydantic import Field, field_validator from sortedcontainers import SortedList from typing_extensions import Annotated @@ -383,77 +383,56 @@ def commit_transaction(self) -> Table: return self._table -class TableUpdateAction(Enum): - upgrade_format_version = "upgrade-format-version" - add_schema = "add-schema" - set_current_schema = "set-current-schema" - add_spec = "add-spec" - set_default_spec = "set-default-spec" - add_sort_order = "add-sort-order" - set_default_sort_order = "set-default-sort-order" - add_snapshot = "add-snapshot" - set_snapshot_ref = "set-snapshot-ref" - remove_snapshots = "remove-snapshots" - remove_snapshot_ref = "remove-snapshot-ref" - set_location = "set-location" - set_properties = "set-properties" - remove_properties = "remove-properties" - - -class TableUpdate(IcebergBaseModel): - action: TableUpdateAction - - -class UpgradeFormatVersionUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.upgrade_format_version +class UpgradeFormatVersionUpdate(IcebergBaseModel): + action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version") format_version: int = Field(alias="format-version") -class AddSchemaUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.add_schema +class AddSchemaUpdate(IcebergBaseModel): + action: Literal['add-schema'] = Field(default="add-schema") schema_: Schema = Field(alias="schema") # This field is required: https://github.com/apache/iceberg/pull/7445 last_column_id: int = Field(alias="last-column-id") -class SetCurrentSchemaUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_current_schema +class SetCurrentSchemaUpdate(IcebergBaseModel): + action: Literal['set-current-schema'] = Field(default="set-current-schema") schema_id: int = Field( alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1 ) -class AddPartitionSpecUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.add_spec +class AddPartitionSpecUpdate(IcebergBaseModel): + action: Literal['add-spec'] = Field(default="add-spec") spec: PartitionSpec -class SetDefaultSpecUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_default_spec +class SetDefaultSpecUpdate(IcebergBaseModel): + action: Literal['set-default-spec'] = Field(default="set-default-spec") spec_id: int = Field( alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1 ) -class AddSortOrderUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.add_sort_order +class AddSortOrderUpdate(IcebergBaseModel): + action: Literal['add-sort-order'] = Field(default="add-sort-order") sort_order: SortOrder = Field(alias="sort-order") -class SetDefaultSortOrderUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_default_sort_order +class SetDefaultSortOrderUpdate(IcebergBaseModel): + action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order") sort_order_id: int = Field( alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1 ) -class AddSnapshotUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.add_snapshot +class AddSnapshotUpdate(IcebergBaseModel): + action: Literal['add-snapshot'] = Field(default="add-snapshot") snapshot: Snapshot -class SetSnapshotRefUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_snapshot_ref +class SetSnapshotRefUpdate(IcebergBaseModel): + action: Literal['set-snapshot-ref'] = Field(default="set-snapshot-ref") ref_name: str = Field(alias="ref-name") type: Literal["tag", "branch"] snapshot_id: int = Field(alias="snapshot-id") @@ -462,23 +441,23 @@ class SetSnapshotRefUpdate(TableUpdate): min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)] -class RemoveSnapshotsUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.remove_snapshots +class RemoveSnapshotsUpdate(IcebergBaseModel): + action: Literal['remove-snapshots'] = Field(default="remove-snapshots") snapshot_ids: List[int] = Field(alias="snapshot-ids") -class RemoveSnapshotRefUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.remove_snapshot_ref +class RemoveSnapshotRefUpdate(IcebergBaseModel): + action: Literal['remove-snapshot-ref'] = Field(default="remove-snapshot-ref") ref_name: str = Field(alias="ref-name") -class SetLocationUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_location +class SetLocationUpdate(IcebergBaseModel): + action: Literal['set-location'] = Field(default="set-location") location: str -class SetPropertiesUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.set_properties +class SetPropertiesUpdate(IcebergBaseModel): + action: Literal['set-properties'] = Field(default="set-properties") updates: Dict[str, str] @field_validator('updates', mode='before') @@ -486,11 +465,32 @@ def transform_properties_dict_value_to_str(cls, properties: Properties) -> Dict[ return transform_dict_value_to_str(properties) -class RemovePropertiesUpdate(TableUpdate): - action: TableUpdateAction = TableUpdateAction.remove_properties +class RemovePropertiesUpdate(IcebergBaseModel): + action: Literal['remove-properties'] = Field(default="remove-properties") removals: List[str] +TableUpdate = Annotated[ + Union[ + UpgradeFormatVersionUpdate, + AddSchemaUpdate, + SetCurrentSchemaUpdate, + AddPartitionSpecUpdate, + SetDefaultSpecUpdate, + AddSortOrderUpdate, + SetDefaultSortOrderUpdate, + AddSnapshotUpdate, + SetSnapshotRefUpdate, + RemoveSnapshotsUpdate, + RemoveSnapshotRefUpdate, + SetLocationUpdate, + SetPropertiesUpdate, + RemovePropertiesUpdate, + ], + Field(discriminator='action'), +] + + class _TableMetadataUpdateContext: _updates: List[TableUpdate] @@ -502,21 +502,15 @@ def add_update(self, update: TableUpdate) -> None: def is_added_snapshot(self, snapshot_id: int) -> bool: return any( - update.snapshot.snapshot_id == snapshot_id - for update in self._updates - if update.action == TableUpdateAction.add_snapshot + update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate) ) def is_added_schema(self, schema_id: int) -> bool: - return any( - update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema - ) + return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate)) def is_added_sort_order(self, sort_order_id: int) -> bool: return any( - update.sort_order.order_id == sort_order_id - for update in self._updates - if update.action == TableUpdateAction.add_sort_order + update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate) ) @@ -767,7 +761,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda return new_metadata.model_copy(deep=True) -class TableRequirement(IcebergBaseModel): +class ValidatableTableRequirement(IcebergBaseModel): type: str @abstractmethod @@ -783,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ... -class AssertCreate(TableRequirement): +class AssertCreate(ValidatableTableRequirement): """The table must not already exist; used for create transactions.""" type: Literal["assert-create"] = Field(default="assert-create") @@ -793,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException("Table already exists") -class AssertTableUUID(TableRequirement): +class AssertTableUUID(ValidatableTableRequirement): """The table UUID must match the requirement's `uuid`.""" type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid") @@ -806,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}") -class AssertRefSnapshotId(TableRequirement): +class AssertRefSnapshotId(ValidatableTableRequirement): """The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`. if `snapshot-id` is `null` or missing, the ref must not already exist. @@ -831,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}") -class AssertLastAssignedFieldId(TableRequirement): +class AssertLastAssignedFieldId(ValidatableTableRequirement): """The table's last assigned column id must match the requirement's `last-assigned-field-id`.""" type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id") @@ -846,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertCurrentSchemaId(TableRequirement): +class AssertCurrentSchemaId(ValidatableTableRequirement): """The table's current schema id must match the requirement's `current-schema-id`.""" type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id") @@ -861,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertLastAssignedPartitionId(TableRequirement): +class AssertLastAssignedPartitionId(ValidatableTableRequirement): """The table's last assigned partition id must match the requirement's `last-assigned-partition-id`.""" type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id") @@ -876,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertDefaultSpecId(TableRequirement): +class AssertDefaultSpecId(ValidatableTableRequirement): """The table's default spec id must match the requirement's `default-spec-id`.""" type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id") @@ -891,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) -class AssertDefaultSortOrderId(TableRequirement): +class AssertDefaultSortOrderId(ValidatableTableRequirement): """The table's default sort order id must match the requirement's `default-sort-order-id`.""" type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id") @@ -906,6 +900,20 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None: ) +TableRequirement = Annotated[ + Union[ + AssertCreate, + AssertTableUUID, + AssertRefSnapshotId, + AssertLastAssignedFieldId, + AssertCurrentSchemaId, + AssertLastAssignedPartitionId, + AssertDefaultSpecId, + AssertDefaultSortOrderId, + ], + Field(discriminator='type'), +] + UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]] @@ -927,8 +935,8 @@ class TableIdentifier(IcebergBaseModel): class CommitTableRequest(IcebergBaseModel): identifier: TableIdentifier = Field() - requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple) - updates: Tuple[SerializeAsAny[TableUpdate], ...] = Field(default_factory=tuple) + requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple) + updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple) class CommitTableResponse(IcebergBaseModel): diff --git a/tests/table/test_init.py b/tests/table/test_init.py index f734211510..bb212d696e 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -53,12 +53,14 @@ AssertLastAssignedPartitionId, AssertRefSnapshotId, AssertTableUUID, + CommitTableRequest, RemovePropertiesUpdate, SetDefaultSortOrderUpdate, SetPropertiesUpdate, SetSnapshotRefUpdate, StaticTable, Table, + TableIdentifier, UpdateSchema, _apply_table_update, _check_schema, @@ -1113,3 +1115,13 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s with pytest.raises(ValidationError) as exc_info: TableMetadataV2(**example_table_metadata_v2) assert "None type is not a supported value in properties: property_name" in str(exc_info.value) + + +def test_serialize_commit_table_request() -> None: + request = CommitTableRequest( + requirements=(AssertTableUUID(uuid='4bfd18a3-74c6-478e-98b1-71c4c32f4163'),), + identifier=TableIdentifier(namespace=['a'], name='b'), + ) + + deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json()) + assert request == deserialized_request