From 607eff0edfc10a1473fa9713a0500bf09f105c13 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 21 Apr 2015 21:44:44 -0700 Subject: [PATCH] [SPARK-6113] [ML] Small cleanups after original tree API PR This does a few clean-ups. With this PR, all spark.ml tree components have ```private[ml]``` constructors. CC: mengxr Author: Joseph K. Bradley Closes #5567 from jkbradley/dt-api-dt2 and squashes the following commits: 2263b5b [Joseph K. Bradley] Added note about tree example issue. bb9f610 [Joseph K. Bradley] Small cleanups after original tree API PR --- .../examples/ml/DecisionTreeExample.scala | 25 ++++++++++++++----- .../spark/ml/impl/tree/treeParams.scala | 4 +-- .../org/apache/spark/ml/tree/Split.scala | 7 +++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index 921b396e79..2cd515c89d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -44,6 +44,13 @@ import org.apache.spark.sql.{SQLContext, DataFrame} * {{{ * ./bin/run-example ml.DecisionTreeExample [options] * }}} + * Note that Decision Trees can take a large amount of memory. If the run-example command above + * fails, try running via spark-submit and specifying the amount of memory as at least 1g. + * For local mode, run + * {{{ + * ./bin/spark-submit --class org.apache.spark.examples.ml.DecisionTreeExample --driver-memory 1g + * [examples JAR path] [options] + * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ object DecisionTreeExample { @@ -70,7 +77,7 @@ object DecisionTreeExample { val parser = new OptionParser[Params]("DecisionTreeExample") { head("DecisionTreeExample: an example decision tree app.") opt[String]("algo") - .text(s"algorithm (Classification, Regression), default: ${defaultParams.algo}") + .text(s"algorithm (classification, regression), default: ${defaultParams.algo}") .action((x, c) => c.copy(algo = x)) opt[Int]("maxDepth") .text(s"max depth of the tree, default: ${defaultParams.maxDepth}") @@ -222,18 +229,23 @@ object DecisionTreeExample { // (1) For classification, re-index classes. val labelColName = if (algo == "classification") "indexedLabel" else "label" if (algo == "classification") { - val labelIndexer = new StringIndexer().setInputCol("labelString").setOutputCol(labelColName) + val labelIndexer = new StringIndexer() + .setInputCol("labelString") + .setOutputCol(labelColName) stages += labelIndexer } // (2) Identify categorical features using VectorIndexer. // Features with more than maxCategories values will be treated as continuous. - val featuresIndexer = new VectorIndexer().setInputCol("features") - .setOutputCol("indexedFeatures").setMaxCategories(10) + val featuresIndexer = new VectorIndexer() + .setInputCol("features") + .setOutputCol("indexedFeatures") + .setMaxCategories(10) stages += featuresIndexer // (3) Learn DecisionTree val dt = algo match { case "classification" => - new DecisionTreeClassifier().setFeaturesCol("indexedFeatures") + new DecisionTreeClassifier() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) @@ -242,7 +254,8 @@ object DecisionTreeExample { .setCacheNodeIds(params.cacheNodeIds) .setCheckpointInterval(params.checkpointInterval) case "regression" => - new DecisionTreeRegressor().setFeaturesCol("indexedFeatures") + new DecisionTreeRegressor() + .setFeaturesCol("indexedFeatures") .setLabelCol(labelColName) .setMaxDepth(params.maxDepth) .setMaxBins(params.maxBins) diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala index 6f4509f03d..eb2609faef 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/impl/tree/treeParams.scala @@ -117,7 +117,7 @@ private[ml] trait DecisionTreeParams extends PredictorParams { def setMaxDepth(value: Int): this.type = { require(value >= 0, s"maxDepth parameter must be >= 0. Given bad value: $value") set(maxDepth, value) - this.asInstanceOf[this.type] + this } /** @group getParam */ @@ -283,7 +283,7 @@ private[ml] trait TreeRegressorParams extends Params { def getImpurity: String = getOrDefault(impurity) /** Convert new impurity to old impurity. */ - protected def getOldImpurity: OldImpurity = { + private[ml] def getOldImpurity: OldImpurity = { getImpurity match { case "variance" => OldVariance case _ => diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala index cb940f6299..708c769087 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Split.scala @@ -38,7 +38,7 @@ sealed trait Split extends Serializable { private[tree] def toOld: OldSplit } -private[ml] object Split { +private[tree] object Split { def fromOld(oldSplit: OldSplit, categoricalFeatures: Map[Int, Int]): Split = { oldSplit.featureType match { @@ -58,7 +58,7 @@ private[ml] object Split { * left. Otherwise, it goes right. * @param numCategories Number of categories for this feature. */ -final class CategoricalSplit( +final class CategoricalSplit private[ml] ( override val featureIndex: Int, leftCategories: Array[Double], private val numCategories: Int) @@ -130,7 +130,8 @@ final class CategoricalSplit( * @param threshold If the feature value is <= this threshold, then the split goes left. * Otherwise, it goes right. */ -final class ContinuousSplit(override val featureIndex: Int, val threshold: Double) extends Split { +final class ContinuousSplit private[ml] (override val featureIndex: Int, val threshold: Double) + extends Split { override private[ml] def shouldGoLeft(features: Vector): Boolean = { features(featureIndex) <= threshold