MLI-1 Decision Trees

Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010.

Key features:
+ Supports binary classification and regression
+ Supports gini, entropy and variance for information gain calculation
+ Supports both continuous and categorical features

The algorithm has gone through several development iterations over the last few months, leading to a highly optimized implementation. Optimizations include:

1. Level-wise training to reduce passes over the entire dataset.
2. Bin-wise split calculation to reduce computation overhead.
3. Aggregation over partitions before combining to reduce communication overhead.
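
Optimizations 2 and 3 can be illustrated with a minimal, Spark-free sketch. This is not the code in this commit: `Point`, `binIndex`, `aggregatePartition` and `combine` are hypothetical names, and plain Scala `Seq`s stand in for RDD partitions. Each partition builds a small per-bin label histogram locally, and only those fixed-size arrays are merged.

```scala
object PartitionAggregationSketch {
  // Hypothetical stand-in for a labeled point: binary label plus one continuous feature.
  final case class Point(label: Int, feature: Double)

  // Bin i covers values x with thresholds(i - 1) < x <= thresholds(i);
  // values above the last threshold fall into the final bin.
  def binIndex(thresholds: Array[Double], x: Double): Int = {
    val i = thresholds.indexWhere(x <= _)
    if (i == -1) thresholds.length else i
  }

  // Per-partition pass: count labels per bin. Layout: counts(2 * bin + label).
  def aggregatePartition(points: Seq[Point], thresholds: Array[Double]): Array[Double] = {
    val counts = new Array[Double](2 * (thresholds.length + 1))
    points.foreach { p => counts(2 * binIndex(thresholds, p.feature) + p.label) += 1.0 }
    counts
  }

  // Combine step: element-wise sum of two partition-level aggregates.
  def combine(a: Array[Double], b: Array[Double]): Array[Double] =
    a.zip(b).map { case (x, y) => x + y }

  def main(args: Array[String]): Unit = {
    val thresholds = Array(1.0, 2.0)
    // Two "partitions" of the data set (an assumption made for illustration).
    val partitions = Seq(
      Seq(Point(0, 0.5), Point(1, 1.5)),
      Seq(Point(1, 2.5), Point(0, 0.7)))
    val total = partitions.map(aggregatePartition(_, thresholds)).reduce(combine)
    println(total.mkString(", "))
  }
}
```

The amount of data exchanged between partitions depends only on the number of bins, not on the number of points, which is the motivation stated in item 3.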

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #79 from manishamde/tree and squashes the following commits:

1e8c704 [Manish Amde] remove numBins field in the Strategy class
7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree
f536ae9 [Xiangrui Meng] another pass on code style
e1dd86f [Manish Amde] implementing code style suggestions
62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing
201702f [Manish Amde] making some more methods private
f963ef5 [Manish Amde] making methods private
c487e6a [manishamde] Merge pull request #1 from mengxr/dtree
24500c5 [Xiangrui Meng] minor style updates
4576b64 [Manish Amde] documentation and for to while loop conversion
ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins
632818f [Manish Amde] removing threshold for classification predict method
2116360 [Manish Amde] removing dummy bin calculation for categorical variables
6068356 [Manish Amde] ensuring num bins is always greater than max number of categories
62c2562 [Manish Amde] fixing comment indentation
ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions
d1ef4f6 [Manish Amde] more documentation
794ff4d [Manish Amde] minor improvements to docs and style
eb8fcbe [Manish Amde] minor code style updates
cd2c2b4 [Manish Amde] fixing code style based on feedback
63e786b [Manish Amde] added multiple train methods for java compatability
d3023b3 [Manish Amde] adding more docs for nested methods
84f85d6 [Manish Amde] code documentation
9372779 [Manish Amde] code style: max line lenght <= 100
dd0c0d7 [Manish Amde] minor: some docs
0dd7659 [manishamde] basic doc
5841c28 [Manish Amde] unit tests for categorical features
f067d68 [Manish Amde] minor cleanup
c0e522b [Manish Amde] updated predict and split threshold logic
b09dc98 [Manish Amde] minor refactoring
6b7de78 [Manish Amde] minor refactoring and tests
d504eb1 [Manish Amde] more tests for categorical features
dbb7ac1 [Manish Amde] categorical feature support
6df35b9 [Manish Amde] regression predict logic
53108ed [Manish Amde] fixing index for highest bin
e23c2e5 [Manish Amde] added regression support
c8f6d60 [Manish Amde] adding enum for feature type
b0e3e76 [Manish Amde] adding enum for feature type
154aa77 [Manish Amde] enums for configurations
733d6dd [Manish Amde] fixed tests
02c595c [Manish Amde] added command line parsing
98ec8d5 [Manish Amde] tree building and prediction logic
b0eb866 [Manish Amde] added logic to handle leaf nodes
80e8c66 [Manish Amde] working version of multi-level split calculation
4798aae [Manish Amde] added gain stats class
dad0afc [Manish Amde] decison stump functionality working
03f534c [Manish Amde] some more tests
0012a77 [Manish Amde] basic stump working
8bca1e2 [Manish Amde] additional code for creating intermediate RDD
92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested.
cd53eae [Manish Amde] skeletal framework
Manish Amde 2014-04-01 21:40:49 -07:00 committed by Matei Zaharia
parent 45df912736
commit 8b3045ceab
17 changed files with 2188 additions and 0 deletions

File diff suppressed because it is too large.

@ -0,0 +1,17 @@
This package contains the default implementation of the decision tree algorithm.
The decision tree algorithm supports:
+ Binary classification
+ Regression
+ Information gain calculation with entropy and gini for classification and variance for regression
+ Both continuous and categorical features
# Tree improvements
+ Node model pruning
+ Printing to dot files
# Future Ensemble Extensions
+ Random forests
+ Boosting
+ Extremely randomized trees

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.configuration
/**
* Enum to select the algorithm for the decision tree
*/
object Algo extends Enumeration {
type Algo = Value
val Classification, Regression = Value
}

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.configuration
/**
* Enum to describe whether a feature is "continuous" or "categorical"
*/
object FeatureType extends Enumeration {
type FeatureType = Value
val Continuous, Categorical = Value
}

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.configuration
/**
* Enum for selecting the quantile calculation strategy
*/
object QuantileStrategy extends Enumeration {
type QuantileStrategy = Value
val Sort, MinMax, ApproxHist = Value
}

@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.configuration
import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
/**
* Stores all the configuration options for tree construction
* @param algo classification or regression
* @param impurity criterion used for information gain calculation
* @param maxDepth maximum depth of the tree
* @param maxBins maximum number of bins used for splitting features
* @param quantileCalculationStrategy algorithm for calculating quantiles
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
*/
class Strategy (
val algo: Algo,
val impurity: Impurity,
val maxDepth: Int,
val maxBins: Int = 100,
val quantileCalculationStrategy: QuantileStrategy = Sort,
val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
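
A short usage sketch of the constructor above, assuming the classes added in this commit are on the classpath. The values are illustrative only; the first call mirrors the one in `DecisionTreeSuite`, and the object name `StrategyUsageSketch` is hypothetical.

```scala
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}

object StrategyUsageSketch {
  // Binary classification: features 0 and 1 are categorical with two categories each
  // (values 0 and 1); every other feature is treated as continuous.
  val classificationStrategy = new Strategy(
    Classification,
    Gini,
    maxDepth = 3,
    maxBins = 100,
    categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))

  // Regression with the defaults: maxBins = 100, Sort quantiles, no categorical features.
  val regressionStrategy = new Strategy(Regression, Variance, maxDepth = 5)
}
```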

@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity
/**
* Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
* binary classification.
*/
object Entropy extends Impurity {
def log2(x: Double) = scala.math.log(x) / scala.math.log(2)
/**
* entropy calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return entropy value
*/
def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
-(f0 * log2(f0)) - (f1 * log2(f1))
}
}
def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Entropy.calculate")
}
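
For reference, the quantity computed by `Entropy.calculate` above can be written out as follows, where $f_0$ and $f_1$ are the label fractions derived from the counts `c0` and `c1` (the symbol $H$ is notation introduced here, not taken from the source):

```latex
f_0 = \frac{c_0}{c_0 + c_1}, \qquad f_1 = \frac{c_1}{c_0 + c_1}, \qquad
H(c_0, c_1) =
\begin{cases}
0 & \text{if } c_0 = 0 \text{ or } c_1 = 0, \\
-f_0 \log_2 f_0 - f_1 \log_2 f_1 & \text{otherwise.}
\end{cases}
```

As a worked example, $c_0 = 1$ and $c_1 = 3$ give $f_0 = 0.25$ and $f_1 = 0.75$, so $H = 0.25 \cdot 2 + 0.75 \cdot 0.415 \approx 0.811$. A balanced node ($c_0 = c_1$) gives the maximum value $H = 1$.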

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity
/**
* Class for calculating the
* [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
* during binary classification.
*/
object Gini extends Impurity {
/**
* Gini impurity calculation
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return Gini impurity value
*/
override def calculate(c0: Double, c1: Double): Double = {
if (c0 == 0 || c1 == 0) {
0
} else {
val total = c0 + c1
val f0 = c0 / total
val f1 = c1 / total
1 - f0 * f0 - f1 * f1
}
}
def calculate(count: Double, sum: Double, sumSquares: Double): Double =
throw new UnsupportedOperationException("Gini.calculate")
}
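
The value returned by `Gini.calculate` above, written as a formula with $f_0 = c_0 / (c_0 + c_1)$ and $f_1 = c_1 / (c_0 + c_1)$ (the symbol $G$ is notation introduced here):

```latex
G(c_0, c_1) =
\begin{cases}
0 & \text{if } c_0 = 0 \text{ or } c_1 = 0, \\
1 - f_0^2 - f_1^2 = 2 f_0 f_1 & \text{otherwise.}
\end{cases}
```

The second form follows from $f_0 + f_1 = 1$. For example, $c_0 = 1$ and $c_1 = 3$ give $G = 1 - 0.0625 - 0.5625 = 0.375$, and a balanced node gives the maximum $G = 0.5$.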

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity
/**
* Trait for calculating information gain.
*/
trait Impurity extends Serializable {
/**
* information calculation for binary classification
* @param c0 count of instances with label 0
* @param c1 count of instances with label 1
* @return information value
*/
def calculate(c0 : Double, c1 : Double): Double
/**
* information calculation for regression
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
* @return information value
*/
def calculate(count: Double, sum: Double, sumSquares: Double): Double
}

@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.impurity
/**
* Class for calculating variance during regression
*/
object Variance extends Impurity {
override def calculate(c0: Double, c1: Double): Double =
throw new UnsupportedOperationException("Variance.calculate")
/**
* variance calculation
* @param count number of instances
* @param sum sum of labels
* @param sumSquares summation of squares of the labels
*/
override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
val squaredLoss = sumSquares - (sum * sum) / count
squaredLoss / count
}
}
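
The formula in `Variance.calculate` works from three sufficient statistics (count, sum, sum of squares) instead of the raw labels, which is what allows it to be computed from aggregated bin statistics. The Spark-free sketch below copies that formula into a hypothetical `fromStats` helper and compares it with the direct two-pass definition; `VarianceSketch` and `direct` are names introduced for illustration.

```scala
object VarianceSketch {
  // Same arithmetic as Variance.calculate, copied so the sketch runs without Spark.
  def fromStats(count: Double, sum: Double, sumSquares: Double): Double =
    (sumSquares - (sum * sum) / count) / count

  // Direct definition: mean of squared deviations from the mean.
  def direct(labels: Seq[Double]): Double = {
    val mean = labels.sum / labels.size
    labels.map(y => (y - mean) * (y - mean)).sum / labels.size
  }

  def main(args: Array[String]): Unit = {
    val labels = Seq(1.0, 2.0, 3.0)
    // count = 3, sum = 6, sumSquares = 14
    val viaStats = fromStats(labels.size, labels.sum, labels.map(y => y * y).sum)
    println(s"from stats = $viaStats, direct = ${direct(labels)}")
  }
}
```

Both routes give 2/3 for the labels 1, 2, 3. Note that `count = 0` makes the division undefined, so callers are assumed to pass a non-empty set of instances.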

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
* Used for "binning" feature values for faster best split calculation. For a continuous
* feature, a bin is determined by a low and a high "split". For a categorical feature,
* a bin is determined using a single label value (category).
* @param lowSplit signifying the lower threshold for the continuous feature to be
* accepted in the bin
* @param highSplit signifying the upper threshold for the continuous feature to be
* accepted in the bin
* @param featureType type of feature -- categorical or continuous
* @param category categorical label value accepted in the bin
*/
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)

@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD
/**
* Model to store the decision tree parameters
* @param topNode root node
* @param algo algorithm type -- classification or regression
*/
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
/**
* Predict values for a single data point using the model trained.
*
* @param features array representing a single data point
* @return Double prediction from the trained model
*/
def predict(features: Array[Double]): Double = {
topNode.predictIfLeaf(features)
}
/**
* Predict values for the given data set using the model trained.
*
* @param features RDD representing data points to be predicted
* @return RDD[Double] where each entry contains the corresponding prediction
*/
def predict(features: RDD[Array[Double]]): RDD[Double] = {
features.map(x => predict(x))
}
}
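
To show how `predict` walks the tree, here is a sketch that builds a one-split model by hand instead of training it. It assumes the `Node`, `Split` and `DecisionTreeModel` classes from this commit are on the classpath; the object name `ManualTreeSketch` and all values are made up for illustration.

```scala
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node, Split}

object ManualTreeSketch {
  // Two leaves; ids 1 and 2 follow the 2 * id + 1 / 2 * id + 2 child numbering used by Node.
  val left = new Node(1, 0.0, true, None, None, None, None)
  val right = new Node(2, 1.0, true, None, None, None, None)

  // Root splits continuous feature 0 at threshold 0.5 (values <= 0.5 go left).
  val split = Split(0, 0.5, Continuous, List())
  val root = new Node(0, 0.0, false, Some(split), Some(left), Some(right), None)
  val model = new DecisionTreeModel(root, Classification)

  def main(args: Array[String]): Unit = {
    println(model.predict(Array(0.3))) // feature 0 is below the threshold: left leaf
    println(model.predict(Array(0.9))) // feature 0 is above the threshold: right leaf
  }
}
```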

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
/**
* Filter specifying a split and type of comparison to be applied on features
* @param split split specifying the feature index, type and threshold
* @param comparison integer specifying <,=,>
*/
case class Filter(split: Split, comparison: Int) {
// Comparison -1,0,1 signifies <,=,>
override def toString = " split = " + split + ", comparison = " + comparison
}

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
/**
* Information gain statistics for each split
* @param gain information gain value
* @param impurity current node impurity
* @param leftImpurity left node impurity
* @param rightImpurity right node impurity
* @param predict predicted value
*/
class InformationGainStats(
val gain: Double,
val impurity: Double,
val leftImpurity: Double,
val rightImpurity: Double,
val predict: Double) extends Serializable {
override def toString = {
"gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
.format(gain, impurity, leftImpurity, rightImpurity, predict)
}
}

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
import org.apache.spark.Logging
import org.apache.spark.mllib.tree.configuration.FeatureType._
/**
* Node in a decision tree
* @param id integer node id
* @param predict predicted value at the node
* @param isLeaf whether the node is a leaf
* @param split split to calculate left and right nodes
* @param leftNode left child
* @param rightNode right child
* @param stats information gain stats
*/
class Node (
val id: Int,
val predict: Double,
val isLeaf: Boolean,
val split: Option[Split],
var leftNode: Option[Node],
var rightNode: Option[Node],
val stats: Option[InformationGainStats]) extends Serializable with Logging {
override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
"split = " + split + ", stats = " + stats
/**
* build the left node and right nodes if not leaf
* @param nodes array of nodes
*/
def build(nodes: Array[Node]): Unit = {
logDebug("building node " + id + " at level " +
(scala.math.log(id + 1)/scala.math.log(2)).toInt )
logDebug("id = " + id + ", split = " + split)
logDebug("stats = " + stats)
logDebug("predict = " + predict)
if (!isLeaf) {
val leftNodeIndex = id*2 + 1
val rightNodeIndex = id*2 + 2
leftNode = Some(nodes(leftNodeIndex))
rightNode = Some(nodes(rightNodeIndex))
leftNode.get.build(nodes)
rightNode.get.build(nodes)
}
}
/**
* predict value at this node if it is a leaf; otherwise descend to the matching child
* @param feature feature values of a single data point
* @return predicted value
*/
def predictIfLeaf(feature: Array[Double]) : Double = {
if (isLeaf) {
predict
} else {
if (split.get.featureType == Continuous) {
if (feature(split.get.feature) <= split.get.threshold) {
leftNode.get.predictIfLeaf(feature)
} else {
rightNode.get.predictIfLeaf(feature)
}
} else {
if (split.get.categories.contains(feature(split.get.feature))) {
leftNode.get.predictIfLeaf(feature)
} else {
rightNode.get.predictIfLeaf(feature)
}
}
}
}
}
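
`Node.build` relies on a heap-style numbering: the children of node `id` live at indices `2 * id + 1` and `2 * id + 2`, and the level logged for a node is `floor(log2(id + 1))`. The Spark-free sketch below spells that out. The integer-only `level` helper is an alternative suggested here to avoid floating-point rounding, not what the commit uses, and `NodeIndexSketch` is a hypothetical name.

```scala
object NodeIndexSketch {
  def leftChild(id: Int): Int = 2 * id + 1
  def rightChild(id: Int): Int = 2 * id + 2

  // Level of a node (root = 0): floor(log2(id + 1)), computed from the position of
  // the highest set bit of id + 1. Assumes id >= 0.
  def level(id: Int): Int = 31 - Integer.numberOfLeadingZeros(id + 1)

  def main(args: Array[String]): Unit = {
    (0 to 6).foreach { id =>
      println(s"id = $id, level = ${level(id)}, " +
        s"children = (${leftChild(id)}, ${rightChild(id)})")
    }
  }
}
```

With this numbering, level `l` holds the ids from `2^l - 1` to `2^(l+1) - 2`, so a tree of depth `d` fits in an array of `2^(d+1) - 1` nodes.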

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree.model
import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
/**
* Split applied to a feature
* @param feature feature index
* @param threshold threshold for continuous feature
* @param featureType type of feature -- categorical or continuous
* @param categories accepted values for categorical variables
*/
case class Split(
feature: Int,
threshold: Double,
featureType: FeatureType,
categories: List[Double]){
override def toString =
"Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
", categories = " + categories
}
/**
* Split with minimum threshold for continuous features. Helps with the smallest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyLowSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MinValue, featureType, List())
/**
* Split with maximum threshold for continuous features. Helps with the highest bin creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyHighSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())
/**
* Split with no acceptable feature values for categorical features. Helps with the first bin
* creation.
* @param feature feature index
* @param featureType type of feature -- categorical or continuous
*/
class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
extends Split(feature, Double.MaxValue, featureType, List())

@ -0,0 +1,425 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.mllib.tree
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
@transient private var sc: SparkContext = _
override def beforeAll() {
sc = new SparkContext("local", "test")
}
override def afterAll() {
sc.stop()
System.clearProperty("spark.driver.port")
}
test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
assert(bins(0).length === 100)
}
test("split and bin calculation for categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 99)
assert(bins(0).length === 100)
// Check splits.
assert(splits(0)(0).feature === 0)
assert(splits(0)(0).threshold === Double.MinValue)
assert(splits(0)(0).featureType === Categorical)
assert(splits(0)(0).categories.length === 1)
assert(splits(0)(0).categories.contains(1.0))
assert(splits(0)(1).feature === 0)
assert(splits(0)(1).threshold === Double.MinValue)
assert(splits(0)(1).featureType === Categorical)
assert(splits(0)(1).categories.length === 2)
assert(splits(0)(1).categories.contains(1.0))
assert(splits(0)(1).categories.contains(0.0))
assert(splits(0)(2) === null)
assert(splits(1)(0).feature === 1)
assert(splits(1)(0).threshold === Double.MinValue)
assert(splits(1)(0).featureType === Categorical)
assert(splits(1)(0).categories.length === 1)
assert(splits(1)(0).categories.contains(0.0))
assert(splits(1)(1).feature === 1)
assert(splits(1)(1).threshold === Double.MinValue)
assert(splits(1)(1).featureType === Categorical)
assert(splits(1)(1).categories.length === 2)
assert(splits(1)(1).categories.contains(1.0))
assert(splits(1)(1).categories.contains(0.0))
assert(splits(1)(2) === null)
// Check bins.
assert(bins(0)(0).category === 1.0)
assert(bins(0)(0).lowSplit.categories.length === 0)
assert(bins(0)(0).highSplit.categories.length === 1)
assert(bins(0)(0).highSplit.categories.contains(1.0))
assert(bins(0)(1).category === 0.0)
assert(bins(0)(1).lowSplit.categories.length === 1)
assert(bins(0)(1).lowSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.length === 2)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(0.0))
assert(bins(0)(2) === null)
assert(bins(1)(0).category === 0.0)
assert(bins(1)(0).lowSplit.categories.length === 0)
assert(bins(1)(0).highSplit.categories.length === 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))
assert(bins(1)(1).category === 1.0)
assert(bins(1)(1).lowSplit.categories.length === 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length === 2)
assert(bins(1)(1).highSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.contains(1.0))
assert(bins(1)(2) === null)
}
test("split and bin calculations for categorical variables with no sample for one category") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
// Check splits.
assert(splits(0)(0).feature === 0)
assert(splits(0)(0).threshold === Double.MinValue)
assert(splits(0)(0).featureType === Categorical)
assert(splits(0)(0).categories.length === 1)
assert(splits(0)(0).categories.contains(1.0))
assert(splits(0)(1).feature === 0)
assert(splits(0)(1).threshold === Double.MinValue)
assert(splits(0)(1).featureType === Categorical)
assert(splits(0)(1).categories.length === 2)
assert(splits(0)(1).categories.contains(1.0))
assert(splits(0)(1).categories.contains(0.0))
assert(splits(0)(2).feature === 0)
assert(splits(0)(2).threshold === Double.MinValue)
assert(splits(0)(2).featureType === Categorical)
assert(splits(0)(2).categories.length === 3)
assert(splits(0)(2).categories.contains(1.0))
assert(splits(0)(2).categories.contains(0.0))
assert(splits(0)(2).categories.contains(2.0))
assert(splits(0)(3) === null)
assert(splits(1)(0).feature === 1)
assert(splits(1)(0).threshold === Double.MinValue)
assert(splits(1)(0).featureType === Categorical)
assert(splits(1)(0).categories.length === 1)
assert(splits(1)(0).categories.contains(0.0))
assert(splits(1)(1).feature === 1)
assert(splits(1)(1).threshold === Double.MinValue)
assert(splits(1)(1).featureType === Categorical)
assert(splits(1)(1).categories.length === 2)
assert(splits(1)(1).categories.contains(1.0))
assert(splits(1)(1).categories.contains(0.0))
assert(splits(1)(2).feature === 1)
assert(splits(1)(2).threshold === Double.MinValue)
assert(splits(1)(2).featureType === Categorical)
assert(splits(1)(2).categories.length === 3)
assert(splits(1)(2).categories.contains(1.0))
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(2.0))
assert(splits(1)(3) === null)
// Check bins.
assert(bins(0)(0).category === 1.0)
assert(bins(0)(0).lowSplit.categories.length === 0)
assert(bins(0)(0).highSplit.categories.length === 1)
assert(bins(0)(0).highSplit.categories.contains(1.0))
assert(bins(0)(1).category === 0.0)
assert(bins(0)(1).lowSplit.categories.length === 1)
assert(bins(0)(1).lowSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.length === 2)
assert(bins(0)(1).highSplit.categories.contains(1.0))
assert(bins(0)(1).highSplit.categories.contains(0.0))
assert(bins(0)(2).category === 2.0)
assert(bins(0)(2).lowSplit.categories.length === 2)
assert(bins(0)(2).lowSplit.categories.contains(1.0))
assert(bins(0)(2).lowSplit.categories.contains(0.0))
assert(bins(0)(2).highSplit.categories.length === 3)
assert(bins(0)(2).highSplit.categories.contains(1.0))
assert(bins(0)(2).highSplit.categories.contains(0.0))
assert(bins(0)(2).highSplit.categories.contains(2.0))
assert(bins(0)(3) === null)
assert(bins(1)(0).category === 0.0)
assert(bins(1)(0).lowSplit.categories.length === 0)
assert(bins(1)(0).highSplit.categories.length === 1)
assert(bins(1)(0).highSplit.categories.contains(0.0))
assert(bins(1)(1).category === 1.0)
assert(bins(1)(1).lowSplit.categories.length === 1)
assert(bins(1)(1).lowSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.length === 2)
assert(bins(1)(1).highSplit.categories.contains(0.0))
assert(bins(1)(1).highSplit.categories.contains(1.0))
assert(bins(1)(2).category === 2.0)
assert(bins(1)(2).lowSplit.categories.length === 2)
assert(bins(1)(2).lowSplit.categories.contains(0.0))
assert(bins(1)(2).lowSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.length === 3)
assert(bins(1)(2).highSplit.categories.contains(0.0))
assert(bins(1)(2).highSplit.categories.contains(1.0))
assert(bins(1)(2).highSplit.categories.contains(2.0))
assert(bins(1)(3) === null)
}
test("classification stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Classification,
Gini,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.impurity > 0.2)
}
test("regression stump with all categorical variables") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(
Regression,
Variance,
maxDepth = 3,
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
val split = bestSplits(0)._1
assert(split.categories.length === 1)
assert(split.categories.contains(1.0))
assert(split.featureType === Categorical)
assert(split.threshold === Double.MinValue)
val stats = bestSplits(0)._2
assert(stats.gain > 0)
assert(stats.predict > 0.4)
assert(stats.predict < 0.5)
assert(stats.impurity > 0.2)
}
test("stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 0)
}
test("stump with fixed label 1 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Gini, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}
test("stump with fixed label 0 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 0)
}
test("stump with fixed label 1 for Entropy") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
val rdd = sc.parallelize(arr)
val strategy = new Strategy(Classification, Entropy, 3, 100)
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
assert(splits.length === 2)
assert(splits(0).length === 99)
assert(bins.length === 2)
assert(bins(0).length === 100)
val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
assert(bestSplits(0)._1.threshold === 10)
assert(bestSplits(0)._2.gain === 0)
assert(bestSplits(0)._2.leftImpurity === 0)
assert(bestSplits(0)._2.rightImpurity === 0)
assert(bestSplits(0)._2.predict === 1)
}
}
object DecisionTreeSuite {
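// Generates 1000 points, all with label 0.0. Feature 0 increases from 0 to 999
// and feature 1 decreases from 1000 to 1.
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
val lp = new LabeledPoint(0.0, Array(i.toDouble, 1000.0 - i))
arr(i) = lp
}
arr
}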
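// Generates 1000 points, all with label 1.0. Feature 0 increases from 0 to 999
// and feature 1 decreases from 999 to 0.
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
val lp = new LabeledPoint(1.0, Array(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
}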
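// Generates 1000 points with two categorical features. The first 600 points have
// label 1.0 and features (0.0, 1.0); the remaining 400 have label 0.0 and
// features (1.0, 0.0). Category 2.0 never appears in either feature.
def generateCategoricalDataPoints(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000) {
if (i < 600) {
arr(i) = new LabeledPoint(1.0, Array(0.0, 1.0))
} else {
arr(i) = new LabeledPoint(0.0, Array(1.0, 0.0))
}
}
arr
}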
}