[SPARK-14299][EXAMPLES] Remove duplications for scala.examples.ml
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-14299

Delete duplicated examples in scala/examples/ml:

* TrainValidationSplitExample.scala --> ModelSelectionViaTrainValidationSplitExample
* CrossValidatorExample.scala --> ModelSelectionViaCrossValidationExample

## How was this patch tested?

Existing tests passed.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #12366 from yinxusen/SPARK-14299-2.
Commit 8c62edb70f (parent f31a62d1b2)
CrossValidatorExample.scala (deleted)
@@ -1,114 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// scalastyle:off println
package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}

/**
 * A simple example demonstrating model selection using CrossValidator.
 * This example also demonstrates how Pipelines are Estimators.
 *
 * This example uses the [[LabeledDocument]] and [[Document]] case classes from
 * [[SimpleTextClassificationPipeline]].
 *
 * Run with
 * {{{
 * bin/run-example ml.CrossValidatorExample
 * }}}
 */
object CrossValidatorExample {

  def main(args: Array[String]) {
    val conf = new SparkConf().setAppName("CrossValidatorExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Prepare training documents, which are labeled.
    val training = sc.parallelize(Seq(
      LabeledDocument(0L, "a b c d e spark", 1.0),
      LabeledDocument(1L, "b d", 0.0),
      LabeledDocument(2L, "spark f g h", 1.0),
      LabeledDocument(3L, "hadoop mapreduce", 0.0),
      LabeledDocument(4L, "b spark who", 1.0),
      LabeledDocument(5L, "g d a y", 0.0),
      LabeledDocument(6L, "spark fly", 1.0),
      LabeledDocument(7L, "was mapreduce", 0.0),
      LabeledDocument(8L, "e spark program", 1.0),
      LabeledDocument(9L, "a e c l", 0.0),
      LabeledDocument(10L, "spark compile", 1.0),
      LabeledDocument(11L, "hadoop software", 0.0)))

    // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
    val tokenizer = new Tokenizer()
      .setInputCol("text")
      .setOutputCol("words")
    val hashingTF = new HashingTF()
      .setInputCol(tokenizer.getOutputCol)
      .setOutputCol("features")
    val lr = new LogisticRegression()
      .setMaxIter(10)
    val pipeline = new Pipeline()
      .setStages(Array(tokenizer, hashingTF, lr))

    // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
    // This will allow us to jointly choose parameters for all Pipeline stages.
    // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    val crossval = new CrossValidator()
      .setEstimator(pipeline)
      .setEvaluator(new BinaryClassificationEvaluator)
    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
    // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
    val paramGrid = new ParamGridBuilder()
      .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .build()
    crossval.setEstimatorParamMaps(paramGrid)
    crossval.setNumFolds(2) // Use 3+ in practice

    // Run cross-validation, and choose the best set of parameters.
    val cvModel = crossval.fit(training.toDF())

    // Prepare test documents, which are unlabeled.
    val test = sc.parallelize(Seq(
      Document(4L, "spark i j k"),
      Document(5L, "l m n"),
      Document(6L, "mapreduce spark"),
      Document(7L, "apache hadoop")))

    // Make predictions on test documents. cvModel uses the best model found (lrModel).
    cvModel.transform(test.toDF())
      .select("id", "text", "probability", "prediction")
      .collect()
      .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
        println(s"($id, $text) --> prob=$prob, prediction=$prediction")
      }

    sc.stop()
  }
}
// scalastyle:on println
ModelSelectionViaCrossValidationExample.scala
@@ -30,6 +30,15 @@ import org.apache.spark.sql.Row
// $example off$
import org.apache.spark.sql.SQLContext

/**
 * A simple example demonstrating model selection using CrossValidator.
 * This example also demonstrates how Pipelines are Estimators.
 *
 * Run with
 * {{{
 * bin/run-example ml.ModelSelectionViaCrossValidationExample
 * }}}
 */
object ModelSelectionViaCrossValidationExample {

  def main(args: Array[String]): Unit = {
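The body of the renamed example sits outside this hunk. For orientation, here is a minimal Scala sketch of the CrossValidator-based selection it wraps, condensed from the deleted CrossValidatorExample above; the pipeline stages, grid values, and fold count are taken from that file, and the `training` DataFrame (with "text" and "label" columns) is assumed to already exist:

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Pipeline: tokenize text, hash terms into feature vectors, fit logistic regression.
val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol(tokenizer.getOutputCol).setOutputCol("features")
val lr = new LogisticRegression().setMaxIter(10)
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, lr))

// 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// The whole Pipeline is treated as a single Estimator by the CrossValidator.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2) // use 3+ in practice

// val cvModel = cv.fit(training) // training: DataFrame with "text" and "label" columns
```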
ModelSelectionViaTrainValidationSplitExample.scala
@@ -25,6 +25,14 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
// $example off$
import org.apache.spark.sql.SQLContext

/**
 * A simple example demonstrating model selection using TrainValidationSplit.
 *
 * Run with
 * {{{
 * bin/run-example ml.ModelSelectionViaTrainValidationSplitExample
 * }}}
 */
object ModelSelectionViaTrainValidationSplitExample {

  def main(args: Array[String]): Unit = {
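Likewise, the body of this renamed example is not shown in the hunk. A minimal sketch of the TrainValidationSplit flow it is assumed to keep, condensed from the deleted TrainValidationSplitExample below (the `training` DataFrame of (label, features) rows is assumed to already exist):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

val lr = new LinearRegression()

// Grid over regularization, intercept, and elastic-net mixing parameters.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept, Array(true, false))
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// A single 80% / 20% train/validation split instead of k folds.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setTrainRatio(0.8)

// val model = trainValidationSplit.fit(training) // training: DataFrame of (label, features)
```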
TrainValidationSplitExample.scala (deleted)
@@ -1,78 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.examples.ml

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.SQLContext

/**
 * A simple example demonstrating model selection using TrainValidationSplit.
 *
 * The example is based on [[SimpleParamsExample]] using linear regression.
 * Run with
 * {{{
 * bin/run-example ml.TrainValidationSplitExample
 * }}}
 */
object TrainValidationSplitExample {

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("TrainValidationSplitExample")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // Prepare training and test data.
    val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
    val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)

    val lr = new LinearRegression()

    // We use a ParamGridBuilder to construct a grid of parameters to search over.
    // TrainValidationSplit will try all combinations of values and determine best model using
    // the evaluator.
    val paramGrid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.1, 0.01))
      .addGrid(lr.fitIntercept, Array(true, false))
      .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
      .build()

    // In this case the estimator is simply the linear regression.
    // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
    val trainValidationSplit = new TrainValidationSplit()
      .setEstimator(lr)
      .setEvaluator(new RegressionEvaluator)
      .setEstimatorParamMaps(paramGrid)

    // 80% of the data will be used for training and the remaining 20% for validation.
    trainValidationSplit.setTrainRatio(0.8)

    // Run train validation split, and choose the best set of parameters.
    val model = trainValidationSplit.fit(training)

    // Make predictions on test data. model is the model with combination of parameters
    // that performed best.
    model.transform(test)
      .select("features", "label", "prediction")
      .show()

    sc.stop()
  }
}