CrossValidator with Pipeline Example

Caution
FIXME The example below does NOT work. Being investigated.
Caution
FIXME Can k-means be cross-validated? Does it make any sense, or does cross-validation apply only to supervised learning? See the sketch below.
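For the second FIXME: the spark.ml tuning API itself does not require a supervised estimator, since CrossValidator accepts any Estimator paired with a matching Evaluator. A minimal sketch of cross-validating k-means over k (assuming Spark 2.3+ for ClusteringEvaluator and a dataset that already has a features column):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.ml.evaluation.ClusteringEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// No label column is involved; the evaluator scores cluster cohesion (silhouette) instead
val kmeans = new KMeans().setFeaturesCol("features")
val silhouette = new ClusteringEvaluator()  // squared Euclidean silhouette by default
val kmeansGrid = new ParamGridBuilder()
  .addGrid(kmeans.k, Array(2, 3, 4))
  .build
val kmeansCV = new CrossValidator()
  .setEstimator(kmeans)
  .setEvaluator(silhouette)
  .setEstimatorParamMaps(kmeansGrid)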
// Let's create a pipeline with transformers and an estimator
import org.apache.spark.ml.feature._

val tok = new Tokenizer().setInputCol("text")

val hashTF = new HashingTF()
  .setInputCol(tok.getOutputCol)
  .setOutputCol("features")
  .setNumFeatures(10)

import org.apache.spark.ml.classification.RandomForestClassifier
// Uses the default labelCol ("label") and featuresCol ("features"), matching the columns below
val rfc = new RandomForestClassifier

import org.apache.spark.ml.Pipeline
val pipeline = new Pipeline()
  .setStages(Array(tok, hashTF, rfc))

// CAUTION: the label column must be of type Double
// 0 = scientific text
// 1 = non-scientific text
// toDF requires the implicits below (imported automatically in spark-shell)
import spark.implicits._
val trainDS = Seq(
  (0L, "[science] hello world", 0d),
  (1L, "long text", 1d),
  (2L, "[science] hello all people", 0d),
  (3L, "[science] hello hello", 0d)).toDF("id", "text", "label").cache

// Sanity check: fit the pipeline and transform the training dataset
// Values in the label and prediction columns should match
val sampleModel = pipeline.fit(trainDS)
sampleModel
  .transform(trainDS)
  .select('text, 'label, 'features, 'prediction)
  .show(truncate = false)

+--------------------------+-----+--------------------------+----------+
|text                      |label|features                  |prediction|
+--------------------------+-----+--------------------------+----------+
|[science] hello world     |0.0  |(10,[0,8],[2.0,1.0])      |0.0       |
|long text                 |1.0  |(10,[4,9],[1.0,1.0])      |1.0       |
|[science] hello all people|0.0  |(10,[0,6,8],[1.0,1.0,2.0])|0.0       |
|[science] hello hello     |0.0  |(10,[0,8],[1.0,2.0])      |0.0       |
+--------------------------+-----+--------------------------+----------+

// Apply the model to unseen text (Tokenizer lowercases the input)
val input = Seq("Hello ScienCE").toDF("text")
sampleModel
  .transform(input)
  .select('text, 'rawPrediction, 'prediction)
  .show(truncate = false)

+-------------+--------------------------------------+----------+
|text         |rawPrediction                         |prediction|
+-------------+--------------------------------------+----------+
|Hello ScienCE|[12.666666666666668,7.333333333333333]|0.0       |
+-------------+--------------------------------------+----------+

import org.apache.spark.ml.tuning.ParamGridBuilder
// An empty parameter grid, i.e. the pipeline is evaluated with its current parameters only
val paramGrid = new ParamGridBuilder().build
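For reference, a non-empty grid over the stages defined above could look as follows (a sketch only; the parameter values are arbitrary):

// Every combination of the values below becomes one candidate pipeline
val richParamGrid = new ParamGridBuilder()
  .addGrid(hashTF.numFeatures, Array(10, 100, 1000))
  .addGrid(rfc.maxDepth, Array(3, 5, 10))
  .build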

import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
// Defaults: metric areaUnderROC, labelCol "label", rawPredictionCol "rawPrediction"
val binEval = new BinaryClassificationEvaluator

import org.apache.spark.ml.tuning.CrossValidator
val cv = new CrossValidator()
  .setEstimator(pipeline) // <-- pipeline is the estimator
  .setEvaluator(binEval)  // the evaluator's label and rawPrediction columns have to match the pipeline's output
  .setEstimatorParamMaps(paramGrid) // numFolds is left at its default of 3

// WARNING: As noted in the CAUTION above, this step currently fails and is being investigated
val cvModel = cv.fit(trainDS)
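Once the failure above is resolved and fit succeeds, the fitted CrossValidatorModel can be used like the pipeline model earlier (a sketch):

// The model picked by cross-validation, and the average metric per grid entry
val bestModel = cvModel.bestModel
cvModel.avgMetrics.foreach(println)

// CrossValidatorModel.transform delegates to the best model
cvModel
  .transform(input)
  .select('text, 'prediction)
  .show(truncate = false)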