Skip to content

Commit

Permalink
Fix CommitTableRequest serialisation (#525)
Browse files Browse the repository at this point in the history
* add failing test

* make requirements a discriminated union

* use discriminated type union

* add return type

* use type annotation

* have requirements inherit from ValidatableTableRequirement

* AddSortOrder filter by type

* lint

---------

Co-authored-by: Kieran Higgins <[email protected]>
  • Loading branch information
kdbhiggins and Kieran Higgins authored Mar 17, 2024
1 parent 7f712fd commit a077c73
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 70 deletions.
148 changes: 78 additions & 70 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
Union,
)

from pydantic import Field, SerializeAsAny, field_validator
from pydantic import Field, field_validator
from sortedcontainers import SortedList
from typing_extensions import Annotated

Expand Down Expand Up @@ -383,77 +383,56 @@ def commit_transaction(self) -> Table:
return self._table


class TableUpdateAction(Enum):
upgrade_format_version = "upgrade-format-version"
add_schema = "add-schema"
set_current_schema = "set-current-schema"
add_spec = "add-spec"
set_default_spec = "set-default-spec"
add_sort_order = "add-sort-order"
set_default_sort_order = "set-default-sort-order"
add_snapshot = "add-snapshot"
set_snapshot_ref = "set-snapshot-ref"
remove_snapshots = "remove-snapshots"
remove_snapshot_ref = "remove-snapshot-ref"
set_location = "set-location"
set_properties = "set-properties"
remove_properties = "remove-properties"


class TableUpdate(IcebergBaseModel):
action: TableUpdateAction


class UpgradeFormatVersionUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.upgrade_format_version
class UpgradeFormatVersionUpdate(IcebergBaseModel):
action: Literal['upgrade-format-version'] = Field(default="upgrade-format-version")
format_version: int = Field(alias="format-version")


class AddSchemaUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.add_schema
class AddSchemaUpdate(IcebergBaseModel):
action: Literal['add-schema'] = Field(default="add-schema")
schema_: Schema = Field(alias="schema")
# This field is required: https://github.com/apache/iceberg/pull/7445
last_column_id: int = Field(alias="last-column-id")


class SetCurrentSchemaUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_current_schema
class SetCurrentSchemaUpdate(IcebergBaseModel):
action: Literal['set-current-schema'] = Field(default="set-current-schema")
schema_id: int = Field(
alias="schema-id", description="Schema ID to set as current, or -1 to set last added schema", default=-1
)


class AddPartitionSpecUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.add_spec
class AddPartitionSpecUpdate(IcebergBaseModel):
action: Literal['add-spec'] = Field(default="add-spec")
spec: PartitionSpec


class SetDefaultSpecUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_default_spec
class SetDefaultSpecUpdate(IcebergBaseModel):
action: Literal['set-default-spec'] = Field(default="set-default-spec")
spec_id: int = Field(
alias="spec-id", description="Partition spec ID to set as the default, or -1 to set last added spec", default=-1
)


class AddSortOrderUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.add_sort_order
class AddSortOrderUpdate(IcebergBaseModel):
action: Literal['add-sort-order'] = Field(default="add-sort-order")
sort_order: SortOrder = Field(alias="sort-order")


class SetDefaultSortOrderUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_default_sort_order
class SetDefaultSortOrderUpdate(IcebergBaseModel):
action: Literal['set-default-sort-order'] = Field(default="set-default-sort-order")
sort_order_id: int = Field(
alias="sort-order-id", description="Sort order ID to set as the default, or -1 to set last added sort order", default=-1
)


class AddSnapshotUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.add_snapshot
class AddSnapshotUpdate(IcebergBaseModel):
action: Literal['add-snapshot'] = Field(default="add-snapshot")
snapshot: Snapshot


class SetSnapshotRefUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_snapshot_ref
class SetSnapshotRefUpdate(IcebergBaseModel):
action: Literal['set-snapshot-ref'] = Field(default="set-snapshot-ref")
ref_name: str = Field(alias="ref-name")
type: Literal["tag", "branch"]
snapshot_id: int = Field(alias="snapshot-id")
Expand All @@ -462,35 +441,56 @@ class SetSnapshotRefUpdate(TableUpdate):
min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)]


class RemoveSnapshotsUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.remove_snapshots
class RemoveSnapshotsUpdate(IcebergBaseModel):
action: Literal['remove-snapshots'] = Field(default="remove-snapshots")
snapshot_ids: List[int] = Field(alias="snapshot-ids")


class RemoveSnapshotRefUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.remove_snapshot_ref
class RemoveSnapshotRefUpdate(IcebergBaseModel):
action: Literal['remove-snapshot-ref'] = Field(default="remove-snapshot-ref")
ref_name: str = Field(alias="ref-name")


class SetLocationUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_location
class SetLocationUpdate(IcebergBaseModel):
action: Literal['set-location'] = Field(default="set-location")
location: str


class SetPropertiesUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.set_properties
class SetPropertiesUpdate(IcebergBaseModel):
action: Literal['set-properties'] = Field(default="set-properties")
updates: Dict[str, str]

@field_validator('updates', mode='before')
def transform_properties_dict_value_to_str(cls, properties: Properties) -> Dict[str, str]:
return transform_dict_value_to_str(properties)


class RemovePropertiesUpdate(TableUpdate):
action: TableUpdateAction = TableUpdateAction.remove_properties
class RemovePropertiesUpdate(IcebergBaseModel):
action: Literal['remove-properties'] = Field(default="remove-properties")
removals: List[str]


TableUpdate = Annotated[
Union[
UpgradeFormatVersionUpdate,
AddSchemaUpdate,
SetCurrentSchemaUpdate,
AddPartitionSpecUpdate,
SetDefaultSpecUpdate,
AddSortOrderUpdate,
SetDefaultSortOrderUpdate,
AddSnapshotUpdate,
SetSnapshotRefUpdate,
RemoveSnapshotsUpdate,
RemoveSnapshotRefUpdate,
SetLocationUpdate,
SetPropertiesUpdate,
RemovePropertiesUpdate,
],
Field(discriminator='action'),
]


class _TableMetadataUpdateContext:
_updates: List[TableUpdate]

Expand All @@ -502,21 +502,15 @@ def add_update(self, update: TableUpdate) -> None:

def is_added_snapshot(self, snapshot_id: int) -> bool:
return any(
update.snapshot.snapshot_id == snapshot_id
for update in self._updates
if update.action == TableUpdateAction.add_snapshot
update.snapshot.snapshot_id == snapshot_id for update in self._updates if isinstance(update, AddSnapshotUpdate)
)

def is_added_schema(self, schema_id: int) -> bool:
return any(
update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema
)
return any(update.schema_.schema_id == schema_id for update in self._updates if isinstance(update, AddSchemaUpdate))

def is_added_sort_order(self, sort_order_id: int) -> bool:
return any(
update.sort_order.order_id == sort_order_id
for update in self._updates
if update.action == TableUpdateAction.add_sort_order
update.sort_order.order_id == sort_order_id for update in self._updates if isinstance(update, AddSortOrderUpdate)
)


Expand Down Expand Up @@ -767,7 +761,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
return new_metadata.model_copy(deep=True)


class TableRequirement(IcebergBaseModel):
class ValidatableTableRequirement(IcebergBaseModel):
type: str

@abstractmethod
Expand All @@ -783,7 +777,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
...


class AssertCreate(TableRequirement):
class AssertCreate(ValidatableTableRequirement):
"""The table must not already exist; used for create transactions."""

type: Literal["assert-create"] = Field(default="assert-create")
Expand All @@ -793,7 +787,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException("Table already exists")


class AssertTableUUID(TableRequirement):
class AssertTableUUID(ValidatableTableRequirement):
"""The table UUID must match the requirement's `uuid`."""

type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
Expand All @@ -806,7 +800,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")


class AssertRefSnapshotId(TableRequirement):
class AssertRefSnapshotId(ValidatableTableRequirement):
"""The table branch or tag identified by the requirement's `ref` must reference the requirement's `snapshot-id`.
if `snapshot-id` is `null` or missing, the ref must not already exist.
Expand All @@ -831,7 +825,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
raise CommitFailedException(f"Requirement failed: branch or tag {self.ref} is missing, expected {self.snapshot_id}")


class AssertLastAssignedFieldId(TableRequirement):
class AssertLastAssignedFieldId(ValidatableTableRequirement):
"""The table's last assigned column id must match the requirement's `last-assigned-field-id`."""

type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
Expand All @@ -846,7 +840,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertCurrentSchemaId(TableRequirement):
class AssertCurrentSchemaId(ValidatableTableRequirement):
"""The table's current schema id must match the requirement's `current-schema-id`."""

type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
Expand All @@ -861,7 +855,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertLastAssignedPartitionId(TableRequirement):
class AssertLastAssignedPartitionId(ValidatableTableRequirement):
"""The table's last assigned partition id must match the requirement's `last-assigned-partition-id`."""

type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
Expand All @@ -876,7 +870,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertDefaultSpecId(TableRequirement):
class AssertDefaultSpecId(ValidatableTableRequirement):
"""The table's default spec id must match the requirement's `default-spec-id`."""

type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
Expand All @@ -891,7 +885,7 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


class AssertDefaultSortOrderId(TableRequirement):
class AssertDefaultSortOrderId(ValidatableTableRequirement):
"""The table's default sort order id must match the requirement's `default-sort-order-id`."""

type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
Expand All @@ -906,6 +900,20 @@ def validate(self, base_metadata: Optional[TableMetadata]) -> None:
)


TableRequirement = Annotated[
Union[
AssertCreate,
AssertTableUUID,
AssertRefSnapshotId,
AssertLastAssignedFieldId,
AssertCurrentSchemaId,
AssertLastAssignedPartitionId,
AssertDefaultSpecId,
AssertDefaultSortOrderId,
],
Field(discriminator='type'),
]

UpdatesAndRequirements = Tuple[Tuple[TableUpdate, ...], Tuple[TableRequirement, ...]]


Expand All @@ -927,8 +935,8 @@ class TableIdentifier(IcebergBaseModel):

class CommitTableRequest(IcebergBaseModel):
identifier: TableIdentifier = Field()
requirements: Tuple[SerializeAsAny[TableRequirement], ...] = Field(default_factory=tuple)
updates: Tuple[SerializeAsAny[TableUpdate], ...] = Field(default_factory=tuple)
requirements: Tuple[TableRequirement, ...] = Field(default_factory=tuple)
updates: Tuple[TableUpdate, ...] = Field(default_factory=tuple)


class CommitTableResponse(IcebergBaseModel):
Expand Down
12 changes: 12 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@
AssertLastAssignedPartitionId,
AssertRefSnapshotId,
AssertTableUUID,
CommitTableRequest,
RemovePropertiesUpdate,
SetDefaultSortOrderUpdate,
SetPropertiesUpdate,
SetSnapshotRefUpdate,
StaticTable,
Table,
TableIdentifier,
UpdateSchema,
_apply_table_update,
_check_schema,
Expand Down Expand Up @@ -1113,3 +1115,13 @@ def test_table_properties_raise_for_none_value(example_table_metadata_v2: Dict[s
with pytest.raises(ValidationError) as exc_info:
TableMetadataV2(**example_table_metadata_v2)
assert "None type is not a supported value in properties: property_name" in str(exc_info.value)


def test_serialize_commit_table_request() -> None:
request = CommitTableRequest(
requirements=(AssertTableUUID(uuid='4bfd18a3-74c6-478e-98b1-71c4c32f4163'),),
identifier=TableIdentifier(namespace=['a'], name='b'),
)

deserialized_request = CommitTableRequest.model_validate_json(request.model_dump_json())
assert request == deserialized_request

0 comments on commit a077c73

Please sign in to comment.