Skip to content

Commit

Permalink
Fix setting V1 format version for Non-REST catalogs
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh-jahagirdar committed Feb 11, 2024
1 parent a576fc9 commit 8c5f085
Show file tree
Hide file tree
Showing 3 changed files with 173 additions and 4 deletions.
55 changes: 52 additions & 3 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
The TableMetadata with the defaults applied.
"""
# When the schema doesn't have an ID
if data.get("schema") and "schema_id" not in data["schema"]:
data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
schema = data.get("schema")
if isinstance(schema, dict):
if "schema_id" not in schema:
schema["schema_id"] = DEFAULT_SCHEMA_ID

return data

Expand Down Expand Up @@ -313,6 +315,34 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]:

return data

@model_validator(mode="before")
def construct_v1_spec_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
specs_field = "partition_specs"
default_spec_id_field = "default_spec_id"
if specs_field in data and default_spec_id_field in data:
specs = data[specs_field]
spec_id = data[default_spec_id_field]
for spec in specs:
if spec.spec_id == spec_id:
data["partition_spec"] = [spec.model_dump()]
return data

return data

@model_validator(mode="before")
def construct_v1_schema_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]:
schemas_field = "schemas"
current_schema_id_field = "current_schema_id"
if schemas_field in data and current_schema_id_field in data:
schemas = data[schemas_field]
current_schema_id = data[current_schema_id_field]
for schema in schemas:
if schema.schema_id == current_schema_id:
data["schema"] = schema
return data

return data

@model_validator(mode="before")
def set_sort_orders(cls, data: Dict[str, Any]) -> Dict[str, Any]:
"""Set the sort_orders if not provided.
Expand All @@ -335,7 +365,7 @@ def to_v2(self) -> TableMetadataV2:
metadata["format-version"] = 2
return TableMetadataV2.model_validate(metadata)

format_version: Literal[1] = Field(alias="format-version")
format_version: Literal[1] = Field(alias="format-version", default=1)
"""An integer version number for the format. Currently, this can be 1 or 2
based on the spec. Implementations must throw an exception if a table’s
version is higher than the supported version."""
Expand Down Expand Up @@ -394,6 +424,7 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata:


TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")]
DEFAULT_FORMAT_VERSION = "2"


def new_table_metadata(
Expand All @@ -411,6 +442,24 @@ def new_table_metadata(
if table_uuid is None:
table_uuid = uuid.uuid4()

# Remove format-version so it does not get persisted
format_version = properties.pop("format-version", DEFAULT_FORMAT_VERSION)

if int(format_version) == 1:
return TableMetadataV1(
location=location,
schema=fresh_schema,
last_column_id=fresh_schema.highest_field_id,
current_schema_id=fresh_schema.schema_id,
partition_specs=[fresh_partition_spec],
default_spec_id=fresh_partition_spec.spec_id,
sort_orders=[fresh_sort_order],
default_sort_order_id=fresh_sort_order.order_id,
properties=properties,
last_partition_id=fresh_partition_spec.last_assigned_field_id,
table_uuid=table_uuid,
)

return TableMetadataV2(
location=location,
schemas=[fresh_schema],
Expand Down
56 changes: 55 additions & 1 deletion tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
MetadataLogEntry,
Expand Down Expand Up @@ -294,6 +294,60 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,

assert metadata.model_dump() == expected.model_dump()

catalog.create_table(
("default", "table_v1"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
)


@patch("time.time", MagicMock(return_value=12345))
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

catalog._client = MagicMock()
catalog._client.__enter__().create_table.return_value = None
catalog._client.__enter__().get_table.return_value = hive_table
catalog._client.__enter__().get_database.return_value = hive_database
catalog.create_table(
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
)

# Test creating V1 table
called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
metadata_location = called_v1_table.parameters["metadata_location"]
with open(metadata_location, encoding=UTF8) as f:
payload = f.read()

actual_v1_metadata = TableMetadataUtil.parse_raw(payload)
expected_v1_metadata = TableMetadataV1(
location=actual_v1_metadata.location,
table_uuid=actual_v1_metadata.table_uuid,
last_updated_ms=actual_v1_metadata.last_updated_ms,
last_column_id=3,
schemas=[
Schema(
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
schema_id=0,
identifier_field_ids=[2],
)
],
current_schema_id=0,
last_partition_id=1000,
properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
partition_spec=[{'fields': [], 'spec-id': 0}],
current_snapshot_id=None,
snapshots=[],
snapshot_log=[],
metadata_log=[],
sort_orders=[SortOrder(order_id=0)],
default_sort_order_id=0,
refs={},
format_version=1,
)

assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump()


def test_load_table(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
Expand Down
66 changes: 66 additions & 0 deletions tests/table/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,72 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
]


def test_new_table_metadata_with_explicit_v1_format() -> None:
schema = Schema(
NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
schema_id=10,
identifier_field_ids=[22],
)

partition_spec = PartitionSpec(
PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10
)

sort_order = SortOrder(
SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
order_id=10,
)

actual = new_table_metadata(
schema=schema,
partition_spec=partition_spec,
sort_order=sort_order,
location="s3://some_v1_location/",
properties={'format-version': "1"},
)

expected = TableMetadataV1(
location="s3://some_v1_location/",
table_uuid=actual.table_uuid,
last_updated_ms=actual.last_updated_ms,
last_column_id=3,
schemas=[
Schema(
NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
schema_id=0,
identifier_field_ids=[2],
)
],
current_schema_id=0,
partition_specs=[PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))],
default_spec_id=0,
last_partition_id=1000,
properties={},
current_snapshot_id=None,
snapshots=[],
snapshot_log=[],
metadata_log=[],
sort_orders=[
SortOrder(
SortField(
source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST
),
order_id=1,
)
],
default_sort_order_id=1,
refs={},
format_version=1,
last_sequence_number=0,
)

assert actual.model_dump() == expected.model_dump()


def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
"""Test the exception when trying to load an unknown version"""

Expand Down

0 comments on commit 8c5f085

Please sign in to comment.