[SPARK-7547] [ML] Scala Example code for ElasticNet
This is Scala example code for both linear and logistic regression. Python and Java versions are to be added.
Author: DB Tsai <dbt@netflix.com>
Closes #6576 from dbtsai/elasticNetExample and squashes the following commits:
e7ca406 [DB Tsai] fix test
6bb6d77 [DB Tsai] fix suite and remove duplicated setMaxIter
136e0dd [DB Tsai] address feedback
1ec29d4 [DB Tsai] fix style
9462f5f [DB Tsai] add example
(cherry picked from commit a86b3e9b9b)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
This commit is contained in:
parent
6a3e32ad1e
commit
6391be872d
|
@ -0,0 +1,142 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.ml
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.language.reflectiveCalls
|
||||
|
||||
import scopt.OptionParser
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.examples.mllib.AbstractParams
|
||||
import org.apache.spark.ml.{Pipeline, PipelineStage}
|
||||
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
/**
 * Command-line example that fits a linear regression model with elastic-net
 * (combined L1/L2) regularization.
 *
 * Launch it with
 * {{{
 * bin/run-example ml.LinearRegressionExample [options]
 * }}}
 * For instance, the synthetic dataset at `data/mllib/sample_linear_regression_data.txt`
 * can be fitted with
 * {{{
 * bin/run-example ml.LinearRegressionExample --regParam 0.15 --elasticNetParam 1.0 \
 *   data/mllib/sample_linear_regression_data.txt
 * }}}
 * When adapting this file into your own application, submit it with `spark-submit`.
 */
object LinearRegressionExample {

  /**
   * Settings collected from the command line.
   *
   * @param input path to the labeled examples (required positional argument)
   * @param testInput path to a separate test dataset; empty by default
   * @param dataFormat format name forwarded to the dataset loader
   * @param regParam regularization parameter
   * @param elasticNetParam elastic-net mixing parameter
   * @param maxIter maximum number of iterations
   * @param tol convergence tolerance
   * @param fracTest fraction of the data held out for testing; validated to lie in [0,1)
   */
  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      regParam: Double = 0.0,
      elasticNetParam: Double = 0.0,
      maxIter: Int = 100,
      tol: Double = 1E-6,
      fracTest: Double = 0.2) extends AbstractParams[Params]

  /**
   * Entry point: parses `args` into a [[Params]] and hands it to [[run]].
   * Exits the JVM with status 1 when the arguments cannot be parsed or validated.
   */
  def main(args: Array[String]): Unit = {
    val defaults = Params()

    val cli = new OptionParser[Params]("LinearRegressionExample") {
      head("LinearRegressionExample: an example Linear Regression with Elastic-Net app.")
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaults.regParam}")
        .action((value, cfg) => cfg.copy(regParam = value))
      opt[Double]("elasticNetParam")
        .text("ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
          "For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
          s"L1 and L2, default: ${defaults.elasticNetParam}")
        .action((value, cfg) => cfg.copy(elasticNetParam = value))
      opt[Int]("maxIter")
        .text(s"maximum number of iterations, default: ${defaults.maxIter}")
        .action((value, cfg) => cfg.copy(maxIter = value))
      opt[Double]("tol")
        .text("the convergence tolerance of iterations, Smaller value will lead " +
          s"to higher accuracy with the cost of more iterations, default: ${defaults.tol}")
        .action((value, cfg) => cfg.copy(tol = value))
      opt[Double]("fracTest")
        .text("fraction of data to hold out for testing. If given option testInput, " +
          s"this option is ignored. default: ${defaults.fracTest}")
        .action((value, cfg) => cfg.copy(fracTest = value))
      opt[String]("testInput")
        .text("input path to test dataset. If given, option fracTest is ignored." +
          s" default: ${defaults.testInput}")
        .action((value, cfg) => cfg.copy(testInput = value))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((value, cfg) => cfg.copy(dataFormat = value))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((value, cfg) => cfg.copy(input = value))
      checkConfig { cfg =>
        // Reject any held-out fraction outside the half-open interval [0, 1).
        val fracTestOutOfRange = cfg.fracTest < 0 || cfg.fracTest >= 1
        if (fracTestOutOfRange) {
          failure(s"fracTest ${cfg.fracTest} value incorrect; should be in [0,1).")
        } else {
          success
        }
      }
    }

    cli.parse(args, defaults) match {
      case Some(parsed) => run(parsed)
      case None => sys.exit(1)
    }
  }

  /**
   * Creates a SparkContext, fits a [[LinearRegression]] configured from `params`,
   * prints the training time and the fitted weights/intercept, hands the model to
   * `DecisionTreeExample.evaluateRegressionModel` for both datasets, and stops the context.
   */
  def run(params: Params): Unit = {
    val sparkConf = new SparkConf().setAppName(s"LinearRegressionExample with $params")
    val sc = new SparkContext(sparkConf)

    println(s"LinearRegressionExample with parameters:\n$params")

    // Obtain the training and test DataFrames through the shared DecisionTreeExample helper.
    val (trainingData: DataFrame, testData: DataFrame) = DecisionTreeExample.loadDatasets(
      sc, params.input, params.dataFormat, params.testInput, "regression", params.fracTest)

    val estimator = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("label")
      .setRegParam(params.regParam)
      .setElasticNetParam(params.elasticNetParam)
      .setMaxIter(params.maxIter)
      .setTol(params.tol)

    // Fit the model, timing only the call to fit().
    val fitStart = System.nanoTime()
    val model = estimator.fit(trainingData)
    val fitSeconds = (System.nanoTime() - fitStart) / 1e9
    println(s"Training time: $fitSeconds seconds")

    // Show the fitted coefficients.
    println(s"Weights: ${model.weights} Intercept: ${model.intercept}")

    println("Training data results:")
    DecisionTreeExample.evaluateRegressionModel(model, trainingData, "label")
    println("Test data results:")
    DecisionTreeExample.evaluateRegressionModel(model, testData, "label")

    sc.stop()
  }
}
|
|
@ -0,0 +1,159 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.examples.ml
|
||||
|
||||
import scala.collection.mutable
|
||||
import scala.language.reflectiveCalls
|
||||
|
||||
import scopt.OptionParser
|
||||
|
||||
import org.apache.spark.{SparkConf, SparkContext}
|
||||
import org.apache.spark.examples.mllib.AbstractParams
|
||||
import org.apache.spark.ml.{Pipeline, PipelineStage}
|
||||
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
|
||||
import org.apache.spark.ml.feature.StringIndexer
|
||||
import org.apache.spark.sql.DataFrame
|
||||
|
||||
/**
 * An example runner for logistic regression with elastic-net (mixing L1/L2) regularization.
 * Run with
 * {{{
 * bin/run-example ml.LogisticRegressionExample [options]
 * }}}
 * A synthetic dataset can be found at `data/mllib/sample_libsvm_data.txt` which can be
 * trained by
 * {{{
 * bin/run-example ml.LogisticRegressionExample --regParam 0.3 --elasticNetParam 0.8 \
 *   data/mllib/sample_libsvm_data.txt
 * }}}
 * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
 */
object LogisticRegressionExample {

  /**
   * Settings collected from the command line.
   *
   * @param input path to the labeled examples (required positional argument)
   * @param testInput path to a separate test dataset; empty by default
   * @param dataFormat format name forwarded to the dataset loader
   * @param regParam regularization parameter
   * @param elasticNetParam elastic-net mixing parameter
   * @param maxIter maximum number of iterations
   * @param fitIntercept whether to fit an intercept term
   * @param tol convergence tolerance
   * @param fracTest fraction of the data held out for testing; validated to lie in [0,1)
   */
  case class Params(
      input: String = null,
      testInput: String = "",
      dataFormat: String = "libsvm",
      regParam: Double = 0.0,
      elasticNetParam: Double = 0.0,
      maxIter: Int = 100,
      fitIntercept: Boolean = true,
      tol: Double = 1E-6,
      fracTest: Double = 0.2) extends AbstractParams[Params]

  /**
   * Entry point: parses `args` into a [[Params]] and hands it to [[run]].
   * Exits the JVM with status 1 when the arguments cannot be parsed or validated.
   */
  def main(args: Array[String]): Unit = {
    val defaultParams = Params()

    val parser = new OptionParser[Params]("LogisticRegressionExample") {
      head("LogisticRegressionExample: an example Logistic Regression with Elastic-Net app.")
      opt[Double]("regParam")
        .text(s"regularization parameter, default: ${defaultParams.regParam}")
        .action((x, c) => c.copy(regParam = x))
      opt[Double]("elasticNetParam")
        .text(s"ElasticNet mixing parameter. For alpha = 0, the penalty is an L2 penalty. " +
          s"For alpha = 1, it is an L1 penalty. For 0 < alpha < 1, the penalty is a combination of " +
          s"L1 and L2, default: ${defaultParams.elasticNetParam}")
        .action((x, c) => c.copy(elasticNetParam = x))
      opt[Int]("maxIter")
        .text(s"maximum number of iterations, default: ${defaultParams.maxIter}")
        .action((x, c) => c.copy(maxIter = x))
      opt[Boolean]("fitIntercept")
        .text(s"whether to fit an intercept term, default: ${defaultParams.fitIntercept}")
        .action((x, c) => c.copy(fitIntercept = x))
      opt[Double]("tol")
        .text(s"the convergence tolerance of iterations, Smaller value will lead " +
          s"to higher accuracy with the cost of more iterations, default: ${defaultParams.tol}")
        .action((x, c) => c.copy(tol = x))
      opt[Double]("fracTest")
        .text(s"fraction of data to hold out for testing. If given option testInput, " +
          s"this option is ignored. default: ${defaultParams.fracTest}")
        .action((x, c) => c.copy(fracTest = x))
      opt[String]("testInput")
        .text(s"input path to test dataset. If given, option fracTest is ignored." +
          s" default: ${defaultParams.testInput}")
        .action((x, c) => c.copy(testInput = x))
      opt[String]("dataFormat")
        .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
        .action((x, c) => c.copy(dataFormat = x))
      arg[String]("<input>")
        .text("input path to labeled examples")
        .required()
        .action((x, c) => c.copy(input = x))
      checkConfig { params =>
        if (params.fracTest < 0 || params.fracTest >= 1) {
          failure(s"fracTest ${params.fracTest} value incorrect; should be in [0,1).")
        } else {
          success
        }
      }
    }

    parser.parse(args, defaultParams).map { params =>
      run(params)
    }.getOrElse {
      sys.exit(1)
    }
  }

  /**
   * Creates a SparkContext, fits a two-stage pipeline (label indexing followed by
   * logistic regression) configured from `params`, prints the training time and the
   * fitted weights/intercept, hands the pipeline model to
   * `DecisionTreeExample.evaluateClassificationModel` for both datasets, and stops
   * the context.
   */
  def run(params: Params): Unit = {
    val conf = new SparkConf().setAppName(s"LogisticRegressionExample with $params")
    val sc = new SparkContext(conf)

    println(s"LogisticRegressionExample with parameters:\n$params")

    // Load training and test data through the shared DecisionTreeExample helper.
    val (training: DataFrame, test: DataFrame) = DecisionTreeExample.loadDatasets(sc, params.input,
      params.dataFormat, params.testInput, "classification", params.fracTest)

    // Stage 1: index the string label column into the column consumed by the classifier.
    val labelIndexer = new StringIndexer()
      .setInputCol("labelString")
      .setOutputCol("indexedLabel")

    // Stage 2: the classifier. fitIntercept must be forwarded explicitly; otherwise the
    // --fitIntercept command-line option would be parsed but have no effect.
    val lor = new LogisticRegression()
      .setFeaturesCol("features")
      .setLabelCol("indexedLabel")
      .setRegParam(params.regParam)
      .setElasticNetParam(params.elasticNetParam)
      .setMaxIter(params.maxIter)
      .setFitIntercept(params.fitIntercept)
      .setTol(params.tol)

    // The stage list is fixed, so build it directly instead of appending to a buffer.
    val pipeline = new Pipeline().setStages(Array[PipelineStage](labelIndexer, lor))

    // Fit the Pipeline, timing only the call to fit().
    val startTime = System.nanoTime()
    val pipelineModel = pipeline.fit(training)
    val elapsedTime = (System.nanoTime() - startTime) / 1e9
    println(s"Training time: $elapsedTime seconds")

    // The logistic regression model is the last fitted stage of the pipeline.
    val lorModel = pipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
    // Print the weights and intercept for logistic regression.
    println(s"Weights: ${lorModel.weights} Intercept: ${lorModel.intercept}")

    println("Training data results:")
    DecisionTreeExample.evaluateClassificationModel(pipelineModel, training, "indexedLabel")
    println("Test data results:")
    DecisionTreeExample.evaluateClassificationModel(pipelineModel, test, "indexedLabel")

    sc.stop()
  }
}
|
|
@ -74,7 +74,7 @@ class LogisticRegression(override val uid: String)
|
|||
setDefault(elasticNetParam -> 0.0)
|
||||
|
||||
/**
|
||||
* Set the maximal number of iterations.
|
||||
* Set the maximum number of iterations.
|
||||
* Default is 100.
|
||||
* @group setParam
|
||||
*/
|
||||
|
@ -90,7 +90,11 @@ class LogisticRegression(override val uid: String)
|
|||
def setTol(value: Double): this.type = set(tol, value)
|
||||
setDefault(tol -> 1E-6)
|
||||
|
||||
/** @group setParam */
|
||||
/**
|
||||
* Whether to fit an intercept term.
|
||||
* Default is true.
|
||||
* @group setParam
|
||||
* */
|
||||
def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
|
||||
setDefault(fitIntercept -> true)
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ private[shared] object SharedParamsCodeGen {
|
|||
val params = Seq(
|
||||
ParamDesc[Double]("regParam", "regularization parameter (>= 0)",
|
||||
isValid = "ParamValidators.gtEq(0)"),
|
||||
ParamDesc[Int]("maxIter", "max number of iterations (>= 0)",
|
||||
ParamDesc[Int]("maxIter", "maximum number of iterations (>= 0)",
|
||||
isValid = "ParamValidators.gtEq(0)"),
|
||||
ParamDesc[String]("featuresCol", "features column name", Some("\"features\"")),
|
||||
ParamDesc[String]("labelCol", "label column name", Some("\"label\"")),
|
||||
|
|
|
@ -45,10 +45,10 @@ private[ml] trait HasRegParam extends Params {
|
|||
private[ml] trait HasMaxIter extends Params {
|
||||
|
||||
/**
|
||||
* Param for max number of iterations (>= 0).
|
||||
* Param for maximum number of iterations (>= 0).
|
||||
* @group param
|
||||
*/
|
||||
final val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations (>= 0)", ParamValidators.gtEq(0))
|
||||
final val maxIter: IntParam = new IntParam(this, "maxIter", "maximum number of iterations (>= 0)", ParamValidators.gtEq(0))
|
||||
|
||||
/** @group getParam */
|
||||
final def getMaxIter: Int = $(maxIter)
|
||||
|
|
|
@ -83,7 +83,7 @@ class LinearRegression(override val uid: String)
|
|||
setDefault(elasticNetParam -> 0.0)
|
||||
|
||||
/**
|
||||
* Set the maximal number of iterations.
|
||||
* Set the maximum number of iterations.
|
||||
* Default is 100.
|
||||
* @group setParam
|
||||
*/
|
||||
|
|
|
@ -27,7 +27,7 @@ class ParamsSuite extends FunSuite {
|
|||
import solver.{maxIter, inputCol}
|
||||
|
||||
assert(maxIter.name === "maxIter")
|
||||
assert(maxIter.doc === "max number of iterations (>= 0)")
|
||||
assert(maxIter.doc === "maximum number of iterations (>= 0)")
|
||||
assert(maxIter.parent === uid)
|
||||
assert(maxIter.toString === s"${uid}__maxIter")
|
||||
assert(!maxIter.isValid(-1))
|
||||
|
@ -36,7 +36,7 @@ class ParamsSuite extends FunSuite {
|
|||
|
||||
solver.setMaxIter(5)
|
||||
assert(solver.explainParam(maxIter) ===
|
||||
"maxIter: max number of iterations (>= 0) (default: 10, current: 5)")
|
||||
"maxIter: maximum number of iterations (>= 0) (default: 10, current: 5)")
|
||||
|
||||
assert(inputCol.toString === s"${uid}__inputCol")
|
||||
|
||||
|
@ -120,7 +120,7 @@ class ParamsSuite extends FunSuite {
|
|||
intercept[NoSuchElementException](solver.getInputCol)
|
||||
|
||||
assert(solver.explainParam(maxIter) ===
|
||||
"maxIter: max number of iterations (>= 0) (default: 10, current: 100)")
|
||||
"maxIter: maximum number of iterations (>= 0) (default: 10, current: 100)")
|
||||
assert(solver.explainParams() ===
|
||||
Seq(inputCol, maxIter).map(solver.explainParam).mkString("\n"))
|
||||
|
||||
|
|
Loading…
Reference in a new issue