spark-instrumented-optimizer/core/src/main/scala/spark/RDDCheckpointData.scala
Matei Zaharia 64ba6a8c2c Simplify checkpointing code and RDD class a little:
- RDD's getDependencies and getSplits methods are now guaranteed to be
  called only once, so subclasses can safely do computation in there
  without worrying about caching the results (see the sketch below the
  commit message).

- The management of a "splits_" variable that is cleared out when we
  checkpoint an RDD is now done in the RDD class.

- A few of the RDD subclasses are simpler.

- CheckpointRDD's compute() method no longer assumes that it is given a
  CheckpointRDDSplit -- it can work just as well on a split from the
  original RDD, because it only looks at its index. This is important
  because things like UnionRDD and ZippedRDD remember the parent's
  splits as part of their own and wouldn't work on checkpointed parents.

- RDD.iterator can now reuse cached data if an RDD is computed before it
  is checkpointed. It seems like it wouldn't do this before (it always
  called iterator() on the CheckpointRDD, which read from HDFS).
2013-01-28 22:30:12 -08:00
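
A minimal sketch of the memoization described in the first point, assuming a
hypothetical field name splits_ and an RDD field checkpointData of type
Option[RDDCheckpointData[T]] (the real RDD class may differ in detail):

    // Inside the RDD base class: the public accessor caches the result of
    // getSplits, so the subclass hook runs at most once. Once checkpointed,
    // the splits of the CheckpointRDD are returned instead.
    @transient private var splits_ : Array[Split] = null

    final def splits: Array[Split] = {
      checkpointData.flatMap(_.checkpointRDD) match {
        case Some(cp) => cp.splits
        case None =>
          if (splits_ == null) splits_ = getSplits
          splits_
      }
    }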

package spark

import org.apache.hadoop.fs.Path
import rdd.{CheckpointRDD, CoalescedRDD}
import scheduler.{ResultTask, ShuffleMapTask}

/**
 * Enumeration to manage state transitions of an RDD through checkpointing
 *
 * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
 */
private[spark] object CheckpointState extends Enumeration {
  type CheckpointState = Value
  val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
}

/**
 * This class contains all the information related to RDD checkpointing. Each instance of this
 * class is associated with an RDD. It manages the checkpointing process of the associated RDD,
 * as well as the post-checkpoint state, by providing the updated splits, iterator and preferred
 * locations of the checkpointed RDD.
 */
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
  extends Logging with Serializable {

  import CheckpointState._

  // The checkpoint state of the associated RDD
  var cpState = Initialized

  // The file to which the associated RDD has been checkpointed
  @transient var cpFile: Option[String] = None

  // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD
  var cpRDD: Option[RDD[T]] = None

  // Mark the RDD for checkpointing
  def markForCheckpoint() {
    RDDCheckpointData.synchronized {
      if (cpState == Initialized) cpState = MarkedForCheckpoint
    }
  }

  // Is the RDD already checkpointed?
  def isCheckpointed: Boolean = {
    RDDCheckpointData.synchronized { cpState == Checkpointed }
  }

  // Get the file to which this RDD was checkpointed, as an Option
  def getCheckpointFile: Option[String] = {
    RDDCheckpointData.synchronized { cpFile }
  }

  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
  def doCheckpoint() {
    // If it is marked for checkpointing AND checkpointing is not already in progress,
    // then set it to be in progress, else return
    RDDCheckpointData.synchronized {
      if (cpState == MarkedForCheckpoint) {
        cpState = CheckpointingInProgress
      } else {
        return
      }
    }

    // Save to file, and reload it as an RDD
    val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
    val newRDD = new CheckpointRDD[T](rdd.context, path)

    // Change the dependencies and splits of the RDD
    RDDCheckpointData.synchronized {
      cpFile = Some(path)
      cpRDD = Some(newRDD)
      rdd.markCheckpointed(newRDD)  // Update the RDD's dependencies and splits
      cpState = Checkpointed
      RDDCheckpointData.clearTaskCaches()
      logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
    }
  }
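
  // Sketch (an assumption, not part of this file): doCheckpoint() is driven from
  // the job-running path once the first job using the RDD completes, roughly:
  //
  //   def runJob[T, U](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
  //     val results = scheduler.runJob(rdd, func)
  //     rdd.doCheckpoint()  // checkpoints this RDD and its ancestors if marked
  //     results
  //   }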

  // Get the preferred locations of a split after checkpointing
  def getPreferredLocations(split: Split): Seq[String] = {
    RDDCheckpointData.synchronized {
      cpRDD.get.preferredLocations(split)
    }
  }

  // Get the splits of the CheckpointRDD that replaced the original RDD
  def getSplits: Array[Split] = {
    RDDCheckpointData.synchronized {
      cpRDD.get.splits
    }
  }

  // The CheckpointRDD, if checkpointing has completed
  def checkpointRDD: Option[RDD[T]] = {
    RDDCheckpointData.synchronized {
      cpRDD
    }
  }
}

private[spark] object RDDCheckpointData {
  // Clear cached serialized task copies, which may reference the pre-checkpoint RDD graph
  def clearTaskCaches() {
    ShuffleMapTask.clearCache()
    ResultTask.clearCache()
  }
}
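
A hedged end-to-end usage sketch (setCheckpointDir, checkpoint, count and
isCheckpointed belong to the surrounding Spark API, not this file):

    sc.setCheckpointDir("hdfs://namenode:9000/checkpoints")
    val rdd = sc.parallelize(1 to 1000).map(_ * 2)
    rdd.checkpoint()       // cpState: Initialized -> MarkedForCheckpoint
    rdd.count()            // first job runs; doCheckpoint() then writes the data
                           // to the checkpoint dir and swaps in a CheckpointRDD
    assert(rdd.isCheckpointed)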