diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 0355618e43..17bc2515f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -265,6 +265,19 @@ abstract class RDD[T: ClassManifest]( def distinct(): RDD[T] = distinct(partitions.size) + /** + * Return a new RDD that has exactly numPartitions partitions. + * + * Used to increase or decrease the level of parallelism in this RDD. This will use + * a shuffle to redistribute data. + * + * If you are decreasing the number of partitions in this RDD, consider using `coalesce`, + * which can avoid performing a shuffle. + */ + def repartition(numPartitions: Int): RDD[T] = { + coalesce(numPartitions, true) + } + /** * Return a new RDD that is reduced into `numPartitions` partitions. * diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 6d1bc5e296..354ab8ae5d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(rdd.union(emptyKv).collect().size === 2) } + test("repartitioned RDDs") { + val data = sc.parallelize(1 to 1000, 10) + + // Coalesce partitions + val repartitioned1 = data.repartition(2) + assert(repartitioned1.partitions.size == 2) + val partitions1 = repartitioned1.glom().collect() + assert(partitions1(0).length > 0) + assert(partitions1(1).length > 0) + assert(repartitioned1.collect().toSet === (1 to 1000).toSet) + + // Split partitions + val repartitioned2 = data.repartition(20) + assert(repartitioned2.partitions.size == 20) + val partitions2 = repartitioned2.glom().collect() + assert(partitions2(0).length > 0) + assert(partitions2(19).length > 0) + assert(repartitioned2.collect().toSet === (1 to 1000).toSet) + } + test("coalesced RDDs") { val data = sc.parallelize(1 to 10, 10) diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index 835b257238..851e30fe76 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's: Similar to map, but runs separately on each partition (block) of the DStream, so func must be of type Iterator[T] => Iterator[U] when running on an DStream of type T. + + repartition(numPartitions) + Changes the level of parallelism in this DStream by creating more or fewer partitions. + union(otherStream) Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 80da6bd30b..6da2261f06 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] ( */ def glom(): DStream[Array[T]] = new GlommedDStream(this) + + /** + * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the + * returned DStream has exactly numPartitions partitions. + */ + def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions)) + /** * Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs * of this DStream. Applying mapPartitions() to an RDD applies a function to each partition diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala index 8a6604904d..5344ae7682 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaTestUtils.scala @@ -54,8 +54,7 @@ trait JavaTestBase extends TestSuiteBase { { implicit val cm: ClassManifest[T] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] - val ostream = new TestOutputStream(dstream.dstream, - new ArrayBuffer[Seq[T]] with SynchronizedBuffer[Seq[T]]) + val ostream = new TestOutputStreamWithPartitions(dstream.dstream) dstream.dstream.ssc.registerOutputStream(ostream) } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index 11586f72b6..55cfcb371a 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -82,6 +82,44 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(input, operation, output, true) } + test("repartition (more partitions)") { + val input = Seq(1 to 100, 101 to 200, 201 to 300) + val operation = (r: DStream[Int]) => r.repartition(5) + val ssc = setupStreams(input, operation, 2) + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 5) + assert(second.size === 5) + assert(third.size === 5) + + assert(first.flatten.toSet === (1 to 100).toSet) + assert(second.flatten.toSet === (101 to 200).toSet) + assert(third.flatten.toSet === (201 to 300).toSet) + } + + test("repartition (fewer partitions)") { + val input = Seq(1 to 100, 101 to 200, 201 to 300) + val operation = (r: DStream[Int]) => r.repartition(2) + val ssc = setupStreams(input, operation, 5) + val output = runStreamsWithPartitions(ssc, 3, 3) + assert(output.size === 3) + val first = output(0) + val second = output(1) + val third = output(2) + + assert(first.size === 2) + assert(second.size === 2) + assert(third.size === 2) + + assert(first.flatten.toSet === (1 to 100).toSet) + assert(second.flatten.toSet === (101 to 200).toSet) + assert(third.flatten.toSet === (201 to 300).toSet) + } + test("groupByKey") { testOperation( Seq( Seq("a", "a", "b"), Seq("", ""), Seq() ), diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index a327de80b3..beb20831bd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -366,7 +366,7 @@ class CheckpointSuite extends TestSuiteBase with BeforeAndAfter { logInfo("Manual clock after advancing = " + clock.time) Thread.sleep(batchDuration.milliseconds) - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] - outputStream.output + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] + outputStream.output.map(_.flatten) } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index 37dd9c4cc6..be140699c2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -60,8 +60,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ /** * This is a output stream just for the testsuites. All the output is collected into a * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * + * The buffer contains a sequence of RDD's, each containing a sequence of items */ -class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBuffer[Seq[T]]) +class TestOutputStream[T: ClassManifest](parent: DStream[T], + val output: ArrayBuffer[Seq[T]] = ArrayBuffer[Seq[T]]()) extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { val collected = rdd.collect() output += collected @@ -75,6 +78,30 @@ class TestOutputStream[T: ClassManifest](parent: DStream[T], val output: ArrayBu } } +/** + * This is a output stream just for the testsuites. All the output is collected into a + * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. + * + * The buffer contains a sequence of RDD's, each containing a sequence of partitions, each + * containing a sequence of items. + */ +class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T], + val output: ArrayBuffer[Seq[Seq[T]]] = ArrayBuffer[Seq[Seq[T]]]()) + extends ForEachDStream[T](parent, (rdd: RDD[T], t: Time) => { + val collected = rdd.glom().collect().map(_.toSeq) + output += collected + }) { + + // This is to clear the output buffer every it is read from a checkpoint + @throws(classOf[IOException]) + private def readObject(ois: ObjectInputStream) { + ois.defaultReadObject() + output.clear() + } + + def toTestOutputStream = new TestOutputStream[T](this.parent, this.output.map(_.flatten)) +} + /** * This is the base trait for Spark Streaming testsuites. This provides basic functionality * to run user-defined set of input on user-defined stream operations, and verify the output. @@ -108,7 +135,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { */ def setupStreams[U: ClassManifest, V: ClassManifest]( input: Seq[Seq[U]], - operation: DStream[U] => DStream[V] + operation: DStream[U] => DStream[V], + numPartitions: Int = numInputPartitions ): StreamingContext = { // Create StreamingContext @@ -118,9 +146,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { } // Setup the stream computation - val inputStream = new TestInputStream(ssc, input, numInputPartitions) + val inputStream = new TestInputStream(ssc, input, numPartitions) val operatedStream = operation(inputStream) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[V]] with SynchronizedBuffer[Seq[V]]) + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[V]]] with SynchronizedBuffer[Seq[Seq[V]]]) ssc.registerInputStream(inputStream) ssc.registerOutputStream(outputStream) ssc @@ -146,7 +175,8 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { val inputStream1 = new TestInputStream(ssc, input1, numInputPartitions) val inputStream2 = new TestInputStream(ssc, input2, numInputPartitions) val operatedStream = operation(inputStream1, inputStream2) - val outputStream = new TestOutputStream(operatedStream, new ArrayBuffer[Seq[W]] with SynchronizedBuffer[Seq[W]]) + val outputStream = new TestOutputStreamWithPartitions(operatedStream, + new ArrayBuffer[Seq[Seq[W]]] with SynchronizedBuffer[Seq[Seq[W]]]) ssc.registerInputStream(inputStream1) ssc.registerInputStream(inputStream2) ssc.registerOutputStream(outputStream) @@ -157,18 +187,37 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and * returns the collected output. It will wait until `numExpectedOutput` number of * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. + * + * Returns a sequence of items for each RDD. */ def runStreams[V: ClassManifest]( ssc: StreamingContext, numBatches: Int, numExpectedOutput: Int ): Seq[Seq[V]] = { + // Flatten each RDD into a single Seq + runStreamsWithPartitions(ssc, numBatches, numExpectedOutput).map(_.flatten.toSeq) + } + + /** + * Runs the streams set up in `ssc` on manual clock for `numBatches` batches and + * returns the collected output. It will wait until `numExpectedOutput` number of + * output data has been collected or timeout (set by `maxWaitTimeMillis`) is reached. + * + * Returns a sequence of RDD's. Each RDD is represented as several sequences of items, each + * representing one partition. + */ + def runStreamsWithPartitions[V: ClassManifest]( + ssc: StreamingContext, + numBatches: Int, + numExpectedOutput: Int + ): Seq[Seq[Seq[V]]] = { assert(numBatches > 0, "Number of batches to run stream computation is zero") assert(numExpectedOutput > 0, "Number of expected outputs after " + numBatches + " is zero") logInfo("numBatches = " + numBatches + ", numExpectedOutput = " + numExpectedOutput) // Get the output buffer - val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStream[V]] + val outputStream = ssc.graph.getOutputStreams.head.asInstanceOf[TestOutputStreamWithPartitions[V]] val output = outputStream.output try {