diff --git a/.github/workflows/ibis-backends.yml b/.github/workflows/ibis-backends.yml
index a8646280400c8..2fc9d7eb5cf4d 100644
--- a/.github/workflows/ibis-backends.yml
+++ b/.github/workflows/ibis-backends.yml
@@ -723,7 +723,7 @@ jobs:
       - name: install iceberg
         shell: bash
         if: matrix.pyspark-version == '3.5'
-        run: pushd "$(poetry run python -c "import pyspark; print(pyspark.__file__.rsplit('/', 1)[0])")/jars" && curl -LO https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.2/iceberg-spark-runtime-3.5_2.12-1.5.2.jar
+        run: just download-iceberg-jar
 
       - name: run tests
         run: just ci-check -m pyspark
diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py
index 063fbc214d0a8..babbd124eeb66 100644
--- a/ibis/backends/pyspark/__init__.py
+++ b/ibis/backends/pyspark/__init__.py
@@ -11,13 +11,6 @@
 import sqlglot.expressions as sge
 from packaging.version import parse as vparse
 from pyspark import SparkConf
-
-try:
-    from pyspark.errors.exceptions.base import AnalysisException  # PySpark 3.5+
-except ImportError:
-    from pyspark.sql.utils import AnalysisException  # PySpark 3.3
-
-
 from pyspark.sql import SparkSession
 from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType
 
@@ -38,9 +31,9 @@
 from ibis.util import deprecated
 
 try:
-    from pyspark.errors import ParseException as PySparkParseException
+    from pyspark.errors import AnalysisException, ParseException
 except ImportError:
-    from pyspark.sql.utils import ParseException as PySparkParseException
+    from pyspark.sql.utils import AnalysisException, ParseException
 
 if TYPE_CHECKING:
     from collections.abc import Mapping, Sequence
@@ -53,8 +46,9 @@
     from ibis.expr.api import Watermark
 
 
-PYSPARK_LT_34 = vparse(pyspark.__version__) < vparse("3.4")
-PYSPARK_LT_35 = vparse(pyspark.__version__) < vparse("3.5")
+PYSPARK_VERSION = vparse(pyspark.__version__)
+PYSPARK_LT_34 = PYSPARK_VERSION < vparse("3.4")
+PYSPARK_LT_35 = PYSPARK_VERSION < vparse("3.5")
 
 ConnectionMode = Literal["streaming", "batch"]
 
@@ -279,55 +273,89 @@ def _active_catalog_database(self, catalog: str | None, db: str | None):
         #
         # We attempt to use the Unity-specific Spark SQL to set CATALOG and DATABASE
         # and if that causes a parser exception we fall back to using the catalog API.
+        v = self.compiler.v
+        quoted = self.compiler.quoted
+        dialect = self.dialect
+        catalog_api = self._session.catalog
+
         try:
             if catalog is not None:
+                catalog_sql = sge.Use(
+                    kind=v.CATALOG, this=sg.to_identifier(catalog, quoted=quoted)
+                ).sql(dialect)
+
                 try:
-                    catalog_sql = sg.to_identifier(catalog).sql(self.dialect)
-                    self.raw_sql(f"USE CATALOG {catalog_sql}")
-                except PySparkParseException:
-                    self._session.catalog.setCurrentCatalog(catalog)
+                    self.raw_sql(catalog_sql)
+                except ParseException:
+                    catalog_api.setCurrentCatalog(catalog)
+
+            db_sql = sge.Use(
+                kind=v.DATABASE, this=sg.to_identifier(db, quoted=quoted)
+            ).sql(dialect)
+
             try:
-                db_sql = sg.to_identifier(db).sql(self.dialect)
-                self.raw_sql(f"USE DATABASE {db_sql}")
-            except PySparkParseException:
-                self._session.catalog.setCurrentDatabase(db)
+                self.raw_sql(db_sql)
+            except ParseException:
+                catalog_api.setCurrentDatabase(db)
             yield
         finally:
             if catalog is not None:
+                catalog_sql = sge.Use(
+                    kind=v.CATALOG,
+                    this=sg.to_identifier(current_catalog, quoted=quoted),
+                ).sql(dialect)
                 try:
-                    catalog_sql = sg.to_identifier(current_catalog).sql(self.dialect)
-                    self.raw_sql(f"USE CATALOG {catalog_sql}")
-                except PySparkParseException:
-                    self._session.catalog.setCurrentCatalog(current_catalog)
+                    self.raw_sql(catalog_sql)
+                except ParseException:
+                    catalog_api.setCurrentCatalog(current_catalog)
+
+            db_sql = sge.Use(
+                kind=v.DATABASE, this=sg.to_identifier(current_db, quoted=quoted)
+            ).sql(dialect)
+
             try:
-                db_sql = sg.to_identifier(current_db).sql(self.dialect)
-                self.raw_sql(f"USE DATABASE {db_sql}")
-            except PySparkParseException:
-                self._session.catalog.setCurrentDatabase(current_db)
+                self.raw_sql(db_sql)
+            except ParseException:
+                catalog_api.setCurrentDatabase(current_db)
 
     @contextlib.contextmanager
     def _active_catalog(self, name: str | None):
         if name is None or PYSPARK_LT_34:
             yield
             return
+
         prev_catalog = self.current_catalog
         prev_database = self.current_database
+
+        v = self.compiler.v
+        quoted = self.compiler.quoted
+        dialect = self.dialect
+
+        catalog_sql = sge.Use(
+            kind=v.CATALOG, this=sg.to_identifier(name, quoted=quoted)
+        ).sql(dialect)
+        catalog_api = self._session.catalog
+
         try:
             try:
-                catalog_sql = sg.to_identifier(name).sql(self.dialect)
-                self.raw_sql(f"USE CATALOG {catalog_sql};")
-            except PySparkParseException:
-                self._session.catalog.setCurrentCatalog(name)
+                self.raw_sql(catalog_sql)
+            except ParseException:
+                catalog_api.setCurrentCatalog(name)
             yield
         finally:
+            catalog_sql = sge.Use(
+                kind=v.CATALOG, this=sg.to_identifier(prev_catalog, quoted=quoted)
+            ).sql(dialect)
+            db_sql = sge.Use(
+                kind=v.DATABASE, this=sg.to_identifier(prev_database, quoted=quoted)
+            ).sql(dialect)
+
             try:
-                catalog_sql = sg.to_identifier(prev_catalog).sql(self.dialect)
-                db_sql = sg.to_identifier(prev_database).sql(self.dialect)
-                self.raw_sql(f"USE CATALOG {catalog_sql};")
-                self.raw_sql(f"USE DATABASE {db_sql};")
-            except PySparkParseException:
-                self._session.catalog.setCurrentCatalog(prev_catalog)
-                self._session.catalog.setCurrentDatabase(prev_database)
+                self.raw_sql(catalog_sql)
+                self.raw_sql(db_sql)
+            except ParseException:
+                catalog_api.setCurrentCatalog(prev_catalog)
+                catalog_api.setCurrentDatabase(prev_database)
 
     def list_catalogs(self, like: str | None = None) -> list[str]:
         catalogs = [res.catalog for res in self._session.sql("SHOW CATALOGS").collect()]
@@ -491,7 +519,7 @@ def create_database(
         sql = sge.Create(
             kind="DATABASE",
             exist=force,
-            this=sg.to_identifier(name),
+            this=sg.to_identifier(name, quoted=self.compiler.quoted),
             properties=properties,
         )
         with self._active_catalog(catalog):
@@ -515,7 +543,10 @@
         """
         sql = sge.Drop(
-            kind="DATABASE", exist=force, this=sg.to_identifier(name), cascade=force
+            kind="DATABASE",
+            exist=force,
+            this=sg.to_identifier(name, quoted=self.compiler.quoted),
+            cascade=force,
         )
         with self._active_catalog(catalog):
             with self._safe_raw_sql(sql):
diff --git a/ibis/backends/pyspark/tests/conftest.py b/ibis/backends/pyspark/tests/conftest.py
index 8c7a977d96530..19d3fddd8c5f9 100644
--- a/ibis/backends/pyspark/tests/conftest.py
+++ b/ibis/backends/pyspark/tests/conftest.py
@@ -149,36 +149,35 @@ def connect(*, tmpdir, worker_id, **kw):
         config = (
             SparkSession.builder.appName("ibis_testing")
             .master("local[1]")
-            .config("spark.cores.max", 1)
-            .config("spark.default.parallelism", 1)
-            .config("spark.driver.extraJavaOptions", "-Duser.timezone=GMT")
-            .config("spark.dynamicAllocation.enabled", False)
-            .config("spark.executor.extraJavaOptions", "-Duser.timezone=GMT")
-            .config("spark.executor.heartbeatInterval", "3600s")
-            .config("spark.executor.instances", 1)
-            .config("spark.network.timeout", "4200s")
-            .config("spark.rdd.compress", False)
-            .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-            .config("spark.shuffle.compress", False)
-            .config("spark.shuffle.spill.compress", False)
-            .config("spark.sql.legacy.timeParserPolicy", "LEGACY")
-            .config("spark.sql.session.timeZone", "UTC")
-            .config("spark.sql.shuffle.partitions", 1)
-            .config("spark.storage.blockManagerSlaveTimeoutMs", "4200s")
-            .config("spark.ui.enabled", False)
-            .config("spark.ui.showConsoleProgress", False)
-            .config("spark.sql.execution.arrow.pyspark.enabled", False)
-            .config("spark.sql.streaming.schemaInference", True)
-        )
-
-        config = (
-            config.config(
-                "spark.sql.extensions",
-                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
+            .config(
+                map={
+                    "spark.cores.max": 1,
+                    "spark.default.parallelism": 1,
+                    "spark.driver.extraJavaOptions": "-Duser.timezone=GMT",
+                    "spark.dynamicAllocation.enabled": False,
+                    "spark.executor.extraJavaOptions": "-Duser.timezone=GMT",
+                    "spark.executor.heartbeatInterval": "3600s",
+                    "spark.executor.instances": 1,
+                    "spark.jars.packages": "org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:1.5.2",
+                    "spark.network.timeout": "4200s",
+                    "spark.rdd.compress": False,
+                    "spark.serializer": "org.apache.spark.serializer.KryoSerializer",
+                    "spark.shuffle.compress": False,
+                    "spark.shuffle.spill.compress": False,
+                    "spark.sql.catalog.local": "org.apache.iceberg.spark.SparkCatalog",
+                    "spark.sql.catalog.local.type": "hadoop",
+                    "spark.sql.catalog.local.warehouse": "icehouse",
+                    "spark.sql.execution.arrow.pyspark.enabled": False,
+                    "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
+                    "spark.sql.legacy.timeParserPolicy": "LEGACY",
+                    "spark.sql.session.timeZone": "UTC",
+                    "spark.sql.shuffle.partitions": 1,
+                    "spark.sql.streaming.schemaInference": True,
+                    "spark.storage.blockManagerSlaveTimeoutMs": "4200s",
+                    "spark.ui.enabled": False,
+                    "spark.ui.showConsoleProgress": False,
+                }
             )
-            .config("spark.sql.catalog.local", "org.apache.iceberg.spark.SparkCatalog")
-            .config("spark.sql.catalog.local.type", "hadoop")
-            .config("spark.sql.catalog.local.warehouse", "icehouse")
         )
 
         try:
diff --git a/ibis/backends/pyspark/tests/test_client.py b/ibis/backends/pyspark/tests/test_client.py
index 6c977118cafd4..474938b880d9c 100644
--- a/ibis/backends/pyspark/tests/test_client.py
+++ b/ibis/backends/pyspark/tests/test_client.py
@@ -6,8 +6,7 @@
 
 @pytest.mark.xfail_version(pyspark=["pyspark<3.4"], reason="no catalog support")
-def test_catalog_db_args(con, monkeypatch):
-    monkeypatch.setattr(ibis.options, "default_backend", con)
+def test_catalog_db_args(con):
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
     assert con.current_catalog == "spark_catalog"
@@ -40,8 +39,7 @@ def test_catalog_db_args(con, monkeypatch):
     assert con.current_database == "ibis_testing"
 
 
-def test_create_table_no_catalog(con, monkeypatch):
-    monkeypatch.setattr(ibis.options, "default_backend", con)
+def test_create_table_no_catalog(con):
     t = ibis.memtable({"epoch": [1712848119, 1712848121, 1712848155]})
 
     assert con.current_database != "default"
diff --git a/justfile b/justfile
index e55754ef4d3f4..ba0d9aa59577f 100644
--- a/justfile
+++ b/justfile
@@ -135,6 +135,23 @@ download-data owner="ibis-project" repo="testing-data" rev="master":
         git -C "${outdir}" checkout "{{ rev }}"
     fi
 
+# download the iceberg jar used for testing pyspark and iceberg integration
+download-iceberg-jar pyspark="3.5" scala="2.12" iceberg="1.5.2":
+    #!/usr/bin/env bash
+    set -eo pipefail
+
+    runner=(python)
+
+    if [ -n "${CI}" ]; then
+        runner=(poetry run python)
+    fi
+    pyspark="$("${runner[@]}" -c "import pyspark; print(pyspark.__file__.rsplit('/', 1)[0])")"
+    pushd "${pyspark}/jars"
+    jar="iceberg-spark-runtime-{{ pyspark }}_{{ scala }}-{{ iceberg }}.jar"
+    url="https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-{{ pyspark }}_{{ scala }}/{{ iceberg }}/${jar}"
+    curl -qSsL -o "${jar}" "${url}"
+    ls "${jar}"
+
 # start backends using docker compose; no arguments starts all backends
 up *backends:
     docker compose up --build --wait {{ backends }}
diff --git a/poetry-overrides.nix b/poetry-overrides.nix
index ee5829919c464..47dd28ab3d607 100644
--- a/poetry-overrides.nix
+++ b/poetry-overrides.nix
@@ -5,7 +5,7 @@ final: prev: {
       icebergJar = final.pkgs.fetchurl {
         name = "iceberg-spark-runtime-3.5_2.12-1.5.2.jar";
         url = icebergJarUrl;
-        sha256 = "12v1704h0bq3qr2fci0mckg9171lyr8v6983wpa83k06v1w4pv1a";
+        sha256 = "sha256-KuxLeNgGzIHU5QMls1H2NJyQ3mQVROZExgMvAAk4YYs=";
       };
     in
     {
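
Note on the sge.Use change in ibis/backends/pyspark/__init__.py: instead of interpolating identifiers into f-strings ("USE CATALOG {x}"), the backend now builds a sqlglot Use expression and renders it with the compiler's quoting rules, so catalog and database names are quoted per the dialect. The following is a minimal standalone sketch of that construction, not the backend code itself: the catalog name "my_catalog" is made up, sge.Var(this="CATALOG") stands in for self.compiler.v.CATALOG, and the exact rendered string may vary with the sqlglot version.

    import sqlglot as sg
    import sqlglot.expressions as sge

    # Build USE CATALOG <identifier> as an expression tree rather than an f-string,
    # so identifier quoting is handled by the dialect instead of manual formatting.
    use_catalog = sge.Use(
        kind=sge.Var(this="CATALOG"),                      # approximates compiler.v.CATALOG
        this=sg.to_identifier("my_catalog", quoted=True),  # hypothetical catalog name
    )
    print(use_catalog.sql("spark"))  # expected: USE CATALOG `my_catalog`

The same pattern covers USE DATABASE, and when raw_sql raises a ParseException the backend still falls back to setCurrentCatalog / setCurrentDatabase, as shown in the diff. The iceberg jar itself now comes either from the new just recipe (for example, just download-iceberg-jar 3.5 2.12 1.5.2 matches the recipe defaults) or, in the test conftest, from spark.jars.packages.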