[SPARK-8581] [SPARK-8584] Simplify checkpointing code + better error message
This patch rewrites the old checkpointing code in a way that is easier to understand. It also adds a guard against an invalid specification of the checkpoint directory to provide a clearer error message. Most of the changes here are relatively minor.

Author: Andrew Or <andrew@databricks.com>

Closes #6968 from andrewor14/checkpoint-cleanup and squashes the following commits:

4ef8263 [Andrew Or] Use global synchronized instead
6f6fd84 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup
b1437ad [Andrew Or] Warn instead of throw
5484293 [Andrew Or] Merge branch 'master' of github.com:apache/spark into checkpoint-cleanup
7fb4af5 [Andrew Or] Guard against bad settings of checkpoint directory
691da98 [Andrew Or] Simplify checkpoint code / code style / comments
This commit is contained in:
parent
0e553a3e93
commit
2e2f32603c
|
@ -1906,6 +1906,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
|
|||
* be an HDFS path if running on a cluster.
|
||||
*/
|
||||
def setCheckpointDir(directory: String) {
|
||||
|
||||
// If we are running on a cluster, log a warning if the directory is local.
|
||||
// Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
|
||||
// its own local file system, which is incorrect because the checkpoint files
|
||||
// are actually on the executor machines.
|
||||
if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
|
||||
logWarning("Checkpoint directory must be non-local " +
|
||||
"if Spark is running on a cluster: " + directory)
|
||||
}
|
||||
|
||||
checkpointDir = Option(directory).map { dir =>
|
||||
val path = new Path(dir, UUID.randomUUID().toString)
|
||||
val fs = path.getFileSystem(hadoopConfiguration)
|
||||
|
|
|
@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
|
|||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.util.{SerializableConfiguration, Utils}
|
||||
|
||||
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
|
||||
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
|
||||
|
||||
/**
|
||||
* This RDD represents an RDD checkpoint file (similar to HadoopRDD).
|
||||
|
@ -37,9 +37,11 @@ private[spark]
|
|||
class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
|
||||
extends RDD[T](sc, Nil) {
|
||||
|
||||
val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
|
||||
private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
|
||||
|
||||
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
|
||||
@transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
|
||||
|
||||
override def getCheckpointFile: Option[String] = Some(checkpointPath)
|
||||
|
||||
override def getPartitions: Array[Partition] = {
|
||||
val cpath = new Path(checkpointPath)
|
||||
|
@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
|
|||
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
|
||||
}
|
||||
|
||||
checkpointData = Some(new RDDCheckpointData[T](this))
|
||||
checkpointData.get.cpFile = Some(checkpointPath)
|
||||
|
||||
override def getPreferredLocations(split: Partition): Seq[String] = {
|
||||
val status = fs.getFileStatus(new Path(checkpointPath,
|
||||
CheckpointRDD.splitIdToFile(split.index)))
|
||||
|
@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
|
|||
CheckpointRDD.readFromFile(file, broadcastedConf, context)
|
||||
}
|
||||
|
||||
override def checkpoint() {
|
||||
// Do nothing. CheckpointRDD should not be checkpointed.
|
||||
}
|
||||
// CheckpointRDD should not be checkpointed again
|
||||
override def checkpoint(): Unit = { }
|
||||
override def doCheckpoint(): Unit = { }
|
||||
}
|
||||
|
||||
private[spark] object CheckpointRDD extends Logging {
|
||||
|
|
|
@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag](
|
|||
@transient private var partitions_ : Array[Partition] = null
|
||||
|
||||
/** An Option holding our checkpoint RDD, if we are checkpointed */
|
||||
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
|
||||
private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
|
||||
|
||||
/**
|
||||
* Get the list of dependencies of this RDD, taking into account whether the
|
||||
|
@ -1451,12 +1451,16 @@ abstract class RDD[T: ClassTag](
|
|||
* executed on this RDD. It is strongly recommended that this RDD is persisted in
|
||||
* memory, otherwise saving it on a file will require recomputation.
|
||||
*/
|
||||
def checkpoint() {
|
||||
def checkpoint(): Unit = {
|
||||
if (context.checkpointDir.isEmpty) {
|
||||
throw new SparkException("Checkpoint directory has not been set in the SparkContext")
|
||||
} else if (checkpointData.isEmpty) {
|
||||
checkpointData = Some(new RDDCheckpointData(this))
|
||||
checkpointData.get.markForCheckpoint()
|
||||
// NOTE: we use a global lock here due to complexities downstream with ensuring
|
||||
// children RDD partitions point to the correct parent partitions. In the future
|
||||
// we should revisit this consideration.
|
||||
RDDCheckpointData.synchronized {
|
||||
checkpointData = Some(new RDDCheckpointData(this))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1497,7 +1501,7 @@ abstract class RDD[T: ClassTag](
|
|||
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
|
||||
|
||||
/** Returns the first parent RDD */
|
||||
protected[spark] def firstParent[U: ClassTag] = {
|
||||
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
|
||||
dependencies.head.rdd.asInstanceOf[RDD[U]]
|
||||
}
|
||||
|
||||
|
|
|
@ -22,16 +22,15 @@ import scala.reflect.ClassTag
|
|||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
|
||||
import org.apache.spark.util.SerializableConfiguration
|
||||
|
||||
/**
|
||||
* Enumeration to manage state transitions of an RDD through checkpointing
|
||||
* [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
|
||||
* [ Initialized --> checkpointing in progress --> checkpointed ].
|
||||
*/
|
||||
private[spark] object CheckpointState extends Enumeration {
|
||||
type CheckpointState = Value
|
||||
val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
|
||||
val Initialized, CheckpointingInProgress, Checkpointed = Value
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
|
|||
import CheckpointState._
|
||||
|
||||
// The checkpoint state of the associated RDD.
|
||||
var cpState = Initialized
|
||||
private var cpState = Initialized
|
||||
|
||||
// The file to which the associated RDD has been checkpointed
|
||||
@transient var cpFile: Option[String] = None
|
||||
private var cpFile: Option[String] = None
|
||||
|
||||
// The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
|
||||
var cpRDD: Option[RDD[T]] = None
|
||||
// This is defined if and only if `cpState` is `Checkpointed`.
|
||||
private var cpRDD: Option[CheckpointRDD[T]] = None
|
||||
|
||||
// Mark the RDD for checkpointing
|
||||
def markForCheckpoint() {
|
||||
RDDCheckpointData.synchronized {
|
||||
if (cpState == Initialized) cpState = MarkedForCheckpoint
|
||||
}
|
||||
}
|
||||
// TODO: are we sure we need to use a global lock in the following methods?
|
||||
|
||||
// Is the RDD already checkpointed
|
||||
def isCheckpointed: Boolean = {
|
||||
RDDCheckpointData.synchronized { cpState == Checkpointed }
|
||||
def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
|
||||
cpState == Checkpointed
|
||||
}
|
||||
|
||||
// Get the file to which this RDD was checkpointed, as an Option
|
||||
def getCheckpointFile: Option[String] = {
|
||||
RDDCheckpointData.synchronized { cpFile }
|
||||
def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
|
||||
cpFile
|
||||
}
|
||||
|
||||
// Do the checkpointing of the RDD. Called after the first job using that RDD is over.
|
||||
def doCheckpoint() {
|
||||
// If it is marked for checkpointing AND checkpointing is not already in progress,
|
||||
// then set it to be in progress, else return
|
||||
/**
|
||||
* Materialize this RDD and write its content to a reliable DFS.
|
||||
* This is called immediately after the first action invoked on this RDD has completed.
|
||||
*/
|
||||
def doCheckpoint(): Unit = {
|
||||
|
||||
// Guard against multiple threads checkpointing the same RDD by
|
||||
// atomically flipping the state of this RDDCheckpointData
|
||||
RDDCheckpointData.synchronized {
|
||||
if (cpState == MarkedForCheckpoint) {
|
||||
if (cpState == Initialized) {
|
||||
cpState = CheckpointingInProgress
|
||||
} else {
|
||||
return
|
||||
|
@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
|
|||
val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
|
||||
val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
|
||||
if (!fs.mkdirs(path)) {
|
||||
throw new SparkException("Failed to create checkpoint path " + path)
|
||||
throw new SparkException(s"Failed to create checkpoint path $path")
|
||||
}
|
||||
|
||||
// Save to file, and reload it as an RDD
|
||||
|
@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
|
|||
cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
|
||||
rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
|
||||
if (newRDD.partitions.length != rdd.partitions.length) {
|
||||
throw new SparkException(
|
||||
|
@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
|
|||
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
|
||||
cpState = Checkpointed
|
||||
}
|
||||
logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
|
||||
logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
|
||||
}
|
||||
|
||||
// Get preferred location of a split after checkpointing
|
||||
def getPreferredLocations(split: Partition): Seq[String] = {
|
||||
RDDCheckpointData.synchronized {
|
||||
cpRDD.get.preferredLocations(split)
|
||||
}
|
||||
def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
|
||||
cpRDD.get.partitions
|
||||
}
|
||||
|
||||
def getPartitions: Array[Partition] = {
|
||||
RDDCheckpointData.synchronized {
|
||||
cpRDD.get.partitions
|
||||
}
|
||||
}
|
||||
|
||||
def checkpointRDD: Option[RDD[T]] = {
|
||||
RDDCheckpointData.synchronized {
|
||||
cpRDD
|
||||
}
|
||||
def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
|
||||
cpRDD
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] object RDDCheckpointData {
|
||||
|
||||
/** Return the path of the directory to which this RDD's checkpoint data is written. */
|
||||
def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
|
||||
sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
|
||||
sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
|
||||
}
|
||||
|
||||
/** Clean up the files associated with the checkpoint data for this RDD. */
|
||||
def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
|
||||
rddCheckpointDataPath(sc, rddId).foreach { path =>
|
||||
val fs = path.getFileSystem(sc.hadoopConfiguration)
|
||||
|
|
|
@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
|
|||
val parCollection = sc.makeRDD(1 to 4)
|
||||
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
|
||||
flatMappedRDD.checkpoint()
|
||||
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
|
||||
assert(flatMappedRDD.dependencies.head.rdd === parCollection)
|
||||
val result = flatMappedRDD.collect()
|
||||
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
|
||||
assert(flatMappedRDD.collect() === result)
|
||||
|
|
Loading…
Reference in a new issue