diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala index 572815df0b..4e372702f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.commons.math3.distribution.PoissonDistribution diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala index c745e9f8db..61091bb803 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.spark.mllib.tree.impurity._ @@ -86,6 +86,7 @@ private[spark] class DTStatsAggregator( /** * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * * @param featureOffset This is a pre-computed (node, feature) offset * from [[getFeatureOffset]]. */ @@ -118,6 +119,7 @@ private[spark] class DTStatsAggregator( /** * Faster version of [[update]]. * Update the stats for a given (feature, bin), using the given label. + * * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. */ @@ -138,6 +140,7 @@ private[spark] class DTStatsAggregator( /** * For a given feature, merge the stats for two bins. + * * @param featureOffset This is a pre-computed feature offset * from [[getFeatureOffset]]. * @param binIndex The other bin is merged into this bin. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala index 4f27dc44ef..df8eb5d1f9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import scala.collection.mutable diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b37f4e891e..0749d93b7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 2c8286766f..9d697a36b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} -import org.apache.spark.mllib.tree.impl.BaggedPoint import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index cccf052b3e..7b1fd089f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, - TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD @@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata * @param topNodes Root node for each tree. Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala similarity index 98% rename from mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala rename to mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala index 70afaa162b..4cc250aa46 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import scala.collection.mutable.{HashMap => MutableHashMap} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9fa27e5e1f..3a2bf3c725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index d166dc7905..0f0c6b466d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.impl.TimeTracker import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} import org.apache.spark.rdd.RDD @@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. + * * @param input Training dataset. * @param validationInput Validation dataset, ignored if validate is set to false. * @param boostingStrategy Boosting parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala deleted file mode 100644 index dc7e969f7b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.model.{Bin, Node, Split} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -/** - * :: DeveloperApi :: - * This is used by the node id cache to find the child id that a data point would belong to. - * @param split Split information. - * @param nodeIndex The current node index of a data point that this will update. - */ -@DeveloperApi -private[tree] case class NodeIndexUpdater( - split: Split, - nodeIndex: Int) { - /** - * Determine a child node index based on the feature value and the split. - * @param binnedFeatures Binned feature values. - * @param bins Bin information to convert the bin indices to approximate feature values. - * @return Child node index to update to. - */ - def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { - if (split.featureType == Continuous) { - val featureIndex = split.feature - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - if (featureValueUpperBound <= split.threshold) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } else { - if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } - } -} - -/** - * :: DeveloperApi :: - * A given TreePoint would belong to a particular node per tree. - * Each row in the nodeIdsForInstances RDD is an array over trees of the node index - * in each tree. Initially, values should all be 1 for root node. - * The nodeIdsForInstances RDD needs to be updated at each iteration. - * @param nodeIdsForInstances The initial values in the cache - * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - */ -@DeveloperApi -private[spark] class NodeIdCache( - var nodeIdsForInstances: RDD[Array[Int]], - val checkpointInterval: Int) { - - // Keep a reference to a previous node Ids for instances. - // Because we will keep on re-persisting updated node Ids, - // we want to unpersist the previous RDD. - private var prevNodeIdsForInstances: RDD[Array[Int]] = null - - // To keep track of the past checkpointed RDDs. - private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() - private var rddUpdateCount = 0 - - /** - * Update the node index values in the cache. - * This updates the RDD and its lineage. - * TODO: Passing bin information to executors seems unnecessary and costly. - * @param data The RDD of training rows. - * @param nodeIdUpdaters A map of node index updaters. - * The key is the indices of nodes that we want to update. - * @param bins Bin information needed to find child node indices. - */ - def updateNodeIndices( - data: RDD[BaggedPoint[TreePoint]], - nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], - bins: Array[Array[Bin]]): Unit = { - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - - prevNodeIdsForInstances = nodeIdsForInstances - nodeIdsForInstances = data.zip(nodeIdsForInstances).map { - case (point, node) => { - var treeId = 0 - while (treeId < nodeIdUpdaters.length) { - val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null) - if (nodeIdUpdater != null) { - val newNodeIndex = nodeIdUpdater.updateNodeIndex( - binnedFeatures = point.datum.binnedFeatures, - bins = bins) - node(treeId) = newNodeIndex - } - - treeId += 1 - } - - node - } - } - - // Keep on persisting new ones. - nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) - rddUpdateCount += 1 - - // Handle checkpointing if the directory is not None. - if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && - (rddUpdateCount % checkpointInterval) == 0) { - // Let's see if we can delete previous checkpoints. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // We can delete the oldest checkpoint iff - // the next checkpoint actually exists in the file system. - if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { - val old = checkpointQueue.dequeue() - - // Since the old checkpoint is not deleted by Spark, - // we'll manually delete it here. - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(old.getCheckpointFile.get), true) - } else { - canDelete = false - } - } - - nodeIdsForInstances.checkpoint() - checkpointQueue.enqueue(nodeIdsForInstances) - } - } - - /** - * Call this after training is finished to delete any remaining checkpoints. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - val old = checkpointQueue.dequeue() - for (checkpointFile <- old.getCheckpointFile) { - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(checkpointFile), true) - } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - } -} - -private[spark] object NodeIdCache { - /** - * Initialize the node Id cache with initial node Id values. - * @param data The RDD of training rows. - * @param numTrees The number of trees that we want to create cache for. - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - * @param initVal The initial values in the cache. - * @return A node Id cache containing an RDD of initial root node Indices. - */ - def init( - data: RDD[BaggedPoint[TreePoint]], - numTrees: Int, - checkpointInterval: Int, - initVal: Int = 1): NodeIdCache = { - new NodeIdCache( - data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointInterval) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala deleted file mode 100644 index 21919d69a3..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Bin -import org.apache.spark.rdd.RDD - - -/** - * Internal representation of LabeledPoint for DecisionTree. - * This bins feature values based on a subsampled of data as follows: - * (a) Continuous features are binned into ranges. - * (b) Unordered categorical features are binned based on subsets of feature values. - * "Unordered categorical features" are categorical features with low arity used in - * multiclass classification. - * (c) Ordered categorical features are binned based on feature values. - * "Ordered categorical features" are categorical features with high arity, - * or any categorical feature used in regression or binary classification. - * - * @param label Label from LabeledPoint - * @param binnedFeatures Binned feature values. - * Same length as LabeledPoint.features, but values are bin indices. - */ -private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) - extends Serializable { -} - -private[spark] object TreePoint { - - /** - * Convert an input dataset into its TreePoint representation, - * binning feature values in preparation for DecisionTree training. - * @param input Input dataset. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata - * @return TreePoint dataset representation - */ - def convertToTreeRDD( - input: RDD[LabeledPoint], - bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity for efficiency in the inner loop. - val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - featureIndex += 1 - } - input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity) - } - } - - /** - * Convert one LabeledPoint into its TreePoint representation. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories - * for categorical features. - */ - private def labeledPointToTreePoint( - labeledPoint: LabeledPoint, - bins: Array[Array[Bin]], - featureArity: Array[Int]): TreePoint = { - val numFeatures = labeledPoint.features.size - val arr = new Array[Int](numFeatures) - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - bins) - featureIndex += 1 - } - new TreePoint(labeledPoint.label, arr) - } - - /** - * Find bin for one (labeledPoint, feature). - * - * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param bins Bins for features, of size (numFeatures, numBins). - */ - private def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - featureArity: Int, - bins: Array[Array[Bin]]): Int = { - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } else if (lowThreshold >= feature) { - right = mid - 1 - } else { - left = mid + 1 - } - } - -1 - } - - if (featureArity == 0) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new RuntimeException("No bin was found for continuous feature." + - " This error can occur when given invalid data values (such as NaN)." + - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") - } - binIndex - } else { - // Categorical feature bins are indexed by feature values. - val featureValue = labeledPoint.features(featureIndex) - if (featureValue < 0 || featureValue >= featureArity) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureArity - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - featureValue.toInt - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 13aff11007..ff7700d2d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class EntropyAggregator(numClasses: Int) +private[spark] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 39c7f9c3be..58dc79b739 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -81,7 +81,7 @@ object Gini extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class GiniAggregator(numClasses: Int) +private[spark] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 92d74a1b83..2423516123 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -71,7 +71,7 @@ object Variance extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. */ -private[tree] class VarianceAggregator() +private[spark] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala deleted file mode 100644 index 0cad473782..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.mllib.tree.configuration.FeatureType._ - -/** - * Used for "binning" the feature values for faster best split calculation. - * - * For a continuous feature, the bin is determined by a low and a high split, - * where an example with featureValue falls into the bin s.t. - * lowSplit.threshold < featureValue <= highSplit.threshold. - * - * For ordered categorical features, there is a 1-1-1 correspondence between - * bins, splits, and feature values. The bin is determined by category/feature value. - * However, the bins are not necessarily ordered by feature value; - * they are ordered using impurity. - * - * For unordered categorical features, there is a 1-1 correspondence between bins, splits, - * where bins and splits correspond to subsets of feature values (in highSplit.categories). - * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all - * partitionings of categories into 2 disjoint, non-empty sets. - * - * @param lowSplit signifying the lower threshold for the continuous feature to be - * accepted in the bin - * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin - * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for ordered features - */ -private[tree] -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala similarity index 99% rename from mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala rename to mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala index 9d756da410..77ab3d8bb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.tree.impl +package org.apache.spark.ml.tree.impl import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.tree.EnsembleTestHelper diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 441338e74e..e64551f03c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,6 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bb1041b109..49cb7e1f24 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext