From ea739251eb763b756a282534268e765b8d4b70f0 Mon Sep 17 00:00:00 2001 From: seanm Date: Sun, 20 Jan 2013 11:29:21 -0700 Subject: [PATCH] adding updateStateByKey object lifecycle test --- .../streaming/BasicOperationsSuite.scala | 45 +++++++++++++++++++ .../scala/spark/streaming/TestSuiteBase.scala | 5 +++ 2 files changed, 50 insertions(+) diff --git a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala index f73f9b1823..2bc94463b1 100644 --- a/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/spark/streaming/BasicOperationsSuite.scala @@ -160,6 +160,51 @@ class BasicOperationsSuite extends TestSuiteBase { testOperation(inputData, updateStateOperation, outputData, true) } + test("updateStateByKey - object lifecycle") { + val inputData = + Seq( + Seq("a","b"), + null, + Seq("a","c","a"), + Seq("c"), + null, + null + ) + + val outputData = + Seq( + Seq(("a", 1), ("b", 1)), + Seq(("a", 1), ("b", 1)), + Seq(("a", 3), ("c", 1)), + Seq(("a", 3), ("c", 2)), + Seq(("c", 2)), + Seq() + ) + + val updateStateOperation = (s: DStream[String]) => { + class StateObject(var counter: Int = 0, var expireCounter: Int = 0) extends Serializable + + // updateFunc clears a state when a StateObject is seen without new values twice in a row + val updateFunc = (values: Seq[Int], state: Option[StateObject]) => { + val stateObj = state.getOrElse(new StateObject) + values.foldLeft(0)(_ + _) match { + case 0 => stateObj.expireCounter += 1 // no new values + case n => { // has new values, increment and reset expireCounter + stateObj.counter += n + stateObj.expireCounter = 0 + } + } + stateObj.expireCounter match { + case 2 => None // seen twice with no new values, give it the boot + case _ => Option(stateObj) + } + } + s.map(_ -> 1).updateStateByKey[StateObject](updateFunc).mapValues(_.counter) + } + + testOperation(inputData, updateStateOperation, outputData, true) + } + test("forgetting of RDDs - map and window operations") { assert(batchDuration === Seconds(1), "Batch duration has changed from 1 second") diff --git a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala index a76f61d4ad..11cfcba827 100644 --- a/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/spark/streaming/TestSuiteBase.scala @@ -28,6 +28,11 @@ class TestInputStream[T: ClassManifest](ssc_ : StreamingContext, input: Seq[Seq[ logInfo("Computing RDD for time " + validTime) val index = ((validTime - zeroTime) / slideDuration - 1).toInt val selectedInput = if (index < input.size) input(index) else Seq[T]() + + // lets us test cases where RDDs are not created + if (selectedInput == null) + return None + val rdd = ssc.sc.makeRDD(selectedInput, numPartitions) logInfo("Created RDD " + rdd.id + " with " + selectedInput) Some(rdd)