From e3f5dcb81d2426da919ac8526c536448ad801520 Mon Sep 17 00:00:00 2001 From: amogh-jahagirdar Date: Sat, 10 Feb 2024 09:21:11 -0800 Subject: [PATCH] Fix setting V1 format version for Non-REST catalogs --- pyiceberg/table/metadata.py | 55 ++++++++++++++++++++++++++++-- tests/catalog/test_hive.py | 56 +++++++++++++++++++++++++++++- tests/table/test_metadata.py | 66 ++++++++++++++++++++++++++++++++++++ 3 files changed, 173 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 43e29c7b03..c03cf919c6 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]: The TableMetadata with the defaults applied. """ # When the schema doesn't have an ID - if data.get("schema") and "schema_id" not in data["schema"]: - data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID + schema = data.get("schema") + if isinstance(schema, dict): + if "schema_id" not in schema: + schema["schema_id"] = DEFAULT_SCHEMA_ID return data @@ -313,6 +315,34 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]: return data + @model_validator(mode="before") + def construct_v1_spec_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]: + specs_field = "partition_specs" + default_spec_id_field = "default_spec_id" + if specs_field in data and default_spec_id_field in data: + specs = data[specs_field] + spec_id = data[default_spec_id_field] + for spec in specs: + if spec.spec_id == spec_id: + data["partition_spec"] = [spec.model_dump()] + return data + + return data + + @model_validator(mode="before") + def construct_v1_schema_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]: + schemas_field = "schemas" + current_schema_id_field = "current_schema_id" + if schemas_field in data and current_schema_id_field in data: + schemas = data[schemas_field] + current_schema_id = data[current_schema_id_field] + for schema in schemas: + if schema.schema_id == current_schema_id: + data["schema"] = schema + return data + + return data + @model_validator(mode="before") def set_sort_orders(cls, data: Dict[str, Any]) -> Dict[str, Any]: """Set the sort_orders if not provided. @@ -335,7 +365,7 @@ def to_v2(self) -> TableMetadataV2: metadata["format-version"] = 2 return TableMetadataV2.model_validate(metadata) - format_version: Literal[1] = Field(alias="format-version") + format_version: Literal[1] = Field(alias="format-version", default=1) """An integer version number for the format. Currently, this can be 1 or 2 based on the spec. Implementations must throw an exception if a table’s version is higher than the supported version.""" @@ -394,6 +424,7 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")] +DEFAULT_FORMAT_VERSION = "2" def new_table_metadata( @@ -411,6 +442,24 @@ def new_table_metadata( if table_uuid is None: table_uuid = uuid.uuid4() + # Remove format-version so it does not get persisted + format_version = int(properties.pop("format-version", DEFAULT_FORMAT_VERSION)) + + if format_version == 1: + return TableMetadataV1( + location=location, + schema=fresh_schema, + last_column_id=fresh_schema.highest_field_id, + current_schema_id=fresh_schema.schema_id, + partition_specs=[fresh_partition_spec], + default_spec_id=fresh_partition_spec.spec_id, + sort_orders=[fresh_sort_order], + default_sort_order_id=fresh_sort_order.order_id, + properties=properties, + last_partition_id=fresh_partition_spec.last_assigned_field_id, + table_uuid=table_uuid, + ) + return TableMetadataV2( location=location, schemas=[fresh_schema], diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py index f42962f0f3..66ecc9ce92 100644 --- a/tests/catalog/test_hive.py +++ b/tests/catalog/test_hive.py @@ -43,7 +43,7 @@ ) from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2 +from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2 from pyiceberg.table.refs import SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( MetadataLogEntry, @@ -294,6 +294,60 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, assert metadata.model_dump() == expected.model_dump() + catalog.create_table( + ("default", "table_v1"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"} + ) + + +@patch("time.time", MagicMock(return_value=12345)) +def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None: + catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) + + catalog._client = MagicMock() + catalog._client.__enter__().create_table.return_value = None + catalog._client.__enter__().get_table.return_value = hive_table + catalog._client.__enter__().get_database.return_value = hive_database + catalog.create_table( + ("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"} + ) + + # Test creating V1 table + called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0] + metadata_location = called_v1_table.parameters["metadata_location"] + with open(metadata_location, encoding=UTF8) as f: + payload = f.read() + + actual_v1_metadata = TableMetadataUtil.parse_raw(payload) + expected_v1_metadata = TableMetadataV1( + location=actual_v1_metadata.location, + table_uuid=actual_v1_metadata.table_uuid, + last_updated_ms=actual_v1_metadata.last_updated_ms, + last_column_id=3, + schemas=[ + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[2], + ) + ], + current_schema_id=0, + last_partition_id=1000, + properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"}, + partition_spec=[{'fields': [], 'spec-id': 0}], + current_snapshot_id=None, + snapshots=[], + snapshot_log=[], + metadata_log=[], + sort_orders=[SortOrder(order_id=0)], + default_sort_order_id=0, + refs={}, + format_version=1, + ) + + assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump() + def test_load_table(hive_table: HiveTable) -> None: catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) diff --git a/tests/table/test_metadata.py b/tests/table/test_metadata.py index 2c453b6e03..e4f576982c 100644 --- a/tests/table/test_metadata.py +++ b/tests/table/test_metadata.py @@ -199,6 +199,72 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) - ] +def test_new_table_metadata_with_explicit_v1_format() -> None: + schema = Schema( + NestedField(field_id=10, name="foo", field_type=StringType(), required=False), + NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False), + schema_id=10, + identifier_field_ids=[22], + ) + + partition_spec = PartitionSpec( + PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10 + ) + + sort_order = SortOrder( + SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST), + order_id=10, + ) + + actual = new_table_metadata( + schema=schema, + partition_spec=partition_spec, + sort_order=sort_order, + location="s3://some_v1_location/", + properties={'format-version': "1"}, + ) + + expected = TableMetadataV1( + location="s3://some_v1_location/", + table_uuid=actual.table_uuid, + last_updated_ms=actual.last_updated_ms, + last_column_id=3, + schemas=[ + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[2], + ) + ], + current_schema_id=0, + partition_specs=[PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))], + default_spec_id=0, + last_partition_id=1000, + properties={}, + current_snapshot_id=None, + snapshots=[], + snapshot_log=[], + metadata_log=[], + sort_orders=[ + SortOrder( + SortField( + source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST + ), + order_id=1, + ) + ], + default_sort_order_id=1, + refs={}, + format_version=1, + last_sequence_number=0, + ) + + assert actual.model_dump() == expected.model_dump() + + def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None: """Test the exception when trying to load an unknown version"""