Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix setting V1 format version for Non-REST catalogs #411

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ class TableProperties:
METRICS_MODE_COLUMN_CONF_PREFIX = "write.metadata.metrics.column"

DEFAULT_NAME_MAPPING = "schema.name-mapping.default"
FORMAT_VERSION = "format-version"
DEFAULT_FORMAT_VERSION = 2


class PropertyUtil:
Expand Down
28 changes: 25 additions & 3 deletions pyiceberg/table/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,10 @@ def set_v2_compatible_defaults(cls, data: Dict[str, Any]) -> Dict[str, Any]:
The TableMetadata with the defaults applied.
"""
# When the schema doesn't have an ID
if data.get("schema") and "schema_id" not in data["schema"]:
data["schema"]["schema_id"] = DEFAULT_SCHEMA_ID
schema = data.get("schema")
if isinstance(schema, dict):
if "schema_id" not in schema and "schema-id" not in schema:
schema["schema_id"] = DEFAULT_SCHEMA_ID

return data

Expand Down Expand Up @@ -335,7 +337,7 @@ def to_v2(self) -> TableMetadataV2:
metadata["format-version"] = 2
return TableMetadataV2.model_validate(metadata)

format_version: Literal[1] = Field(alias="format-version")
format_version: Literal[1] = Field(alias="format-version", default=1)
"""An integer version number for the format. Currently, this can be 1 or 2
based on the spec. Implementations must throw an exception if a table’s
version is higher than the supported version."""
Expand Down Expand Up @@ -404,13 +406,33 @@ def new_table_metadata(
properties: Properties = EMPTY_DICT,
table_uuid: Optional[uuid.UUID] = None,
) -> TableMetadata:
from pyiceberg.table import TableProperties

fresh_schema = assign_fresh_schema_ids(schema)
fresh_partition_spec = assign_fresh_partition_spec_ids(partition_spec, schema, fresh_schema)
fresh_sort_order = assign_fresh_sort_order_ids(sort_order, schema, fresh_schema)

if table_uuid is None:
table_uuid = uuid.uuid4()

# Remove format-version so it does not get persisted
format_version = int(properties.pop(TableProperties.FORMAT_VERSION, TableProperties.DEFAULT_FORMAT_VERSION))
if format_version == 1:
return TableMetadataV1(
location=location,
last_column_id=fresh_schema.highest_field_id,
current_schema_id=fresh_schema.schema_id,
schema=fresh_schema,
partition_spec=[field.model_dump() for field in fresh_partition_spec.fields],
partition_specs=[fresh_partition_spec],
default_spec_id=fresh_partition_spec.spec_id,
sort_orders=[fresh_sort_order],
default_sort_order_id=fresh_sort_order.order_id,
properties=properties,
last_partition_id=fresh_partition_spec.last_assigned_field_id,
table_uuid=table_uuid,
)

return TableMetadataV2(
location=location,
schemas=[fresh_schema],
Expand Down
32 changes: 32 additions & 0 deletions tests/catalog/test_glue.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,38 @@ def test_create_table_with_database_location(
assert storage_descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"


@mock_aws
def test_create_v1_table(
    _bucket_initialize: None,
    _glue: boto3.client,
    moto_endpoint_url: str,
    table_schema_nested: Schema,
    database_name: str,
    table_name: str,
) -> None:
    """Create a table with ``format-version`` "1" through the Glue catalog.

    Checks that the returned table reports format version 1 and that the
    table registered in Glue carries the expected columns and storage location.
    """
    glue_catalog = GlueCatalog("glue", **{"s3.endpoint": moto_endpoint_url})
    glue_catalog.create_namespace(
        namespace=database_name,
        properties={"location": f"s3://{BUCKET_NAME}/{database_name}.db"},
    )

    # The format version is requested through the table properties.
    created_table = glue_catalog.create_table(
        (database_name, table_name),
        table_schema_nested,
        properties={"format-version": "1"},
    )
    assert created_table.format_version == 1

    # Inspect what was actually registered in Glue.
    glue_response = _glue.get_table(DatabaseName=database_name, Name=table_name)
    descriptor = glue_response["Table"]["StorageDescriptor"]
    glue_columns = descriptor["Columns"]

    assert len(glue_columns) == len(table_schema_nested.fields)

    expected_first_column = {
        "Name": "foo",
        "Type": "string",
        "Parameters": {"iceberg.field.id": "1", "iceberg.field.optional": "true", "iceberg.field.current": "true"},
    }
    assert glue_columns[0] == expected_first_column

    assert descriptor["Location"] == f"s3://{BUCKET_NAME}/{database_name}.db/{table_name}"


@mock_aws
def test_create_table_with_default_warehouse(
_bucket_initialize: None, moto_endpoint_url: str, table_schema_nested: Schema, database_name: str, table_name: str
Expand Down
55 changes: 54 additions & 1 deletion tests/catalog/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
)
from pyiceberg.partitioning import PartitionField, PartitionSpec
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV2
from pyiceberg.table.metadata import TableMetadataUtil, TableMetadataV1, TableMetadataV2
from pyiceberg.table.refs import SnapshotRef, SnapshotRefType
from pyiceberg.table.snapshots import (
MetadataLogEntry,
Expand Down Expand Up @@ -295,6 +295,59 @@ def test_create_table(table_schema_simple: Schema, hive_database: HiveDatabase,
assert metadata.model_dump() == expected.model_dump()


@patch("time.time", MagicMock(return_value=12345))
def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabase, hive_table: HiveTable) -> None:
    """Create a table with ``format-version`` "1" through the Hive catalog.

    The metastore client is replaced by a ``MagicMock``. The metadata file
    referenced by the table passed to the mocked ``create_table`` is read back,
    parsed, and compared against a hand-built ``TableMetadataV1``.
    """
    catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

    # Stub out the metastore client; the context-manager entry always yields
    # the same mock, so it is configured through a single handle.
    catalog._client = MagicMock()
    mocked_client = catalog._client.__enter__()
    mocked_client.create_table.return_value = None
    mocked_client.get_table.return_value = hive_table
    mocked_client.get_database.return_value = hive_database

    catalog.create_table(
        ("default", "table"),
        schema=table_schema_simple,
        properties={"owner": "javaberg", "format-version": "1"},
    )

    # Read back the metadata file the catalog wrote for the V1 table.
    submitted_table: HiveTable = mocked_client.create_table.call_args[0][0]
    metadata_path = submitted_table.parameters["metadata_location"]
    with open(metadata_path, encoding=UTF8) as metadata_file:
        raw_metadata = metadata_file.read()
    parsed_metadata = TableMetadataUtil.parse_raw(raw_metadata)

    schema_expected = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        schema_id=0,
        identifier_field_ids=[2],
    )
    spec_expected = PartitionSpec()

    # "format-version" is absent from the expected properties; the parquet codec
    # entry is presumably added by the catalog on creation (not visible here).
    metadata_expected = TableMetadataV1(
        format_version=1,
        location=parsed_metadata.location,
        table_uuid=parsed_metadata.table_uuid,
        last_updated_ms=parsed_metadata.last_updated_ms,
        last_column_id=3,
        schema=schema_expected,
        schemas=[schema_expected],
        current_schema_id=0,
        partition_spec=[],
        partition_specs=[spec_expected],
        default_spec_id=0,
        last_partition_id=1000,
        properties={"owner": "javaberg", "write.parquet.compression-codec": "zstd"},
        current_snapshot_id=None,
        snapshots=[],
        snapshot_log=[],
        metadata_log=[],
        sort_orders=[SortOrder(order_id=0)],
        default_sort_order_id=0,
        refs={},
    )

    assert parsed_metadata.model_dump() == metadata_expected.model_dump()


def test_load_table(hive_table: HiveTable) -> None:
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)

Expand Down
19 changes: 19 additions & 0 deletions tests/catalog/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from pyiceberg.io import FSSPEC_FILE_IO, PY_IO_IMPL
from pyiceberg.io.pyarrow import schema_to_pyarrow
from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC
from pyiceberg.schema import Schema
from pyiceberg.table.snapshots import Operation
from pyiceberg.table.sorting import (
Expand Down Expand Up @@ -158,6 +159,24 @@ def test_create_table_default_sort_order(catalog: SqlCatalog, table_schema_neste
catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
    'catalog',
    [
        lazy_fixture('catalog_memory'),
        lazy_fixture('catalog_sqlite'),
    ],
)
def test_create_v1_table(catalog: SqlCatalog, table_schema_nested: Schema, random_identifier: Identifier) -> None:
    """Create a table with ``format-version`` "1" through the SQL catalog.

    Checks that the new table is unsorted, unpartitioned and reports format
    version 1, then drops it again.
    """
    namespace, _ = random_identifier
    catalog.create_namespace(namespace)

    v1_table = catalog.create_table(
        random_identifier,
        table_schema_nested,
        properties={"format-version": "1"},
    )

    current_order = v1_table.sort_order()
    assert current_order.order_id == 0, "Order ID must match"
    assert current_order.is_unsorted is True, "Order must be unsorted"
    assert v1_table.format_version == 1
    assert v1_table.spec() == UNPARTITIONED_PARTITION_SPEC

    catalog.drop_table(random_identifier)


@pytest.mark.parametrize(
'catalog',
[
Expand Down
69 changes: 69 additions & 0 deletions tests/table/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,75 @@ def test_migrate_v1_partition_specs(example_table_metadata_v1: Dict[str, Any]) -
]


def test_new_table_metadata_with_explicit_v1_format() -> None:
    """Check that ``new_table_metadata`` builds V1 metadata when "format-version" is "1".

    The input schema, partition spec and sort order use arbitrary IDs; the
    expected metadata uses the freshly assigned IDs and has empty properties,
    i.e. the "format-version" entry is not kept.
    """
    input_schema = Schema(
        NestedField(field_id=10, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=22, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=33, name="baz", field_type=BooleanType(), required=False),
        schema_id=10,
        identifier_field_ids=[22],
    )
    input_spec = PartitionSpec(
        PartitionField(source_id=22, field_id=1022, transform=IdentityTransform(), name="bar"),
        spec_id=10,
    )
    input_order = SortOrder(
        SortField(source_id=10, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
        order_id=10,
    )

    actual = new_table_metadata(
        schema=input_schema,
        partition_spec=input_spec,
        sort_order=input_order,
        location="s3://some_v1_location/",
        properties={'format-version': "1"},
    )

    # Expected values after fresh ID assignment.
    expected_schema = Schema(
        NestedField(field_id=1, name="foo", field_type=StringType(), required=False),
        NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False),
        schema_id=0,
        identifier_field_ids=[2],
    )
    expected_spec = PartitionSpec(PartitionField(source_id=2, field_id=1000, transform=IdentityTransform(), name="bar"))
    expected_order = SortOrder(
        SortField(source_id=1, transform=IdentityTransform(), direction=SortDirection.ASC, null_order=NullOrder.NULLS_LAST),
        order_id=1,
    )

    expected = TableMetadataV1(
        format_version=1,
        location="s3://some_v1_location/",
        # Generated per call, so copied over from the actual result.
        table_uuid=actual.table_uuid,
        last_updated_ms=actual.last_updated_ms,
        last_column_id=3,
        schema_=expected_schema,
        schemas=[expected_schema],
        current_schema_id=0,
        partition_spec=[partition_field.model_dump() for partition_field in expected_spec.fields],
        partition_specs=[expected_spec],
        default_spec_id=0,
        last_partition_id=1000,
        properties={},
        current_snapshot_id=None,
        snapshots=[],
        snapshot_log=[],
        metadata_log=[],
        sort_orders=[expected_order],
        default_sort_order_id=1,
        refs={},
    )

    assert actual.model_dump() == expected.model_dump()


def test_invalid_format_version(example_table_metadata_v1: Dict[str, Any]) -> None:
"""Test the exception when trying to load an unknown version"""

Expand Down