[SPARK-5604][MLLIB] remove checkpointDir from trees

This is the second part of SPARK-5604, which removes checkpointDir from tree strategies. Note that this is a breaking change; I will mention it in the migration guide.

Author: Xiangrui Meng <meng@databricks.com>

Closes #4407 from mengxr/SPARK-5604-1 and squashes the following commits:

13a276d [Xiangrui Meng] remove checkpointDir from trees
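For reference, a minimal sketch of what the migration looks like in user code, assuming an existing SparkContext named sc; the path and parameter values below are illustrative, not from the patch:

    import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
    import org.apache.spark.mllib.tree.impurity.Gini

    // Before this patch, the directory could be passed through the strategy:
    //   checkpointDir = Some("/tmp/tree-checkpoints")
    // After this patch, set it once on the SparkContext instead:
    sc.setCheckpointDir("/tmp/tree-checkpoints")

    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      useNodeIdCache = true,
      checkpointInterval = 10) // ignored if no checkpoint directory is set

Centralizing the directory on the SparkContext avoids two sources of truth when several components of an application need checkpointing.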
Xiangrui Meng committed 2015-02-05 23:32:09 -08:00
parent 7dc4965f34
commit 6b88825a25
4 changed files with 6 additions and 20 deletions

DecisionTreeRunner.scala

@@ -272,6 +272,8 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
+    params.checkpointDir.foreach(sc.setCheckpointDir)
+
     val strategy
       = new Strategy(
           algo = params.algo,
@@ -282,7 +284,6 @@ object DecisionTreeRunner {
           minInstancesPerNode = params.minInstancesPerNode,
           minInfoGain = params.minInfoGain,
           useNodeIdCache = params.useNodeIdCache,
-          checkpointDir = params.checkpointDir,
           checkpointInterval = params.checkpointInterval)
     if (params.numTrees == 1) {
       val startTime = System.nanoTime()
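The one-liner added above relies on Option.foreach, so the context's checkpoint directory is set only when the runner's optional checkpointDir argument was supplied. A standalone illustration of the idiom, with a made-up value and an assumed SparkContext sc:

    val checkpointDir: Option[String] = Some("/tmp/checkpoints") // stand-in value
    // sc.setCheckpointDir is invoked only if the Option is non-empty.
    checkpointDir.foreach(sc.setCheckpointDir)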

RandomForest.scala

@@ -204,7 +204,6 @@ private class RandomForest (
         Some(NodeIdCache.init(
           data = baggedInput,
           numTrees = numTrees,
-          checkpointDir = strategy.checkpointDir,
           checkpointInterval = strategy.checkpointInterval,
           initVal = 1))
       } else {

Strategy.scala

@@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  * @param subsamplingRate Fraction of the training data used for learning decision tree.
  * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
  *                       maintain a separate RDD of node Id cache for each row.
- * @param checkpointDir If the node Id cache is used, it will help to checkpoint
- *                      the node Id cache periodically. This is the checkpoint directory
- *                      to be used for the node Id cache.
  * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
- *                           E.g. 10 means that the cache will get checkpointed every 10 updates.
+ *                           E.g. 10 means that the cache will get checkpointed every 10 updates. If
+ *                           the checkpoint directory is not set in
+ *                           [[org.apache.spark.SparkContext]], this setting is ignored.
  */
 @Experimental
 class Strategy (
@@ -82,7 +81,6 @@ class Strategy (
     @BeanProperty var maxMemoryInMB: Int = 256,
     @BeanProperty var subsamplingRate: Double = 1,
     @BeanProperty var useNodeIdCache: Boolean = false,
-    @BeanProperty var checkpointDir: Option[String] = None,
     @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
 
   def isMulticlassClassification =
@@ -165,7 +163,7 @@ class Strategy (
   def copy: Strategy = {
     new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
       quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
-      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+      maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
   }
 }
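Because the remaining fields are @BeanProperty vars, the interval can also be adjusted after construction through the generated setters. A small sketch, assuming Strategy.defaultStrategy("Classification") is available in this version of the API:

    import org.apache.spark.mllib.tree.configuration.Strategy

    val strategy = Strategy.defaultStrategy("Classification")
    strategy.setUseNodeIdCache(true)   // setter generated by @BeanProperty
    strategy.setCheckpointInterval(5)  // takes effect only if a checkpoint dir is set on the context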

NodeIdCache.scala

@@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
  * The nodeIdsForInstances RDD needs to be updated at each iteration.
  * @param nodeIdsForInstances The initial values in the cache
  *                            (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointDir The checkpoint directory where
- *                      the checkpointed files will be stored.
  * @param checkpointInterval The checkpointing interval
  *                           (how often should the cache be checkpointed.).
  */
 @DeveloperApi
 private[tree] class NodeIdCache(
     var nodeIdsForInstances: RDD[Array[Int]],
-    val checkpointDir: Option[String],
     val checkpointInterval: Int) {
 
   // Keep a reference to a previous node Ids for instances.
@@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
   private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
   private var rddUpdateCount = 0
 
-  // If a checkpoint directory is given, and there's no prior checkpoint directory,
-  // then set the checkpoint directory with the given one.
-  if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
-    nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
-  }
-
   /**
    * Update the node index values in the cache.
    * This updates the RDD and its lineage.
@@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
    * Initialize the node Id cache with initial node Id values.
    * @param data The RDD of training rows.
    * @param numTrees The number of trees that we want to create cache for.
-   * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
    * @param checkpointInterval The checkpointing interval
    *                           (how often should the cache be checkpointed.).
    * @param initVal The initial values in the cache.
@@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
   def init(
       data: RDD[BaggedPoint[TreePoint]],
       numTrees: Int,
-      checkpointDir: Option[String],
       checkpointInterval: Int,
       initVal: Int = 1): NodeIdCache = {
     new NodeIdCache(
       data.map(_ => Array.fill[Int](numTrees)(initVal)),
-      checkpointDir,
       checkpointInterval)
   }
 }
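After this change, NodeIdCache checkpoints only when the application has already configured a checkpoint directory on the SparkContext. A hypothetical, simplified sketch of the pattern it follows (PeriodicCheckpointer and its members are illustrative names, not the actual implementation): checkpoint every checkpointInterval updates, and queue past RDDs so stale checkpoint files can be cleaned up later.

    import scala.collection.mutable
    import org.apache.spark.rdd.RDD

    // Illustrative sketch only: checkpoint an evolving RDD every `interval` updates,
    // but only when the SparkContext has a checkpoint directory configured.
    class PeriodicCheckpointer[T](interval: Int) {
      private val checkpointQueue = mutable.Queue[RDD[T]]()
      private var updateCount = 0

      def update(rdd: RDD[T]): Unit = {
        updateCount += 1
        if (updateCount % interval == 0 && rdd.sparkContext.getCheckpointDir.nonEmpty) {
          rdd.checkpoint()
          checkpointQueue.enqueue(rdd) // remember it so old checkpoint files can be removed later
        }
      }
    }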