From 2ec4b2e38d432ef4f21b725c2fceac863d5f9ea1 Mon Sep 17 00:00:00 2001
From: Tathagata Das
Date: Wed, 20 Nov 2013 23:49:30 -0800
Subject: [PATCH] Added partition aware union to improve reduceByKeyAndWindow

---
 .../streaming/dstream/WindowedDStream.scala | 51 ++++++++++++++++++-
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 3c57294269..03f522e581 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -20,7 +20,12 @@ package org.apache.spark.streaming.dstream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.rdd.UnionRDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import org.apache.spark.streaming._
+import org.apache.spark._
+import scala.Some
+import scala.Some
+import scala.Some
+import org.apache.spark.streaming.Duration
 
 private[streaming]
 class WindowedDStream[T: ClassManifest](
@@ -49,9 +54,51 @@ class WindowedDStream[T: ClassManifest](
 
   override def compute(validTime: Time): Option[RDD[T]] = {
     val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
-    Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
+    val rddsInWindow = parent.slice(currentWindow)
+    val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) {
+      logInfo("Using partition aware union")
+      new PartitionAwareUnionRDD(ssc.sc, rddsInWindow)
+    } else {
+      logInfo("Using normal union")
+      new UnionRDD(ssc.sc,rddsInWindow)
+    }
+    Some(windowRDD)
   }
 }
+
+private[streaming]
+class PartitionAwareUnionRDDPartition(val idx: Int, val partitions: Array[Partition])
+  extends Partition {
+  override val index = idx
+  override def hashCode(): Int = idx
+}
+
+private[streaming]
+class PartitionAwareUnionRDD[T: ClassManifest](
+    sc: SparkContext,
+    var rdds: Seq[RDD[T]])
+  extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
+  require(rdds.length > 0)
+  require(rdds.flatMap(_.partitioner).distinct.length == 1, "Parent RDDs have different partitioners")
+
+  override val partitioner = rdds.head.partitioner
+
+  override def getPartitions: Array[Partition] = {
+    val numPartitions = rdds.head.partitions.length
+    (0 until numPartitions).map(index => {
+      val parentPartitions = rdds.map(_.partitions(index)).toArray
+      new PartitionAwareUnionRDDPartition(index, parentPartitions)
+    }).toArray
+  }
+
+  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+    val parentPartitions = s.asInstanceOf[PartitionAwareUnionRDDPartition].partitions
+    rdds.zip(parentPartitions).iterator.flatMap {
+      case (rdd, p) => rdd.iterator(p, context)
+    }
+  }
+}
+
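Note (not part of the patch): the partition-aware branch above is taken whenever every RDD in the window reports the same partitioner, which is what a per-batch reduceByKey shuffle produces. The following is a minimal sketch of a streaming job that would exercise the new PartitionAwareUnionRDD path; the object name, master, socket source, host/port, and batch/window durations are illustrative assumptions, not values from this change.

import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object WindowedWordCount {
  def main(args: Array[String]) {
    // 10 second batches; "local[2]" and the app name are placeholders.
    val ssc = new StreamingContext("local[2]", "WindowedWordCount", Seconds(10))

    // Each batch is pre-reduced with reduceByKey under the hood, so every RDD
    // in the 30 second window carries the same HashPartitioner and the
    // windowed union can take the partition aware branch instead of a plain
    // UnionRDD.
    val words = ssc.socketTextStream("localhost", 9999).flatMap(_.split(" "))
    val counts = words.map(word => (word, 1))
                      .reduceByKeyAndWindow(_ + _, Seconds(30), Seconds(10))
    counts.print()

    ssc.start()
  }
}

Because PartitionAwareUnionRDD preserves the parents' partitioner, the final per-window reduceByKey can be computed without re-shuffling the window's data, which is the reduceByKeyAndWindow improvement the subject line refers to.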