From e56326d4a3ba74eb279bf57202a9121c4f2c731d Mon Sep 17 00:00:00 2001 From: Kevin Liu Date: Wed, 6 Mar 2024 16:35:24 -0800 Subject: [PATCH] Disable Spark Catalog caching for integration tests (#501) --- tests/conftest.py | 1 + tests/integration/test_writes.py | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index 21c036ec09..a005966ea5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1965,6 +1965,7 @@ def spark() -> SparkSession: .config("spark.sql.extensions", "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions") .config("spark.sql.catalog.integration", "org.apache.iceberg.spark.SparkCatalog") .config("spark.sql.catalog.integration.catalog-impl", "org.apache.iceberg.rest.RESTCatalog") + .config("spark.sql.catalog.integration.cache-enabled", "false") .config("spark.sql.catalog.integration.uri", "http://localhost:8181") .config("spark.sql.catalog.integration.io-impl", "org.apache.iceberg.aws.s3.S3FileIO") .config("spark.sql.catalog.integration.warehouse", "s3://warehouse/wh/") diff --git a/tests/integration/test_writes.py b/tests/integration/test_writes.py index f0d1c85797..3b6d476b74 100644 --- a/tests/integration/test_writes.py +++ b/tests/integration/test_writes.py @@ -355,6 +355,28 @@ def test_data_files(spark: SparkSession, session_catalog: Catalog, arrow_table_w assert [row.deleted_data_files_count for row in rows] == [0, 0, 1, 0, 0] +@pytest.mark.integration +def test_python_writes_with_spark_snapshot_reads( + spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table +) -> None: + identifier = "default.python_writes_with_spark_snapshot_reads" + tbl = _create_table(session_catalog, identifier, {"format-version": "1"}, []) + + def get_current_snapshot_id(identifier: str) -> int: + return ( + spark.sql(f"SELECT snapshot_id FROM {identifier}.snapshots order by committed_at desc limit 1") + .collect()[0] + .snapshot_id + ) + + tbl.overwrite(arrow_table_with_null) + assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore + tbl.overwrite(arrow_table_with_null) + assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore + tbl.append(arrow_table_with_null) + assert tbl.current_snapshot().snapshot_id == get_current_snapshot_id(identifier) # type: ignore + + @pytest.mark.integration @pytest.mark.parametrize("format_version", [1, 2]) @pytest.mark.parametrize(