From dab5d768d17cb60b443e299b7a945dece559f7b1 Mon Sep 17 00:00:00 2001
From: Amogh Jahagirdar
Date: Mon, 12 Feb 2024 03:05:30 -0700
Subject: [PATCH] Fix setting V1 format version for Non-REST catalogs (#411)

---
Reviewer notes (commentary only; not part of the commit message or the diff):

* NOTE(review): line breaks and indentation were restored from a
  whitespace-flattened copy of this patch. Added blank lines follow the lone
  "+" markers; blank context lines and indentation were inferred from the
  hunk line counts. Confirm against the upstream commit before applying.
* pyiceberg/table/__init__.py: adds TableProperties.FORMAT_VERSION
  ("format-version") and TableProperties.DEFAULT_FORMAT_VERSION (2).
* pyiceberg/table/metadata.py:
  - set_v2_compatible_defaults now fills in "schema_id" only when "schema"
    is a dict that has neither a "schema_id" nor a "schema-id" key.
  - the `format_version: Literal[1]` field gets default=1.
  - new_table_metadata reads "format-version" from `properties` (falling
    back to DEFAULT_FORMAT_VERSION), returns a TableMetadataV1 when the value
    is 1 and otherwise falls through to the existing TableMetadataV2 return.
    The key is removed with properties.pop(), so the mapping object passed
    in by the caller is modified in place.
* tests: add V1 creation tests for the Glue, Hive and SQL catalogs and a
  direct test of new_table_metadata with properties {'format-version': "1"}.

 pyiceberg/table/__init__.py  |  2 ++
 pyiceberg/table/metadata.py  | 28 +++++++++++++--
 tests/catalog/test_glue.py   | 32 +++++++++++++++++
 tests/catalog/test_hive.py   | 55 +++++++++++++++++++++++++++-
 tests/catalog/test_sql.py    | 19 ++++++++++
 tests/table/test_metadata.py | 69 ++++++++++++++++++++++++++++++++++++
 6 files changed, 201 insertions(+), 4 deletions(-)

diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index cc4bbf52a3..a87435fcfb 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -166,6 +166,8 @@ class TableProperties:
     METRICS_MODE_COLUMN_CONF_PREFIX = "write.metadata.metrics.column"
 
     DEFAULT_NAME_MAPPING = "schema.name-mapping.default"
+    FORMAT_VERSION = "format-version"
+    DEFAULT_FORMAT_VERSION = 2
 
 
 class PropertyUtil:
diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py
index 43e29c7b03..a5dfb6ce4c 100644
--- a/pyiceberg/table/metadata.py
+++ b/pyiceberg/table/metadata.py
@@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
             The TableMetadata with the defaults applied.
         """
         # When the schema doesn't have an ID
-        if data.get("schema") and "schema_id" not in data["schema"]:
-            data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
+        schema = data.get("schema")
+        if isinstance(schema, dict):
+            if "schema_id" not in schema and "schema-id" not in schema:
+                schema["schema_id"] = DEFAULT_SCHEMA_ID
 
         return data
 
@@ -335,7 +337,7 @@ def to_v2(self) -> TableMetadataV2:
         metadata["format-version"] = 2
         return TableMetadataV2.model_validate(metadata)
 
-    format_version: Literal[1] = Field(alias="format-version")
+    format_version: Literal[1] = Field(alias="format-version", default=1)
     """An integer version number for the format. Currently, this can be 1 or 2
     based on the spec. Implementations must throw an exception if a table’s
     version is higher than the supported version."""
@@ -404,6 +406,8 @@ def new_table_metadata(
     properties: Properties = EMPTY_DICT,
     table_uuid: Optional[uuid.UUID] = None,
 ) -> TableMetadata:
+    from pyiceberg.table import TableProperties
+
     fresh_schema = assign_fresh_schema_ids(schema)
     fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, schema, fresh_schema)
     fresh_sort_order = assign_fresh_sort_order_ids(sort_order, schema, fresh_schema)
@@ -411,6 +415,24 @@ def new_table_metadata(
     if table_uuid is None:
         table_uuid = uuid.uuid4()
 
+    # Remove format-version so it does not get persisted
+    format_version = int(properties.pop(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION))
+    if format_version == 1:
+        return TableMetadataV1(
+            location=location,
+            last_column_id=fresh_schema.highest_field_id,
+            current_schema_id=fresh_schema.schema_id,
+            schema=fresh_schema,
+            partition_spec=[field.model_dump() for field in fresh_partition_spec.fields],
+            partition_specs=[fresh_partition_spec],
+            default_spec_id=fresh_partition_spec.spec_id,
+            sort_orders=[fresh_sort_order],
+            default_sort_order_id=fresh_sort_order.order_id,
+            properties=properties,
+            last_partition_id=fresh_partition_spec.last_assigned_field_id,
+            table_uuid=table_uuid,
+        )
+
     return TableMetadataV2(
         location=location,
         schemas=[fresh_schema],
diff --git a/tests/catalog/test_glue.py b/tests/catalog/test_glue.py
index 270d2251ba..6e0196c1a2 100644
--- a/tests/catalog/test_glue.py
+++ b/tests/catalog/test_glue.py
@@ -72,6 +72,38 @@ def test_create_table_with_database_location(
     assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
 
 
+@mock_aws
+def test_create_v1_table(
+    _bucket_initialize: None,
+    _glue: boto3.client,
+    moto_endpoint_url: str,
+    table_schema_nested: Schema,
+    database_name: str,
+    table_name: str,
+) -> None:
+    catalog_name = "glue"
+    test_catalog = GlueCatalog(catalog_name, **{"s3.endpoint": moto_endpoint_url})
+    test_catalog.create_namespace(namespace=database_name, properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"})
+    table = test_catalog.create_table((database_name, table_name), table_schema_nested, properties={"format-version": "1"})
+    assert table.format_version == 1
+
+    table_info = _glue.get_table(
+        DatabaseName=database_name,
+        Name=table_name,
+    )
+
+    storage_descriptor = table_info["Table"]["StorageDescriptor"]
+    columns = storage_descriptor["Columns"]
+    assert len(columns) == len(table_schema_nested.fields)
+    assert columns[0] == {
+        "Name": "foo",
+        "Type": "string",
+        "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
+    }
+
+    assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"
+
+
 @mock_aws
 def test_create_table_with_default_warehouse(
     _bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
diff --git a/tests/catalog/test_hive.py b/tests/catalog/test_hive.py
index f42962f0f3..dc2689e0d8 100644
--- a/tests/catalog/test_hive.py
+++ b/tests/catalog/test_hive.py
@@ -43,7 +43,7 @@
 )
 from pyiceberg.partitioning import PartitionField, PartitionSpec
 from pyiceberg.schema import Schema
-from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
+from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
 from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
 from pyiceberg.table.snapshots import (
     MetadataLogEntry,
@@ -295,6 +295,59 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
     assert metadata.model_dump() == expected.model_dump()
 
 
+@patch("time.time", MagicMock(return_value=12345))
+def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
+    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
+
+    catalog._client = MagicMock()
+    catalog._client.__enter__().create_table.return_value = None
+    catalog._client.__enter__().get_table.return_value = hive_table
+    catalog._client.__enter__().get_database.return_value = hive_database
+    catalog.create_table(
+        ("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
+    )
+
+    # Test creating V1 table
+    called_v1_table: HiveTable = catalog._client.__enter__().create_table.call_args[0][0]
+    metadata_location = called_v1_table.parameters["metadata_location"]
+    with open(metadata_location, encoding=UTF8) as f:
+        payload = f.read()
+
+    expected_schema = Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        schema_id=0,
+        identifier_field_ids=[2],
+    )
+    actual_v1_metadata = TableMetadataUtil.parse_raw(payload)
+    expected_spec = PartitionSpec()
+    expected_v1_metadata = TableMetadataV1(
+        location=actual_v1_metadata.location,
+        table_uuid=actual_v1_metadata.table_uuid,
+        last_updated_ms=actual_v1_metadata.last_updated_ms,
+        last_column_id=3,
+        schema=expected_schema,
+        schemas=[expected_schema],
+        current_schema_id=0,
+        last_partition_id=1000,
+        properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
+        partition_spec=[],
+        partition_specs=[expected_spec],
+        default_spec_id=0,
+        current_snapshot_id=None,
+        snapshots=[],
+        snapshot_log=[],
+        metadata_log=[],
+        sort_orders=[SortOrder(order_id=0)],
+        default_sort_order_id=0,
+        refs={},
+        format_version=1,
+    )
+
+    assert actual_v1_metadata.model_dump() == expected_v1_metadata.model_dump()
+
+
 def test_load_table(hive_table: HiveTable) -> None:
     catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
 
diff --git a/tests/catalog/test_sql.py b/tests/catalog/test_sql.py
index 1ca8fd16d2..1d127378be 100644
--- a/tests/catalog/test_sql.py
+++ b/tests/catalog/test_sql.py
@@ -38,6 +38,7 @@
 )
 from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
 from pyiceberg.io.pyarrow import schema_to_pyarrow
+from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
 from pyiceberg.schema import Schema
 from pyiceberg.table.snapshots import Operation
 from pyiceberg.table.sorting import (
@@ -158,6 +159,24 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
     catalog.drop_table(random_identifier)
 
 
+@pytest.mark.parametrize(
+    'catalog',
+    [
+        lazy_fixture('catalog_memory'),
+        lazy_fixture('catalog_sqlite'),
+    ],
+)
+def test_create_v1_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
+    database_name, _table_name = random_identifier
+    catalog.create_namespace(database_name)
+    table = catalog.create_table(random_identifier, table_schema_nested, properties={"format-version": "1"})
+    assert table.sort_order().order_id == 0, "Order ID must match"
+    assert table.sort_order().is_unsorted is True, "Order must be unsorted"
+    assert table.format_version == 1
+    assert table.spec() == UNPARTITIONED_PARTITION_SPEC
+    catalog.drop_table(random_identifier)
+
+
 @pytest.mark.parametrize(
     'catalog',
     [
diff --git a/tests/table/test_metadata.py b/tests/table/test_metadata.py
index 2c453b6e03..97a7931cbb 100644
--- a/tests/table/test_metadata.py
+++ b/tests/table/test_metadata.py
@@ -199,6 +199,75 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
     ]
 
 
+def test_new_table_metadata_with_explicit_v1_format() -> None:
+    schema = Schema(
+        NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
+        schema_id=10,
+        identifier_field_ids=[22],
+    )
+
+    partition_spec = PartitionSpec(
+        PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"), spec_id=10
+    )
+
+    sort_order = SortOrder(
+        SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
+        order_id=10,
+    )
+
+    actual = new_table_metadata(
+        schema=schema,
+        partition_spec=partition_spec,
+        sort_order=sort_order,
+        location="s3://some_v1_location/",
+        properties={'format-version': "1"},
+    )
+
+    expected_schema = Schema(
+        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
+        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
+        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
+        schema_id=0,
+        identifier_field_ids=[2],
+    )
+
+    expected_spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))
+
+    expected = TableMetadataV1(
+        location="s3://some_v1_location/",
+        table_uuid=actual.table_uuid,
+        last_updated_ms=actual.last_updated_ms,
+        last_column_id=3,
+        schemas=[expected_schema],
+        schema_=expected_schema,
+        current_schema_id=0,
+        partition_spec=[field.model_dump() for field in expected_spec.fields],
+        partition_specs=[expected_spec],
+        default_spec_id=0,
+        last_partition_id=1000,
+        properties={},
+        current_snapshot_id=None,
+        snapshots=[],
+        snapshot_log=[],
+        metadata_log=[],
+        sort_orders=[
+            SortOrder(
+                SortField(
+                    source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST
+                ),
+                order_id=1,
+            )
+        ],
+        default_sort_order_id=1,
+        refs={},
+        format_version=1,
+    )
+
+    assert actual.model_dump() == expected.model_dump()
+
+
 def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
     """Test the exception when trying to load an unknown version"""
 