MLI-1 Decision Trees
Joint work with @hirakendu, @etrain, @atalwalkar and @harsha2010.

Key features:
+ Supports binary classification and regression
+ Supports gini, entropy and variance for information gain calculation
+ Supports both continuous and categorical features

The algorithm has gone through several development iterations over the last few months leading to a highly optimized implementation. Optimizations include:

1. Level-wise training to reduce passes over the entire dataset.
2. Bin-wise split calculation to reduce computation overhead.
3. Aggregation over partitions before combining to reduce communication overhead.

Author: Manish Amde <manish9ue@gmail.com>
Author: manishamde <manish9ue@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #79 from manishamde/tree and squashes the following commits:

1e8c704 [Manish Amde] remove numBins field in the Strategy class
7d54b4f [manishamde] Merge pull request #4 from mengxr/dtree
f536ae9 [Xiangrui Meng] another pass on code style
e1dd86f [Manish Amde] implementing code style suggestions
62dc723 [Manish Amde] updating javadoc and converting helper methods to package private to allow unit testing
201702f [Manish Amde] making some more methods private
f963ef5 [Manish Amde] making methods private
c487e6a [manishamde] Merge pull request #1 from mengxr/dtree
24500c5 [Xiangrui Meng] minor style updates
4576b64 [Manish Amde] documentation and for to while loop conversion
ff363a7 [Manish Amde] binary search for bins and while loop for categorical feature bins
632818f [Manish Amde] removing threshold for classification predict method
2116360 [Manish Amde] removing dummy bin calculation for categorical variables
6068356 [Manish Amde] ensuring num bins is always greater than max number of categories
62c2562 [Manish Amde] fixing comment indentation
ad1fc21 [Manish Amde] incorporated mengxr's code style suggestions
d1ef4f6 [Manish Amde] more documentation
794ff4d [Manish Amde] minor improvements to docs and style
eb8fcbe [Manish Amde] minor code style updates
cd2c2b4 [Manish Amde] fixing code style based on feedback
63e786b [Manish Amde] added multiple train methods for java compatability
d3023b3 [Manish Amde] adding more docs for nested methods
84f85d6 [Manish Amde] code documentation
9372779 [Manish Amde] code style: max line lenght <= 100
dd0c0d7 [Manish Amde] minor: some docs
0dd7659 [manishamde] basic doc
5841c28 [Manish Amde] unit tests for categorical features
f067d68 [Manish Amde] minor cleanup
c0e522b [Manish Amde] updated predict and split threshold logic
b09dc98 [Manish Amde] minor refactoring
6b7de78 [Manish Amde] minor refactoring and tests
d504eb1 [Manish Amde] more tests for categorical features
dbb7ac1 [Manish Amde] categorical feature support
6df35b9 [Manish Amde] regression predict logic
53108ed [Manish Amde] fixing index for highest bin
e23c2e5 [Manish Amde] added regression support
c8f6d60 [Manish Amde] adding enum for feature type
b0e3e76 [Manish Amde] adding enum for feature type
154aa77 [Manish Amde] enums for configurations
733d6dd [Manish Amde] fixed tests
02c595c [Manish Amde] added command line parsing
98ec8d5 [Manish Amde] tree building and prediction logic
b0eb866 [Manish Amde] added logic to handle leaf nodes
80e8c66 [Manish Amde] working version of multi-level split calculation
4798aae [Manish Amde] added gain stats class
dad0afc [Manish Amde] decison stump functionality working
03f534c [Manish Amde] some more tests
0012a77 [Manish Amde] basic stump working
8bca1e2 [Manish Amde] additional code for creating intermediate RDD
92cedce [Manish Amde] basic building blocks for intermediate RDD calculation. untested.
cd53eae [Manish Amde] skeletal framework
parent 45df912736
commit 8b3045ceab
1150
mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
Normal file
File diff suppressed because it is too large
17
mllib/src/main/scala/org/apache/spark/mllib/tree/README.md
Normal file
|
@ -0,0 +1,17 @@
This package contains the default implementation of the decision tree algorithm.

The decision tree algorithm supports:
+ Binary classification
+ Regression
+ Information gain calculation with entropy and gini for classification and variance for regression
+ Both continuous and categorical features

# Tree improvements
+ Node model pruning
+ Printing to dot files

# Future Ensemble Extensions

+ Random forests
+ Boosting
+ Extremely randomized trees
@ -0,0 +1,26 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.configuration

/**
 * Enum to select the algorithm for the decision tree
 */
object Algo extends Enumeration {
  type Algo = Value
  val Classification, Regression = Value
}
@ -0,0 +1,26 @@
package org.apache.spark.mllib.tree.configuration

/**
 * Enum to describe whether a feature is "continuous" or "categorical"
 */
object FeatureType extends Enumeration {
  type FeatureType = Value
  val Continuous, Categorical = Value
}
@ -0,0 +1,26 @@
package org.apache.spark.mllib.tree.configuration

/**
 * Enum for selecting the quantile calculation strategy
 */
object QuantileStrategy extends Enumeration {
  type QuantileStrategy = Value
  val Sort, MinMax, ApproxHist = Value
}
@ -0,0 +1,43 @@
package org.apache.spark.mllib.tree.configuration

import org.apache.spark.mllib.tree.impurity.Impurity
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._

/**
 * Stores all the configuration options for tree construction
 * @param algo classification or regression
 * @param impurity criterion used for information gain calculation
 * @param maxDepth maximum depth of the tree
 * @param maxBins maximum number of bins used for splitting features
 * @param quantileCalculationStrategy algorithm for calculating quantiles
 * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
 *                                number of discrete values they take. For example, an entry (n ->
 *                                k) implies the feature n is categorical with k categories 0,
 *                                1, 2, ... , k-1. It's important to note that features are
 *                                zero-indexed.
 */
class Strategy (
    val algo: Algo,
    val impurity: Impurity,
    val maxDepth: Int,
    val maxBins: Int = 100,
    val quantileCalculationStrategy: QuantileStrategy = Sort,
    val categoricalFeaturesInfo: Map[Int,Int] = Map[Int,Int]()) extends Serializable
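Reviewer sketch (not part of the commit): the commit log mentions "ensuring num bins is always greater than max number of categories", and the Strategy doc says features are zero-indexed. The snippet below illustrates what such a consistency check could look like. `StrategySketch`, `StrategyCheck` and `validate` are hypothetical names used only here; the real check, if any, lives in the suppressed DecisionTree.scala and may differ.

```scala
// Hypothetical, simplified stand-in for Strategy: only the two fields needed here.
final case class StrategySketch(
    maxBins: Int = 100,
    categoricalFeaturesInfo: Map[Int, Int] = Map.empty)

object StrategyCheck {

  /** Returns an error message when the configuration is inconsistent, None otherwise. */
  def validate(s: StrategySketch, numFeatures: Int): Option[String] = {
    // Feature indices are zero-based, so valid keys are 0 until numFeatures.
    val badIndex = s.categoricalFeaturesInfo.keys.find(i => i < 0 || i >= numFeatures)
    val maxCategories =
      if (s.categoricalFeaturesInfo.isEmpty) 0 else s.categoricalFeaturesInfo.values.max
    if (badIndex.isDefined) {
      Some("feature index " + badIndex.get + " is out of range")
    } else if (maxCategories >= s.maxBins) {
      // Assumption taken from the commit log: maxBins must exceed the largest category count.
      Some("maxBins (" + s.maxBins + ") must exceed the largest category count (" +
        maxCategories + ")")
    } else {
      None
    }
  }
}
```

With the configuration used in the categorical test of this commit (`maxBins = 100`, `Map(0 -> 2, 1 -> 2)`, two features) this sketch reports no problem.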
@ -0,0 +1,47 @@
package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating [[http://en.wikipedia.org/wiki/Binary_entropy_function entropy]] during
 * binary classification.
 */
object Entropy extends Impurity {

  def log2(x: Double) = scala.math.log(x) / scala.math.log(2)

  /**
   * entropy calculation
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return entropy value
   */
  def calculate(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      -(f0 * log2(f0)) - (f1 * log2(f1))
    }
  }

  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("Entropy.calculate")
}
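Reviewer sketch (not part of the commit): a standalone copy of the binary entropy formula above, so it can be tried without a Spark dependency. `EntropyExample` is a hypothetical name; the arithmetic mirrors `Entropy.calculate(c0, c1)`.

```scala
object EntropyExample {
  private def log2(x: Double): Double = math.log(x) / math.log(2)

  // Same formula as Entropy.calculate(c0, c1): 0 for a pure node,
  // otherwise -f0*log2(f0) - f1*log2(f1) over the class fractions.
  def entropy(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) {
      0.0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      -(f0 * log2(f0)) - (f1 * log2(f1))
    }

  def main(args: Array[String]): Unit = {
    println(entropy(5, 5))   // balanced node: the maximum for two classes
    println(entropy(10, 0))  // pure node: zero impurity
    println(entropy(9, 1))   // mostly one class: somewhere in between
  }
}
```

The early return for a zero count matters: without it, `0 * log2(0)` would evaluate to NaN instead of the conventional 0.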
@ -0,0 +1,46 @@
package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating the
 * [[http://en.wikipedia.org/wiki/Decision_tree_learning#Gini_impurity Gini impurity]]
 * during binary classification.
 */
object Gini extends Impurity {

  /**
   * Gini impurity calculation
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return Gini impurity value
   */
  override def calculate(c0: Double, c1: Double): Double = {
    if (c0 == 0 || c1 == 0) {
      0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0 * f0 - f1 * f1
    }
  }

  def calculate(count: Double, sum: Double, sumSquares: Double): Double =
    throw new UnsupportedOperationException("Gini.calculate")
}
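Reviewer sketch (not part of the commit): how an impurity such as Gini typically feeds an information gain value for a candidate split. The gain formula used here (parent impurity minus the size-weighted child impurities) is the textbook definition and is an assumption; the commit's own gain computation is in the suppressed DecisionTree.scala. `GiniGainExample` and `gain` are hypothetical names.

```scala
object GiniGainExample {

  // Same formula as Gini.calculate(c0, c1).
  def gini(c0: Double, c1: Double): Double =
    if (c0 == 0 || c1 == 0) {
      0.0
    } else {
      val total = c0 + c1
      val f0 = c0 / total
      val f1 = c1 / total
      1 - f0 * f0 - f1 * f1
    }

  // Assumed textbook gain: impurity(parent) - wL * impurity(left) - wR * impurity(right),
  // where wL and wR are the fractions of instances sent to each child.
  // Both children together must contain at least one instance.
  def gain(left0: Double, left1: Double, right0: Double, right1: Double): Double = {
    val leftCount = left0 + left1
    val rightCount = right0 + right1
    val total = leftCount + rightCount
    gini(left0 + right0, left1 + right1) -
      (leftCount / total) * gini(left0, left1) -
      (rightCount / total) * gini(right0, right1)
  }
}
```

A split that separates the two classes perfectly recovers the full parent impurity as gain, while a split that leaves both children as mixed as the parent yields a gain of zero.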
@ -0,0 +1,42 @@
package org.apache.spark.mllib.tree.impurity

/**
 * Trait for calculating the impurity used in information gain.
 */
trait Impurity extends Serializable {

  /**
   * information calculation for binary classification
   * @param c0 count of instances with label 0
   * @param c1 count of instances with label 1
   * @return information value
   */
  def calculate(c0: Double, c1: Double): Double

  /**
   * information calculation for regression
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   * @return information value
   */
  def calculate(count: Double, sum: Double, sumSquares: Double): Double

}
@ -0,0 +1,37 @@
package org.apache.spark.mllib.tree.impurity

/**
 * Class for calculating variance during regression
 */
object Variance extends Impurity {
  override def calculate(c0: Double, c1: Double): Double =
    throw new UnsupportedOperationException("Variance.calculate")

  /**
   * variance calculation
   * @param count number of instances
   * @param sum sum of labels
   * @param sumSquares summation of squares of the labels
   */
  override def calculate(count: Double, sum: Double, sumSquares: Double): Double = {
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }
}
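Reviewer sketch (not part of the commit): the variance above is computed from three running totals rather than from the raw labels. The standalone copy below shows that these totals reproduce the population variance; `VarianceExample` and `fromLabels` are hypothetical names added for illustration.

```scala
object VarianceExample {

  // Same formula as Variance.calculate(count, sum, sumSquares):
  // (sumSquares - sum^2 / count) / count, i.e. the population variance.
  def variance(count: Double, sum: Double, sumSquares: Double): Double = {
    val squaredLoss = sumSquares - (sum * sum) / count
    squaredLoss / count
  }

  // Hypothetical helper: derive the three statistics from raw labels.
  def fromLabels(labels: Seq[Double]): Double =
    variance(labels.size.toDouble, labels.sum, labels.map(y => y * y).sum)

  def main(args: Array[String]): Unit = {
    // Labels 1, 2, 3, 4: count = 4, sum = 10, sumSquares = 30.
    println(fromLabels(Seq(1.0, 2.0, 3.0, 4.0)))
  }
}
```

A likely reason for this shape, given the commit's note about aggregating over partitions, is that count, sum and sum of squares can each be added up independently per partition and combined afterwards. Note that a count of zero produces NaN, so callers presumably only evaluate non-empty nodes.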
@ -0,0 +1,33 @@
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
 * Used for "binning" the feature values for faster best split calculation. For a continuous
 * feature, a bin is determined by a low and a high "split". For a categorical feature,
 * a bin is determined using a single label value (category).
 * @param lowSplit signifying the lower threshold for the continuous feature to be
 *                 accepted in the bin
 * @param highSplit signifying the upper threshold for the continuous feature to be
 *                  accepted in the bin
 * @param featureType type of feature -- categorical or continuous
 * @param category categorical label value accepted in the bin
 */
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
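Reviewer sketch (not part of the commit): the commit log mentions "binary search for bins". The snippet below shows one way a continuous value could be mapped to a bin index when each bin is bounded by a low and a high split. It works on a plain array of upper thresholds instead of `Bin` objects, and it assumes the lower bound is exclusive and the upper bound inclusive (chosen to match the `<=` comparison in `Node.predictIfLeaf`). `BinSearchExample` and `findBin` are hypothetical; the actual lookup is in the suppressed DecisionTree.scala.

```scala
object BinSearchExample {

  /**
   * upperThresholds(i) is the high-split threshold of bin i, sorted ascending. The last
   * entry plays the role of a dummy high split (Double.MaxValue). The array must be
   * non-empty. Returns the index of the first bin whose upper threshold is >= value.
   */
  def findBin(upperThresholds: Array[Double], value: Double): Int = {
    var lo = 0
    var hi = upperThresholds.length - 1
    while (lo < hi) {
      val mid = lo + (hi - lo) / 2
      if (value <= upperThresholds(mid)) {
        hi = mid       // value fits in bin mid or an earlier one
      } else {
        lo = mid + 1   // value lies strictly above bin mid
      }
    }
    lo
  }

  def main(args: Array[String]): Unit = {
    val uppers = Array(1.0, 2.0, 3.0, Double.MaxValue)
    println(findBin(uppers, 1.5))
  }
}
```

With 100 bins per feature, as in the tests of this commit, a lookup of this kind needs about 7 comparisons instead of up to 100 for a linear scan.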
@ -0,0 +1,49 @@
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.rdd.RDD

/**
 * Model to store the decision tree parameters
 * @param topNode root node
 * @param algo algorithm type -- classification or regression
 */
class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {

  /**
   * Predict values for a single data point using the model trained.
   *
   * @param features array representing a single data point
   * @return Double prediction from the trained model
   */
  def predict(features: Array[Double]): Double = {
    topNode.predictIfLeaf(features)
  }

  /**
   * Predict values for the given data set using the model trained.
   *
   * @param features RDD representing data points to be predicted
   * @return RDD[Double] where each entry contains the corresponding prediction
   */
  def predict(features: RDD[Array[Double]]): RDD[Double] = {
    features.map(x => predict(x))
  }
}
@ -0,0 +1,28 @@
package org.apache.spark.mllib.tree.model

/**
 * Filter specifying a split and type of comparison to be applied on features
 * @param split split specifying the feature index, type and threshold
 * @param comparison integer specifying <,=,>
 */
case class Filter(split: Split, comparison: Int) {
  // Comparison -1,0,1 signifies <,=,>
  override def toString = " split = " + split + ", comparison = " + comparison
}
@ -0,0 +1,39 @@
package org.apache.spark.mllib.tree.model

/**
 * Information gain statistics for each split
 * @param gain information gain value
 * @param impurity current node impurity
 * @param leftImpurity left node impurity
 * @param rightImpurity right node impurity
 * @param predict predicted value
 */
class InformationGainStats(
    val gain: Double,
    val impurity: Double,
    val leftImpurity: Double,
    val rightImpurity: Double,
    val predict: Double) extends Serializable {

  override def toString = {
    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
      .format(gain, impurity, leftImpurity, rightImpurity, predict)
  }
}
@ -0,0 +1,90 @@
package org.apache.spark.mllib.tree.model

import org.apache.spark.Logging
import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
 * Node in a decision tree
 * @param id integer node id
 * @param predict predicted value at the node
 * @param isLeaf whether the node is a leaf
 * @param split split to calculate left and right nodes
 * @param leftNode left child
 * @param rightNode right child
 * @param stats information gain stats
 */
class Node (
    val id: Int,
    val predict: Double,
    val isLeaf: Boolean,
    val split: Option[Split],
    var leftNode: Option[Node],
    var rightNode: Option[Node],
    val stats: Option[InformationGainStats]) extends Serializable with Logging {

  override def toString = "id = " + id + ", isLeaf = " + isLeaf + ", predict = " + predict + ", " +
    "split = " + split + ", stats = " + stats

  /**
   * build the left node and right nodes if not leaf
   * @param nodes array of nodes
   */
  def build(nodes: Array[Node]): Unit = {

    logDebug("building node " + id + " at level " +
      (scala.math.log(id + 1)/scala.math.log(2)).toInt )
    logDebug("id = " + id + ", split = " + split)
    logDebug("stats = " + stats)
    logDebug("predict = " + predict)
    if (!isLeaf) {
      val leftNodeIndex = id*2 + 1
      val rightNodeIndex = id*2 + 2
      leftNode = Some(nodes(leftNodeIndex))
      rightNode = Some(nodes(rightNodeIndex))
      leftNode.get.build(nodes)
      rightNode.get.build(nodes)
    }
  }

  /**
   * return the predicted value if this node is a leaf, otherwise descend into the child
   * selected by the split
   * @param feature feature vector of a single data point
   * @return predicted value
   */
  def predictIfLeaf(feature: Array[Double]) : Double = {
    if (isLeaf) {
      predict
    } else {
      if (split.get.featureType == Continuous) {
        if (feature(split.get.feature) <= split.get.threshold) {
          leftNode.get.predictIfLeaf(feature)
        } else {
          rightNode.get.predictIfLeaf(feature)
        }
      } else {
        if (split.get.categories.contains(feature(split.get.feature))) {
          leftNode.get.predictIfLeaf(feature)
        } else {
          rightNode.get.predictIfLeaf(feature)
        }
      }
    }
  }
}
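Reviewer sketch (not part of the commit): `Node.build` relies on a heap-style layout in which node `id` has children `2 * id + 1` and `2 * id + 2`, and the debug log derives the level as `floor(log2(id + 1))`. The standalone helpers below spell out that index arithmetic. `NodeIndexExample` and its methods are hypothetical names, and the integer-only `level` is an alternative formulation, not the commit's code.

```scala
object NodeIndexExample {

  // Child indices exactly as computed in Node.build.
  def leftChild(id: Int): Int = id * 2 + 1
  def rightChild(id: Int): Int = id * 2 + 2

  // Inverse mapping, valid for id > 0 (the root, id 0, has no parent).
  def parent(id: Int): Int = (id - 1) / 2

  // floor(log2(id + 1)) using integer operations only, for id >= 0.
  def level(id: Int): Int = 31 - Integer.numberOfLeadingZeros(id + 1)

  def main(args: Array[String]): Unit = {
    // Root at level 0, ids 1 and 2 at level 1, ids 3 to 6 at level 2.
    (0 to 6).foreach(id => println("id = " + id + ", level = " + level(id)))
  }
}
```

Under this layout a tree of depth d occupies indices 0 to 2^(d+1) - 2, which is why a flat `Array[Node]` is enough to wire up the children. The integer form of `level` sidesteps any floating-point rounding concerns in the logarithm ratio.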
@ -0,0 +1,64 @@
package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType

/**
 * Split applied to a feature
 * @param feature feature index
 * @param threshold threshold for continuous feature
 * @param featureType type of feature -- categorical or continuous
 * @param categories accepted values for categorical variables
 */
case class Split(
    feature: Int,
    threshold: Double,
    featureType: FeatureType,
    categories: List[Double]) {

  override def toString =
    "Feature = " + feature + ", threshold = " + threshold + ", featureType = " + featureType +
      ", categories = " + categories
}

/**
 * Split with minimum threshold for continuous features. Helps with the smallest bin creation.
 * @param feature feature index
 * @param featureType type of feature -- categorical or continuous
 */
class DummyLowSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MinValue, featureType, List())

/**
 * Split with maximum threshold for continuous features. Helps with the highest bin creation.
 * @param feature feature index
 * @param featureType type of feature -- categorical or continuous
 */
class DummyHighSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List())

/**
 * Split with no acceptable feature values for categorical features. Helps with the first bin
 * creation.
 * @param feature feature index
 * @param featureType type of feature -- categorical or continuous
 */
class DummyCategoricalSplit(feature: Int, featureType: FeatureType)
  extends Split(feature, Double.MaxValue, featureType, List())

@ -0,0 +1,425 @@
package org.apache.spark.mllib.tree
|
||||
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import org.apache.spark.SparkContext
|
||||
import org.apache.spark.mllib.regression.LabeledPoint
|
||||
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
|
||||
import org.apache.spark.mllib.tree.model.Filter
|
||||
import org.apache.spark.mllib.tree.configuration.Strategy
|
||||
import org.apache.spark.mllib.tree.configuration.Algo._
|
||||
import org.apache.spark.mllib.tree.configuration.FeatureType._
|
||||
|
||||
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {

  @transient private var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
  }

  override def afterAll() {
    sc.stop()
    System.clearProperty("spark.driver.port")
  }

  test("split and bin calculation") {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(Classification, Gini, 3, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(bins.length === 2)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)
  }

  test("split and bin calculation for categorical variables") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Classification,
      Gini,
      maxDepth = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(bins.length === 2)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)

    // Check splits. Each successive split holds one more category than the previous one.
    // The expected order (1.0 before 0.0 for feature 0, 0.0 before 1.0 for feature 1)
    // matches ascending mean label per category in the generated data.

    assert(splits(0)(0).feature === 0)
    assert(splits(0)(0).threshold === Double.MinValue)
    assert(splits(0)(0).featureType === Categorical)
    assert(splits(0)(0).categories.length === 1)
    assert(splits(0)(0).categories.contains(1.0))

    assert(splits(0)(1).feature === 0)
    assert(splits(0)(1).threshold === Double.MinValue)
    assert(splits(0)(1).featureType === Categorical)
    assert(splits(0)(1).categories.length === 2)
    assert(splits(0)(1).categories.contains(1.0))
    assert(splits(0)(1).categories.contains(0.0))

    assert(splits(0)(2) === null)

    assert(splits(1)(0).feature === 1)
    assert(splits(1)(0).threshold === Double.MinValue)
    assert(splits(1)(0).featureType === Categorical)
    assert(splits(1)(0).categories.length === 1)
    assert(splits(1)(0).categories.contains(0.0))

    assert(splits(1)(1).feature === 1)
    assert(splits(1)(1).threshold === Double.MinValue)
    assert(splits(1)(1).featureType === Categorical)
    assert(splits(1)(1).categories.length === 2)
    assert(splits(1)(1).categories.contains(1.0))
    assert(splits(1)(1).categories.contains(0.0))

    assert(splits(1)(2) === null)

    // Check bins. Each bin is bounded by two consecutive splits; the first bin's low split
    // has no categories.

    assert(bins(0)(0).category === 1.0)
    assert(bins(0)(0).lowSplit.categories.length === 0)
    assert(bins(0)(0).highSplit.categories.length === 1)
    assert(bins(0)(0).highSplit.categories.contains(1.0))

    assert(bins(0)(1).category === 0.0)
    assert(bins(0)(1).lowSplit.categories.length === 1)
    assert(bins(0)(1).lowSplit.categories.contains(1.0))
    assert(bins(0)(1).highSplit.categories.length === 2)
    assert(bins(0)(1).highSplit.categories.contains(1.0))
    assert(bins(0)(1).highSplit.categories.contains(0.0))

    assert(bins(0)(2) === null)

    assert(bins(1)(0).category === 0.0)
    assert(bins(1)(0).lowSplit.categories.length === 0)
    assert(bins(1)(0).highSplit.categories.length === 1)
    assert(bins(1)(0).highSplit.categories.contains(0.0))

    assert(bins(1)(1).category === 1.0)
    assert(bins(1)(1).lowSplit.categories.length === 1)
    assert(bins(1)(1).lowSplit.categories.contains(0.0))
    assert(bins(1)(1).highSplit.categories.length === 2)
    assert(bins(1)(1).highSplit.categories.contains(0.0))
    assert(bins(1)(1).highSplit.categories.contains(1.0))

    assert(bins(1)(2) === null)
  }

  test("split and bin calculations for categorical variables with no sample for one category") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Classification,
      Gini,
      maxDepth = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)

    // Check splits. Both features are declared with 3 categories, but category 2.0 never
    // occurs in the generated data; it is expected last in the ordering.

    assert(splits(0)(0).feature === 0)
    assert(splits(0)(0).threshold === Double.MinValue)
    assert(splits(0)(0).featureType === Categorical)
    assert(splits(0)(0).categories.length === 1)
    assert(splits(0)(0).categories.contains(1.0))

    assert(splits(0)(1).feature === 0)
    assert(splits(0)(1).threshold === Double.MinValue)
    assert(splits(0)(1).featureType === Categorical)
    assert(splits(0)(1).categories.length === 2)
    assert(splits(0)(1).categories.contains(1.0))
    assert(splits(0)(1).categories.contains(0.0))

    assert(splits(0)(2).feature === 0)
    assert(splits(0)(2).threshold === Double.MinValue)
    assert(splits(0)(2).featureType === Categorical)
    assert(splits(0)(2).categories.length === 3)
    assert(splits(0)(2).categories.contains(1.0))
    assert(splits(0)(2).categories.contains(0.0))
    assert(splits(0)(2).categories.contains(2.0))

    assert(splits(0)(3) === null)

    assert(splits(1)(0).feature === 1)
    assert(splits(1)(0).threshold === Double.MinValue)
    assert(splits(1)(0).featureType === Categorical)
    assert(splits(1)(0).categories.length === 1)
    assert(splits(1)(0).categories.contains(0.0))

    assert(splits(1)(1).feature === 1)
    assert(splits(1)(1).threshold === Double.MinValue)
    assert(splits(1)(1).featureType === Categorical)
    assert(splits(1)(1).categories.length === 2)
    assert(splits(1)(1).categories.contains(1.0))
    assert(splits(1)(1).categories.contains(0.0))

    assert(splits(1)(2).feature === 1)
    assert(splits(1)(2).threshold === Double.MinValue)
    assert(splits(1)(2).featureType === Categorical)
    assert(splits(1)(2).categories.length === 3)
    assert(splits(1)(2).categories.contains(1.0))
    assert(splits(1)(2).categories.contains(0.0))
    assert(splits(1)(2).categories.contains(2.0))

    assert(splits(1)(3) === null)

    // Check bins. The bin for the unseen category 2.0 comes last for both features.

    assert(bins(0)(0).category === 1.0)
    assert(bins(0)(0).lowSplit.categories.length === 0)
    assert(bins(0)(0).highSplit.categories.length === 1)
    assert(bins(0)(0).highSplit.categories.contains(1.0))

    assert(bins(0)(1).category === 0.0)
    assert(bins(0)(1).lowSplit.categories.length === 1)
    assert(bins(0)(1).lowSplit.categories.contains(1.0))
    assert(bins(0)(1).highSplit.categories.length === 2)
    assert(bins(0)(1).highSplit.categories.contains(1.0))
    assert(bins(0)(1).highSplit.categories.contains(0.0))

    assert(bins(0)(2).category === 2.0)
    assert(bins(0)(2).lowSplit.categories.length === 2)
    assert(bins(0)(2).lowSplit.categories.contains(1.0))
    assert(bins(0)(2).lowSplit.categories.contains(0.0))
    assert(bins(0)(2).highSplit.categories.length === 3)
    assert(bins(0)(2).highSplit.categories.contains(1.0))
    assert(bins(0)(2).highSplit.categories.contains(0.0))
    assert(bins(0)(2).highSplit.categories.contains(2.0))

    assert(bins(0)(3) === null)

    assert(bins(1)(0).category === 0.0)
    assert(bins(1)(0).lowSplit.categories.length === 0)
    assert(bins(1)(0).highSplit.categories.length === 1)
    assert(bins(1)(0).highSplit.categories.contains(0.0))

    assert(bins(1)(1).category === 1.0)
    assert(bins(1)(1).lowSplit.categories.length === 1)
    assert(bins(1)(1).lowSplit.categories.contains(0.0))
    assert(bins(1)(1).highSplit.categories.length === 2)
    assert(bins(1)(1).highSplit.categories.contains(0.0))
    assert(bins(1)(1).highSplit.categories.contains(1.0))

    assert(bins(1)(2).category === 2.0)
    assert(bins(1)(2).lowSplit.categories.length === 2)
    assert(bins(1)(2).lowSplit.categories.contains(0.0))
    assert(bins(1)(2).lowSplit.categories.contains(1.0))
    assert(bins(1)(2).highSplit.categories.length === 3)
    assert(bins(1)(2).highSplit.categories.contains(0.0))
    assert(bins(1)(2).highSplit.categories.contains(1.0))
    assert(bins(1)(2).highSplit.categories.contains(2.0))

    assert(bins(1)(3) === null)
  }

  test("classification stump with all categorical variables") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Classification,
      Gini,
      maxDepth = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
      Array[List[Filter]](), splits, bins)

    val split = bestSplits(0)._1
    assert(split.categories.length === 1)
    assert(split.categories.contains(1.0))
    assert(split.featureType === Categorical)
    assert(split.threshold === Double.MinValue)

    val stats = bestSplits(0)._2
    assert(stats.gain > 0)
    assert(stats.predict > 0.4)
    assert(stats.predict < 0.5)
    assert(stats.impurity > 0.2)
  }

  test("regression stump with all categorical variables") {
    val arr = DecisionTreeSuite.generateCategoricalDataPoints()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(
      Regression,
      Variance,
      maxDepth = 3,
      maxBins = 100,
      categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
      Array[List[Filter]](), splits, bins)

    val split = bestSplits(0)._1
    assert(split.categories.length === 1)
    assert(split.categories.contains(1.0))
    assert(split.featureType === Categorical)
    assert(split.threshold === Double.MinValue)

    val stats = bestSplits(0)._2
    assert(stats.gain > 0)
    assert(stats.predict > 0.4)
    assert(stats.predict < 0.5)
    assert(stats.impurity > 0.2)
  }

  test("stump with fixed label 0 for Gini") {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(Classification, Gini, 3, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
    assert(bins.length === 2)
    assert(bins(0).length === 100)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)

    val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
      Array[List[Filter]](), splits, bins)
    assert(bestSplits.length === 1)
    assert(bestSplits(0)._1.feature === 0)
    assert(bestSplits(0)._1.threshold === 10)
    assert(bestSplits(0)._2.gain === 0)
    assert(bestSplits(0)._2.leftImpurity === 0)
    assert(bestSplits(0)._2.rightImpurity === 0)
  }

  test("stump with fixed label 1 for Gini") {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(Classification, Gini, 3, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
    assert(bins.length === 2)
    assert(bins(0).length === 100)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)

    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
      Array[List[Filter]](), splits, bins)
    assert(bestSplits.length === 1)
    assert(bestSplits(0)._1.feature === 0)
    assert(bestSplits(0)._1.threshold === 10)
    assert(bestSplits(0)._2.gain === 0)
    assert(bestSplits(0)._2.leftImpurity === 0)
    assert(bestSplits(0)._2.rightImpurity === 0)
    assert(bestSplits(0)._2.predict === 1)
  }

  test("stump with fixed label 0 for Entropy") {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(Classification, Entropy, 3, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
    assert(bins.length === 2)
    assert(bins(0).length === 100)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)

    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
      Array[List[Filter]](), splits, bins)
    assert(bestSplits.length === 1)
    assert(bestSplits(0)._1.feature === 0)
    assert(bestSplits(0)._1.threshold === 10)
    assert(bestSplits(0)._2.gain === 0)
    assert(bestSplits(0)._2.leftImpurity === 0)
    assert(bestSplits(0)._2.rightImpurity === 0)
    assert(bestSplits(0)._2.predict === 0)
  }

  test("stump with fixed label 1 for Entropy") {
    val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
    assert(arr.length === 1000)
    val rdd = sc.parallelize(arr)
    val strategy = new Strategy(Classification, Entropy, 3, 100)
    val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
    assert(splits.length === 2)
    assert(splits(0).length === 99)
    assert(bins.length === 2)
    assert(bins(0).length === 100)
    assert(splits(0).length === 99)
    assert(bins(0).length === 100)

    val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
      Array[List[Filter]](), splits, bins)
    assert(bestSplits.length === 1)
    assert(bestSplits(0)._1.feature === 0)
    assert(bestSplits(0)._1.threshold === 10)
    assert(bestSplits(0)._2.gain === 0)
    assert(bestSplits(0)._2.leftImpurity === 0)
    assert(bestSplits(0)._2.rightImpurity === 0)
    assert(bestSplits(0)._2.predict === 1)
  }
}

object DecisionTreeSuite {

  // 1000 points, all labeled 0.0; feature 0 increases with i and feature 1 decreases.
  def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
    val arr = new Array[LabeledPoint](1000)
    for (i <- 0 until 1000) {
      val lp = new LabeledPoint(0.0, Array(i.toDouble, 1000.0 - i))
      arr(i) = lp
    }
    arr
  }

  // 1000 points, all labeled 1.0; feature 0 increases with i and feature 1 decreases.
  def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
    val arr = new Array[LabeledPoint](1000)
    for (i <- 0 until 1000) {
      val lp = new LabeledPoint(1.0, Array(i.toDouble, 999.0 - i))
      arr(i) = lp
    }
    arr
  }

  // 600 points labeled 1.0 with features (0.0, 1.0), then 400 labeled 0.0 with features (1.0, 0.0).
  def generateCategoricalDataPoints(): Array[LabeledPoint] = {
    val arr = new Array[LabeledPoint](1000)
    for (i <- 0 until 1000) {
      if (i < 600) {
        arr(i) = new LabeledPoint(1.0, Array(0.0, 1.0))
      } else {
        arr(i) = new LabeledPoint(0.0, Array(1.0, 0.0))
      }
    }
    arr
  }
}
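The category orders asserted in the categorical tests above can be reproduced without Spark. The following is a minimal sketch, not code from this commit: it assumes categories are ordered by ascending mean label, with categories that have no samples placed last, which is consistent with the test expectations but is not shown in this part of the diff. `CategoryOrderingSketch`, `Point`, `orderedCategories` and `cumulativeSplits` are hypothetical names introduced for illustration only.

```scala
// Hypothetical sketch; none of these names exist in the commit.
object CategoryOrderingSketch {

  // Stand-in for LabeledPoint so the example has no Spark dependency.
  final case class Point(label: Double, features: Array[Double])

  // Same shape as DecisionTreeSuite.generateCategoricalDataPoints().
  def data(): Array[Point] =
    Array.tabulate(1000) { i =>
      if (i < 600) Point(1.0, Array(0.0, 1.0)) else Point(0.0, Array(1.0, 0.0))
    }

  // Orders categories 0 until numCategories by ascending mean label.
  // Assumption: categories with no samples sort last.
  def orderedCategories(points: Array[Point], feature: Int, numCategories: Int): List[Double] = {
    val meanLabel: Map[Double, Double] =
      points.groupBy(_.features(feature)).map { case (category, ps) =>
        category -> ps.map(_.label).sum / ps.length
      }
    (0 until numCategories)
      .map(_.toDouble)
      .sortBy(c => meanLabel.getOrElse(c, Double.MaxValue))
      .toList
  }

  // The i-th split holds the first i + 1 categories of the ordering.
  def cumulativeSplits(ordered: List[Double]): List[List[Double]] =
    ordered.indices.map(i => ordered.take(i + 1)).toList

  def main(args: Array[String]): Unit = {
    val points = data()
    println(orderedCategories(points, 0, 3)) // List(1.0, 0.0, 2.0)
    println(orderedCategories(points, 1, 3)) // List(0.0, 1.0, 2.0)
    println(cumulativeSplits(orderedCategories(points, 0, 2))) // List(List(1.0), List(1.0, 0.0))
  }
}
```

For feature 0, category 1.0 only appears with label 0.0 and category 0.0 only with label 1.0, so 1.0 sorts first; for feature 1 the roles are swapped. This matches the `categories.contains(...)` assertions in the two split-and-bin tests for categorical variables.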