Skip to content

Commit

Permalink
make base_metadata optional and add null check
Browse files Browse the repository at this point in the history
  • Loading branch information
HonahX committed Dec 11, 2023
1 parent 413935e commit 52ceaf8
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 16 deletions.
45 changes: 29 additions & 16 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ class TableRequirement(IcebergBaseModel):
type: str

@abstractmethod
def validate(self, base_metadata: TableMetadata) -> None:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
"""Validate the requirement against the base metadata.
Args:
Expand Down Expand Up @@ -569,8 +569,10 @@ class AssertTableUUID(TableRequirement):
type: Literal["assert-table-uuid"] = Field(default="assert-table-uuid")
uuid: uuid.UUID

def validate(self, base_metadata: TableMetadata) -> None:
if self.uuid != base_metadata.table_uuid:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.uuid != base_metadata.table_uuid:
raise CommitFailedException(f"Table UUID does not match: {self.uuid} != {base_metadata.table_uuid}")


Expand All @@ -584,9 +586,10 @@ class AssertRefSnapshotId(TableRequirement):
ref: str
snapshot_id: Optional[int] = Field(default=None, alias="snapshot-id")

def validate(self, base_metadata: TableMetadata) -> None:
snapshot_ref = base_metadata.refs.get(self.ref)
if snapshot_ref is not None:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif snapshot_ref := base_metadata.refs.get(self.ref):
ref_type = snapshot_ref.snapshot_ref_type
if self.snapshot_id is None:
raise CommitFailedException(f"Requirement failed: {ref_type} {self.ref} was created concurrently")
Expand All @@ -604,8 +607,10 @@ class AssertLastAssignedFieldId(TableRequirement):
type: Literal["assert-last-assigned-field-id"] = Field(default="assert-last-assigned-field-id")
last_assigned_field_id: int = Field(..., alias="last-assigned-field-id")

def validate(self, base_metadata: TableMetadata) -> None:
if base_metadata.last_column_id != self.last_assigned_field_id:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif base_metadata.last_column_id != self.last_assigned_field_id:
raise CommitFailedException(
f"Requirement failed: last assigned field id has changed: expected {self.last_assigned_field_id}, found {base_metadata.last_column_id}"
)
Expand All @@ -617,8 +622,10 @@ class AssertCurrentSchemaId(TableRequirement):
type: Literal["assert-current-schema-id"] = Field(default="assert-current-schema-id")
current_schema_id: int = Field(..., alias="current-schema-id")

def validate(self, base_metadata: TableMetadata) -> None:
if self.current_schema_id != base_metadata.current_schema_id:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.current_schema_id != base_metadata.current_schema_id:
raise CommitFailedException(
f"Requirement failed: current schema id has changed: expected {self.current_schema_id}, found {base_metadata.current_schema_id}"
)
Expand All @@ -630,8 +637,10 @@ class AssertLastAssignedPartitionId(TableRequirement):
type: Literal["assert-last-assigned-partition-id"] = Field(default="assert-last-assigned-partition-id")
last_assigned_partition_id: int = Field(..., alias="last-assigned-partition-id")

def validate(self, base_metadata: TableMetadata) -> None:
if base_metadata.last_partition_id != self.last_assigned_partition_id:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif base_metadata.last_partition_id != self.last_assigned_partition_id:
raise CommitFailedException(
f"Requirement failed: last assigned partition id has changed: expected {self.last_assigned_partition_id}, found {base_metadata.last_partition_id}"
)
Expand All @@ -643,8 +652,10 @@ class AssertDefaultSpecId(TableRequirement):
type: Literal["assert-default-spec-id"] = Field(default="assert-default-spec-id")
default_spec_id: int = Field(..., alias="default-spec-id")

def validate(self, base_metadata: TableMetadata) -> None:
if self.default_spec_id != base_metadata.default_spec_id:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.default_spec_id != base_metadata.default_spec_id:
raise CommitFailedException(
f"Requirement failed: default spec id has changed: expected {self.default_spec_id}, found {base_metadata.default_spec_id}"
)
Expand All @@ -656,8 +667,10 @@ class AssertDefaultSortOrderId(TableRequirement):
type: Literal["assert-default-sort-order-id"] = Field(default="assert-default-sort-order-id")
default_sort_order_id: int = Field(..., alias="default-sort-order-id")

def validate(self, base_metadata: TableMetadata) -> None:
if self.default_sort_order_id != base_metadata.default_sort_order_id:
def validate(self, base_metadata: Optional[TableMetadata]) -> None:
if base_metadata is None:
raise CommitFailedException("Requirement failed: current table metadata is missing")
elif self.default_sort_order_id != base_metadata.default_sort_order_id:
raise CommitFailedException(
f"Requirement failed: default sort order id has changed: expected {self.default_sort_order_id}, found {base_metadata.default_sort_order_id}"
)
Expand Down
21 changes: 21 additions & 0 deletions tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,9 @@ def test_assert_table_uuid(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertTableUUID(uuid=base_metadata.table_uuid).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertTableUUID(uuid=uuid.UUID("9c12d441-03fe-4693-9a96-a0705ddf69c2")).validate(None)

with pytest.raises(
CommitFailedException,
match="Table UUID does not match: 9c12d441-03fe-4693-9a96-a0705ddf69c2 != 9c12d441-03fe-4693-9a96-a0705ddf69c1",
Expand All @@ -755,6 +758,9 @@ def test_assert_ref_snapshot_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertRefSnapshotId(ref="main", snapshot_id=base_metadata.current_snapshot_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertRefSnapshotId(ref="main", snapshot_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: branch main was created concurrently",
Expand All @@ -778,6 +784,9 @@ def test_assert_last_assigned_field_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertLastAssignedFieldId(last_assigned_field_id=base_metadata.last_column_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertLastAssignedFieldId(last_assigned_field_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: last assigned field id has changed: expected 1, found 3",
Expand All @@ -789,6 +798,9 @@ def test_assert_current_schema_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertCurrentSchemaId(current_schema_id=base_metadata.current_schema_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertCurrentSchemaId(current_schema_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: current schema id has changed: expected 2, found 1",
Expand All @@ -800,6 +812,9 @@ def test_last_assigned_partition_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertLastAssignedPartitionId(last_assigned_partition_id=base_metadata.last_partition_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertLastAssignedPartitionId(last_assigned_partition_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: last assigned partition id has changed: expected 1, found 1000",
Expand All @@ -811,6 +826,9 @@ def test_assert_default_spec_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertDefaultSpecId(default_spec_id=base_metadata.default_spec_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertDefaultSpecId(default_spec_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: default spec id has changed: expected 1, found 0",
Expand All @@ -822,6 +840,9 @@ def test_assert_default_sort_order_id(table_v2: Table) -> None:
base_metadata = table_v2.metadata
AssertDefaultSortOrderId(default_sort_order_id=base_metadata.default_sort_order_id).validate(base_metadata)

with pytest.raises(CommitFailedException, match="Requirement failed: current table metadata is missing"):
AssertDefaultSortOrderId(default_sort_order_id=1).validate(None)

with pytest.raises(
CommitFailedException,
match="Requirement failed: default sort order id has changed: expected 1, found 3",
Expand Down

0 comments on commit 52ceaf8

Please sign in to comment.