[SPARK-9910] [ML] User guide for train validation split
Author: martinzapletal <zapletal-martin@email.cz> Closes #8377 from zapletal-martin/SPARK-9910.
This commit is contained in:
parent
2a4e00ca4d
commit
e8ea5bafee
117
docs/ml-guide.md
117
docs/ml-guide.md
|
@ -872,3 +872,120 @@ jsc.stop();
|
|||
</div>
|
||||
|
||||
</div>
|
||||
|
||||
## Example: Model Selection via Train Validation Split
|
||||
In addition to `CrossValidator`, Spark also offers `TrainValidationSplit` for hyper-parameter tuning.
|
||||
`TrainValidationSplit` only evaluates each combination of parameters once as opposed to k times in
|
||||
case of `CrossValidator`. It is therefore less expensive,
|
||||
but will not produce as reliable results when the training dataset is not sufficiently large.
|
||||
|
||||
`TrainValidationSplit` takes an `Estimator`, a set of `ParamMap`s provided in the `estimatorParamMaps` parameter,
|
||||
and an `Evaluator`.
|
||||
It begins by splitting the dataset into two parts using the `trainRatio` parameter
|
||||
which are used as separate training and test datasets. For example, with `$trainRatio=0.75$` (default),
|
||||
`TrainValidationSplit` will generate a training and test dataset pair where 75% of the data is used for training and 25% for validation.
|
||||
Similar to `CrossValidator`, `TrainValidationSplit` also iterates through the set of `ParamMap`s.
|
||||
For each combination of parameters, it trains the given `Estimator` and evaluates it using the given `Evaluator`.
|
||||
The `ParamMap` which produces the best evaluation metric is selected as the best option.
|
||||
`TrainValidationSplit` finally fits the `Estimator` using the best `ParamMap` and the entire dataset.
|
||||
|
||||
<div class="codetabs">
|
||||
|
||||
<div data-lang="scala" markdown="1">
|
||||
{% highlight scala %}
|
||||
import org.apache.spark.ml.evaluation.RegressionEvaluator
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
|
||||
// Prepare training and test data.
|
||||
val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
|
||||
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)
|
||||
|
||||
val lr = new LinearRegression()
|
||||
|
||||
// We use a ParamGridBuilder to construct a grid of parameters to search over.
|
||||
// TrainValidationSplit will try all combinations of values and determine best model using
|
||||
// the evaluator.
|
||||
val paramGrid = new ParamGridBuilder()
|
||||
.addGrid(lr.regParam, Array(0.1, 0.01))
|
||||
.addGrid(lr.fitIntercept, Array(true, false))
|
||||
.addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
|
||||
.build()
|
||||
|
||||
// In this case the estimator is simply the linear regression.
|
||||
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
|
||||
val trainValidationSplit = new TrainValidationSplit()
|
||||
.setEstimator(lr)
|
||||
.setEvaluator(new RegressionEvaluator)
|
||||
.setEstimatorParamMaps(paramGrid)
|
||||
|
||||
// 80% of the data will be used for training and the remaining 20% for validation.
|
||||
trainValidationSplit.setTrainRatio(0.8)
|
||||
|
||||
// Run train validation split, and choose the best set of parameters.
|
||||
val model = trainValidationSplit.fit(training)
|
||||
|
||||
// Make predictions on test data. model is the model with combination of parameters
|
||||
// that performed best.
|
||||
model.transform(test)
|
||||
.select("features", "label", "prediction")
|
||||
.show()
|
||||
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
<div data-lang="java" markdown="1">
|
||||
{% highlight java %}
|
||||
import org.apache.spark.ml.evaluation.RegressionEvaluator;
|
||||
import org.apache.spark.ml.param.ParamMap;
|
||||
import org.apache.spark.ml.regression.LinearRegression;
|
||||
import org.apache.spark.ml.tuning.*;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.apache.spark.mllib.util.MLUtils;
|
||||
import org.apache.spark.rdd.RDD;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
|
||||
DataFrame data = jsql.createDataFrame(
|
||||
MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
|
||||
LabeledPoint.class);
|
||||
|
||||
// Prepare training and test data.
|
||||
DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
|
||||
DataFrame training = splits[0];
|
||||
DataFrame test = splits[1];
|
||||
|
||||
LinearRegression lr = new LinearRegression();
|
||||
|
||||
// We use a ParamGridBuilder to construct a grid of parameters to search over.
|
||||
// TrainValidationSplit will try all combinations of values and determine best model using
|
||||
// the evaluator.
|
||||
ParamMap[] paramGrid = new ParamGridBuilder()
|
||||
.addGrid(lr.regParam(), new double[] {0.1, 0.01})
|
||||
.addGrid(lr.fitIntercept())
|
||||
.addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
|
||||
.build();
|
||||
|
||||
// In this case the estimator is simply the linear regression.
|
||||
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
|
||||
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
|
||||
.setEstimator(lr)
|
||||
.setEvaluator(new RegressionEvaluator())
|
||||
.setEstimatorParamMaps(paramGrid);
|
||||
|
||||
// 80% of the data will be used for training and the remaining 20% for validation.
|
||||
trainValidationSplit.setTrainRatio(0.8);
|
||||
|
||||
// Run train validation split, and choose the best set of parameters.
|
||||
TrainValidationSplitModel model = trainValidationSplit.fit(training);
|
||||
|
||||
// Make predictions on test data. model is the model with combination of parameters
|
||||
// that performed best.
|
||||
model.transform(test)
|
||||
.select("features", "label", "prediction")
|
||||
.show();
|
||||
|
||||
{% endhighlight %}
|
||||
</div>
|
||||
|
||||
</div>
|
||||
|
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.ml;
|
||||
|
||||
import org.apache.spark.SparkConf;
|
||||
import org.apache.spark.api.java.JavaSparkContext;
|
||||
import org.apache.spark.ml.evaluation.RegressionEvaluator;
|
||||
import org.apache.spark.ml.param.ParamMap;
|
||||
import org.apache.spark.ml.regression.LinearRegression;
|
||||
import org.apache.spark.ml.tuning.*;
|
||||
import org.apache.spark.mllib.regression.LabeledPoint;
|
||||
import org.apache.spark.mllib.util.MLUtils;
|
||||
import org.apache.spark.sql.DataFrame;
|
||||
import org.apache.spark.sql.SQLContext;
|
||||
|
||||
/**
|
||||
* A simple example demonstrating model selection using TrainValidationSplit.
|
||||
*
|
||||
* The example is based on {@link org.apache.spark.examples.ml.JavaSimpleParamsExample}
|
||||
* using linear regression.
|
||||
*
|
||||
* Run with
|
||||
* {{{
|
||||
* bin/run-example ml.JavaTrainValidationSplitExample
|
||||
* }}}
|
||||
*/
|
||||
public class JavaTrainValidationSplitExample {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SparkConf conf = new SparkConf().setAppName("JavaTrainValidationSplitExample");
|
||||
JavaSparkContext jsc = new JavaSparkContext(conf);
|
||||
SQLContext jsql = new SQLContext(jsc);
|
||||
|
||||
DataFrame data = jsql.createDataFrame(
|
||||
MLUtils.loadLibSVMFile(jsc.sc(), "data/mllib/sample_libsvm_data.txt"),
|
||||
LabeledPoint.class);
|
||||
|
||||
// Prepare training and test data.
|
||||
DataFrame[] splits = data.randomSplit(new double [] {0.9, 0.1}, 12345);
|
||||
DataFrame training = splits[0];
|
||||
DataFrame test = splits[1];
|
||||
|
||||
LinearRegression lr = new LinearRegression();
|
||||
|
||||
// We use a ParamGridBuilder to construct a grid of parameters to search over.
|
||||
// TrainValidationSplit will try all combinations of values and determine best model using
|
||||
// the evaluator.
|
||||
ParamMap[] paramGrid = new ParamGridBuilder()
|
||||
.addGrid(lr.regParam(), new double[] {0.1, 0.01})
|
||||
.addGrid(lr.fitIntercept())
|
||||
.addGrid(lr.elasticNetParam(), new double[] {0.0, 0.5, 1.0})
|
||||
.build();
|
||||
|
||||
// In this case the estimator is simply the linear regression.
|
||||
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
|
||||
TrainValidationSplit trainValidationSplit = new TrainValidationSplit()
|
||||
.setEstimator(lr)
|
||||
.setEvaluator(new RegressionEvaluator())
|
||||
.setEstimatorParamMaps(paramGrid);
|
||||
|
||||
// 80% of the data will be used for training and the remaining 20% for validation.
|
||||
trainValidationSplit.setTrainRatio(0.8);
|
||||
|
||||
// Run train validation split, and choose the best set of parameters.
|
||||
TrainValidationSplitModel model = trainValidationSplit.fit(training);
|
||||
|
||||
// Make predictions on test data. model is the model with combination of parameters
|
||||
// that performed best.
|
||||
model.transform(test)
|
||||
.select("features", "label", "prediction")
|
||||
.show();
|
||||
|
||||
jsc.stop();
|
||||
}
|
||||
}
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.ml
|
||||
|
||||
import org.apache.spark.ml.evaluation.RegressionEvaluator
|
||||
import org.apache.spark.ml.regression.LinearRegression
|
||||
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
|
||||
import org.apache.spark.mllib.util.MLUtils
|
||||
import org.apache.spark.sql.SQLContext
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
|
||||
/**
 * A simple example demonstrating model selection using TrainValidationSplit.
 *
 * The example is based on [[SimpleParamsExample]] using linear regression.
 * Run with
 * {{{
 * bin/run-example ml.TrainValidationSplitExample
 * }}}
 */
object TrainValidationSplitExample {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TrainValidationSplitExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Load the sample LIBSVM data and split off 10% as a held-out test set
    // (seeded for reproducibility).
    val dataset = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt").toDF()
    val Array(training, test) = dataset.randomSplit(Array(0.9, 0.1), seed = 12345)

    val lr = new LinearRegression()

    // Grid of hyper-parameters to search; TrainValidationSplit evaluates every
    // combination once and picks the best model according to the evaluator.
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .addGrid(lr.fitIntercept, Array(true, false))
      .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
      .build()

    // TrainValidationSplit needs an Estimator (here the linear regression),
    // its candidate ParamMaps, and an Evaluator to score them.
    val tvs = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new RegressionEvaluator)
      .setEstimatorParamMaps(grid)

    // Use 80% of the training data to fit and the remaining 20% to validate.
    tvs.setTrainRatio(0.8)

    // Fit: train on each ParamMap and keep the model with the best metric.
    val model = tvs.fit(training)

    // Score the held-out test data with the selected model.
    model.transform(test)
      .select("features", "label", "prediction")
      .show()

    sc.stop()
  }
}
|
Loading…
Reference in a new issue