add option to reuse existing spark session
Aaron Z-L authored and aaronzo committed Mar 31, 2023
1 parent 7d52abb commit 0483648
Showing 3 changed files with 92 additions and 12 deletions.
7 changes: 4 additions & 3 deletions joblibspark/__init__.py
@@ -21,13 +21,14 @@
 __version__ = '0.5.1'


-def register_spark():
+def register_spark(spark=None):
     """
-    Register spark backend into joblib.
+    Register spark backend into joblib. The user can optionally supply an active SparkSession;
+    otherwise a new one is created by default.
     """
     try:
         from .backend import register  # pylint: disable=C0415
-        register()
+        register(spark)
     except ImportError:
         msg = ("To use the spark.distributed backend you must install "
                "the pyspark and packages.\n\n")
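For orientation, a minimal usage sketch of the new option (not part of the diff; it assumes pyspark and a joblibspark build containing this change are installed, and the app name and inc() helper are only illustrative):

    from joblib import Parallel, delayed, parallel_backend
    from pyspark.sql import SparkSession
    from joblibspark import register_spark

    # Reuse an already-active session instead of letting the backend build one.
    spark = SparkSession.builder.appName("MyExistingApp").getOrCreate()
    register_spark(spark)

    def inc(x):
        return x + 1

    with parallel_backend('spark', n_jobs=4):
        results = Parallel()(delayed(inc)(i) for i in range(8))  # -> [1, 2, ..., 8]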
22 changes: 14 additions & 8 deletions joblibspark/backend.py
@@ -35,9 +35,10 @@
 from pyspark.util import VersionUtils


-def register():
+def register(spark=None):
     """
-    Register joblib spark backend.
+    Register joblib spark backend. The user can optionally supply an active SparkSession;
+    otherwise a new one is created by default.
     """
     try:
         import sklearn  # pylint: disable=C0415
@@ -47,7 +48,7 @@ def register():
                       "make sklearn use spark backend.")
     except ImportError:
         pass
-    register_parallel_backend('spark', SparkDistributedBackend)
+    register_parallel_backend('spark', lambda: SparkDistributedBackend(spark=spark))


 # pylint: disable=too-many-instance-attributes
@@ -61,15 +62,20 @@ class SparkDistributedBackend(ParallelBackendBase, AutoBatchingMixin):
     by `SequentialBackend`
     """

-    def __init__(self, **backend_args):
+    def __init__(self, spark=None, **backend_args):
         # pylint: disable=super-with-arguments
         super(SparkDistributedBackend, self).__init__(**backend_args)
         self._pool = None
         self._n_jobs = None
-        self._spark = SparkSession \
-            .builder \
-            .appName("JoblibSparkBackend") \
-            .getOrCreate()
+
+        if spark is None:
+            self._spark = SparkSession \
+                .builder \
+                .appName("JoblibSparkBackend") \
+                .getOrCreate()
+        else:
+            self._spark = spark
+
         self._spark_context = self._spark.sparkContext
         self._job_group = "joblib-spark-job-group-" + str(uuid.uuid4())
         self._spark_pinned_threads_enabled = isinstance(
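A note on the registration change above: register_parallel_backend() is given a factory rather than the backend class itself, so the lambda closes over the user-supplied session and defers construction until joblib activates the backend. A rough equivalent using functools.partial (a sketch, not what the commit ships; register_with_session is a hypothetical name):

    from functools import partial

    from joblib import register_parallel_backend
    from joblibspark.backend import SparkDistributedBackend

    def register_with_session(spark=None):
        # Like the lambda, partial binds the chosen SparkSession but builds the
        # backend only when joblib asks for it.
        register_parallel_backend('spark', partial(SparkDistributedBackend, spark=spark))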
75 changes: 74 additions & 1 deletion test/test_spark.py
@@ -26,12 +26,18 @@

 from joblibspark import register_spark

+from pyspark.sql import SparkSession
 from sklearn.utils import parallel_backend
 from sklearn.model_selection import cross_val_score
 from sklearn import datasets
 from sklearn import svm

-register_spark()

+@pytest.fixture(scope="session")
+def existing_spark():
+    spark = SparkSession.builder.appName("ExistingSession").getOrCreate()
+    yield spark
+    spark.stop()


 def inc(x):
@@ -45,6 +51,8 @@ def slow_raise_value_error(condition, duration=0.05):


 def test_simple():
+    register_spark()
+
     with parallel_backend('spark') as (ba, _):
         seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
         assert seq == [inc(i) for i in range(10)]
@@ -55,6 +63,8 @@ def test_simple():


 def test_sklearn_cv():
+    register_spark()
+
     iris = datasets.load_iris()
     clf = svm.SVC(kernel='linear', C=1)
     with parallel_backend('spark', n_jobs=3):
@@ -79,6 +89,69 @@ def test_job_cancelling():
     import tempfile
     import os

+    register_spark()
     tmp_dir = tempfile.mkdtemp()

     def test_fn(x):
         if x == 0:
             # make the task-0 fail, then it will cause task 1/2/3 to be canceled.
             raise RuntimeError()
         else:
             time.sleep(15)
             # if the task finished successfully, it will write a flag file to tmp dir.
             with open(os.path.join(tmp_dir, str(x)), 'w'):
                 pass

     with pytest.raises(Exception):
         with parallel_backend('spark', n_jobs=2):
             Parallel()(delayed(test_fn)(i) for i in range(2))

     time.sleep(30)  # wait until we can ensure all task finish or cancelled.
     # assert all jobs was cancelled, no flag file will be written to tmp dir.
     assert len(os.listdir(tmp_dir)) == 0
+
+
+def test_simple_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    with parallel_backend('spark') as (ba, _):
+        seq = Parallel(n_jobs=5)(delayed(inc)(i) for i in range(10))
+        assert seq == [inc(i) for i in range(10)]
+
+        with pytest.raises(BaseException):
+            Parallel(n_jobs=5)(delayed(slow_raise_value_error)(i == 3)
+                               for i in range(10))
+
+
+def test_sklearn_cv_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    iris = datasets.load_iris()
+    clf = svm.SVC(kernel='linear', C=1)
+    with parallel_backend('spark', n_jobs=3):
+        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+    expected = [0.97, 1.0, 0.97, 0.97, 1.0]
+
+    for i in range(5):
+        assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+    # test with default n_jobs=-1
+    with parallel_backend('spark'):
+        scores = cross_val_score(clf, iris.data, iris.target, cv=5)
+
+    for i in range(5):
+        assert(pytest.approx(scores[i], 0.01) == expected[i])
+
+
+def test_job_cancelling_reuse_spark(existing_spark):
+    register_spark(existing_spark)
+
+    from joblib import Parallel, delayed
+    import time
+    import tempfile
+    import os
+
+    tmp_dir = tempfile.mkdtemp()
+
+    def test_fn(x):
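One behavioural nuance worth noting (my reading, not stated in the diff): SparkSession.builder.getOrCreate() returns any already-active session, so the default path will often land on the same session anyway; passing it explicitly just makes that deterministic and avoids relying on Spark's active-session global. A small sketch, with "MyApp" as a placeholder name:

    from pyspark.sql import SparkSession
    from joblibspark import register_spark

    spark = SparkSession.builder.appName("MyApp").getOrCreate()

    register_spark()       # backend calls getOrCreate() and will usually pick up `spark` anyway
    register_spark(spark)  # explicit reuse: deterministic, independent of the active-session state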
