[SPARK-8581] [SPARK-8584] Simplify checkpointing code + better error message

This patch rewrites the old checkpointing code so that it is easier to understand. It also adds a guard against invalid checkpoint directory settings in order to give a clearer error message. Most of the changes here are relatively minor.
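For example, with the guard in place, pointing the checkpoint directory at a driver-local path while running against a cluster produces an immediate warning, and calling checkpoint() before any directory is set fails fast with a clear exception. A minimal sketch of both behaviors (the master URL and paths below are hypothetical):

    val conf = new SparkConf().setAppName("checkpoint-demo").setMaster("spark://master:7077")
    val sc = new SparkContext(conf)

    // Before setCheckpointDir is called, rdd.checkpoint() throws:
    //   SparkException: Checkpoint directory has not been set in the SparkContext

    // A driver-local path on a cluster now logs a warning along the lines of:
    //   Checkpoint directory must be non-local if Spark is running on a cluster: /tmp/checkpoints
    sc.setCheckpointDir("/tmp/checkpoints")

    val rdd = sc.makeRDD(1 to 4).flatMap(x => 1 to x)
    rdd.checkpoint()   // only marks the RDD; the files are written after the first action
    rdd.collect()      // materializes the RDD and writes the checkpoint files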

Author: Andrew Or <andrew@databricks.com>

Closes #6968 from andrewor14/checkpoint-cleanup and squashes the following commits:

4ef8263 [Andrew Or] Use global synchronized instead
6f6fd84 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup
b1437ad [Andrew Or] Warn instead of throw
5484293 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup
7fb4af5 [Andrew Or] Guard against bad settings of checkpoint directory
691da98 [Andrew Or] Simplify checkpoint code / code style / comments
Andrew Or 2015-07-02 10:57:02 -07:00
parent 0e553a3e93
commit 2e2f32603c
5 changed files with 60 additions and 54 deletions

core/src/main/scala/org/apache/spark/SparkContext.scala

@@ -1906,6 +1906,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient
* be a HDFS path if running on a cluster.
*/
def setCheckpointDir(directory: String) {
// If we are running on a cluster, log a warning if the directory is local.
// Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
// its own local file system, which is incorrect because the checkpoint files
// are actually on the executor machines.
if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
logWarning("Checkpoint directory must be non-local " +
"if Spark is running on a cluster: " + directory)
}
checkpointDir = Option(directory).map { dir =>
val path = new Path(dir, UUID.randomUUID().toString)
val fs = path.getFileSystem(hadoopConfiguration)

core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala

@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.{SerializableConfiguration, Utils}
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
@@ -37,9 +37,11 @@ private[spark]
class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getCheckpointFile: Option[String] = Some(checkpointPath)
override def getPartitions: Array[Partition] = {
val cpath = new Path(checkpointPath)
@@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
checkpointData.get.cpFile = Some(checkpointPath)
override def getPreferredLocations(split: Partition): Seq[String] = {
val status = fs.getFileStatus(new Path(checkpointPath,
CheckpointRDD.splitIdToFile(split.index)))
@@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
CheckpointRDD.readFromFile(file, broadcastedConf, context)
}
override def checkpoint() {
// Do nothing. CheckpointRDD should not be checkpointed.
}
// CheckpointRDD should not be checkpointed again
override def checkpoint(): Unit = { }
override def doCheckpoint(): Unit = { }
}
private[spark] object CheckpointRDD extends Logging {

core/src/main/scala/org/apache/spark/rdd/RDD.scala

@@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag](
@transient private var partitions_ : Array[Partition] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
/**
* Get the list of dependencies of this RDD, taking into account whether the
@@ -1451,12 +1451,16 @@ abstract class RDD[T: ClassTag](
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() {
def checkpoint(): Unit = {
if (context.checkpointDir.isEmpty) {
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
// NOTE: we use a global lock here due to complexities downstream with ensuring
// children RDD partitions point to the correct parent partitions. In the future
// we should revisit this consideration.
RDDCheckpointData.synchronized {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData.get.markForCheckpoint()
}
}
}
@@ -1497,7 +1501,7 @@ abstract class RDD[T: ClassTag](
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
/** Returns the first parent RDD */
protected[spark] def firstParent[U: ClassTag] = {
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}

core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala

@@ -22,16 +22,15 @@ import scala.reflect.ClassTag
import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
import org.apache.spark.util.SerializableConfiguration
/**
* Enumeration to manage state transitions of an RDD through checkpointing
* [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
* [ Initialized --> checkpointing in progress --> checkpointed ].
*/
private[spark] object CheckpointState extends Enumeration {
type CheckpointState = Value
val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
val Initialized, CheckpointingInProgress, Checkpointed = Value
}
/**
@@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
import CheckpointState._
// The checkpoint state of the associated RDD.
var cpState = Initialized
private var cpState = Initialized
// The file to which the associated RDD has been checkpointed to
@transient var cpFile: Option[String] = None
private var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
var cpRDD: Option[RDD[T]] = None
// This is defined if and only if `cpState` is `Checkpointed`.
private var cpRDD: Option[CheckpointRDD[T]] = None
// Mark the RDD for checkpointing
def markForCheckpoint() {
RDDCheckpointData.synchronized {
if (cpState == Initialized) cpState = MarkedForCheckpoint
}
}
// TODO: are we sure we need to use a global lock in the following methods?
// Is the RDD already checkpointed
def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
cpState == Checkpointed
}
// Get the file to which this RDD was checkpointed to as an Option
def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile }
def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
cpFile
}
// Do the checkpointing of the RDD. Called after the first job using that RDD is over.
def doCheckpoint() {
// If it is marked for checkpointing AND checkpointing is not already in progress,
// then set it to be in progress, else return
/**
* Materialize this RDD and write its content to a reliable DFS.
* This is called immediately after the first action invoked on this RDD has completed.
*/
def doCheckpoint(): Unit = {
// Guard against multiple threads checkpointing the same RDD by
// atomically flipping the state of this RDDCheckpointData
RDDCheckpointData.synchronized {
if (cpState == MarkedForCheckpoint) {
if (cpState == Initialized) {
cpState = CheckpointingInProgress
} else {
return
@@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
throw new SparkException(s"Failed to create checkpoint path $path")
}
// Save to file, and reload it as an RDD
@@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
}
}
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
if (newRDD.partitions.length != rdd.partitions.length) {
throw new SparkException(
@@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed
}
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
}
// Get preferred location of a split after checkpointing
def getPreferredLocations(split: Partition): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
}
def getPartitions: Array[Partition] = {
RDDCheckpointData.synchronized {
def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
cpRDD.get.partitions
}
}
def checkpointRDD: Option[RDD[T]] = {
RDDCheckpointData.synchronized {
def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
cpRDD
}
}
}
private[spark] object RDDCheckpointData {
/** Return the path of the directory to which this RDD's checkpoint data is written. */
def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
}
/** Clean up the files associated with the checkpoint data for this RDD. */
def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
rddCheckpointDataPath(sc, rddId).foreach { path =>
val fs = path.getFileSystem(sc.hadoopConfiguration)

core/src/test/scala/org/apache/spark/CheckpointSuite.scala

@@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
flatMappedRDD.checkpoint()
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
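
The assertions above capture the observable effect of checkpointing: before the first action the RDD still depends on its original parent, and afterwards its lineage is truncated so that the parent is a CheckpointRDD backed by the saved files. A rough sketch of that sequence (the checkpoint directory path is hypothetical):

    sc.setCheckpointDir("/tmp/spark-checkpoints")   // hypothetical local test path
    val parCollection = sc.makeRDD(1 to 4)
    val flatMappedRDD = parCollection.flatMap(x => 1 to x)
    flatMappedRDD.checkpoint()
    flatMappedRDD.dependencies.head.rdd   // still parCollection
    flatMappedRDD.collect()               // first action triggers doCheckpoint()
    flatMappedRDD.dependencies.head.rdd   // now a CheckpointRDD, no longer parCollection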