From 83306104a25a4ecd1f2185ec46cd9fda247544f4 Mon Sep 17 00:00:00 2001 From: HonahX <140284484+HonahX@users.noreply.github.com> Date: Mon, 4 Dec 2023 13:30:59 -0800 Subject: [PATCH] Support updating table metadata (#139) * Implement table metadata updater first draft * fix updater error and add tests * implement apply_metadata_update which is simpler * remove old implementation * re-organize method place * fix nit * fix test * add another test * clear TODO * add a combined test * Fix merge conflict * remove table requirement validation for PR simplification * make context private and solve elif issue * remove private field access * push snapshot ref validation to its builder using pydantic * fix comment * remove unnecessary code for AddSchemaUpdate update * replace if with elif * enhance the set current schema update implementation and some other changes * make apply_table_update private * fix an error * remove unnecessary last_added_schema_id --- pyiceberg/table/__init__.py | 203 +++++++++++++++++++++++-- pyiceberg/table/metadata.py | 10 ++ pyiceberg/table/refs.py | 22 ++- tests/conftest.py | 42 +++++- tests/table/test_init.py | 280 +++++++++++++++++++++++++++-------- tests/table/test_metadata.py | 25 ---- tests/table/test_refs.py | 54 +++++++ 7 files changed, 535 insertions(+), 101 deletions(-) diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index 6fbde32cc7..9aa6c1c9c5 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -16,13 +16,14 @@ # under the License. 
from __future__ import annotations +import datetime import itertools import uuid from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass from enum import Enum -from functools import cached_property +from functools import cached_property, singledispatch from itertools import chain from typing import ( TYPE_CHECKING, @@ -41,6 +42,7 @@ from pydantic import Field, SerializeAsAny from sortedcontainers import SortedList +from typing_extensions import Annotated from pyiceberg.exceptions import ResolveError, ValidationError from pyiceberg.expressions import ( @@ -69,8 +71,13 @@ promote, visit, ) -from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadata -from pyiceberg.table.refs import SnapshotRef +from pyiceberg.table.metadata import ( + INITIAL_SEQUENCE_NUMBER, + SUPPORTED_TABLE_FORMAT_VERSION, + TableMetadata, + TableMetadataUtil, +) +from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef from pyiceberg.table.snapshots import Snapshot, SnapshotLogEntry from pyiceberg.table.sorting import SortOrder from pyiceberg.typedef import ( @@ -90,6 +97,7 @@ StructType, ) from pyiceberg.utils.concurrent import ExecutorFactory +from pyiceberg.utils.datetime import datetime_to_millis if TYPE_CHECKING: import pandas as pd @@ -320,9 +328,9 @@ class SetSnapshotRefUpdate(TableUpdate): ref_name: str = Field(alias="ref-name") type: Literal["tag", "branch"] snapshot_id: int = Field(alias="snapshot-id") - max_age_ref_ms: int = Field(alias="max-ref-age-ms") - max_snapshot_age_ms: int = Field(alias="max-snapshot-age-ms") - min_snapshots_to_keep: int = Field(alias="min-snapshots-to-keep") + max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None)] + max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None)] + min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None)] class RemoveSnapshotsUpdate(TableUpdate): @@ -350,6 +358,184 @@ 
class RemovePropertiesUpdate(TableUpdate): removals: List[str] +class _TableMetadataUpdateContext: + _updates: List[TableUpdate] + + def __init__(self) -> None: + self._updates = [] + + def add_update(self, update: TableUpdate) -> None: + self._updates.append(update) + + def is_added_snapshot(self, snapshot_id: int) -> bool: + return any( + update.snapshot.snapshot_id == snapshot_id + for update in self._updates + if update.action == TableUpdateAction.add_snapshot + ) + + def is_added_schema(self, schema_id: int) -> bool: + return any( + update.schema_.schema_id == schema_id for update in self._updates if update.action == TableUpdateAction.add_schema + ) + + +@singledispatch +def _apply_table_update(update: TableUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + """Apply a table update to the table metadata. + + Args: + update: The update to be applied. + base_metadata: The base metadata to be updated. + context: Contains previous updates and other change tracking information in the current transaction. + + Returns: + The updated metadata. 
+ + """ + raise NotImplementedError(f"Unsupported table update: {update}") + + +@_apply_table_update.register(UpgradeFormatVersionUpdate) +def _(update: UpgradeFormatVersionUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if update.format_version > SUPPORTED_TABLE_FORMAT_VERSION: + raise ValueError(f"Unsupported table format version: {update.format_version}") + elif update.format_version < base_metadata.format_version: + raise ValueError(f"Cannot downgrade v{base_metadata.format_version} table to v{update.format_version}") + elif update.format_version == base_metadata.format_version: + return base_metadata + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["format-version"] = update.format_version + + context.add_update(update) + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@_apply_table_update.register(AddSchemaUpdate) +def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if update.last_column_id < base_metadata.last_column_id: + raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}") + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["last-column-id"] = update.last_column_id + updated_metadata_data["schemas"].append(update.schema_.model_dump()) + + context.add_update(update) + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@_apply_table_update.register(SetCurrentSchemaUpdate) +def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + new_schema_id = update.schema_id + if new_schema_id == -1: + # The last added schema should be in base_metadata.schemas at this point + new_schema_id = max(schema.schema_id for schema in base_metadata.schemas) + if not context.is_added_schema(new_schema_id): + raise ValueError("Cannot set current 
schema to last added schema when no schema has been added") + + if new_schema_id == base_metadata.current_schema_id: + return base_metadata + + schema = base_metadata.schema_by_id(new_schema_id) + if schema is None: + raise ValueError(f"Schema with id {new_schema_id} does not exist") + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["current-schema-id"] = new_schema_id + + context.add_update(update) + return TableMetadataUtil.parse_obj(updated_metadata_data) + + +@_apply_table_update.register(AddSnapshotUpdate) +def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + if len(base_metadata.schemas) == 0: + raise ValueError("Attempting to add a snapshot before a schema is added") + elif len(base_metadata.partition_specs) == 0: + raise ValueError("Attempting to add a snapshot before a partition spec is added") + elif len(base_metadata.sort_orders) == 0: + raise ValueError("Attempting to add a snapshot before a sort order is added") + elif base_metadata.snapshot_by_id(update.snapshot.snapshot_id) is not None: + raise ValueError(f"Snapshot with id {update.snapshot.snapshot_id} already exists") + elif ( + base_metadata.format_version == 2 + and update.snapshot.sequence_number is not None + and update.snapshot.sequence_number <= base_metadata.last_sequence_number + and update.snapshot.parent_snapshot_id is not None + ): + raise ValueError( + f"Cannot add snapshot with sequence number {update.snapshot.sequence_number} " + f"older than last sequence number {base_metadata.last_sequence_number}" + ) + + updated_metadata_data = copy(base_metadata.model_dump()) + updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms + updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number + updated_metadata_data["snapshots"].append(update.snapshot.model_dump()) + context.add_update(update) + return TableMetadataUtil.parse_obj(updated_metadata_data) + + 
+@_apply_table_update.register(SetSnapshotRefUpdate) +def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _TableMetadataUpdateContext) -> TableMetadata: + snapshot_ref = SnapshotRef( + snapshot_id=update.snapshot_id, + snapshot_ref_type=update.type, + min_snapshots_to_keep=update.min_snapshots_to_keep, + max_snapshot_age_ms=update.max_snapshot_age_ms, + max_ref_age_ms=update.max_ref_age_ms, + ) + + existing_ref = base_metadata.refs.get(update.ref_name) + if existing_ref is not None and existing_ref == snapshot_ref: + return base_metadata + + snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id) + if snapshot is None: + raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}") + + update_metadata_data = copy(base_metadata.model_dump()) + update_last_updated_ms = True + if context.is_added_snapshot(snapshot_ref.snapshot_id): + update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms + update_last_updated_ms = False + + if update.ref_name == MAIN_BRANCH: + update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id + if update_last_updated_ms: + update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone()) + update_metadata_data["snapshot-log"].append( + SnapshotLogEntry( + snapshot_id=snapshot_ref.snapshot_id, + timestamp_ms=update_metadata_data["last-updated-ms"], + ).model_dump() + ) + + update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump() + context.add_update(update) + return TableMetadataUtil.parse_obj(update_metadata_data) + + +def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata: + """Update the table metadata with the given updates in one transaction. + + Args: + base_metadata: The base metadata to be updated. + updates: The updates in one transaction. + + Returns: + The metadata with the updates applied. 
+ """ + context = _TableMetadataUpdateContext() + new_metadata = base_metadata + + for update in updates: + new_metadata = _apply_table_update(update, new_metadata, context) + + return new_metadata + + class TableRequirement(IcebergBaseModel): type: str @@ -552,10 +738,7 @@ def current_snapshot(self) -> Optional[Snapshot]: def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: """Get the snapshot of this table with the given id, or None if there is no matching snapshot.""" - try: - return next(snapshot for snapshot in self.metadata.snapshots if snapshot.snapshot_id == snapshot_id) - except StopIteration: - return None + return self.metadata.snapshot_by_id(snapshot_id) def snapshot_by_name(self, name: str) -> Optional[Snapshot]: """Return the snapshot referenced by the given name or null if no such reference exists.""" diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 73d76d8606..43e29c7b03 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -69,6 +69,8 @@ INITIAL_SPEC_ID = 0 DEFAULT_SCHEMA_ID = 0 +SUPPORTED_TABLE_FORMAT_VERSION = 2 + def cleanup_snapshot_id(data: Dict[str, Any]) -> Dict[str, Any]: """Run before validation.""" @@ -216,6 +218,14 @@ class TableMetadataCommonFields(IcebergBaseModel): There is always a main branch reference pointing to the current-snapshot-id even if the refs map is null.""" + def snapshot_by_id(self, snapshot_id: int) -> Optional[Snapshot]: + """Get the snapshot by snapshot_id.""" + return next((snapshot for snapshot in self.snapshots if snapshot.snapshot_id == snapshot_id), None) + + def schema_by_id(self, schema_id: int) -> Optional[Schema]: + """Get the schema by schema_id.""" + return next((schema for schema in self.schemas if schema.schema_id == schema_id), None) + class TableMetadataV1(TableMetadataCommonFields, IcebergBaseModel): """Represents version 1 of the Table Metadata. 
diff --git a/pyiceberg/table/refs.py b/pyiceberg/table/refs.py index b9692ca975..6f17880cac 100644 --- a/pyiceberg/table/refs.py +++ b/pyiceberg/table/refs.py @@ -17,8 +17,10 @@ from enum import Enum from typing import Optional -from pydantic import Field +from pydantic import Field, model_validator +from typing_extensions import Annotated +from pyiceberg.exceptions import ValidationError from pyiceberg.typedef import IcebergBaseModel MAIN_BRANCH = "main" @@ -36,6 +38,18 @@ def __repr__(self) -> str: class SnapshotRef(IcebergBaseModel): snapshot_id: int = Field(alias="snapshot-id") snapshot_ref_type: SnapshotRefType = Field(alias="type") - min_snapshots_to_keep: Optional[int] = Field(alias="min-snapshots-to-keep", default=None) - max_snapshot_age_ms: Optional[int] = Field(alias="max-snapshot-age-ms", default=None) - max_ref_age_ms: Optional[int] = Field(alias="max-ref-age-ms", default=None) + min_snapshots_to_keep: Annotated[Optional[int], Field(alias="min-snapshots-to-keep", default=None, gt=0)] + max_snapshot_age_ms: Annotated[Optional[int], Field(alias="max-snapshot-age-ms", default=None, gt=0)] + max_ref_age_ms: Annotated[Optional[int], Field(alias="max-ref-age-ms", default=None, gt=0)] + + @model_validator(mode='after') + def check_min_snapshots_to_keep(self) -> 'SnapshotRef': + if self.min_snapshots_to_keep is not None and self.snapshot_ref_type == SnapshotRefType.TAG: + raise ValidationError("Tags do not support setting minSnapshotsToKeep") + return self + + @model_validator(mode='after') + def check_max_snapshot_age_ms(self) -> 'SnapshotRef': + if self.max_snapshot_age_ms is not None and self.snapshot_ref_type == SnapshotRefType.TAG: + raise ValidationError("Tags do not support setting maxSnapshotAgeMs") + return self diff --git a/tests/conftest.py b/tests/conftest.py index 72a7ad0310..367e08151e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -73,7 +73,7 @@ from pyiceberg.schema import Accessor, Schema from pyiceberg.serializers import 
ToOutputFile from pyiceberg.table import FileScanTask, Table -from pyiceberg.table.metadata import TableMetadataV2 +from pyiceberg.table.metadata import TableMetadataV1, TableMetadataV2 from pyiceberg.typedef import UTF8 from pyiceberg.types import ( BinaryType, @@ -354,6 +354,32 @@ def all_avro_types() -> Dict[str, Any]: } +EXAMPLE_TABLE_METADATA_V1 = { + "format-version": 1, + "table-uuid": "d20125c8-7284-442c-9aea-15fee620737c", + "location": "s3://bucket/test/location", + "last-updated-ms": 1602638573874, + "last-column-id": 3, + "schema": { + "type": "struct", + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], + "properties": {}, + "current-snapshot-id": -1, + "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], +} + + +@pytest.fixture(scope="session") +def example_table_metadata_v1() -> Dict[str, Any]: + return EXAMPLE_TABLE_METADATA_V1 + + EXAMPLE_TABLE_METADATA_WITH_SNAPSHOT_V1 = { "format-version": 1, "table-uuid": "b55d9dda-6561-423a-8bfc-787980ce421f", @@ -1780,7 +1806,19 @@ def example_task(data_file: str) -> FileScanTask: @pytest.fixture -def table(example_table_metadata_v2: Dict[str, Any]) -> Table: +def table_v1(example_table_metadata_v1: Dict[str, Any]) -> Table: + table_metadata = TableMetadataV1(**example_table_metadata_v1) + return Table( + identifier=("database", "table"), + metadata=table_metadata, + metadata_location=f"{table_metadata.location}/uuid.metadata.json", + io=load_file_io(), + catalog=NoopCatalog("NoopCatalog"), + ) + + +@pytest.fixture +def table_v2(example_table_metadata_v2: Dict[str, Any]) -> Table: table_metadata = TableMetadataV2(**example_table_metadata_v2) return Table( identifier=("database", "table"), diff --git a/tests/table/test_init.py 
b/tests/table/test_init.py index 369df4fa92..6d188befeb 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -37,12 +37,18 @@ from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema from pyiceberg.table import ( + AddSnapshotUpdate, SetPropertiesUpdate, + SetSnapshotRefUpdate, + SnapshotRef, StaticTable, Table, UpdateSchema, + _apply_table_update, _generate_snapshot_id, _match_deletes_to_datafile, + _TableMetadataUpdateContext, + update_table_metadata, ) from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER from pyiceberg.table.snapshots import ( @@ -79,8 +85,8 @@ ) -def test_schema(table: Table) -> None: - assert table.schema() == Schema( +def test_schema(table_v2: Table) -> None: + assert table_v2.schema() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), NestedField(field_id=3, name="z", field_type=LongType(), required=True), @@ -89,8 +95,8 @@ def test_schema(table: Table) -> None: ) -def test_schemas(table: Table) -> None: - assert table.schemas() == { +def test_schemas(table_v2: Table) -> None: + assert table_v2.schemas() == { 0: Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), schema_id=0, @@ -106,20 +112,20 @@ def test_schemas(table: Table) -> None: } -def test_spec(table: Table) -> None: - assert table.spec() == PartitionSpec( +def test_spec(table_v2: Table) -> None: + assert table_v2.spec() == PartitionSpec( PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="x"), spec_id=0 ) -def test_specs(table: Table) -> None: - assert table.specs() == { +def test_specs(table_v2: Table) -> None: + assert table_v2.specs() == { 0: PartitionSpec(PartitionField(source_id=1, field_id=1000, transform=IdentityTransform(), name="x"), spec_id=0) } -def test_sort_order(table: Table) -> None: - assert table.sort_order() == 
SortOrder( +def test_sort_order(table_v2: Table) -> None: + assert table_v2.sort_order() == SortOrder( SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), SortField( source_id=3, @@ -131,8 +137,8 @@ def test_sort_order(table: Table) -> None: ) -def test_sort_orders(table: Table) -> None: - assert table.sort_orders() == { +def test_sort_orders(table_v2: Table) -> None: + assert table_v2.sort_orders() == { 3: SortOrder( SortField(source_id=2, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_FIRST), SortField( @@ -146,12 +152,12 @@ def test_sort_orders(table: Table) -> None: } -def test_location(table: Table) -> None: - assert table.location() == "s3://bucket/test/location" +def test_location(table_v2: Table) -> None: + assert table_v2.location() == "s3://bucket/test/location" -def test_current_snapshot(table: Table) -> None: - assert table.current_snapshot() == Snapshot( +def test_current_snapshot(table_v2: Table) -> None: + assert table_v2.current_snapshot() == Snapshot( snapshot_id=3055729675574597004, parent_snapshot_id=3051729675574597004, sequence_number=1, @@ -162,8 +168,8 @@ def test_current_snapshot(table: Table) -> None: ) -def test_snapshot_by_id(table: Table) -> None: - assert table.snapshot_by_id(3055729675574597004) == Snapshot( +def test_snapshot_by_id(table_v2: Table) -> None: + assert table_v2.snapshot_by_id(3055729675574597004) == Snapshot( snapshot_id=3055729675574597004, parent_snapshot_id=3051729675574597004, sequence_number=1, @@ -174,12 +180,12 @@ def test_snapshot_by_id(table: Table) -> None: ) -def test_snapshot_by_id_does_not_exist(table: Table) -> None: - assert table.snapshot_by_id(-1) is None +def test_snapshot_by_id_does_not_exist(table_v2: Table) -> None: + assert table_v2.snapshot_by_id(-1) is None -def test_snapshot_by_name(table: Table) -> None: - assert table.snapshot_by_name("test") == Snapshot( +def 
test_snapshot_by_name(table_v2: Table) -> None: + assert table_v2.snapshot_by_name("test") == Snapshot( snapshot_id=3051729675574597004, parent_snapshot_id=None, sequence_number=0, @@ -190,11 +196,11 @@ def test_snapshot_by_name(table: Table) -> None: ) -def test_snapshot_by_name_does_not_exist(table: Table) -> None: - assert table.snapshot_by_name("doesnotexist") is None +def test_snapshot_by_name_does_not_exist(table_v2: Table) -> None: + assert table_v2.snapshot_by_name("doesnotexist") is None -def test_repr(table: Table) -> None: +def test_repr(table_v2: Table) -> None: expected = """table( 1: x: required long, 2: y: required long (comment), @@ -203,37 +209,37 @@ def test_repr(table: Table) -> None: partition by: [x], sort order: [2 ASC NULLS FIRST, bucket[4](3) DESC NULLS LAST], snapshot: Operation.APPEND: id=3055729675574597004, parent_id=3051729675574597004, schema_id=1""" - assert repr(table) == expected + assert repr(table_v2) == expected -def test_history(table: Table) -> None: - assert table.history() == [ +def test_history(table_v2: Table) -> None: + assert table_v2.history() == [ SnapshotLogEntry(snapshot_id=3051729675574597004, timestamp_ms=1515100955770), SnapshotLogEntry(snapshot_id=3055729675574597004, timestamp_ms=1555100955770), ] -def test_table_scan_select(table: Table) -> None: - scan = table.scan() +def test_table_scan_select(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.selected_fields == ("*",) assert scan.select("a", "b").selected_fields == ("a", "b") assert scan.select("a", "c").select("a").selected_fields == ("a",) -def test_table_scan_row_filter(table: Table) -> None: - scan = table.scan() +def test_table_scan_row_filter(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.row_filter == AlwaysTrue() assert scan.filter(EqualTo("x", 10)).row_filter == EqualTo("x", 10) assert scan.filter(EqualTo("x", 10)).filter(In("y", (10, 11))).row_filter == And(EqualTo("x", 10), In("y", (10, 11))) -def 
test_table_scan_ref(table: Table) -> None: - scan = table.scan() +def test_table_scan_ref(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.use_ref("test").snapshot_id == 3051729675574597004 -def test_table_scan_ref_does_not_exists(table: Table) -> None: - scan = table.scan() +def test_table_scan_ref_does_not_exists(table_v2: Table) -> None: + scan = table_v2.scan() with pytest.raises(ValueError) as exc_info: _ = scan.use_ref("boom") @@ -241,8 +247,8 @@ def test_table_scan_ref_does_not_exists(table: Table) -> None: assert "Cannot scan unknown ref=boom" in str(exc_info.value) -def test_table_scan_projection_full_schema(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_full_schema(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.select("x", "y", "z").projection() == Schema( NestedField(field_id=1, name="x", field_type=LongType(), required=True), NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), @@ -252,8 +258,8 @@ def test_table_scan_projection_full_schema(table: Table) -> None: ) -def test_table_scan_projection_single_column(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_single_column(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.select("y").projection() == Schema( NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), schema_id=1, @@ -261,8 +267,8 @@ def test_table_scan_projection_single_column(table: Table) -> None: ) -def test_table_scan_projection_single_column_case_sensitive(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_single_column_case_sensitive(table_v2: Table) -> None: + scan = table_v2.scan() assert scan.with_case_sensitive(False).select("Y").projection() == Schema( NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), schema_id=1, @@ -270,8 +276,8 @@ def test_table_scan_projection_single_column_case_sensitive(table: 
Table) -> Non ) -def test_table_scan_projection_unknown_column(table: Table) -> None: - scan = table.scan() +def test_table_scan_projection_unknown_column(table_v2: Table) -> None: + scan = table_v2.scan() with pytest.raises(ValueError) as exc_info: _ = scan.select("a").projection() @@ -279,16 +285,16 @@ def test_table_scan_projection_unknown_column(table: Table) -> None: assert "Could not find column: 'a'" in str(exc_info.value) -def test_static_table_same_as_table(table: Table, metadata_location: str) -> None: +def test_static_table_same_as_table(table_v2: Table, metadata_location: str) -> None: static_table = StaticTable.from_metadata(metadata_location) assert isinstance(static_table, Table) - assert static_table.metadata == table.metadata + assert static_table.metadata == table_v2.metadata -def test_static_table_gz_same_as_table(table: Table, metadata_location_gz: str) -> None: +def test_static_table_gz_same_as_table(table_v2: Table, metadata_location_gz: str) -> None: static_table = StaticTable.from_metadata(metadata_location_gz) assert isinstance(static_table, Table) - assert static_table.metadata == table.metadata + assert static_table.metadata == table_v2.metadata def test_static_table_io_does_not_exist(metadata_location: str) -> None: @@ -409,8 +415,8 @@ def test_serialize_set_properties_updates() -> None: assert SetPropertiesUpdate(updates={"abc": "🤪"}).model_dump_json() == """{"action":"set-properties","updates":{"abc":"🤪"}}""" -def test_add_column(table: Table) -> None: - update = UpdateSchema(table) +def test_add_column(table_v2: Table) -> None: + update = UpdateSchema(table_v2) update.add_column(path="b", field_type=IntegerType()) apply_schema: Schema = update._apply() # pylint: disable=W0212 assert len(apply_schema.fields) == 4 @@ -426,7 +432,7 @@ def test_add_column(table: Table) -> None: assert apply_schema.highest_field_id == 4 -def test_add_primitive_type_column(table: Table) -> None: +def test_add_primitive_type_column(table_v2: Table) -> None: 
primitive_type: Dict[str, PrimitiveType] = { "boolean": BooleanType(), "int": IntegerType(), @@ -444,7 +450,7 @@ def test_add_primitive_type_column(table: Table) -> None: for name, type_ in primitive_type.items(): field_name = f"new_column_{name}" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) update.add_column(path=field_name, field_type=type_, doc=f"new_column_{name}") new_schema = update._apply() # pylint: disable=W0212 @@ -453,10 +459,10 @@ def test_add_primitive_type_column(table: Table) -> None: assert field.doc == f"new_column_{name}" -def test_add_nested_type_column(table: Table) -> None: +def test_add_nested_type_column(table_v2: Table) -> None: # add struct type column field_name = "new_column_struct" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) struct_ = StructType( NestedField(1, "lat", DoubleType()), NestedField(2, "long", DoubleType()), @@ -471,10 +477,10 @@ def test_add_nested_type_column(table: Table) -> None: assert schema_.highest_field_id == 6 -def test_add_nested_map_type_column(table: Table) -> None: +def test_add_nested_map_type_column(table_v2: Table) -> None: # add map type column field_name = "new_column_map" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) map_ = MapType(1, StringType(), 2, IntegerType(), False) update.add_column(path=field_name, field_type=map_) new_schema = update._apply() # pylint: disable=W0212 @@ -483,10 +489,10 @@ def test_add_nested_map_type_column(table: Table) -> None: assert new_schema.highest_field_id == 6 -def test_add_nested_list_type_column(table: Table) -> None: +def test_add_nested_list_type_column(table_v2: Table) -> None: # add list type column field_name = "new_column_list" - update = UpdateSchema(table) + update = UpdateSchema(table_v2) list_ = ListType( element_id=101, element_type=StructType( @@ -509,6 +515,160 @@ def test_add_nested_list_type_column(table: Table) -> None: assert new_schema.highest_field_id == 7 -def 
test_generate_snapshot_id(table: Table) -> None: +def test_apply_add_schema_update(table_v2: Table) -> None: + transaction = table_v2.transaction() + update = transaction.update_schema() + update.add_column(path="b", field_type=IntegerType()) + update.commit() + + test_context = _TableMetadataUpdateContext() + + new_table_metadata = _apply_table_update( + transaction._updates[0], base_metadata=table_v2.metadata, context=test_context + ) # pylint: disable=W0212 + assert len(new_table_metadata.schemas) == 3 + assert new_table_metadata.current_schema_id == 1 + assert len(test_context._updates) == 1 + assert test_context._updates[0] == transaction._updates[0] # pylint: disable=W0212 + assert test_context.is_added_schema(2) + + new_table_metadata = _apply_table_update( + transaction._updates[1], base_metadata=new_table_metadata, context=test_context + ) # pylint: disable=W0212 + assert len(new_table_metadata.schemas) == 3 + assert new_table_metadata.current_schema_id == 2 + assert len(test_context._updates) == 2 + assert test_context._updates[1] == transaction._updates[1] # pylint: disable=W0212 + assert test_context.is_added_schema(2) + + +def test_update_metadata_table_schema(table_v2: Table) -> None: + transaction = table_v2.transaction() + update = transaction.update_schema() + update.add_column(path="b", field_type=IntegerType()) + update.commit() + new_metadata = update_table_metadata(table_v2.metadata, transaction._updates) # pylint: disable=W0212 + apply_schema: Schema = next(schema for schema in new_metadata.schemas if schema.schema_id == 2) + assert len(apply_schema.fields) == 4 + + assert apply_schema == Schema( + NestedField(field_id=1, name="x", field_type=LongType(), required=True), + NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"), + NestedField(field_id=3, name="z", field_type=LongType(), required=True), + NestedField(field_id=4, name="b", field_type=IntegerType(), required=False), + identifier_field_ids=[1, 2], + ) 
+ assert apply_schema.schema_id == 2 + assert apply_schema.highest_field_id == 4 + + assert new_metadata.current_schema_id == 2 + + +def test_update_metadata_add_snapshot(table_v2: Table) -> None: + new_snapshot = Snapshot( + snapshot_id=25, + parent_snapshot_id=19, + sequence_number=200, + timestamp_ms=1602638573590, + manifest_list="s3:/a/b/c.avro", + summary=Summary(Operation.APPEND), + schema_id=3, + ) + + new_metadata = update_table_metadata(table_v2.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) + assert len(new_metadata.snapshots) == 3 + assert new_metadata.snapshots[-1] == new_snapshot + assert new_metadata.last_sequence_number == new_snapshot.sequence_number + assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms + + +def test_update_metadata_set_snapshot_ref(table_v2: Table) -> None: + update = SetSnapshotRefUpdate( + ref_name="main", + type="branch", + snapshot_id=3051729675574597004, + max_ref_age_ms=123123123, + max_snapshot_age_ms=12312312312, + min_snapshots_to_keep=1, + ) + + new_metadata = update_table_metadata(table_v2.metadata, (update,)) + assert len(new_metadata.snapshot_log) == 3 + assert new_metadata.snapshot_log[2].snapshot_id == 3051729675574597004 + assert new_metadata.current_snapshot_id == 3051729675574597004 + assert new_metadata.last_updated_ms > table_v2.metadata.last_updated_ms + assert new_metadata.refs[update.ref_name] == SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type="branch", + min_snapshots_to_keep=1, + max_snapshot_age_ms=12312312312, + max_ref_age_ms=123123123, + ) + + +def test_update_metadata_with_multiple_updates(table_v1: Table) -> None: + base_metadata = table_v1.metadata + transaction = table_v1.transaction() + transaction.upgrade_table_version(format_version=2) + + schema_update_1 = transaction.update_schema() + schema_update_1.add_column(path="b", field_type=IntegerType()) + schema_update_1.commit() + + test_updates = transaction._updates # pylint: disable=W0212 + + new_snapshot 
= Snapshot(
+        snapshot_id=25,
+        parent_snapshot_id=19,
+        sequence_number=200,
+        timestamp_ms=1602638573590,
+        manifest_list="s3:/a/b/c.avro",
+        summary=Summary(Operation.APPEND),
+        schema_id=3,
+    )
+
+    test_updates += (
+        AddSnapshotUpdate(snapshot=new_snapshot),
+        SetSnapshotRefUpdate(
+            ref_name="main",
+            type="branch",
+            snapshot_id=25,
+            max_ref_age_ms=123123123,
+            max_snapshot_age_ms=12312312312,
+            min_snapshots_to_keep=1,
+        ),
+    )
+
+    new_metadata = update_table_metadata(base_metadata, test_updates)
+
+    # UpgradeFormatVersionUpdate
+    assert new_metadata.format_version == 2
+
+    # UpdateSchema
+    assert len(new_metadata.schemas) == 2
+    assert new_metadata.current_schema_id == 1
+    assert new_metadata.schema_by_id(new_metadata.current_schema_id).highest_field_id == 4  # type: ignore
+
+    # AddSnapshotUpdate
+    assert len(new_metadata.snapshots) == 2
+    assert new_metadata.snapshots[-1] == new_snapshot
+    assert new_metadata.last_sequence_number == new_snapshot.sequence_number
+    assert new_metadata.last_updated_ms == new_snapshot.timestamp_ms
+
+    # SetSnapshotRefUpdate
+    assert len(new_metadata.snapshot_log) == 1
+    assert new_metadata.snapshot_log[0].snapshot_id == 25
+    assert new_metadata.current_snapshot_id == 25
+    assert new_metadata.last_updated_ms == 1602638573590
+    assert new_metadata.refs["main"] == SnapshotRef(
+        snapshot_id=25,
+        snapshot_ref_type="branch",
+        min_snapshots_to_keep=1,
+        max_snapshot_age_ms=12312312312,
+        max_ref_age_ms=123123123,
+    )
+
+
+def test_generate_snapshot_id(table_v2: Table) -> None:
     assert isinstance(_generate_snapshot_id(), int)
-    assert isinstance(table.new_snapshot_id(), int)
+    assert isinstance(table_v2.new_snapshot_id(), int)
diff --git a/tests/table/test_metadata.py b/tests/table/test_metadata.py
index 173e93e35d..2c453b6e03 100644
--- a/tests/table/test_metadata.py
+++ b/tests/table/test_metadata.py
@@ -52,31 +52,6 @@
     StructType,
 )
 
-EXAMPLE_TABLE_METADATA_V1 = {
-    "format-version": 1,
-    "table-uuid": 
"d20125c8-7284-442c-9aea-15fee620737c", - "location": "s3://bucket/test/location", - "last-updated-ms": 1602638573874, - "last-column-id": 3, - "schema": { - "type": "struct", - "fields": [ - {"id": 1, "name": "x", "required": True, "type": "long"}, - {"id": 2, "name": "y", "required": True, "type": "long", "doc": "comment"}, - {"id": 3, "name": "z", "required": True, "type": "long"}, - ], - }, - "partition-spec": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}], - "properties": {}, - "current-snapshot-id": -1, - "snapshots": [{"snapshot-id": 1925, "timestamp-ms": 1602638573822}], -} - - -@pytest.fixture(scope="session") -def example_table_metadata_v1() -> Dict[str, Any]: - return EXAMPLE_TABLE_METADATA_V1 - def test_from_dict_v1(example_table_metadata_v1: Dict[str, Any]) -> None: """Test initialization of a TableMetadata instance from a dictionary""" diff --git a/tests/table/test_refs.py b/tests/table/test_refs.py index d106f0237a..e6b7006a99 100644 --- a/tests/table/test_refs.py +++ b/tests/table/test_refs.py @@ -15,6 +15,10 @@ # specific language governing permissions and limitations # under the License. 
# pylint:disable=eval-used +import pytest +from pydantic import ValidationError + +from pyiceberg import exceptions from pyiceberg.table.refs import SnapshotRef, SnapshotRefType @@ -32,3 +36,53 @@ def test_snapshot_with_properties_repr() -> None: == """SnapshotRef(snapshot_id=3051729675574597004, snapshot_ref_type=SnapshotRefType.TAG, min_snapshots_to_keep=None, max_snapshot_age_ms=None, max_ref_age_ms=10000000)""" ) assert snapshot_ref == eval(repr(snapshot_ref)) + + +def test_snapshot_with_invalid_field() -> None: + # min_snapshots_to_keep, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=-1, + max_snapshot_age_ms=None, + max_ref_age_ms=10000000, + ) + + # max_snapshot_age_ms, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=1, + max_snapshot_age_ms=-1, + max_ref_age_ms=10000000, + ) + + # max_ref_age_ms, if present, must be greater than 0 + with pytest.raises(ValidationError): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=None, + max_snapshot_age_ms=None, + max_ref_age_ms=-1, + ) + + with pytest.raises(exceptions.ValidationError, match="Tags do not support setting minSnapshotsToKeep"): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=1, + max_snapshot_age_ms=None, + max_ref_age_ms=10000000, + ) + + with pytest.raises(exceptions.ValidationError, match="Tags do not support setting maxSnapshotAgeMs"): + SnapshotRef( + snapshot_id=3051729675574597004, + snapshot_ref_type=SnapshotRefType.TAG, + min_snapshots_to_keep=None, + max_snapshot_age_ms=1, + max_ref_age_ms=100000, + )