[SPARK-6113] [ML] Small cleanups after original tree API PR

This does a few clean-ups.  With this PR, all spark.ml tree components have ```private[ml]``` constructors.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits:

2263b5b [Joseph K. Bradley] Added note about tree example issue.
bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR
This commit is contained in:
Joseph K. Bradley 2015-04-21 21:44:44 -07:00 committed by Xiangrui Meng
parent 70f9f8ff38
commit 607eff0edf
3 changed files with 25 additions and 11 deletions

View file

@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame}
* {{{
* ./bin/run-example ml.DecisionTreeExample [options]
* }}}
* Note that Decision Trees can take a large amount of memory. If the run-example command above
* fails, try running via spark-submit and specifying the amount of memory as at least 1g.
* For local mode, run
* {{{
* ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g
* [examples JAR path] [options]
* }}}
* If you use it as a template to create your own app, please use `spark-submit` to submit your app.
*/
object DecisionTreeExample {
@ -70,7 +77,7 @@ object DecisionTreeExample {
val parser = new OptionParser[Params]("DecisionTreeExample") {
head("DecisionTreeExample: an example decision tree app.")
opt[String]("algo")
.text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}")
.text(s"algorithm (classification, regression), default: ${defaultParams.algo}")
.action((x, c) => c.copy(algo = x))
opt[Int]("maxDepth")
.text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
@ -222,18 +229,23 @@ object DecisionTreeExample {
// (1) For classification, re-index classes.
val labelColName = if (algo == "classification") "indexedLabel" else "label"
if (algo == "classification") {
val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName)
val labelIndexer = new StringIndexer()
.setInputCol("labelString")
.setOutputCol(labelColName)
stages += labelIndexer
}
// (2) Identify categorical features using VectorIndexer.
// Features with more than maxCategories values will be treated as continuous.
val featuresIndexer = new VectorIndexer().setInputCol("features")
.setOutputCol("indexedFeatures").setMaxCategories(10)
val featuresIndexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexedFeatures")
.setMaxCategories(10)
stages += featuresIndexer
// (3) Learn DecisionTree
val dt = algo match {
case "classification" =>
new DecisionTreeClassifier().setFeaturesCol("indexedFeatures")
new DecisionTreeClassifier()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)
@ -242,7 +254,8 @@ object DecisionTreeExample {
.setCacheNodeIds(params.cacheNodeIds)
.setCheckpointInterval(params.checkpointInterval)
case "regression" =>
new DecisionTreeRegressor().setFeaturesCol("indexedFeatures")
new DecisionTreeRegressor()
.setFeaturesCol("indexedFeatures")
.setLabelCol(labelColName)
.setMaxDepth(params.maxDepth)
.setMaxBins(params.maxBins)

View file

@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
def setMaxDepth(value: Int): this.type = {
require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value")
set(maxDepth, value)
this.asInstanceOf[this.type]
this
}
/** @group getParam */
@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params {
def getImpurity: String = getOrDefault(impurity)
/** Convert new impurity to old impurity. */
protected def getOldImpurity: OldImpurity = {
private[ml] def getOldImpurity: OldImpurity = {
getImpurity match {
case "variance" => OldVariance
case _ =>

View file

@ -38,7 +38,7 @@ sealed trait Split extends Serializable {
private[tree] def toOld: OldSplit
}
private[ml] object Split {
private[tree] object Split {
def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = {
oldSplit.featureType match {
@ -58,7 +58,7 @@ private[ml] object Split {
* left. Otherwise, it goes right.
* @param numCategories Number of categories for this feature.
*/
final class CategoricalSplit(
final class CategoricalSplit private[ml] (
override val featureIndex: Int,
leftCategories: Array[Double],
private val numCategories: Int)
@ -130,7 +130,8 @@ final class CategoricalSplit(
* @param threshold If the feature value is <= this threshold, then the split goes left.
* Otherwise, it goes right.
*/
final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split {
final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double)
extends Split {
override private[ml] def shouldGoLeft(features: Vector): Boolean = {
features(featureIndex) <= threshold