From 04836484066c22628fbe53d422d82a1a8f14550e Mon Sep 17 00:00:00 2001
From: Aaron Z-L
Date: Sun, 29 May 2022 17:07:56 +0100
Subject: [PATCH] add option to reuse existing spark session

---
 joblibspark/__init__.py |  7 ++--
 joblibspark/backend.py  | 22 +++++++-----
 test/test_spark.py      | 75 ++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 92 insertions(+), 12 deletions(-)

diff --git a/joblibspark/__init__.py b/joblibspark/__init__.py
index 7553d85..cfd673c 100644
--- a/joblibspark/__init__.py
+++ b/joblibspark/__init__.py
@@ -21,13 +21,14 @@
 __version__ = '0.5.1'
 
 
-def register_spark():
+def register_spark(spark=None):
     """
-    Register spark backend into joblib.
+    Register spark backend into joblib. The user can optionally supply an active SparkSession,
+    otherwise a new one is created by default.
     """
     try:
         from .backend import register  # pylint: disable=C0415
-        register()
+        register(spark)
     except ImportError:
         msg = ("To use the spark.distributed backend you must install "
                "the pyspark and packages.\n\n")
diff --git a/joblibspark/backend.py b/joblibspark/backend.py
index dc7af28..828ac54 100644
--- a/joblibspark/backend.py
+++ b/joblibspark/backend.py
@@ -35,9 +35,10 @@
 from pyspark.util import VersionUtils
 
 
-def register():
+def register(spark=None):
     """
-    Register joblib spark backend.
+    Register joblib spark backend. The user can optionally supply an active SparkSession,
+    otherwise a new one is created by default.
     """
     try:
         import sklearn  # pylint: disable=C0415
@@ -47,7 +48,7 @@ def register():
                       "make sklearn use spark backend.")
     except ImportError:
         pass
-    register_parallel_backend('spark', SparkDistributedBackend)
+    register_parallel_backend('spark', lambda: SparkDistributedBackend(spark=spark))
 
 
 # pylint: disable=too-many-instance-attributes
@@ -61,15 +62,20 @@ class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
     by `SequentialBackend`
     """
 
-    def __init__(self, **backend_args):
+    def __init__(self, spark=None, **backend_args):
         # pylint: disable=super-with-arguments
         super(SparkDistributedBackend, self).__init__(**backend_args)
         self._pool = None
         self._n_jobs = None
-        self._spark = SparkSession \
-            .builder \
-            .appName("JoblibSparkBackend") \
-            .getOrCreate()
+
+        if spark is None:
+            self._spark = SparkSession \
+                .builder \
+                .appName("JoblibSparkBackend") \
+                .getOrCreate()
+        else:
+            self._spark = spark
+
         self._spark_context = self._spark.sparkContext
         self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
         self._spark_pinned_threads_enabled = isinstance(
diff --git a/test/test_spark.py b/test/test_spark.py
index 2e2d24e..f5eba0b 100644
--- a/test/test_spark.py
+++ b/test/test_spark.py
@@ -26,12 +26,18 @@
 
 from joblibspark import register_spark
+from pyspark.sql import SparkSession
 
 from sklearn.utils import parallel_backend
 from sklearn.model_selection import cross_val_score
 from sklearn import datasets
 from sklearn import svm
 
-register_spark()
+
+@pytest.fixture(scope="session")
+def existing_spark():
+    spark = SparkSession.builder.appName("ExistingSession").getOrCreate()
+    yield spark
+    spark.stop()
 
 
 def inc(x):
@@ -45,6 +51,8 @@ def slow_raise_value_error(condition, duration=0.05):
 
 
 def test_simple():
+    register_spark()
+
     with parallel_backend('spark') as (ba, _):
         seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
         assert seq == [inc(i) for i in range(10)]
@@ -55,6 +63,8 @@
 
 
 def test_sklearn_cv():
+    register_spark()
+
     iris = datasets.load_iris()
     clf = svm.SVC(kernel='linear', C=1)
     with parallel_backend('spark', n_jobs=3):
@@ -79,6 +89,69 @@ def test_job_cancelling():
     import tempfile
     import os
+    register_spark()
+    tmp_dir = tempfile.mkdtemp()
+
+    def test_fn(x):
+        if x == 0:
+            # make task-0 fail, then it will cause tasks 1/2/3 to be canceled.
+            raise RuntimeError()
+        else:
+            time.sleep(15)
+            # if the task finished successfully, it will write a flag file to tmp dir.
+            with open(os.path.join(tmp_dir, str(x)), 'w'):
+                pass
+
+    with pytest.raises(Exception):
+        with parallel_backend('spark', n_jobs=2):
+            Parallel()(delayed(test_fn)(i) for i in range(2))
+
+    time.sleep(30)  # wait until we can ensure all tasks finish or are cancelled.
+    # assert all jobs were cancelled, i.e. no flag file was written to tmp dir.
+    assert len(os.listdir(tmp_dir)) == 0
+
+
+def test_simple_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    with parallel_backend('spark') as (ba, _):
+        seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
+        assert seq == [inc(i) for i in range(10)]
+
+    with pytest.raises(BaseException):
+        Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
+                           for i in range(10))
+
+
+def test_sklearn_cv_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    iris = datasets.load_iris()
+    clf = svm.SVC(kernel='linear', C=1)
+    with parallel_backend('spark', n_jobs=3):
+        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+    expected = [0.97, 1.0, 0.97, 0.97, 1.0]
+
+    for i in range(5):
+        assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+    # test with default n_jobs=-1
+    with parallel_backend('spark'):
+        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+    for i in range(5):
+        assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+
+def test_job_cancelling_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    from joblib import Parallel, delayed
+    import time
+    import tempfile
+    import os
+
     tmp_dir = tempfile.mkdtemp()
 
     def test_fn(x):
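
A minimal usage sketch of the new optional `spark` argument, assuming this patch
is applied and pyspark/joblib are installed; the app name and the `square`
worker function below are illustrative, not part of the change:

    from pyspark.sql import SparkSession
    from joblib import Parallel, delayed, parallel_backend
    from joblibspark import register_spark


    def square(x):
        return x * x


    # The application already owns a session; pass it to register_spark() so the
    # backend reuses it instead of building its own "JoblibSparkBackend" session
    # via SparkSession.builder.getOrCreate().
    spark = SparkSession.builder.appName("MyExistingApp").getOrCreate()
    register_spark(spark)

    with parallel_backend('spark', n_jobs=4):
        results = Parallel()(delayed(square)(i) for i in range(10))

    print(results)  # [0, 1, 4, 9, ..., 81]

    # The session's lifecycle stays with the caller: stop it only when the
    # application is done, not when the joblib work finishes.
    spark.stop()

This mirrors what the new tests exercise through the `existing_spark` fixture.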