[SPARK-5604][MLLIB] remove checkpointDir from trees
This is the second part of SPARK-5604, which removes checkpointDir from tree strategies. Note that this is a break change. I will mention it in the migration guide. Author: Xiangrui Meng <meng@databricks.com> Closes #4407 from mengxr/SPARK-5604-1 and squashes the following commits: 13a276d [Xiangrui Meng] remove checkpointDir from trees
This commit is contained in:
parent
7dc4965f34
commit
6b88825a25
|
@ -272,6 +272,8 @@ object DecisionTreeRunner {
|
|||
case Variance => impurity.Variance
|
||||
}
|
||||
|
||||
params.checkpointDir.foreach(sc.setCheckpointDir)
|
||||
|
||||
val strategy
|
||||
= new Strategy(
|
||||
algo = params.algo,
|
||||
|
@ -282,7 +284,6 @@ object DecisionTreeRunner {
|
|||
minInstancesPerNode = params.minInstancesPerNode,
|
||||
minInfoGain = params.minInfoGain,
|
||||
useNodeIdCache = params.useNodeIdCache,
|
||||
checkpointDir = params.checkpointDir,
|
||||
checkpointInterval = params.checkpointInterval)
|
||||
if (params.numTrees == 1) {
|
||||
val startTime = System.nanoTime()
|
||||
|
|
|
@ -204,7 +204,6 @@ private class RandomForest (
|
|||
Some(NodeIdCache.init(
|
||||
data = baggedInput,
|
||||
numTrees = numTrees,
|
||||
checkpointDir = strategy.checkpointDir,
|
||||
checkpointInterval = strategy.checkpointInterval,
|
||||
initVal = 1))
|
||||
} else {
|
||||
|
|
|
@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
|||
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
||||
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
|
||||
* maintain a separate RDD of node Id cache for each row.
|
||||
* @param checkpointDir If the node Id cache is used, it will help to checkpoint
|
||||
* the node Id cache periodically. This is the checkpoint directory
|
||||
* to be used for the node Id cache.
|
||||
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
|
||||
* E.g. 10 means that the cache will get checkpointed every 10 updates.
|
||||
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
|
||||
* the checkpoint directory is not set in
|
||||
* [[org.apache.spark.SparkContext]], this setting is ignored.
|
||||
*/
|
||||
@Experimental
|
||||
class Strategy (
|
||||
|
@ -82,7 +81,6 @@ class Strategy (
|
|||
@BeanProperty var maxMemoryInMB: Int = 256,
|
||||
@BeanProperty var subsamplingRate: Double = 1,
|
||||
@BeanProperty var useNodeIdCache: Boolean = false,
|
||||
@BeanProperty var checkpointDir: Option[String] = None,
|
||||
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||
|
||||
def isMulticlassClassification =
|
||||
|
@ -165,7 +163,7 @@ class Strategy (
|
|||
def copy: Strategy = {
|
||||
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
||||
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
|
||||
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
|
||||
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
|
|||
* The nodeIdsForInstances RDD needs to be updated at each iteration.
|
||||
* @param nodeIdsForInstances The initial values in the cache
|
||||
* (should be an Array of all 1's (meaning the root nodes)).
|
||||
* @param checkpointDir The checkpoint directory where
|
||||
* the checkpointed files will be stored.
|
||||
* @param checkpointInterval The checkpointing interval
|
||||
* (how often should the cache be checkpointed.).
|
||||
*/
|
||||
@DeveloperApi
|
||||
private[tree] class NodeIdCache(
|
||||
var nodeIdsForInstances: RDD[Array[Int]],
|
||||
val checkpointDir: Option[String],
|
||||
val checkpointInterval: Int) {
|
||||
|
||||
// Keep a reference to a previous node Ids for instances.
|
||||
|
@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
|
|||
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
|
||||
private var rddUpdateCount = 0
|
||||
|
||||
// If a checkpoint directory is given, and there's no prior checkpoint directory,
|
||||
// then set the checkpoint directory with the given one.
|
||||
if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
|
||||
nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
|
||||
}
|
||||
|
||||
/**
|
||||
* Update the node index values in the cache.
|
||||
* This updates the RDD and its lineage.
|
||||
|
@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
|
|||
* Initialize the node Id cache with initial node Id values.
|
||||
* @param data The RDD of training rows.
|
||||
* @param numTrees The number of trees that we want to create cache for.
|
||||
* @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
|
||||
* @param checkpointInterval The checkpointing interval
|
||||
* (how often should the cache be checkpointed.).
|
||||
* @param initVal The initial values in the cache.
|
||||
|
@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
|
|||
def init(
|
||||
data: RDD[BaggedPoint[TreePoint]],
|
||||
numTrees: Int,
|
||||
checkpointDir: Option[String],
|
||||
checkpointInterval: Int,
|
||||
initVal: Int = 1): NodeIdCache = {
|
||||
new NodeIdCache(
|
||||
data.map(_ => Array.fill[Int](numTrees)(initVal)),
|
||||
checkpointDir,
|
||||
checkpointInterval)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue