diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py index fa04fb7b24..1a4c39cfd2 100644 --- a/pyiceberg/table/__init__.py +++ b/pyiceberg/table/__init__.py @@ -19,6 +19,7 @@ import datetime import itertools import uuid +import warnings from abc import ABC, abstractmethod from copy import copy from dataclasses import dataclass @@ -942,15 +943,23 @@ def snapshot(self) -> Optional[Snapshot]: return self.table.current_snapshot() def projection(self) -> Schema: - snapshot_schema = self.table.schema() - if snapshot := self.snapshot(): - if snapshot.schema_id is not None: - snapshot_schema = self.table.schemas()[snapshot.schema_id] + current_schema = self.table.schema() + if self.snapshot_id is not None: + snapshot = self.table.snapshot_by_id(self.snapshot_id) + if snapshot is not None: + if snapshot.schema_id is not None: + snapshot_schema = self.table.schemas().get(snapshot.schema_id) + if snapshot_schema is not None: + current_schema = snapshot_schema + else: + warnings.warn(f"Metadata does not contain schema with id: {snapshot.schema_id}") + else: + raise ValueError(f"Snapshot not found: {self.snapshot_id}") if "*" in self.selected_fields: - return snapshot_schema + return current_schema - return snapshot_schema.select(*self.selected_fields, case_sensitive=self.case_sensitive) + return current_schema.select(*self.selected_fields, case_sensitive=self.case_sensitive) @abstractmethod def plan_files(self) -> Iterable[ScanTask]: diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 04d467c318..547990c4c8 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -22,6 +22,7 @@ import pytest from sortedcontainers import SortedList +from pyiceberg.catalog.noop import NoopCatalog from pyiceberg.exceptions import CommitFailedException from pyiceberg.expressions import ( AlwaysTrue, @@ -29,7 +30,7 @@ EqualTo, In, ) -from pyiceberg.io import PY_IO_IMPL +from pyiceberg.io import PY_IO_IMPL, load_file_io from pyiceberg.manifest import ( DataFile, DataFileContent, @@ -848,3 +849,89 @@ def test_assert_default_sort_order_id(table_v2: Table) -> None: match="Requirement failed: default sort order id has changed: expected 1, found 3", ): AssertDefaultSortOrderId(default_sort_order_id=1).validate(base_metadata) + + +def test_correct_schema() -> None: + table_metadata = TableMetadataV2( + **{ + "format-version": 2, + "table-uuid": "9c12d441-03fe-4693-9a96-a0705ddf69c1", + "location": "s3://bucket/test/location", + "last-sequence-number": 34, + "last-updated-ms": 1602638573590, + "last-column-id": 3, + "current-schema-id": 1, + "schemas": [ + {"type": "struct", "schema-id": 0, "fields": [{"id": 1, "name": "x", "required": True, "type": "long"}]}, + { + "type": "struct", + "schema-id": 1, + "identifier-field-ids": [1, 2], + "fields": [ + {"id": 1, "name": "x", "required": True, "type": "long"}, + {"id": 2, "name": "y", "required": True, "type": "long"}, + {"id": 3, "name": "z", "required": True, "type": "long"}, + ], + }, + ], + "default-spec-id": 0, + "partition-specs": [ + {"spec-id": 0, "fields": [{"name": "x", "transform": "identity", "source-id": 1, "field-id": 1000}]} + ], + "last-partition-id": 1000, + "default-sort-order-id": 0, + "sort-orders": [], + "current-snapshot-id": 123, + "snapshots": [ + { + "snapshot-id": 234, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + "schema-id": 10, + }, + { + "snapshot-id": 123, + "timestamp-ms": 1515100955770, + "sequence-number": 0, + "summary": {"operation": "append"}, + "manifest-list": "s3://a/b/1.avro", + "schema-id": 0, + }, + ], + } + ) + + t = Table( + identifier=("default", "t1"), + metadata=table_metadata, + metadata_location="s3://../..", + io=load_file_io(), + catalog=NoopCatalog("NoopCatalog"), + ) + + # Should use the current schema, instead the one from the snapshot + assert t.scan().projection() == Schema( + NestedField(field_id=1, name='x', field_type=LongType(), required=True), + NestedField(field_id=2, name='y', field_type=LongType(), required=True), + NestedField(field_id=3, name='z', field_type=LongType(), required=True), + schema_id=1, + identifier_field_ids=[1, 2], + ) + + # When we explicitly filter on the commit, we want to have the schema that's linked to the snapshot + assert t.scan(snapshot_id=123).projection() == Schema( + NestedField(field_id=1, name='x', field_type=LongType(), required=True), + schema_id=0, + identifier_field_ids=[], + ) + + with pytest.warns(UserWarning, match="Metadata does not contain schema with id: 10"): + t.scan(snapshot_id=234).projection() + + # Invalid snapshot + with pytest.raises(ValueError) as exc_info: + _ = t.scan(snapshot_id=-1).projection() + + assert "Snapshot not found: -1" in str(exc_info.value)