diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala
index 9725017b61..4bd9e1bd54 100644
--- a/core/src/main/scala/spark/ParallelCollection.scala
+++ b/core/src/main/scala/spark/ParallelCollection.scala
@@ -2,6 +2,7 @@ package spark
 
 import scala.collection.immutable.NumericRange
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.Map
 
 private[spark] class ParallelCollectionSplit[T: ClassManifest](
     val rddId: Long,
@@ -24,7 +25,8 @@ private[spark] class ParallelCollectionSplit[T: ClassManifest](
 private[spark] class ParallelCollection[T: ClassManifest](
     @transient sc : SparkContext,
     @transient data: Seq[T],
-    numSlices: Int)
+    numSlices: Int,
+    locationPrefs : Map[Int,Seq[String]])
   extends RDD[T](sc, Nil) {
   // TODO: Right now, each split sends along its full data, even if later down the RDD chain it gets
   // cached. It might be worthwhile to write the data to a file in the DFS and read it in the split
@@ -40,7 +42,12 @@ private[spark] class ParallelCollection[T: ClassManifest](
 
   override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator
 
-  override def preferredLocations(s: Split): Seq[String] = Nil
+  override def preferredLocations(s: Split): Seq[String] = {
+    locationPrefs.get(splits_.indexOf(s)) match {
+      case Some(s) => s
+      case _ => Nil
+    }
+  }
 }
 
 private object ParallelCollection {
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index d7b46bee38..7ae1aea993 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -194,7 +194,7 @@ class SparkContext(
 
   /** Distribute a local Scala collection to form an RDD. */
   def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
-    new ParallelCollection[T](this, seq, numSlices)
+    new ParallelCollection[T](this, seq, numSlices, Map[Int, Seq[String]]())
   }
 
   /** Distribute a local Scala collection to form an RDD. */
@@ -202,6 +202,14 @@ class SparkContext(
     parallelize(seq, numSlices)
   }
 
+  /** Distribute a local Scala collection to form an RDD, with one or more
+    * location preferences for each object. Create a new partition for each
+    * collection item. */
+  def makeLocalityConstrainedRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+    val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
+    new ParallelCollection[T](this, seq.map(_._1), seq.size, indexToPrefs)
+  }
+
   /**
   * Read a text file from HDFS, a local file system (available on all nodes), or any
   * Hadoop-supported file system URI, and return it as an RDD of Strings.
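
A minimal usage sketch (not part of the patch) of the new makeLocalityConstrainedRDD API, based only on the signature added above; the object name, master string, and host names are hypothetical placeholders.

    import spark.SparkContext

    object LocalityPrefsExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "locality-prefs-example")

        // One partition per element, each pinned to that element's preferred hosts.
        val rdd = sc.makeLocalityConstrainedRDD(Seq(
          ("block-0", Seq("host1.example.com")),
          ("block-1", Seq("host2.example.com", "host3.example.com"))
        ))

        // Tasks for each partition should be scheduled on the listed hosts when possible,
        // since ParallelCollection.preferredLocations now consults locationPrefs.
        println(rdd.collect().mkString(", "))

        sc.stop()
      }
    }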