diff --git a/mkdocs/docs/configuration.md b/mkdocs/docs/configuration.md index f0eb56ebd5..1a4d96547f 100644 --- a/mkdocs/docs/configuration.md +++ b/mkdocs/docs/configuration.md @@ -263,3 +263,7 @@ catalog: # Concurrency PyIceberg uses multiple threads to parallelize operations. The number of workers can be configured by supplying a `max-workers` entry in the configuration file, or by setting the `PYICEBERG_MAX_WORKERS` environment variable. The default value depends on the system hardware and Python version. See [the Python documentation](https://docs.python.org/3/library/concurrent.futures.html#threadpoolexecutor) for more details. + +# Backward Compatibility + +Previous versions of Java (`<1.4.0`) implementations incorrectly assume the optional attribute `current-snapshot-id` to be a required attribute in TableMetadata. This means that if `current-snapshot-id` is missing in the metadata file (e.g. on table creation), the application will throw an exception without being able to load the table. This assumption has been corrected in more recent Iceberg versions. However, it is possible to force PyIceberg to create a table with a metadata file that will be compatible with previous versions. This can be configured by setting the `legacy-current-snapshot-id` entry as "True" in the configuration file, or by setting the `LEGACY_CURRENT_SNAPSHOT_ID` environment variable. Refer to the [PR discussion](https://github.com/apache/iceberg-python/pull/473) for more details on the issue diff --git a/pyiceberg/serializers.py b/pyiceberg/serializers.py index 6a580ead80..e2994884c6 100644 --- a/pyiceberg/serializers.py +++ b/pyiceberg/serializers.py @@ -24,6 +24,7 @@ from pyiceberg.io import InputFile, InputStream, OutputFile from pyiceberg.table.metadata import TableMetadata, TableMetadataUtil from pyiceberg.typedef import UTF8 +from pyiceberg.utils.config import Config GZIP = "gzip" @@ -127,6 +128,9 @@ def table_metadata(metadata: TableMetadata, output_file: OutputFile, overwrite: overwrite (bool): Where to overwrite the file if it already exists. Defaults to `False`. """ with output_file.create(overwrite=overwrite) as output_stream: - json_bytes = metadata.model_dump_json().encode(UTF8) + # We need to serialize None values, in order to dump `None` current-snapshot-id as `-1` + exclude_none = False if Config().get_bool("legacy-current-snapshot-id") else True + + json_bytes = metadata.model_dump_json(exclude_none=exclude_none).encode(UTF8) json_bytes = Compressor.get_compressor(output_file.location).bytes_compressor()(json_bytes) output_stream.write(json_bytes) diff --git a/pyiceberg/table/metadata.py b/pyiceberg/table/metadata.py index 931b0cfe0a..1e5f0fdcec 100644 --- a/pyiceberg/table/metadata.py +++ b/pyiceberg/table/metadata.py @@ -28,7 +28,7 @@ Union, ) -from pydantic import Field, field_validator, model_validator +from pydantic import Field, field_serializer, field_validator, model_validator from pydantic import ValidationError as PydanticValidationError from typing_extensions import Annotated @@ -50,6 +50,7 @@ Properties, ) from pyiceberg.types import transform_dict_value_to_str +from pyiceberg.utils.config import Config from pyiceberg.utils.datetime import datetime_to_millis CURRENT_SNAPSHOT_ID = "current-snapshot-id" @@ -263,6 +264,12 @@ def sort_order_by_id(self, sort_order_id: int) -> Optional[SortOrder]: """Get the sort order by sort_order_id.""" return next((sort_order for sort_order in self.sort_orders if sort_order.order_id == sort_order_id), None) + @field_serializer('current_snapshot_id') + def serialize_current_snapshot_id(self, current_snapshot_id: Optional[int]) -> Optional[int]: + if current_snapshot_id is None and Config().get_bool("legacy-current-snapshot-id"): + return -1 + return current_snapshot_id + def _generate_snapshot_id() -> int: """Generate a new Snapshot ID from a UUID. diff --git a/pyiceberg/utils/concurrent.py b/pyiceberg/utils/concurrent.py index f6c0a23a9c..805599bf41 100644 --- a/pyiceberg/utils/concurrent.py +++ b/pyiceberg/utils/concurrent.py @@ -37,13 +37,4 @@ def get_or_create() -> Executor: @staticmethod def max_workers() -> Optional[int]: """Return the max number of workers configured.""" - config = Config() - val = config.config.get("max-workers") - - if val is None: - return None - - try: - return int(val) # type: ignore - except ValueError as err: - raise ValueError(f"Max workers should be an integer or left unset. Current value: {val}") from err + return Config().get_int("max-workers") diff --git a/pyiceberg/utils/config.py b/pyiceberg/utils/config.py index e038005469..8b1b81d3a7 100644 --- a/pyiceberg/utils/config.py +++ b/pyiceberg/utils/config.py @@ -16,6 +16,7 @@ # under the License. import logging import os +from distutils.util import strtobool from typing import List, Optional import strictyaml @@ -154,3 +155,19 @@ def get_catalog_config(self, catalog_name: str) -> Optional[RecursiveDict]: assert isinstance(catalog_conf, dict), f"Configuration path catalogs.{catalog_name_lower} needs to be an object" return catalog_conf return None + + def get_int(self, key: str) -> Optional[int]: + if (val := self.config.get(key)) is not None: + try: + return int(val) # type: ignore + except ValueError as err: + raise ValueError(f"{key} should be an integer or left unset. Current value: {val}") from err + return None + + def get_bool(self, key: str) -> Optional[bool]: + if (val := self.config.get(key)) is not None: + try: + return strtobool(val) # type: ignore + except ValueError as err: + raise ValueError(f"{key} should be a boolean or left unset. Current value: {val}") from err + return None diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index 2b851e14e9..f0d1c85797 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -32,7 +32,7 @@ from pyspark.sql import SparkSession from pytest_mock.plugin import MockerFixture -from pyiceberg.catalog import Catalog, Properties, Table, load_catalog +from pyiceberg.catalog import Catalog, Properties, Table from pyiceberg.catalog.sql import SqlCatalog from pyiceberg.exceptions import NoSuchTableError from pyiceberg.schema import Schema diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 0000000000..140db02700 --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import json +import os +import uuid +from typing import Any, Dict + +import pytest +from pytest_mock import MockFixture + +from pyiceberg.serializers import ToOutputFile +from pyiceberg.table import StaticTable +from pyiceberg.table.metadata import TableMetadataV1 + + +def test_legacy_current_snapshot_id( + mocker: MockFixture, tmp_path_factory: pytest.TempPathFactory, example_table_metadata_no_snapshot_v1: Dict[str, Any] +) -> None: + from pyiceberg.io.pyarrow import PyArrowFileIO + + metadata_location = str(tmp_path_factory.mktemp("metadata") / f"{uuid.uuid4()}.metadata.json") + metadata = TableMetadataV1(**example_table_metadata_no_snapshot_v1) + ToOutputFile.table_metadata(metadata, PyArrowFileIO().new_output(location=metadata_location), overwrite=True) + static_table = StaticTable.from_metadata(metadata_location) + assert static_table.metadata.current_snapshot_id is None + + mocker.patch.dict(os.environ, values={"PYICEBERG_LEGACY_CURRENT_SNAPSHOT_ID": "True"}) + + ToOutputFile.table_metadata(metadata, PyArrowFileIO().new_output(location=metadata_location), overwrite=True) + with PyArrowFileIO().new_input(location=metadata_location).open() as input_stream: + metadata_json_bytes = input_stream.read() + assert json.loads(metadata_json_bytes)['current-snapshot-id'] == -1 + backwards_compatible_static_table = StaticTable.from_metadata(metadata_location) + assert backwards_compatible_static_table.metadata.current_snapshot_id is None + assert backwards_compatible_static_table.metadata == static_table.metadata diff --git a/tests/utils/test_config.py b/tests/utils/test_config.py index 5e3f72ccc6..2f15bb56d8 100644 --- a/tests/utils/test_config.py +++ b/tests/utils/test_config.py @@ -76,3 +76,20 @@ def test_merge_config() -> None: rhs: RecursiveDict = {"common_key": "xyz789"} result = merge_config(lhs, rhs) assert result["common_key"] == rhs["common_key"] + + +def test_from_configuration_files_get_typed_value(tmp_path_factory: pytest.TempPathFactory) -> None: + config_path = str(tmp_path_factory.mktemp("config")) + with open(f"{config_path}/.pyiceberg.yaml", "w", encoding=UTF8) as file: + yaml_str = as_document({"max-workers": "4", "legacy-current-snapshot-id": "True"}).as_yaml() + file.write(yaml_str) + + os.environ["PYICEBERG_HOME"] = config_path + with pytest.raises(ValueError): + Config().get_bool("max-workers") + + with pytest.raises(ValueError): + Config().get_int("legacy-current-snapshot-id") + + assert Config().get_bool("legacy-current-snapshot-id") + assert Config().get_int("max-workers") == 4