Skip to content

Commit

Permalink
fix saving pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
danielenricocahall committed Apr 21, 2021
1 parent 3e5a44a commit 682cce4
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 3 deletions.
4 changes: 3 additions & 1 deletion elephas/ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pyspark import keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param.shared import HasOutputCol, HasFeaturesCol, HasLabelCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.sql import DataFrame
from pyspark.sql.types import DoubleType, StructField, ArrayType
from tensorflow.keras.models import model_from_yaml
Expand All @@ -23,7 +24,8 @@
class ElephasEstimator(Estimator, HasCategoricalLabels, HasValidationSplit, HasKerasModelConfig, HasFeaturesCol,
HasLabelCol, HasMode, HasEpochs, HasBatchSize, HasFrequency, HasVerbosity, HasNumberOfClasses,
HasNumberOfWorkers, HasOutputCol, HasLoss,
HasMetrics, HasKerasOptimizerConfig, HasCustomObjects):
HasMetrics, HasKerasOptimizerConfig, HasCustomObjects, DefaultParamsReadable,
DefaultParamsWritable):
"""
SparkML Estimator implementation of an elephas model. This estimator takes all relevant arguments for model
compilation and training.
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from setuptools import find_packages

setup(name='elephas',
version='1.4.3',
version='2.1.0',
description='Deep learning on Spark with Keras',
url='http://github.com/maxpumperla/elephas',
download_url='https://github.com/maxpumperla/elephas/tarball/1.4.3',
download_url='https://github.com/maxpumperla/elephas/tarball/2.1.0',
author='Daniel Cahall',
author_email='[email protected]',
install_requires=['cython',
Expand Down
22 changes: 22 additions & 0 deletions tests/test_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,3 +339,25 @@ def test_batch_predict_classes_probability(spark_context, classification_model,
assert len(results_np.prediction) == 10
assert len(results_np.prediction_via_batch_inference) == 10
assert np.array_equal(results_np.prediction, results_np.prediction_via_batch_inference)


def test_save_pipeline(spark_context, classification_model):
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
sgd_conf = optimizers.serialize(sgd)

# Initialize Spark ML Estimator
estimator = ElephasEstimator()
estimator.set_keras_model_config(classification_model.to_yaml())
estimator.set_optimizer_config(sgd_conf)
estimator.set_mode("synchronous")
estimator.set_loss("categorical_crossentropy")
estimator.set_metrics(['acc'])
estimator.set_epochs(10)
estimator.set_batch_size(10)
estimator.set_validation_split(0.1)
estimator.set_categorical_labels(True)
estimator.set_nb_classes(10)

# Fitting a model returns a Transformer
pipeline = Pipeline(stages=[estimator])
pipeline.save('tmp')

0 comments on commit 682cce4

Please sign in to comment.