[SPARK-5604][MLLIB] remove checkpointDir from trees
This is the second part of SPARK-5604, which removes checkpointDir from tree strategies. Note that this is a breaking change; I will mention it in the migration guide. Author: Xiangrui Meng <meng@databricks.com> Closes #4407 from mengxr/SPARK-5604-1 and squashes the following commits: 13a276d [Xiangrui Meng] remove checkpointDir from trees
This commit is contained in:
parent
7dc4965f34
commit
6b88825a25
|
@ -272,6 +272,8 @@ object DecisionTreeRunner {
|
||||||
case Variance => impurity.Variance
|
case Variance => impurity.Variance
|
||||||
}
|
}
|
||||||
|
|
||||||
|
params.checkpointDir.foreach(sc.setCheckpointDir)
|
||||||
|
|
||||||
val strategy
|
val strategy
|
||||||
= new Strategy(
|
= new Strategy(
|
||||||
algo = params.algo,
|
algo = params.algo,
|
||||||
|
@ -282,7 +284,6 @@ object DecisionTreeRunner {
|
||||||
minInstancesPerNode = params.minInstancesPerNode,
|
minInstancesPerNode = params.minInstancesPerNode,
|
||||||
minInfoGain = params.minInfoGain,
|
minInfoGain = params.minInfoGain,
|
||||||
useNodeIdCache = params.useNodeIdCache,
|
useNodeIdCache = params.useNodeIdCache,
|
||||||
checkpointDir = params.checkpointDir,
|
|
||||||
checkpointInterval = params.checkpointInterval)
|
checkpointInterval = params.checkpointInterval)
|
||||||
if (params.numTrees == 1) {
|
if (params.numTrees == 1) {
|
||||||
val startTime = System.nanoTime()
|
val startTime = System.nanoTime()
|
||||||
|
|
|
@ -204,7 +204,6 @@ private class RandomForest (
|
||||||
Some(NodeIdCache.init(
|
Some(NodeIdCache.init(
|
||||||
data = baggedInput,
|
data = baggedInput,
|
||||||
numTrees = numTrees,
|
numTrees = numTrees,
|
||||||
checkpointDir = strategy.checkpointDir,
|
|
||||||
checkpointInterval = strategy.checkpointInterval,
|
checkpointInterval = strategy.checkpointInterval,
|
||||||
initVal = 1))
|
initVal = 1))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
|
||||||
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
* @param subsamplingRate Fraction of the training data used for learning decision tree.
|
||||||
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
|
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
|
||||||
* maintain a separate RDD of node Id cache for each row.
|
* maintain a separate RDD of node Id cache for each row.
|
||||||
* @param checkpointDir If the node Id cache is used, it will help to checkpoint
|
|
||||||
* the node Id cache periodically. This is the checkpoint directory
|
|
||||||
* to be used for the node Id cache.
|
|
||||||
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
|
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
|
||||||
* E.g. 10 means that the cache will get checkpointed every 10 updates.
|
* E.g. 10 means that the cache will get checkpointed every 10 updates. If
|
||||||
|
* the checkpoint directory is not set in
|
||||||
|
* [[org.apache.spark.SparkContext]], this setting is ignored.
|
||||||
*/
|
*/
|
||||||
@Experimental
|
@Experimental
|
||||||
class Strategy (
|
class Strategy (
|
||||||
|
@ -82,7 +81,6 @@ class Strategy (
|
||||||
@BeanProperty var maxMemoryInMB: Int = 256,
|
@BeanProperty var maxMemoryInMB: Int = 256,
|
||||||
@BeanProperty var subsamplingRate: Double = 1,
|
@BeanProperty var subsamplingRate: Double = 1,
|
||||||
@BeanProperty var useNodeIdCache: Boolean = false,
|
@BeanProperty var useNodeIdCache: Boolean = false,
|
||||||
@BeanProperty var checkpointDir: Option[String] = None,
|
|
||||||
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
|
||||||
|
|
||||||
def isMulticlassClassification =
|
def isMulticlassClassification =
|
||||||
|
@ -165,7 +163,7 @@ class Strategy (
|
||||||
def copy: Strategy = {
|
def copy: Strategy = {
|
||||||
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
|
||||||
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
|
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
|
||||||
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
|
maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
|
||||||
* The nodeIdsForInstances RDD needs to be updated at each iteration.
|
* The nodeIdsForInstances RDD needs to be updated at each iteration.
|
||||||
* @param nodeIdsForInstances The initial values in the cache
|
* @param nodeIdsForInstances The initial values in the cache
|
||||||
* (should be an Array of all 1's (meaning the root nodes)).
|
* (should be an Array of all 1's (meaning the root nodes)).
|
||||||
* @param checkpointDir The checkpoint directory where
|
|
||||||
* the checkpointed files will be stored.
|
|
||||||
* @param checkpointInterval The checkpointing interval
|
* @param checkpointInterval The checkpointing interval
|
||||||
* (how often should the cache be checkpointed.).
|
* (how often should the cache be checkpointed.).
|
||||||
*/
|
*/
|
||||||
@DeveloperApi
|
@DeveloperApi
|
||||||
private[tree] class NodeIdCache(
|
private[tree] class NodeIdCache(
|
||||||
var nodeIdsForInstances: RDD[Array[Int]],
|
var nodeIdsForInstances: RDD[Array[Int]],
|
||||||
val checkpointDir: Option[String],
|
|
||||||
val checkpointInterval: Int) {
|
val checkpointInterval: Int) {
|
||||||
|
|
||||||
// Keep a reference to a previous node Ids for instances.
|
// Keep a reference to a previous node Ids for instances.
|
||||||
|
@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
|
||||||
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
|
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
|
||||||
private var rddUpdateCount = 0
|
private var rddUpdateCount = 0
|
||||||
|
|
||||||
// If a checkpoint directory is given, and there's no prior checkpoint directory,
|
|
||||||
// then set the checkpoint directory with the given one.
|
|
||||||
if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
|
|
||||||
nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Update the node index values in the cache.
|
* Update the node index values in the cache.
|
||||||
* This updates the RDD and its lineage.
|
* This updates the RDD and its lineage.
|
||||||
|
@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
|
||||||
* Initialize the node Id cache with initial node Id values.
|
* Initialize the node Id cache with initial node Id values.
|
||||||
* @param data The RDD of training rows.
|
* @param data The RDD of training rows.
|
||||||
* @param numTrees The number of trees that we want to create cache for.
|
* @param numTrees The number of trees that we want to create cache for.
|
||||||
* @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
|
|
||||||
* @param checkpointInterval The checkpointing interval
|
* @param checkpointInterval The checkpointing interval
|
||||||
* (how often should the cache be checkpointed.).
|
* (how often should the cache be checkpointed.).
|
||||||
* @param initVal The initial values in the cache.
|
* @param initVal The initial values in the cache.
|
||||||
|
@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
|
||||||
def init(
|
def init(
|
||||||
data: RDD[BaggedPoint[TreePoint]],
|
data: RDD[BaggedPoint[TreePoint]],
|
||||||
numTrees: Int,
|
numTrees: Int,
|
||||||
checkpointDir: Option[String],
|
|
||||||
checkpointInterval: Int,
|
checkpointInterval: Int,
|
||||||
initVal: Int = 1): NodeIdCache = {
|
initVal: Int = 1): NodeIdCache = {
|
||||||
new NodeIdCache(
|
new NodeIdCache(
|
||||||
data.map(_ => Array.fill[Int](numTrees)(initVal)),
|
data.map(_ => Array.fill[Int](numTrees)(initVal)),
|
||||||
checkpointDir,
|
|
||||||
checkpointInterval)
|
checkpointInterval)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue