[SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML

## What changes were proposed in this pull request?

Decision tree helper classes are being migrated to ML. This patch moves the internal classes that are not part of the public API and removes the ones that are no longer used after [SPARK-12183](https://github.com/apache/spark/pull/11855). No functional changes are made.

Details:
* Bin.scala is removed as the ML implementation does not require bins
* The mllib NodeIdCache is removed; it was only used by the old mllib tree implementation, which no longer exists
* The mllib TreePoint is removed; it was likewise only used by the old mllib tree implementation
* BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, BaggedPointSuite and TimeTracker are all moved to ML (the resulting import change is sketched below)
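
None of these helpers are part of the public API (they are `private[spark]` / `private[tree]`), so the only change visible to Spark-internal callers is the package they import from. A minimal sketch of that move, using the class names touched by this patch:

```scala
// Before this patch (import removed by this diff, e.g. from RandomForest.scala):
// import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, TimeTracker}

// After this patch, the same helpers resolve from the ML tree internals
// (callers already inside org.apache.spark.ml.tree.impl need no import at all):
import org.apache.spark.ml.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, TimeTracker}
```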

## How was this patch tested?

No functional changes are made. Existing unit tests ensure behavior is unchanged.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12097 from sethah/cleanup_mllib_tree.
sethah 2016-04-01 21:23:35 -07:00 committed by Joseph K. Bradley
parent 36e8fb8005
commit 4fc35e6f5c
18 changed files with 15 additions and 409 deletions

@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.commons.math3.distribution.PoissonDistribution

@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.mllib.tree.impurity._
@@ -86,6 +86,7 @@ private[spark] class DTStatsAggregator(
/**
* Get an [[ImpurityCalculator]] for a given (node, feature, bin).
*
* @param featureOffset This is a pre-computed (node, feature) offset
* from [[getFeatureOffset]].
*/
@@ -118,6 +119,7 @@ private[spark] class DTStatsAggregator(
/**
* Faster version of [[update]].
* Update the stats for a given (feature, bin), using the given label.
*
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
*/
@@ -138,6 +140,7 @@ private[spark] class DTStatsAggregator(
/**
* For a given feature, merge the stats for two bins.
*
* @param featureOffset This is a pre-computed feature offset
* from [[getFeatureOffset]].
* @param binIndex The other bin is merged into this bin.

@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable

@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
import org.apache.spark.rdd.RDD

@@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
-import org.apache.spark.mllib.tree.impl.BaggedPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

@@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator,
-  TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
@@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging {
/**
* Given a group of nodes, this finds the best split for each node.
*
-* @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
+* @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]]
* @param metadata Learning and dataset metadata
* @param topNodes Root node for each tree. Used for matching instances with nodes.
* @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree

@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import scala.collection.mutable.{HashMap => MutableHashMap}

@@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl
import org.apache.spark.ml.tree.{ContinuousSplit, Split}
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.rdd.RDD

@@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
+import org.apache.spark.ml.tree.impl.TimeTracker
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
-import org.apache.spark.mllib.tree.impl.TimeTracker
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel}
import org.apache.spark.rdd.RDD
@@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
*
* @param input Training dataset.
* @param validationInput Validation dataset, ignored if validate is set to false.
* @param boostingStrategy Boosting parameters.

@@ -1,195 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import scala.collection.mutable

import org.apache.hadoop.fs.{FileSystem, Path}

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * :: DeveloperApi ::
 * This is used by the node id cache to find the child id that a data point would belong to.
 * @param split Split information.
 * @param nodeIndex The current node index of a data point that this will update.
 */
@DeveloperApi
private[tree] case class NodeIndexUpdater(
    split: Split,
    nodeIndex: Int) {

  /**
   * Determine a child node index based on the feature value and the split.
   * @param binnedFeatures Binned feature values.
   * @param bins Bin information to convert the bin indices to approximate feature values.
   * @return Child node index to update to.
   */
  def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
    if (split.featureType == Continuous) {
      val featureIndex = split.feature
      val binIndex = binnedFeatures(featureIndex)
      val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
      if (featureValueUpperBound <= split.threshold) {
        Node.leftChildIndex(nodeIndex)
      } else {
        Node.rightChildIndex(nodeIndex)
      }
    } else {
      if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
        Node.leftChildIndex(nodeIndex)
      } else {
        Node.rightChildIndex(nodeIndex)
      }
    }
  }
}

/**
 * :: DeveloperApi ::
 * A given TreePoint would belong to a particular node per tree.
 * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
 * in each tree. Initially, values should all be 1 for root node.
 * The nodeIdsForInstances RDD needs to be updated at each iteration.
 * @param nodeIdsForInstances The initial values in the cache
 *            (should be an Array of all 1's (meaning the root nodes)).
 * @param checkpointInterval The checkpointing interval
 *            (how often should the cache be checkpointed.).
 */
@DeveloperApi
private[spark] class NodeIdCache(
    var nodeIdsForInstances: RDD[Array[Int]],
    val checkpointInterval: Int) {

  // Keep a reference to a previous node Ids for instances.
  // Because we will keep on re-persisting updated node Ids,
  // we want to unpersist the previous RDD.
  private var prevNodeIdsForInstances: RDD[Array[Int]] = null

  // To keep track of the past checkpointed RDDs.
  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
  private var rddUpdateCount = 0

  /**
   * Update the node index values in the cache.
   * This updates the RDD and its lineage.
   * TODO: Passing bin information to executors seems unnecessary and costly.
   * @param data The RDD of training rows.
   * @param nodeIdUpdaters A map of node index updaters.
   *                       The key is the indices of nodes that we want to update.
   * @param bins Bin information needed to find child node indices.
   */
  def updateNodeIndices(
      data: RDD[BaggedPoint[TreePoint]],
      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
      bins: Array[Array[Bin]]): Unit = {
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }

    prevNodeIdsForInstances = nodeIdsForInstances
    nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
      case (point, node) => {
        var treeId = 0
        while (treeId < nodeIdUpdaters.length) {
          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
          if (nodeIdUpdater != null) {
            val newNodeIndex = nodeIdUpdater.updateNodeIndex(
              binnedFeatures = point.datum.binnedFeatures,
              bins = bins)
            node(treeId) = newNodeIndex
          }
          treeId += 1
        }
        node
      }
    }

    // Keep on persisting new ones.
    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
    rddUpdateCount += 1

    // Handle checkpointing if the directory is not None.
    if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
        (rddUpdateCount % checkpointInterval) == 0) {
      // Let's see if we can delete previous checkpoints.
      var canDelete = true
      while (checkpointQueue.size > 1 && canDelete) {
        // We can delete the oldest checkpoint iff
        // the next checkpoint actually exists in the file system.
        if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
          val old = checkpointQueue.dequeue()
          // Since the old checkpoint is not deleted by Spark,
          // we'll manually delete it here.
          val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
          fs.delete(new Path(old.getCheckpointFile.get), true)
        } else {
          canDelete = false
        }
      }

      nodeIdsForInstances.checkpoint()
      checkpointQueue.enqueue(nodeIdsForInstances)
    }
  }

  /**
   * Call this after training is finished to delete any remaining checkpoints.
   */
  def deleteAllCheckpoints(): Unit = {
    while (checkpointQueue.nonEmpty) {
      val old = checkpointQueue.dequeue()
      for (checkpointFile <- old.getCheckpointFile) {
        val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
        fs.delete(new Path(checkpointFile), true)
      }
    }
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }
  }
}

private[spark] object NodeIdCache {
  /**
   * Initialize the node Id cache with initial node Id values.
   * @param data The RDD of training rows.
   * @param numTrees The number of trees that we want to create cache for.
   * @param checkpointInterval The checkpointing interval
   *                           (how often should the cache be checkpointed.).
   * @param initVal The initial values in the cache.
   * @return A node Id cache containing an RDD of initial root node Indices.
   */
  def init(
      data: RDD[BaggedPoint[TreePoint]],
      numTrees: Int,
      checkpointInterval: Int,
      initVal: Int = 1): NodeIdCache = {
    new NodeIdCache(
      data.map(_ => Array.fill[Int](numTrees)(initVal)),
      checkpointInterval)
  }
}

@@ -1,150 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.impl

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD

/**
 * Internal representation of LabeledPoint for DecisionTree.
 * This bins feature values based on a subsampled of data as follows:
 *  (a) Continuous features are binned into ranges.
 *  (b) Unordered categorical features are binned based on subsets of feature values.
 *      "Unordered categorical features" are categorical features with low arity used in
 *      multiclass classification.
 *  (c) Ordered categorical features are binned based on feature values.
 *      "Ordered categorical features" are categorical features with high arity,
 *      or any categorical feature used in regression or binary classification.
 *
 * @param label Label from LabeledPoint
 * @param binnedFeatures Binned feature values.
 *                       Same length as LabeledPoint.features, but values are bin indices.
 */
private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
  extends Serializable {
}

private[spark] object TreePoint {

  /**
   * Convert an input dataset into its TreePoint representation,
   * binning feature values in preparation for DecisionTree training.
   * @param input Input dataset.
   * @param bins Bins for features, of size (numFeatures, numBins).
   * @param metadata Learning and dataset metadata
   * @return TreePoint dataset representation
   */
  def convertToTreeRDD(
      input: RDD[LabeledPoint],
      bins: Array[Array[Bin]],
      metadata: DecisionTreeMetadata): RDD[TreePoint] = {
    // Construct arrays for featureArity for efficiency in the inner loop.
    val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
    var featureIndex = 0
    while (featureIndex < metadata.numFeatures) {
      featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
      featureIndex += 1
    }
    input.map { x =>
      TreePoint.labeledPointToTreePoint(x, bins, featureArity)
    }
  }

  /**
   * Convert one LabeledPoint into its TreePoint representation.
   * @param bins Bins for features, of size (numFeatures, numBins).
   * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
   *                     for categorical features.
   */
  private def labeledPointToTreePoint(
      labeledPoint: LabeledPoint,
      bins: Array[Array[Bin]],
      featureArity: Array[Int]): TreePoint = {
    val numFeatures = labeledPoint.features.size
    val arr = new Array[Int](numFeatures)
    var featureIndex = 0
    while (featureIndex < numFeatures) {
      arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
        bins)
      featureIndex += 1
    }
    new TreePoint(labeledPoint.label, arr)
  }

  /**
   * Find bin for one (labeledPoint, feature).
   *
   * @param featureArity 0 for continuous features; number of categories for categorical features.
   * @param bins Bins for features, of size (numFeatures, numBins).
   */
  private def findBin(
      featureIndex: Int,
      labeledPoint: LabeledPoint,
      featureArity: Int,
      bins: Array[Array[Bin]]): Int = {

    /**
     * Binary search helper method for continuous feature.
     */
    def binarySearchForBins(): Int = {
      val binForFeatures = bins(featureIndex)
      val feature = labeledPoint.features(featureIndex)
      var left = 0
      var right = binForFeatures.length - 1
      while (left <= right) {
        val mid = left + (right - left) / 2
        val bin = binForFeatures(mid)
        val lowThreshold = bin.lowSplit.threshold
        val highThreshold = bin.highSplit.threshold
        if ((lowThreshold < feature) && (highThreshold >= feature)) {
          return mid
        } else if (lowThreshold >= feature) {
          right = mid - 1
        } else {
          left = mid + 1
        }
      }
      -1
    }

    if (featureArity == 0) {
      // Perform binary search for finding bin for continuous features.
      val binIndex = binarySearchForBins()
      if (binIndex == -1) {
        throw new RuntimeException("No bin was found for continuous feature." +
          " This error can occur when given invalid data values (such as NaN)." +
          s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
      }
      binIndex
    } else {
      // Categorical feature bins are indexed by feature values.
      val featureValue = labeledPoint.features(featureIndex)
      if (featureValue < 0 || featureValue >= featureArity) {
        throw new IllegalArgumentException(
          s"DecisionTree given invalid data:" +
          s" Feature $featureIndex is categorical with values in" +
          s" {0,...,${featureArity - 1}," +
          s" but a data point gives it value $featureValue.\n" +
          " Bad data point: " + labeledPoint.toString)
      }
      featureValue.toInt
    }
  }
}

@@ -85,7 +85,7 @@ object Entropy extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class EntropyAggregator(numClasses: Int)
+private[spark] class EntropyAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**

@@ -81,7 +81,7 @@ object Gini extends Impurity {
* Note: Instances of this class do not hold the data; they operate on views of the data.
* @param numClasses Number of classes for label.
*/
-private[tree] class GiniAggregator(numClasses: Int)
+private[spark] class GiniAggregator(numClasses: Int)
extends ImpurityAggregator(numClasses) with Serializable {
/**

@@ -71,7 +71,7 @@ object Variance extends Impurity {
* in order to compute impurity from a sample.
* Note: Instances of this class do not hold the data; they operate on views of the data.
*/
-private[tree] class VarianceAggregator()
+private[spark] class VarianceAggregator()
extends ImpurityAggregator(statsSize = 3) with Serializable {
/**

@@ -1,47 +0,0 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.mllib.tree.model

import org.apache.spark.mllib.tree.configuration.FeatureType._

/**
 * Used for "binning" the feature values for faster best split calculation.
 *
 * For a continuous feature, the bin is determined by a low and a high split,
 * where an example with featureValue falls into the bin s.t.
 * lowSplit.threshold < featureValue <= highSplit.threshold.
 *
 * For ordered categorical features, there is a 1-1-1 correspondence between
 * bins, splits, and feature values. The bin is determined by category/feature value.
 * However, the bins are not necessarily ordered by feature value;
 * they are ordered using impurity.
 *
 * For unordered categorical features, there is a 1-1 correspondence between bins, splits,
 * where bins and splits correspond to subsets of feature values (in highSplit.categories).
 * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all
 * partitionings of categories into 2 disjoint, non-empty sets.
 *
 * @param lowSplit signifying the lower threshold for the continuous feature to be
 *                 accepted in the bin
 * @param highSplit signifying the upper threshold for the continuous feature to be
 *                  accepted in the bin
 * @param featureType type of feature -- categorical or continuous
 * @param category categorical label value accepted in the bin for ordered features
 */
private[tree]
case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double)
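
As a quick illustration of the bin count described in the removed scaladoc above (a hypothetical worked example, not code from this patch):

```scala
// An unordered categorical feature with k = 3 categories {0, 1, 2} yields
// (1 << k - 1) - 1 == 3 bins/splits, one per partition of the categories
// into two non-empty sets: {0}|{1,2}, {1}|{0,2}, {2}|{0,1}.
val k = 3
val numUnorderedBins = (1 << k - 1) - 1  // == 3
```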

@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.mllib.tree.impl
+package org.apache.spark.ml.tree.impl
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.tree.EnsembleTestHelper

@@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._

@@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext