From d0c964fa9d9785568957d37d7759b0d32e23030c Mon Sep 17 00:00:00 2001 From: amogh-jahagirdar Date: Sat, 10 Feb 2024 09:21:11 -0800 Subject: [PATCH] Fix setting V1 format version for Non-REST catalogs --- pyiceberg/table/metadata.py | 55 ++++++++++++++++++++++++++++-- tests/catalog/test_glue.py | 32 ++++++++++++++++++ tests/catalog/test_hive.py | 55 +++++++++++++++++++++++++++++- tests/catalog/test_sql.py | 19 +++++++++++ tests/table/test_metadata.py | 65 ++++++++++++++++++++++++++++++++++++ 5 files changed, 222 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 43e29c7b03..2b96be2b02 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]: The TableMetadata with the defaults applied. """ # When the schema doesn't have an ID - if data.get("schema") and "schema_id" not in data["schema"]: - data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID + schema = data.get("schema") + if isinstance(schema, dict): + if "schema_id" not in schema: + schema["schema_id"] = DEFAULT_SCHEMA_ID return data @@ -313,6 +315,34 @@ def construct_partition_specs(cls, data: Dict[str, Any]) -> Dict[str, Any]: return data + @model_validator(mode="before") + def construct_v1_spec_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]: + specs_field = "partition_specs" + default_spec_id_field = "default_spec_id" + if specs_field in data and default_spec_id_field in data: + found_spec = next((spec for spec in data[specs_field] if spec.spec_id == data[default_spec_id_field]), None) + if found_spec is not None: + spec_dict = found_spec.model_dump() + spec_dict['fields'] = list(spec_dict['fields']) + data["partition_spec"] = [spec_dict] + return data + + return data + + @model_validator(mode="before") + def construct_v1_schema_from_v2_fields(cls, data: Dict[str, Any]) -> Dict[str, Any]: + schemas_field = "schemas" + current_schema_id_field = "current_schema_id" + if schemas_field in data and current_schema_id_field in data: + found_schema = next( + (schema for schema in data[schemas_field] if schema.schema_id == data[current_schema_id_field]), None + ) + if found_schema is not None: + data["schema"] = found_schema + return data + + return data + @model_validator(mode="before") def set_sort_orders(cls, data: Dict[str, Any]) -> Dict[str, Any]: """Set the sort_orders if not provided. @@ -335,7 +365,7 @@ def to_v2(self) -> TableMetadataV2: metadata["format-version"] = 2 return TableMetadataV2.model_validate(metadata) - format_version: Literal[1] = Field(alias="format-version") + format_version: Literal[1] = Field(alias="format-version", default=1) """An integer version number for the format. Currently, this can be 1 or 2 based on the spec. Implementations must throw an exception if a table’s version is higher than the supported version.""" @@ -394,6 +424,7 @@ def construct_refs(cls, table_metadata: TableMetadata) -> TableMetadata: TableMetadata = Annotated[Union[TableMetadataV1, TableMetadataV2], Field(discriminator="format_version")] +DEFAULT_FORMAT_VERSION = "2" def new_table_metadata( @@ -411,6 +442,24 @@ def new_table_metadata( if table_uuid is None: table_uuid = uuid.uuid4() + # Remove format-version so it does not get persisted + format_version = int(properties.pop("format-version", DEFAULT_FORMAT_VERSION)) + + if format_version == 1: + return TableMetadataV1( + location=location, + schema=fresh_schema, + last_column_id=fresh_schema.highest_field_id, + current_schema_id=fresh_schema.schema_id, + partition_specs=[fresh_partition_spec], + default_spec_id=fresh_partition_spec.spec_id, + sort_orders=[fresh_sort_order], + default_sort_order_id=fresh_sort_order.order_id, + properties=properties, + last_partition_id=fresh_partition_spec.last_assigned_field_id, + table_uuid=table_uuid, + ) + return TableMetadataV2( location=location, schemas=[fresh_schema], diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py index 270d2251ba..6e0196c1a2 100644 --- a/tests/catalog/test_glue.py +++ b/tests/catalog/test_glue.py @@ -72,6 +72,38 @@ def test_create_table_with_database_location( assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}" +@mock_aws +def test_create_v1_table( + _bucket_initialize: None, + _glue: boto3.client, + moto_endpoint_url: str, + table_schema_nested: Schema, + database_name: str, + table_name: str, +) -> None: + catalog_name = "glue" + test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url}) + test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"}) + table = test_catalog.create_table((database_name, table_name), table_schema_nested, properties={"format-version": "1"}) + assert table.format_version == 1 + + table_info = _glue.get_table( + DatabaseName=database_name, + Name=table_name, + ) + + storage_descriptor = table_info["Table"]["StorageDescriptor"] + columns = storage_descriptor["Columns"] + assert len(columns) == len(table_schema_nested.fields) + assert columns[0] == { + "Name": "foo", + "Type": "string", + "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"}, + } + + assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}" + + @mock_aws def test_create_table_with_default_warehouse( _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py index f42962f0f3..11b0bcbdfe 100644 --- a/tests/catalog/test_hive.py +++ b/tests/catalog/test_hive.py @@ -43,7 +43,7 @@ ) from pyiceberg.partitioning import PartitionField, PartitionSpec from pyiceberg.schema import Schema -from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2 +from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2 from pyiceberg.table.refs import SnapshotRef, SnapshotRefType from pyiceberg.table.snapshots import ( MetadataLogEntry, @@ -295,6 +295,59 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase, assert metadata.model_dump() == expected.model_dump() +@patch("time.time", MagicMock(return_value=12345)) +def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None: + catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) + + catalog._client = MagicMock() + catalog._client.__enter__().create_table.return_value = None + catalog._client.__enter__().get_table.return_value = hive_table + catalog._client.__enter__().get_database.return_value = hive_database + catalog.create_table( + ("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"} + ) + + # Test creating V1 table + called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0] + metadata_location = called_v1_table.parameters["metadata_location"] + with open(metadata_location, encoding=UTF8) as f: + payload = f.read() + + actual_v1_metadata = TableMetadataUtil.parse_raw(payload) + expected_v1_metadata = TableMetadataV1( + location=actual_v1_metadata.location, + table_uuid=actual_v1_metadata.table_uuid, + last_updated_ms=actual_v1_metadata.last_updated_ms, + last_column_id=3, + schemas=[ + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[2], + ) + ], + current_schema_id=0, + last_partition_id=1000, + properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"}, + partition_specs=[PartitionSpec()], + default_spec_id=0, + current_snapshot_id=None, + snapshots=[], + snapshot_log=[], + metadata_log=[], + sort_orders=[SortOrder(order_id=0)], + default_sort_order_id=0, + refs={}, + format_version=1, + ) + + print(f"Actual is {actual_v1_metadata.model_dump()}\n, Expected {expected_v1_metadata.model_dump()}\n") + + assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump() + + def test_load_table(hive_table: HiveTable) -> None: catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL) diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py index 1ca8fd16d2..1d127378be 100644 --- a/tests/catalog/test_sql.py +++ b/tests/catalog/test_sql.py @@ -38,6 +38,7 @@ ) from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL from pyiceberg.io.pyarrow import schema_to_pyarrow +from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC from pyiceberg.schema import Schema from pyiceberg.table.snapshots import Operation from pyiceberg.table.sorting import ( @@ -158,6 +159,24 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste catalog.drop_table(random_identifier) +@pytest.mark.parametrize( + 'catalog', + [ + lazy_fixture('catalog_memory'), + lazy_fixture('catalog_sqlite'), + ], +) +def test_create_v1_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None: + database_name, _table_name = random_identifier + catalog.create_namespace(database_name) + table = catalog.create_table(random_identifier, table_schema_nested, properties={"format-version": "1"}) + assert table.sort_order().order_id == 0, "Order ID must match" + assert table.sort_order().is_unsorted is True, "Order must be unsorted" + assert table.format_version == 1 + assert table.spec() == UNPARTITIONED_PARTITION_SPEC + catalog.drop_table(random_identifier) + + @pytest.mark.parametrize( 'catalog', [ diff --git a/tests/table/test_metadata.py b/tests/table/test_metadata.py index 2c453b6e03..98f50e7d7e 100644 --- a/tests/table/test_metadata.py +++ b/tests/table/test_metadata.py @@ -199,6 +199,71 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) - ] +def test_new_table_metadata_with_explicit_v1_format() -> None: + schema = Schema( + NestedField(field_id=10, name="foo", field_type=StringType(), required=False), + NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False), + schema_id=10, + identifier_field_ids=[22], + ) + + partition_spec = PartitionSpec( + PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10 + ) + + sort_order = SortOrder( + SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST), + order_id=10, + ) + + actual = new_table_metadata( + schema=schema, + partition_spec=partition_spec, + sort_order=sort_order, + location="s3://some_v1_location/", + properties={'format-version': "1"}, + ) + + expected = TableMetadataV1( + location="s3://some_v1_location/", + table_uuid=actual.table_uuid, + last_updated_ms=actual.last_updated_ms, + last_column_id=3, + schemas=[ + Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=0, + identifier_field_ids=[2], + ) + ], + current_schema_id=0, + partition_specs=[PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))], + default_spec_id=0, + last_partition_id=1000, + properties={}, + current_snapshot_id=None, + snapshots=[], + snapshot_log=[], + metadata_log=[], + sort_orders=[ + SortOrder( + SortField( + source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST + ), + order_id=1, + ) + ], + default_sort_order_id=1, + refs={}, + format_version=1, + ) + + assert actual.model_dump() == expected.model_dump() + + def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None: """Test the exception when trying to load an unknown version"""