Estimator

Estimator is the contract in Spark MLlib for algorithms that fit a model to a dataset.

You can set an Estimator's parameters through its dedicated setter methods when you create it. You can also pass extra parameters when fitting a model.

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame

// Set parameters through setter methods when creating an Estimator
val lr = new LogisticRegression().
  setMaxIter(5).
  setRegParam(0.01)
val training: DataFrame = ...
val model1 = lr.fit(training)

// Pass extra parameters to fit; they apply to this call only and lr itself is unchanged
val customParams = ParamMap(
  lr.maxIter -> 10,
  lr.featuresCol -> "custom_features"
)
val model2 = lr.fit(training, customParams)

Estimator is a PipelineStage and so can be a part of a Pipeline.
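When a Pipeline is fit, each Estimator stage is fit on the output of the stages before it and the resulting Model takes its place in the fitted pipeline, while Transformer stages are used as they are. The following is a dependency-free Scala sketch of that idea, not Spark's implementation: it assumes a dataset is just a Seq[Double], and Stage, Transformer, Estimator, Scale, Centerer and fitPipeline are hypothetical toy names. For simplicity it runs the data through every stage, whereas Spark stops transforming after the last Estimator.

```scala
// Hypothetical toy types (not Spark's API); a "dataset" is just a Seq[Double]
trait Stage
trait Transformer extends Stage {
  def transform(data: Seq[Double]): Seq[Double]
}
trait Estimator extends Stage {
  def fit(data: Seq[Double]): Transformer
}

// A Transformer stage: multiplies every value by k
final case class Scale(k: Double) extends Transformer {
  def transform(data: Seq[Double]): Seq[Double] = data.map(_ * k)
}

// An Estimator stage: learns the mean and returns a Transformer (model) that subtracts it
object Centerer extends Estimator {
  def fit(data: Seq[Double]): Transformer = {
    val mean = data.sum / data.size
    new Transformer {
      def transform(d: Seq[Double]): Seq[Double] = d.map(_ - mean)
    }
  }
}

// Fit stages in order: each Estimator is fit on the output of the earlier stages
// and replaced by the Transformer it produces; Transformers are kept as they are
def fitPipeline(stages: Seq[Stage], data: Seq[Double]): Seq[Transformer] =
  stages.foldLeft((data, Vector.empty[Transformer])) { case ((current, fitted), stage) =>
    val t = stage match {
      case e: Estimator    => e.fit(current)
      case tr: Transformer => tr
    }
    (t.transform(current), fitted :+ t)
  }._2

val fitted = fitPipeline(Seq(Scale(2.0), Centerer), Seq(1.0, 2.0, 3.0))
// Apply the fitted stages in order, as a fitted pipeline model would
val out = fitted.foldLeft(Seq(1.0, 2.0, 3.0))((d, t) => t.transform(d))
println(out.mkString(", ")) // -2.0, 0.0, 2.0
```

Here Centerer is fit on the already scaled data (2.0, 4.0, 6.0), so the mean it learns is 4.0, not 2.0.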

Estimator Contract

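package org.apache.spark.ml

import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.Dataset

abstract class Estimator[M <: Model[M]] extends PipelineStage {
  // only the abstract methods (those without an implementation)
  def fit(dataset: Dataset[_]): M
  def copy(extra: ParamMap): Estimator[M]
}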
Table 1. Estimator Contract

Method | Description
copy   | Used when…
fit    | Used when…
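To make the two abstract methods concrete, here is a dependency-free Scala sketch that mirrors the shape of the contract, including the F-bounded type parameter M. It is not Spark code: it assumes a dataset is a Seq[Double] and a ParamMap is a Map[String, Double], and ToyModel, ToyEstimator, ConstantModel, MeanEstimator and the "offset" parameter are hypothetical. As with Spark's copy(extra), copy returns a new estimator with the extra parameters applied and leaves the original alone.

```scala
// Hypothetical, simplified mirror of the contract (not Spark's classes):
// a dataset is a Seq[Double] and a ParamMap is a Map[String, Double]
abstract class ToyModel[M <: ToyModel[M]] {
  def predict(x: Double): Double
}

abstract class ToyEstimator[M <: ToyModel[M]] {
  def fit(dataset: Seq[Double]): M                      // learn a model from data
  def copy(extra: Map[String, Double]): ToyEstimator[M] // new instance with extra params
}

final class ConstantModel(val value: Double) extends ToyModel[ConstantModel] {
  def predict(x: Double): Double = value
}

// Learns the mean of the data plus a configurable "offset" parameter
final class MeanEstimator(val params: Map[String, Double] = Map("offset" -> 0.0))
    extends ToyEstimator[ConstantModel] {
  def fit(dataset: Seq[Double]): ConstantModel =
    new ConstantModel(dataset.sum / dataset.size + params("offset"))

  // Extra params override existing ones; this instance is left unchanged
  def copy(extra: Map[String, Double]): MeanEstimator = new MeanEstimator(params ++ extra)
}

val data = Seq(1.0, 2.0, 3.0)
val est = new MeanEstimator()
val shifted = est.copy(Map("offset" -> 1.0))
println(est.fit(data).predict(0.0))     // 2.0
println(shifted.fit(data).predict(0.0)) // 3.0
```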

Fitting Model with Extra Parameters — fit Method

fit(dataset: Dataset[_], paramMap: ParamMap): M

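fit creates a copy of the Estimator with the extra paramMap applied (using copy) and fits a model (of type M) to the dataset with that copy, so the parameters of the original Estimator stay unchanged.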
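In Spark's source this overload is a one-liner, copy(paramMap).fit(dataset). The following dependency-free Scala sketch reproduces that delegation on toy types. It assumes a dataset is a Seq[Double] and a ParamMap is a Map[String, Double]; ConstantModel, ToyEstimator, MeanEstimator and the "offset" parameter are hypothetical names, not Spark's API.

```scala
// Hypothetical, simplified stand-ins (not Spark's API): a dataset is a Seq[Double],
// a ParamMap is a Map[String, Double]
final class ConstantModel(val value: Double)

abstract class ToyEstimator {
  def params: Map[String, Double]
  def fit(dataset: Seq[Double]): ConstantModel
  def copy(extra: Map[String, Double]): ToyEstimator

  // Fitting with extra parameters: copy first, then fit the copy.
  // The params of this estimator are left unchanged.
  def fit(dataset: Seq[Double], paramMap: Map[String, Double]): ConstantModel =
    copy(paramMap).fit(dataset)
}

// Learns the mean of the data plus an "offset" parameter
final class MeanEstimator(val params: Map[String, Double] = Map("offset" -> 0.0))
    extends ToyEstimator {
  def fit(dataset: Seq[Double]): ConstantModel =
    new ConstantModel(dataset.sum / dataset.size + params("offset"))
  def copy(extra: Map[String, Double]): MeanEstimator = new MeanEstimator(params ++ extra)
}

val est = new MeanEstimator()
val data = Seq(1.0, 2.0, 3.0)
val m1 = est.fit(data)                        // uses offset 0.0
val m2 = est.fit(data, Map("offset" -> 10.0)) // uses offset 10.0 for this call only
println(s"${m1.value} ${m2.value} ${est.params("offset")}") // 2.0 12.0 0.0
```

Putting the overload in the base class means every concrete estimator only has to implement fit(dataset) and copy; the extra-parameter variant comes for free.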

Note
fit is used mainly for model tuning to find the best model (using CrossValidator and TrainValidationSplit).
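TrainValidationSplit and CrossValidator fit one model per candidate ParamMap from a parameter grid, score each model with an evaluator, and keep the best. A toy, dependency-free Scala sketch of the TrainValidationSplit idea follows. It assumes a single train/validation split, mean squared error as the metric, and hypothetical names (ConstantModel, ShrunkMeanEstimator, mse); "regParam" here is a toy parameter that shrinks the training mean toward zero, not LogisticRegression's regParam.

```scala
// Hypothetical toy (not Spark's API): a "dataset" is a Seq[Double] and a ParamMap
// is a Map[String, Double]. The model predicts a constant: the training mean
// shrunk toward zero by the "regParam" parameter.
final class ConstantModel(val value: Double)

final class ShrunkMeanEstimator(val params: Map[String, Double] = Map("regParam" -> 0.0)) {
  def fit(dataset: Seq[Double]): ConstantModel =
    new ConstantModel(dataset.sum / (dataset.size + params("regParam")))
  def copy(extra: Map[String, Double]): ShrunkMeanEstimator =
    new ShrunkMeanEstimator(params ++ extra)
  // Fit with extra parameters: copy, then fit the copy
  def fit(dataset: Seq[Double], paramMap: Map[String, Double]): ConstantModel =
    copy(paramMap).fit(dataset)
}

// Mean squared error of the model's constant prediction on held-out data
def mse(model: ConstantModel, data: Seq[Double]): Double =
  data.map(y => math.pow(y - model.value, 2)).sum / data.size

val train = Seq(2.0, 4.0, 6.0)
val validation = Seq(2.0, 3.0, 4.0)
val grid = Seq(0.0, 1.0, 3.0).map(r => Map("regParam" -> r))

// Fit one model per candidate ParamMap on the training split and keep the one
// with the lowest validation error
val est = new ShrunkMeanEstimator()
val (bestParams, bestError) = grid
  .map(pm => (pm, mse(est.fit(train, pm), validation)))
  .minBy(_._2)
println(s"best regParam = ${bestParams("regParam")}") // best regParam = 1.0
```

The three candidates predict 4.0, 3.0 and 2.0; only 3.0 matches the validation mean, so regParam = 1.0 wins.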