diff --git a/core/src/main/scala/spark/BlockRDD.scala b/core/src/main/scala/spark/BlockRDD.scala
new file mode 100644
index 0000000000..ea009f0f4f
--- /dev/null
+++ b/core/src/main/scala/spark/BlockRDD.scala
@@ -0,0 +1,42 @@
+package spark
+
+import scala.collection.mutable.HashMap
+
+class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
+  val index = idx
+}
+
+
+class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) {
+
+  @transient
+  val splits_ = (0 until blockIds.size).map(i => {
+    new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
+  }).toArray
+
+  @transient
+  lazy val locations_ = {
+    val blockManager = SparkEnv.get.blockManager
+    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
+    val locations = blockManager.getLocations(blockIds)
+    HashMap(blockIds.zip(locations):_*)
+  }
+
+  override def splits = splits_
+
+  override def compute(split: Split): Iterator[T] = {
+    val blockManager = SparkEnv.get.blockManager
+    val blockId = split.asInstanceOf[BlockRDDSplit].blockId
+    blockManager.get(blockId) match {
+      case Some(block) => block.asInstanceOf[Iterator[T]]
+      case None =>
+        throw new Exception("Could not compute split, block " + blockId + " not found")
+    }
+  }
+
+  override def preferredLocations(split: Split) =
+    locations_(split.asInstanceOf[BlockRDDSplit].blockId)
+
+  override val dependencies: List[Dependency[_]] = Nil
+}
+
diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 1191523ccc..1190e64f8f 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -94,6 +94,20 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
 
   def getStorageLevel = storageLevel
 
+  def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = {
+    if (!level.useDisk && level.replication < 2) {
+      throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
+    }
+
+    // This is a hack. Ideally this should re-use the code used by the CacheTracker
+    // to generate the key.
+    def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)
+
+    persist(level)
+    sc.runJob(this, (iter: Iterator[T]) => {} )
+    new BlockRDD[T](sc, splits.map(getSplitKey).toArray)
+  }
+
   // Read this RDD; will read from cache if applicable, or otherwise compute
   final def iterator(split: Split): Iterator[T] = {
     if (storageLevel != StorageLevel.NONE) {
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 7199b634b7..8f39820178 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -42,4 +42,11 @@ class RDDSuite extends FunSuite {
     assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
     sc.stop()
   }
+
+  test("checkpointing") {
+    val sc = new SparkContext("local", "test")
+    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
+    assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
+    sc.stop()
+  }
 }
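
A minimal usage sketch of the API this patch adds, not part of the diff itself (the app name, input values, and variable names are illustrative). checkpoint() persists the RDD, runs a no-op job so every partition is written through the BlockManager, and returns a BlockRDD that reads those blocks back, so downstream stages no longer depend on the original lineage:

    // Hypothetical example assuming the Spark API as of this patch.
    val sc = new SparkContext("local", "checkpoint-example")
    val base = sc.makeRDD(Array(1, 2, 3, 4), 2).map(_ * 2)
    // Forces a job that persists all partitions of `base`, then returns a
    // BlockRDD over the stored blocks; `doubled` reads blocks, not the lineage.
    val doubled = base.checkpoint(StorageLevel.DISK_AND_MEMORY_DESER)
    println(doubled.count())
    sc.stop()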