[SPARK-6113] [ML] Small cleanups after original tree API PR
This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors. CC: mengxr Author: Joseph K. Bradley <joseph@databricks.com> Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits: 2263b5b [Joseph K. Bradley] Added note about tree example issue. bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR
This commit is contained in:
parent
70f9f8ff38
commit
607eff0edf
|
@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
|
|||
* {{{
|
||||
* ./bin/run-example ml.DecisionTreeExample [options]
|
||||
* }}}
|
||||
* Note that Decision Trees can take a large amount of memory. If the run-example command above
|
||||
* fails, try running via spark-submit and specifying the amount of memory as at least 1g.
|
||||
* For local mode, run
|
||||
* {{{
|
||||
* ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
|
||||
* [examples JAR path] [options]
|
||||
* }}}
|
||||
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
|
||||
*/
|
||||
object DecisionTreeExample {
|
||||
|
@ -70,7 +77,7 @@ object DecisionTreeExample {
|
|||
val parser = new OptionParser[Params]("DecisionTreeExample") {
|
||||
head("DecisionTreeExample: an example decision tree app.")
|
||||
opt[String]("algo")
|
||||
.text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
|
||||
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
|
||||
.action((x, c) => c.copy(algo = x))
|
||||
opt[Int]("maxDepth")
|
||||
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
|
||||
|
@ -222,18 +229,23 @@ object DecisionTreeExample {
|
|||
// (1) For classification, re-index classes.
|
||||
val labelColName = if (algo == "classification") "indexedLabel" else "label"
|
||||
if (algo == "classification") {
|
||||
val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
|
||||
val labelIndexer = new StringIndexer()
|
||||
.setInputCol("labelString")
|
||||
.setOutputCol(labelColName)
|
||||
stages += labelIndexer
|
||||
}
|
||||
// (2) Identify categorical features using VectorIndexer.
|
||||
// Features with more than maxCategories values will be treated as continuous.
|
||||
val featuresIndexer = new VectorIndexer().setInputCol("features")
|
||||
.setOutputCol("indexedFeatures").setMaxCategories(10)
|
||||
val featuresIndexer = new VectorIndexer()
|
||||
.setInputCol("features")
|
||||
.setOutputCol("indexedFeatures")
|
||||
.setMaxCategories(10)
|
||||
stages += featuresIndexer
|
||||
// (3) Learn DecisionTree
|
||||
val dt = algo match {
|
||||
case "classification" =>
|
||||
new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
|
||||
new DecisionTreeClassifier()
|
||||
.setFeaturesCol("indexedFeatures")
|
||||
.setLabelCol(labelColName)
|
||||
.setMaxDepth(params.maxDepth)
|
||||
.setMaxBins(params.maxBins)
|
||||
|
@ -242,7 +254,8 @@ object DecisionTreeExample {
|
|||
.setCacheNodeIds(params.cacheNodeIds)
|
||||
.setCheckpointInterval(params.checkpointInterval)
|
||||
case "regression" =>
|
||||
new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
|
||||
new DecisionTreeRegressor()
|
||||
.setFeaturesCol("indexedFeatures")
|
||||
.setLabelCol(labelColName)
|
||||
.setMaxDepth(params.maxDepth)
|
||||
.setMaxBins(params.maxBins)
|
||||
|
|
|
@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
|
|||
def setMaxDepth(value: Int): this.type = {
|
||||
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
|
||||
set(maxDepth, value)
|
||||
this.asInstanceOf[this.type]
|
||||
this
|
||||
}
|
||||
|
||||
/** @group getParam */
|
||||
|
@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
|
|||
def getImpurity: String = getOrDefault(impurity)
|
||||
|
||||
/** Convert new impurity to old impurity. */
|
||||
protected def getOldImpurity: OldImpurity = {
|
||||
private[ml] def getOldImpurity: OldImpurity = {
|
||||
getImpurity match {
|
||||
case "variance" => OldVariance
|
||||
case _ =>
|
||||
|
|
|
@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
|
|||
private[tree] def toOld: OldSplit
|
||||
}
|
||||
|
||||
private[ml] object Split {
|
||||
private[tree] object Split {
|
||||
|
||||
def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
|
||||
oldSplit.featureType match {
|
||||
|
@ -58,7 +58,7 @@ private[ml] object Split {
|
|||
* left. Otherwise, it goes right.
|
||||
* @param numCategories Number of categories for this feature.
|
||||
*/
|
||||
final class CategoricalSplit(
|
||||
final class CategoricalSplit private[ml] (
|
||||
override val featureIndex: Int,
|
||||
leftCategories: Array[Double],
|
||||
private val numCategories: Int)
|
||||
|
@ -130,7 +130,8 @@ final class CategoricalSplit(
|
|||
* @param threshold If the feature value is <= this threshold, then the split goes left.
|
||||
* Otherwise, it goes right.
|
||||
*/
|
||||
final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
|
||||
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
|
||||
extends Split {
|
||||
|
||||
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
|
||||
features(featureIndex) <= threshold
|
||||
|
|
Loading…
Reference in a new issue