Added BlockRDD and a first-cut version of checkpoint() to RDD class.
This commit is contained in:
parent d1b7f41671
commit 024905f682
core/src/main/scala/spark/BlockRDD.scala (new file, 42 lines)

@@ -0,0 +1,42 @@
package spark

import scala.collection.mutable.HashMap


class BlockRDDSplit(val blockId: String, idx: Int) extends Split {
  val index = idx
}


class BlockRDD[T: ClassManifest](sc: SparkContext, blockIds: Array[String]) extends RDD[T](sc) {

  @transient
  val splits_ = (0 until blockIds.size).map(i => {
    new BlockRDDSplit(blockIds(i), i).asInstanceOf[Split]
  }).toArray

  @transient
  lazy val locations_ = {
    val blockManager = SparkEnv.get.blockManager
    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
    val locations = blockManager.getLocations(blockIds)
    HashMap(blockIds.zip(locations):_*)
  }

  override def splits = splits_

  override def compute(split: Split): Iterator[T] = {
    val blockManager = SparkEnv.get.blockManager
    val blockId = split.asInstanceOf[BlockRDDSplit].blockId
    blockManager.get(blockId) match {
      case Some(block) => block.asInstanceOf[Iterator[T]]
      case None =>
        throw new Exception("Could not compute split, block " + blockId + " not found")
    }
  }

  override def preferredLocations(split: Split) =
    locations_(split.asInstanceOf[BlockRDDSplit].blockId)

  override val dependencies: List[Dependency[_]] = Nil
}
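Each BlockRDDSplit wraps exactly one block ID, and compute() just fetches that block's iterator from the BlockManager, so a BlockRDD carries no lineage of its own (dependencies is Nil). A minimal sketch, not part of this commit, of constructing such an RDD by hand; it assumes the named blocks already exist in the block store, and the block IDs below are hypothetical, chosen to match the "rdd:<rddId>:<splitIndex>" shape produced by getSplitKey() in checkpoint() further down:

// Sketch only: assumes blocks "rdd:42:0" and "rdd:42:1" were previously
// written to the block store (e.g. by persisting and materializing RDD 42).
val blockIds = Array("rdd:42:0", "rdd:42:1")    // hypothetical block IDs
val restored = new BlockRDD[Int](sc, blockIds)  // one split per block ID
restored.collect()                              // reads each block back via the BlockManager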
@@ -94,6 +94,20 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial

  def getStorageLevel = storageLevel

  def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER): RDD[T] = {
    if (!level.useDisk && level.replication < 2) {
      throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
    }

    // This is a hack. Ideally this should re-use the code used by the CacheTracker
    // to generate the key.
    def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)

    persist(level)
    sc.runJob(this, (iter: Iterator[T]) => {} )
    new BlockRDD[T](sc, splits.map(getSplitKey).toArray)
  }

  // Read this RDD; will read from cache if applicable, or otherwise compute
  final def iterator(split: Split): Iterator[T] = {
    if (storageLevel != StorageLevel.NONE) {
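checkpoint() persists the RDD at the requested level, runs a no-op job so every partition is computed and stored as a block, then returns a BlockRDD over the resulting block keys, which truncates the lineage; levels that neither use disk nor replicate (replication < 2) are rejected, since a lost executor would then lose the only copy. A hedged usage sketch, with names that are illustrative only and not part of this commit:

// Illustrative only: derived's lineage (makeRDD -> map -> filter) is replaced
// by a BlockRDD over the persisted blocks once checkpoint() returns.
val data = sc.makeRDD(1 to 100, 4)
val derived = data.map(_ * 2).filter(_ % 3 == 0)
val checkpointed = derived.checkpoint()   // default level: DISK_AND_MEMORY_DESER
// checkpointed.dependencies == Nil, so recovery re-reads blocks instead of recomputing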
@@ -42,4 +42,11 @@ class RDDSuite extends FunSuite {
    assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
    sc.stop()
  }

  test("checkpointing") {
    val sc = new SparkContext("local", "test")
    val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).flatMap(x => 1 to x).checkpoint()
    assert(rdd.collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4))
    sc.stop()
  }
}