Simplify checkpointing code and RDD class a little:

- RDD's getDependencies and getSplits methods are now guaranteed to be
  called only once, so subclasses can safely do computation in there
  without worrying about caching the results (see the first sketch
  after this list).

- Management of the "splits_" variable, which is cleared when an RDD is
  checkpointed, is now handled in the RDD base class.

- A few of the RDD subclasses are simpler.

- CheckpointRDD's compute() method no longer assumes that it is given a
  CheckpointRDDSplit -- it can work just as well on a split from the
  original RDD, because it only looks at its index. This is important
  because things like UnionRDD and ZippedRDD remember the parent's
  splits as part of their own and wouldn't work on checkpointed parents.

- RDD.iterator can now reuse cached data if an RDD is computed before it
  is checkpointed. Previously it did not: iterator() always delegated to
  the CheckpointRDD, which read from HDFS (see the second sketch after
  this list).
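
The first sketch: a minimal, hypothetical RDD subclass written against the
internal API shown in this diff (ExpensiveSplitsRDD and ExpensiveSplit are
made-up names, not part of this commit). It does its work directly in
getSplits and keeps no cached field of its own, relying on the base class
to call getSplits once and memoize the result:

    package spark.rdd

    import spark.{RDD, Split, TaskContext}

    // Made-up split type for the example; the Split trait only requires an index.
    private[spark] class ExpensiveSplit(val index: Int) extends Split

    class ExpensiveSplitsRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {

      // Potentially slow work is safe here: RDD.splits calls getSplits at most
      // once and caches the result, so the subclass needs no splits_ field.
      override def getSplits: Array[Split] =
        firstParent[T].splits.map(s => new ExpensiveSplit(s.index): Split)

      // One-to-one with the parent: read the parent partition with the same index.
      override def compute(split: Split, context: TaskContext): Iterator[T] =
        firstParent[T].iterator(firstParent[T].splits(split.index), context)
    }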
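
The second sketch, from the driver side: a toy program (object name, master
URL and checkpoint directory are placeholders) showing why the new iterator()
ordering matters when an RDD is both persisted and checkpointed:

    import spark.SparkContext
    import spark.storage.StorageLevel

    object CheckpointCacheExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CheckpointCacheExample")
        sc.setCheckpointDir("/tmp/spark-checkpoint-example")  // placeholder directory

        val data = sc.parallelize(1 to 1000, 4).map(_ * 2).persist(StorageLevel.MEMORY_ONLY)
        data.checkpoint()

        data.count()  // materializes the RDD: fills the cache, then the checkpoint is written
        data.count()  // now served from the cached blocks; before this change, iterator()
                      // checked isCheckpointed first and re-read the data from HDFS
        sc.stop()
      }
    }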
Matei Zaharia 2013-01-28 22:30:12 -08:00
parent b29599e5cf
commit 64ba6a8c2c
15 changed files with 153 additions and 168 deletions


@@ -10,9 +10,9 @@ import spark.storage.{BlockManager, StorageLevel}
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when a RDD is cached. */
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Split, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
@@ -50,7 +50,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.compute(split, context)
elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]


@@ -649,9 +649,7 @@ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
}
private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U)
extends RDD[(K, U)](prev) {
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev) {
override def getSplits = firstParent[(K, V)].splits
override val partitioner = firstParent[(K, V)].partitioner
override def compute(split: Split, context: TaskContext) =


@@ -1,27 +1,17 @@
package spark
import java.io.{ObjectOutputStream, IOException, EOFException, ObjectInputStream}
import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.concurrent.atomic.AtomicLong
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.HadoopWriter
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
@@ -30,7 +20,6 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
@@ -73,11 +62,11 @@ import SparkContext._
* on RDD internals.
*/
abstract class RDD[T: ClassManifest](
@transient var sc: SparkContext,
var dependencies_ : List[Dependency[_]]
@transient private var sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -85,25 +74,27 @@ abstract class RDD[T: ClassManifest](
// Methods that should be implemented by subclasses of RDD
// =======================================================================
/** Function for computing a given partition. */
/** Implemented by subclasses to compute a given partition. */
def compute(split: Split, context: TaskContext): Iterator[T]
/** Set of partitions in this RDD. */
protected def getSplits(): Array[Split]
/**
* Implemented by subclasses to return the set of partitions in this RDD. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
protected def getSplits: Array[Split]
/** How this RDD depends on any parent RDDs. */
protected def getDependencies(): List[Dependency[_]] = dependencies_
/**
* Implemented by subclasses to return how this RDD depends on parent RDDs. This method will only
* be called once, so it is safe to implement a time-consuming computation in it.
*/
protected def getDependencies: Seq[Dependency[_]] = deps
/** A friendly name for this RDD */
var name: String = null
/** Optionally overridden by subclasses to specify placement preferences. */
protected def getPreferredLocations(split: Split): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
// =======================================================================
// Methods and fields available on all RDDs
// =======================================================================
@@ -111,13 +102,16 @@ abstract class RDD[T: ClassManifest](
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
name = _name
this
}
/**
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -142,15 +136,24 @@ abstract class RDD[T: ClassManifest](
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
// Our dependencies and splits will be gotten by calling subclass's methods below, and will
// be overwritten when we're checkpointed
private var dependencies_ : Seq[Dependency[_]] = null
@transient private var splits_ : Array[Split] = null
/** An Option holding our checkpoint RDD, if we are checkpointed */
private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
/**
* Get the preferred location of a split, taking into account whether the
* Get the list of dependencies of this RDD, taking into account whether the
* RDD is checkpointed or not.
*/
final def preferredLocations(split: Split): Seq[String] = {
if (isCheckpointed) {
checkpointData.get.getPreferredLocations(split)
} else {
getPreferredLocations(split)
final def dependencies: Seq[Dependency[_]] = {
checkpointRDD.map(r => List(new OneToOneDependency(r))).getOrElse {
if (dependencies_ == null) {
dependencies_ = getDependencies
}
dependencies_
}
}
@@ -159,22 +162,21 @@ abstract class RDD[T: ClassManifest](
* RDD is checkpointed or not.
*/
final def splits: Array[Split] = {
if (isCheckpointed) {
checkpointData.get.getSplits
} else {
getSplits
checkpointRDD.map(_.splits).getOrElse {
if (splits_ == null) {
splits_ = getSplits
}
splits_
}
}
/**
* Get the list of dependencies of this RDD, taking into account whether the
* Get the preferred location of a split, taking into account whether the
* RDD is checkpointed or not.
*/
final def dependencies: List[Dependency[_]] = {
if (isCheckpointed) {
dependencies_
} else {
getDependencies
final def preferredLocations(split: Split): Seq[String] = {
checkpointRDD.map(_.getPreferredLocations(split)).getOrElse {
getPreferredLocations(split)
}
}
@@ -184,10 +186,19 @@ abstract class RDD[T: ClassManifest](
* subclasses of RDD.
*/
final def iterator(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
checkpointData.get.iterator(split, context)
} else if (storageLevel != StorageLevel.NONE) {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheManager.getOrCompute(this, split, context, storageLevel)
} else {
computeOrReadCheckpoint(split, context)
}
}
/**
* Compute an RDD partition or read it from a checkpoint if the RDD is checkpointing.
*/
private[spark] def computeOrReadCheckpoint(split: Split, context: TaskContext): Iterator[T] = {
if (isCheckpointed) {
firstParent[T].iterator(split, context)
} else {
compute(split, context)
}
@@ -578,15 +589,15 @@ abstract class RDD[T: ClassManifest](
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed(): Boolean = {
if (checkpointData.isDefined) checkpointData.get.isCheckpointed() else false
def isCheckpointed: Boolean = {
checkpointData.map(_.isCheckpointed).getOrElse(false)
}
/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile(): Option[String] = {
if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None
def getCheckpointFile: Option[String] = {
checkpointData.flatMap(_.getCheckpointFile)
}
// =======================================================================
@@ -611,31 +622,36 @@ abstract class RDD[T: ClassManifest](
def context = sc
/**
* Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
* after a job using this RDD has completed (therefore the RDD has been materialized and
* potentially stored in memory). doCheckpoint() is called recursively on the parent RDDs.
*/
protected[spark] def doCheckpoint() {
if (checkpointData.isDefined) checkpointData.get.doCheckpoint()
dependencies.foreach(_.rdd.doCheckpoint())
private[spark] def doCheckpoint() {
if (checkpointData.isDefined) {
checkpointData.get.doCheckpoint()
} else {
dependencies.foreach(_.rdd.doCheckpoint())
}
}
/**
* Changes the dependencies of this RDD from its original parents to the new RDD
* (`newRDD`) created from the checkpoint file.
* Changes the dependencies of this RDD from its original parents to a new RDD (`newRDD`)
* created from the checkpoint file, and forget its old dependencies and splits.
*/
protected[spark] def changeDependencies(newRDD: RDD[_]) {
private[spark] def markCheckpointed(checkpointRDD: RDD[_]) {
clearDependencies()
dependencies_ = List(new OneToOneDependency(newRDD))
dependencies_ = null
splits_ = null
deps = null // Forget the constructor argument for dependencies too
}
/**
* Clears the dependencies of this RDD. This method must ensure that all references
* to the original parent RDDs is removed to enable the parent RDDs to be garbage
* collected. Subclasses of RDD may override this method for implementing their own cleaning
* logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
* logic. See [[spark.rdd.UnionRDD]] for an example.
*/
protected[spark] def clearDependencies() {
protected def clearDependencies() {
dependencies_ = null
}
}


@@ -20,7 +20,7 @@ private[spark] object CheckpointState extends Enumeration {
* of the checkpointed RDD.
*/
private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
extends Logging with Serializable {
extends Logging with Serializable {
import CheckpointState._
@@ -31,7 +31,7 @@ extends Logging with Serializable {
@transient var cpFile: Option[String] = None
// The CheckpointRDD created from the checkpoint file, that is, the new parent the associated RDD.
@transient var cpRDD: Option[RDD[T]] = None
var cpRDD: Option[RDD[T]] = None
// Mark the RDD for checkpointing
def markForCheckpoint() {
@@ -41,12 +41,12 @@ extends Logging with Serializable {
}
// Is the RDD already checkpointed
def isCheckpointed(): Boolean = {
def isCheckpointed: Boolean = {
RDDCheckpointData.synchronized { cpState == Checkpointed }
}
// Get the file to which this RDD was checkpointed to as an Option
def getCheckpointFile(): Option[String] = {
def getCheckpointFile: Option[String] = {
RDDCheckpointData.synchronized { cpFile }
}
@@ -71,7 +71,7 @@ extends Logging with Serializable {
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpRDD = Some(newRDD)
rdd.changeDependencies(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and splits
cpState = Checkpointed
RDDCheckpointData.clearTaskCaches()
logInfo("Done checkpointing RDD " + rdd.id + ", new parent is RDD " + newRDD.id)
@@ -79,7 +79,7 @@ extends Logging with Serializable {
}
// Get preferred location of a split after checkpointing
def getPreferredLocations(split: Split) = {
def getPreferredLocations(split: Split): Seq[String] = {
RDDCheckpointData.synchronized {
cpRDD.get.preferredLocations(split)
}
@@ -91,9 +91,10 @@ extends Logging with Serializable {
}
}
// Get iterator. This is called at the worker nodes.
def iterator(split: Split, context: TaskContext): Iterator[T] = {
rdd.firstParent[T].iterator(split, context)
def checkpointRDD: Option[RDD[T]] = {
RDDCheckpointData.synchronized {
cpRDD
}
}
}


@@ -319,7 +319,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed(): Boolean = rdd.isCheckpointed()
def isCheckpointed: Boolean = rdd.isCheckpointed
/**
* Gets the name of the file to which this RDD was checkpointed


@@ -1,7 +1,7 @@
package spark.rdd
import java.io.{ObjectOutputStream, IOException}
import spark.{OneToOneDependency, NarrowDependency, RDD, SparkContext, Split, TaskContext}
import spark._
private[spark]
@@ -35,7 +35,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
val numSplitsInRdd2 = rdd2.splits.size
@transient var splits_ = {
override def getSplits: Array[Split] = {
// create the cross product split
val array = new Array[Split](rdd1.splits.size * rdd2.splits.size)
for (s1 <- rdd1.splits; s2 <- rdd2.splits) {
@@ -45,8 +45,6 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
array
}
override def getSplits = splits_
override def getPreferredLocations(split: Split) = {
val currSplit = split.asInstanceOf[CartesianSplit]
rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2)
@@ -58,7 +56,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
var deps_ = List(
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(rdd1) {
def getParents(id: Int): Seq[Int] = List(id / numSplitsInRdd2)
},
@@ -67,11 +65,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
}
)
override def getDependencies = deps_
override def clearDependencies() {
deps_ = Nil
splits_ = null
rdd1 = null
rdd2 = null
}


@@ -9,23 +9,26 @@ import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
override val index: Int = idx
}
private[spark] class CheckpointRDDSplit(val index: Int) extends Split {}
/**
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
@transient val path = new Path(checkpointPath)
@transient val fs = path.getFileSystem(new Configuration())
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
@transient val splits_ : Array[Split] = {
val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
val dirContents = fs.listStatus(new Path(checkpointPath))
val splitFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
val numSplits = splitFiles.size
if (!splitFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
!splitFiles(numSplits-1).endsWith(CheckpointRDD.splitIdToFile(numSplits-1))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
Array.tabulate(numSplits)(i => new CheckpointRDDSplit(i))
}
checkpointData = Some(new RDDCheckpointData[T](this))
@@ -34,36 +37,34 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
override def getSplits = splits_
override def getPreferredLocations(split: Split): Seq[String] = {
val status = fs.getFileStatus(path)
val status = fs.getFileStatus(new Path(checkpointPath))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
override def compute(split: Split, context: TaskContext): Iterator[T] = {
CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile, context)
val file = new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index))
CheckpointRDD.readFromFile(file, context)
}
override def checkpoint() {
// Do nothing. Hadoop RDD should not be checkpointed.
// Do nothing. CheckpointRDD should not be checkpointed.
}
}
private[spark] object CheckpointRDD extends Logging {
def splitIdToFileName(splitId: Int): String = {
val numfmt = NumberFormat.getInstance()
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
"part-" + numfmt.format(splitId)
def splitIdToFile(splitId: Int): String = {
"part-%05d".format(splitId)
}
def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration())
val finalOutputName = splitIdToFileName(context.splitId)
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -83,22 +84,22 @@ private[spark] object CheckpointRDD extends Logging {
serializeStream.close()
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.delete(finalOutputPath, true)) {
throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+ context.attemptId)
}
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ context.attemptId)
+ ctx.attemptId + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
fs.delete(tempOutputPath, false)
}
}
}
def readFromFile[T](path: String, context: TaskContext): Iterator[T] = {
val inputPath = new Path(path)
val fs = inputPath.getFileSystem(new Configuration())
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val fs = path.getFileSystem(new Configuration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(inputPath, bufferSize)
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
val deserializeStream = serializer.deserializeStream(fileInputStream)


@@ -27,11 +27,11 @@ private[spark] case class CoalescedRDDSplit(
* or to avoid having a large number of small tasks when processing a directory with many files.
*/
class CoalescedRDD[T: ClassManifest](
var prev: RDD[T],
@transient var prev: RDD[T],
maxPartitions: Int)
extends RDD[T](prev.context, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
@transient var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
val prevSplits = prev.splits
if (prevSplits.length < maxPartitions) {
prevSplits.map(_.index).map{idx => new CoalescedRDDSplit(idx, prev, Array(idx)) }
@@ -44,26 +44,20 @@ class CoalescedRDD[T: ClassManifest](
}
}
override def getSplits = splits_
override def compute(split: Split, context: TaskContext): Iterator[T] = {
split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { parentSplit =>
firstParent[T].iterator(parentSplit, context)
}
}
var deps_ : List[Dependency[_]] = List(
override def getDependencies: Seq[Dependency[_]] = List(
new NarrowDependency(prev) {
def getParents(id: Int): Seq[Int] =
splits(id).asInstanceOf[CoalescedRDDSplit].parentsIndices
}
)
override def getDependencies() = deps_
override def clearDependencies() {
deps_ = Nil
splits_ = null
prev = null
}
}


@@ -3,13 +3,11 @@ package spark.rdd
import spark.{RDD, Split, TaskContext}
private[spark]
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getSplits = firstParent[T].splits
override def compute(split: Split, context: TaskContext) =
firstParent[T].iterator(split, context).map(f)
}
}


@@ -7,23 +7,18 @@ import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*
* TODO: This currently doesn't give partition IDs properly!
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
@transient
var partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].partitions
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
override protected def getSplits = partitions_
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
override val partitioner = firstParent[T].partitioner
override def clearDependencies() {
super.clearDependencies()
partitions_ = null
}
}


@@ -22,16 +22,10 @@ class ShuffledRDD[K, V](
override val partitioner = Some(part)
@transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def getSplits = splits_
override def getSplits = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def compute(split: Split, context: TaskContext): Iterator[(K, V)] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
}
override def clearDependencies() {
splits_ = null
}
}


@@ -26,9 +26,9 @@ private[spark] class UnionSplit[T: ClassManifest](idx: Int, rdd: RDD[T], splitIn
class UnionRDD[T: ClassManifest](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil, so the dependencies_ var does not refer to parent RDDs
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
@transient var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
val array = new Array[Split](rdds.map(_.splits.size).sum)
var pos = 0
for (rdd <- rdds; split <- rdd.splits) {
@@ -38,20 +38,16 @@ class UnionRDD[T: ClassManifest](
array
}
override def getSplits = splits_
@transient var deps_ = {
override def getDependencies: Seq[Dependency[_]] = {
val deps = new ArrayBuffer[Dependency[_]]
var pos = 0
for (rdd <- rdds) {
deps += new RangeDependency(rdd, 0, pos, rdd.splits.size)
pos += rdd.splits.size
}
deps.toList
deps
}
override def getDependencies = deps_
override def compute(s: Split, context: TaskContext): Iterator[T] =
s.asInstanceOf[UnionSplit[T]].iterator(context)
@@ -59,8 +55,6 @@ class UnionRDD[T: ClassManifest](
s.asInstanceOf[UnionSplit[T]].preferredLocations()
override def clearDependencies() {
deps_ = null
splits_ = null
rdds = null
}
}


@@ -32,9 +32,7 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2)))
with Serializable {
// TODO: FIX THIS.
@transient var splits_ : Array[Split] = {
override def getSplits: Array[Split] = {
if (rdd1.splits.size != rdd2.splits.size) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
@@ -45,8 +43,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
array
}
override def getSplits = splits_
override def compute(s: Split, context: TaskContext): Iterator[(T, U)] = {
val (split1, split2) = s.asInstanceOf[ZippedSplit[T, U]].splits
rdd1.iterator(split1, context).zip(rdd2.iterator(split2, context))
@@ -58,7 +54,6 @@ class ZippedRDD[T: ClassManifest, U: ClassManifest](
}
override def clearDependencies() {
splits_ = null
rdd1 = null
rdd2 = null
}


@@ -26,8 +26,8 @@ class MetadataCleaner(name: String, cleanupFunc: (Long) => Unit) extends Logging
if (delaySeconds > 0) {
logDebug(
"Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds and "
+ "period of " + periodSeconds + " secs")
"Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " +
"and period of " + periodSeconds + " secs")
timer.schedule(task, periodSeconds * 1000, periodSeconds * 1000)
}


@@ -99,7 +99,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint // checkpoint that MappedRDD
ones.checkpoint() // checkpoint that MappedRDD
val cartesian = new CartesianRDD(sc, ones, ones)
val splitBeforeCheckpoint =
serializeDeserialize(cartesian.splits.head.asInstanceOf[CartesianSplit])
@@ -125,7 +125,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// the parent RDD has been checkpointed and parent splits have been changed to HadoopSplits.
// Note that this test is very specific to the current implementation of CoalescedRDDSplits
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint // checkpoint that MappedRDD
ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
val splitBeforeCheckpoint =
serializeDeserialize(coalesced.splits.head.asInstanceOf[CoalescedRDDSplit])
@@ -160,7 +160,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// so only the RDD will reduce in serialized size, not the splits.
testParentCheckpointing(
rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
}
/**
@@ -176,7 +175,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testRDDSplitSize: Boolean = false
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD
val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
@@ -245,12 +244,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testRDDSplitSize: Boolean
) {
// Generate the final RDD using given RDD operation
val baseRDD = generateLongLineageRDD
val baseRDD = generateLongLineageRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.head.rdd
val rddType = operatedRDD.getClass.getSimpleName
val parentRDDType = parentRDD.getClass.getSimpleName
// Get the splits and dependencies of the parent in case they're lazily computed
parentRDD.dependencies
parentRDD.splits
// Find serialized sizes before and after the checkpoint
val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
@@ -267,7 +270,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
if (testRDDSize) {
assert(
rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
"Size of " + rddType + " did not reduce after parent checkpointing parent " + parentRDDType +
"Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
"[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
)
}
@@ -318,10 +321,12 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
/**
* Get serialized sizes of the RDD and its splits
* Get serialized sizes of the RDD and its splits, in order to test whether the size shrinks
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
(Utils.serialize(rdd).size, Utils.serialize(rdd.splits).size)
(Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
Utils.serialize(rdd.splits).length)
}
/**