Added BlockRDD and a first-cut version of checkpoint() to RDD class.

Tathagata Das 2012-07-27 12:00:49 -07:00
parent d1b7f41671
commit 024905f682
3 changed files with 63 additions and 0 deletions
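
As a rough usage sketch (not part of the diff below; the input RDD here is made up for illustration, assuming an existing SparkContext sc), the new checkpoint() call persists an RDD's blocks, runs a job to materialize every split, and hands back a BlockRDD that reads from those stored blocks:

    val words = sc.makeRDD(Array("a", "b", "c", "d"), 2)   // hypothetical input RDD
    val saved = words.map(_.toUpperCase).checkpoint()       // persists the data, returns a BlockRDD
    saved.collect()                                         // served from the BlockManager, not recomputed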

@@ -0,0 +1,42 @@
package spark

import scala.collection.mutable.HashMap

class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
  val index = idx
}

class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) {

  @transient
  val splits_ = (0 until blockIds.size).map(i => {
    new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
  }).toArray

  @transient
  lazy val locations_ = {
    val blockManager = SparkEnv.get.blockManager
    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
    val locations = blockManager.getLocations(blockIds)
    HashMap(blockIds.zip(locations):_*)
  }

  override def splits = splits_

  override def compute(split: Split): Iterator[T] = {
    val blockManager = SparkEnv.get.blockManager
    val blockId = split.asInstanceOf[BlockRDDSplit].blockId
    blockManager.get(blockId) match {
      case Some(block) => block.asInstanceOf[Iterator[T]]
      case None =>
        throw new Exception("Could not compute split, block " + blockId + " not found")
    }
  }

  override def preferredLocations(split: Split) =
    locations_(split.asInstanceOf[BlockRDDSplit].blockId)

  override val dependencies: List[Dependency[_]] = Nil
}
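
For illustration only (not part of the commit, assuming an existing SparkContext sc), a BlockRDD can also be built directly over blocks that already sit in the BlockManager; the block IDs below are hypothetical and just follow the "rdd:<rddId>:<splitIndex>" key format that checkpoint() uses in the next file:

    val blockIds = Array("rdd:7:0", "rdd:7:1")        // hypothetical keys for two cached splits
    val restored = new BlockRDD[Int](sc, blockIds)    // one BlockRDDSplit per block ID
    restored.collect()                                // each split is fetched via BlockManager.get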

@@ -94,6 +94,20 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
  def getStorageLevel = storageLevel

  def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = {
    if (!level.useDisk && level.replication < 2) {
      throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
    }

    // This is a hack. Ideally this should re-use the code used by the CacheTracker
    // to generate the key.
    def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)

    persist(level)
    sc.runJob(this, (iter: Iterator[T]) => {} )

    new BlockRDD[T](sc, splits.map(getSplitKey).toArray)
  }

  // Read this RDD; will read from cache if applicable, or otherwise compute
  final def iterator(split: Split): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {

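A small sketch (hypothetical names, assuming an existing SparkContext sc; not part of the commit) of why checkpoint() returns a new RDD: the BlockRDD it hands back has dependencies == Nil, so later jobs start from the stored blocks instead of re-running the original lineage:

    val expensive = sc.makeRDD(1 to 100, 4).map(_ * 2)   // stands in for a long lineage
    val cut = expensive.checkpoint()                      // cut is a BlockRDD with no dependencies
    cut.filter(_ > 50).count()                            // computed from the blocks, not from the map

Note that in this first cut the caller has to keep using the returned RDD; the original RDD is only persisted, its lineage is not rewired.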
@@ -42,4 +42,11 @@ class RDDSuite extends FunSuite {
    assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
    sc.stop()
  }

  test("checkpointing") {
    val sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
    assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
    sc.stop()
  }
}