`Estimator` is the contract in Spark MLlib for estimators that fit models to a dataset.

`Estimator` accepts parameters that you can set through dedicated setter methods when creating an `Estimator`. You can also fit a model with extra parameters passed to `fit`.
```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.sql.DataFrame

// Define parameters upon creating an Estimator
val lr = new LogisticRegression().
  setMaxIter(5).
  setRegParam(0.01)

val training: DataFrame = ...
val model1 = lr.fit(training)

// Define parameters through fit
// (the extra parameters apply to this fit only; lr itself is not modified)
import org.apache.spark.ml.param.ParamMap
val customParams = ParamMap(
  lr.maxIter -> 10,
  lr.featuresCol -> "custom_features"
)
val model2 = lr.fit(training, customParams)
```
`Estimator` is a `PipelineStage` and so can be a part of a `Pipeline`.
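As a minimal sketch of an estimator used as a pipeline stage, the following wires a `LogisticRegression` after two feature transformers. The column names (`text`, `words`, `features`) and the choice of `Tokenizer` and `HashingTF` are assumptions made for illustration only. No dataset is needed to assemble the pipeline, so fitting is left as a comment.

```scala
import org.apache.spark.ml.{Estimator, Pipeline}
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}

// Two transformers that prepare the features (column names are illustrative)
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features")

// The estimator is simply the last stage of the pipeline
val lr = new LogisticRegression().setMaxIter(5)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// Only the stages that are estimators get fitted when the pipeline is fitted
val estimators = pipeline.getStages.collect { case e: Estimator[_] => e }

// With a (hypothetical) DataFrame `training` that has "text" and "label" columns:
// val model = pipeline.fit(training)
```

A `Pipeline` is itself an `Estimator`, so fitting it produces a model that can then transform new data.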
```scala
package org.apache.spark.ml

abstract class Estimator[M <: Model[M]] {
  // only required methods that have no implementation
  def fit(dataset: Dataset[_]): M
  def copy(extra: ParamMap): Estimator[M]
}
```
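To make the contract concrete, here is a hedged sketch of a custom estimator. `MeanEstimator` and `MeanModel` are hypothetical names invented for this example, and the `label` and `prediction` column names are assumptions. Besides `fit` and `copy`, a concrete class also has to provide `uid` and `transformSchema`, which come from the parent types. The sketch is meant to be pasted into `spark-shell` with `:paste` or run as a script.

```scala
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.functions.{avg, col, lit}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

// Hypothetical model: predicts the same constant for every row
class MeanModel(override val uid: String, val mean: Double) extends Model[MeanModel] {
  override def transform(dataset: Dataset[_]): DataFrame =
    dataset.withColumn("prediction", lit(mean))
  override def transformSchema(schema: StructType): StructType =
    schema.add(StructField("prediction", DoubleType, nullable = false))
  override def copy(extra: ParamMap): MeanModel =
    copyValues(new MeanModel(uid, mean), extra).setParent(parent)
}

// Hypothetical estimator: "fits" by averaging the label column
class MeanEstimator(override val uid: String) extends Estimator[MeanModel] {
  def this() = this(Identifiable.randomUID("meanEstimator"))
  override def fit(dataset: Dataset[_]): MeanModel = {
    val mean = dataset.select(avg(col("label"))).head().getDouble(0)
    copyValues(new MeanModel(uid, mean).setParent(this))
  }
  override def transformSchema(schema: StructType): StructType =
    schema.add(StructField("prediction", DoubleType, nullable = false))
  override def copy(extra: ParamMap): MeanEstimator =
    copyValues(new MeanEstimator(uid), extra)
}

// A tiny local dataset, just enough to exercise fit and transform
val spark = SparkSession.builder().master("local[*]").appName("estimator-sketch").getOrCreate()
import spark.implicits._
val training = Seq(1.0, 2.0, 3.0).toDF("label")
val model = new MeanEstimator().fit(training)
val predictions = model.transform(training)
```

The explicit `copy` avoids reflection-based copying, which can be fragile for classes defined inside the REPL.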
Method | Description |
---|---|
`fit` | Used when… |
`copy` | Used when… |
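```scala
fit(dataset: Dataset[_], paramMap: ParamMap): M
```

`fit` copies the estimator with the extra `paramMap` applied and fits a model (of type `M`) using that copy, so the parameters of the original estimator stay unchanged.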
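The copy semantics can be observed without any dataset. The sketch below assumes a `LogisticRegression` as the estimator and calls `copy` directly with the extra parameters, which is the step `fit(dataset, paramMap)` performs before fitting.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.param.ParamMap

val lr = new LogisticRegression().setMaxIter(5)
val extra = ParamMap(lr.maxIter -> 10)

// The copy carries the extra parameters; fitting it is what
// lr.fit(training, extra) would do for a given training dataset
val tuned = lr.copy(extra)

println(lr.getMaxIter)    // the original estimator keeps its own setting
println(tuned.getMaxIter) // the copy uses the overridden value
```

Because the override lives only in the copy, the same estimator instance can be reused safely with many different parameter maps.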
Note: `fit` is used mainly for model tuning to find the best model (using `CrossValidator` and `TrainValidationSplit`).
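A hedged sketch of such a tuning setup follows. The grid values, the number of folds, and the choice of `BinaryClassificationEvaluator` are assumptions for illustration. Each `ParamMap` in the grid is a candidate set of extra parameters for the estimator. The actual fit is left as a comment because it needs a labeled dataset.

```scala
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

val lr = new LogisticRegression()

// 2 values x 2 values = 4 candidate parameter maps
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.1))
  .addGrid(lr.maxIter, Array(5, 10))
  .build()

val cv = new CrossValidator()
  .setEstimator(lr)
  .setEvaluator(new BinaryClassificationEvaluator())
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// With a (hypothetical) labeled DataFrame `training`:
// val cvModel = cv.fit(training)
// cvModel.bestModel is the model fitted with the best-scoring parameter map
```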