Refactored streaming scheduler and added listener interface.

- Refactored Scheduler + JobManager to JobGenerator + JobScheduler and
  added JobSet for cleaner code. Moved scheduler related code to
  streaming.scheduler package.
- Added StreamingListener trait (similar to SparkListener) to enable
  gathering of streaming stats like processing times and delays.
  Added StreamingContext.addListener() to add listeners.
- Deduped some code in streaming tests by modifying TestSuiteBase, and
  added StreamingListenerSuite.
This commit is contained in:
Tathagata Das 2013-12-12 20:41:51 -08:00
parent 6169fe14a1
commit 097e120c0c
23 changed files with 500 additions and 199 deletions

View file

@ -63,7 +63,7 @@ trait SparkListener {
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends

View file

@ -40,7 +40,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time)
val graph = ssc.graph
val checkpointDir = ssc.checkpointDir
val checkpointDuration = ssc.checkpointDuration
val pendingTimes = ssc.scheduler.jobManager.getPendingTimes()
val pendingTimes = ssc.scheduler.getPendingTimes()
val delaySeconds = MetadataCleaner.getDelaySeconds
def validate() {

View file

@ -17,23 +17,18 @@
package org.apache.spark.streaming
import org.apache.spark.streaming.dstream._
import StreamingContext._
import org.apache.spark.util.MetadataCleaner
//import Time._
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.scheduler.Job
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.MetadataCleaner
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.conf.Configuration
/**
* A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous

View file

@ -21,6 +21,7 @@ import dstream.InputDStream
import java.io.{ObjectInputStream, IOException, ObjectOutputStream}
import collection.mutable.ArrayBuffer
import org.apache.spark.Logging
import org.apache.spark.streaming.scheduler.Job
final private[streaming] class DStreamGraph extends Serializable with Logging {
initLogging()

View file

@ -1,88 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming
import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import java.util.concurrent.Executors
import collection.mutable.HashMap
import collection.mutable.ArrayBuffer
/**
 * Executes streaming jobs on a fixed-size thread pool and keeps, per batch time, the list of
 * jobs that have been submitted but have not finished yet.
 *
 * @param ssc        streaming context whose SparkEnv is installed on the worker threads and
 *                   whose scheduler is told when all jobs of a batch time are done
 * @param numThreads size of the thread pool used to run jobs
 */
private[streaming]
class JobManager(ssc: StreamingContext, numThreads: Int = 1) extends Logging {

  /** Runnable that runs one job, logs its delay (or failure) and then deregisters it. */
  class JobHandler(ssc: StreamingContext, job: Job) extends Runnable {
    def run() {
      SparkEnv.set(ssc.env)
      try {
        val executionMillis = job.run()
        val totalDelaySeconds = (System.currentTimeMillis() - job.time.milliseconds) / 1000.0
        logInfo("Total delay: %.5f s for job %s of time %s (execution: %.5f s)".format(
          totalDelaySeconds, job.id, job.time.milliseconds, executionMillis / 1000.0))
      } catch {
        case e: Exception =>
          logError("Running " + job + " failed", e)
      }
      // Deregister the job whether it succeeded or failed with an Exception.
      clearJob(job)
    }
  }

  initLogging()

  val jobExecutor = Executors.newFixedThreadPool(numThreads)
  // Outstanding jobs grouped by batch time; every access is guarded by jobs.synchronized.
  val jobs = new HashMap[Time, ArrayBuffer[Job]]

  /** Registers `job` under its batch time and hands it to the thread pool. */
  def runJob(job: Job) {
    jobs.synchronized {
      val pendingForTime = jobs.getOrElseUpdate(job.time, new ArrayBuffer[Job])
      pendingForTime += job
    }
    jobExecutor.execute(new JobHandler(ssc, job))
    logInfo("Added " + job + " to queue")
  }

  /** Shuts the thread pool down. */
  def stop() {
    jobExecutor.shutdown()
  }

  /**
   * Removes a finished job from the bookkeeping. When it was the last outstanding job of its
   * batch time, the time entry is dropped and the scheduler is asked to clear old metadata.
   * Throws an Exception if the job's time is not being tracked.
   */
  private def clearJob(job: Job) {
    val batchTime = job.time
    val batchFinished = jobs.synchronized {
      jobs.get(batchTime) match {
        case Some(pendingForTime) =>
          pendingForTime -= job
          if (pendingForTime.isEmpty) {
            jobs -= batchTime
            true
          } else {
            false
          }
        case None =>
          throw new Exception("Job finished for time " + job.time +
            " but time does not exist in jobs")
      }
    }
    // Done outside the lock, as in the bookkeeping above only the map itself is guarded.
    if (batchFinished) {
      ssc.scheduler.clearOldMetadata(batchTime)
    }
  }

  /** Returns the batch times that still have outstanding jobs. */
  def getPendingTimes(): Array[Time] = {
    jobs.synchronized {
      jobs.keySet.toArray
    }
  }
}

View file

@ -46,6 +46,7 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
import twitter4j.auth.Authorization
import org.apache.spark.streaming.scheduler._
/**
@ -146,9 +147,10 @@ class StreamingContext private (
}
}
protected[streaming] var checkpointDuration: Duration = if (isCheckpointPresent) cp_.checkpointDuration else null
protected[streaming] var receiverJobThread: Thread = null
protected[streaming] var scheduler: Scheduler = null
protected[streaming] val checkpointDuration: Duration = {
if (isCheckpointPresent) cp_.checkpointDuration else graph.batchDuration
}
protected[streaming] val scheduler = new JobScheduler(this)
/**
* Return the associated Spark context
@ -510,6 +512,10 @@ class StreamingContext private (
graph.addOutputStream(outputStream)
}
def addListener(streamingListener: StreamingListener) {
scheduler.listenerBus.addListener(streamingListener)
}
protected def validate() {
assert(graph != null, "Graph is null")
graph.validate()
@ -525,9 +531,6 @@ class StreamingContext private (
* Start the execution of the streams.
*/
def start() {
if (checkpointDir != null && checkpointDuration == null && graph != null) {
checkpointDuration = graph.batchDuration
}
validate()
@ -545,7 +548,6 @@ class StreamingContext private (
Thread.sleep(1000)
// Start the scheduler
scheduler = new Scheduler(this)
scheduler.start()
}
@ -556,7 +558,6 @@ class StreamingContext private (
try {
if (scheduler != null) scheduler.stop()
if (networkInputTracker != null) networkInputTracker.stop()
if (receiverJobThread != null) receiverJobThread.interrupt()
sc.stop()
logInfo("StreamingContext stopped successfully")
} catch {

View file

@ -18,7 +18,8 @@
package org.apache.spark.streaming.dstream
import org.apache.spark.rdd.RDD
import org.apache.spark.streaming.{Duration, DStream, Job, Time}
import org.apache.spark.streaming.{Duration, DStream, Time}
import org.apache.spark.streaming.scheduler.Job
private[streaming]
class ForEachDStream[T: ClassManifest] (

View file

@ -32,6 +32,7 @@ import org.apache.spark.streaming._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.rdd.{RDD, BlockRDD}
import org.apache.spark.storage.{BlockId, StorageLevel, StreamBlockId}
import org.apache.spark.streaming.scheduler.{DeregisterReceiver, AddBlocks, RegisterReceiver}
/**
* Abstract class for defining any InputDStream that has to start a receiver on worker

View file

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.Time
/**
 * Timing information about a single batch, as handed to StreamingListeners.
 *
 * @param batchTime           time of the batch
 * @param submissionTime      timestamp (milliseconds) at which the batch was submitted
 * @param processingStartTime timestamp (milliseconds) at which processing started, if it has
 * @param processingEndTime   timestamp (milliseconds) at which processing ended, if it has
 */
case class BatchInfo(
    batchTime: Time,
    submissionTime: Long,
    processingStartTime: Option[Long],
    processingEndTime: Option[Long]
  ) {

  /** `processingStartTime - submissionTime`, or None while the start time is unknown. */
  def schedulingDelay: Option[Long] = processingStartTime.map(_ - submissionTime)

  /** `processingEndTime - processingStartTime`, or None unless both times are known. */
  def processingDelay: Option[Long] =
    for (start <- processingStartTime; end <- processingEndTime) yield end - start

  /** Sum of scheduling and processing delay, or None unless both are defined. */
  def totalDelay: Option[Long] =
    for (scheduling <- schedulingDelay; processing <- processingDelay)
      yield scheduling + processing
}

View file

@ -15,13 +15,15 @@
* limitations under the License.
*/
package org.apache.spark.streaming
package org.apache.spark.streaming.scheduler
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark.streaming.Time
private[streaming]
class Job(val time: Time, func: () => _) {
val id = Job.getNewId()
var id: String = _
def run(): Long = {
val startTime = System.currentTimeMillis
func()
@ -29,13 +31,17 @@ class Job(val time: Time, func: () => _) {
(stopTime - startTime)
}
override def toString = "streaming job " + id + " @ " + time
}
def setId(number: Int) {
id = "streaming job " + time + "." + number
}
override def toString = id
}
/*
private[streaming]
object Job {
val id = new AtomicLong(0)
def getNewId() = id.getAndIncrement()
}
*/

View file

@ -15,31 +15,30 @@
* limitations under the License.
*/
package org.apache.spark.streaming
package org.apache.spark.streaming.scheduler
import util.{ManualClock, RecurringTimer, Clock}
import org.apache.spark.SparkEnv
import org.apache.spark.Logging
import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
private[streaming]
class Scheduler(ssc: StreamingContext) extends Logging {
class JobGenerator(jobScheduler: JobScheduler) extends Logging {
initLogging()
val concurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
val jobManager = new JobManager(ssc, concurrentJobs)
val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(ssc.checkpointDir)
} else {
null
}
val ssc = jobScheduler.ssc
val clockClass = System.getProperty(
"spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
val clock = Class.forName(clockClass).newInstance().asInstanceOf[Clock]
val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
longTime => generateJobs(new Time(longTime)))
val graph = ssc.graph
lazy val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) {
new CheckpointWriter(ssc.checkpointDir)
} else {
null
}
var latestTime: Time = null
def start() = synchronized {
@ -48,26 +47,24 @@ class Scheduler(ssc: StreamingContext) extends Logging {
} else {
startFirstTime()
}
logInfo("Scheduler started")
logInfo("JobGenerator started")
}
def stop() = synchronized {
timer.stop()
jobManager.stop()
if (checkpointWriter != null) checkpointWriter.stop()
ssc.graph.stop()
logInfo("Scheduler stopped")
logInfo("JobGenerator stopped")
}
private def startFirstTime() {
val startTime = new Time(timer.getStartTime())
graph.start(startTime - graph.batchDuration)
timer.start(startTime.milliseconds)
logInfo("Scheduler's timer started at " + startTime)
logInfo("JobGenerator's timer started at " + startTime)
}
private def restart() {
// If manual clock is being used for testing, then
// either set the manual clock to the last checkpointed time,
// or if the property is defined set it to that time
@ -93,35 +90,34 @@ class Scheduler(ssc: StreamingContext) extends Logging {
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
logInfo("Batches to reschedule: " + timesToReschedule.mkString(", "))
timesToReschedule.foreach(time =>
graph.generateJobs(time).foreach(jobManager.runJob)
jobScheduler.runJobs(time, graph.generateJobs(time))
)
// Restart the timer
timer.start(restartTime.milliseconds)
logInfo("Scheduler's timer restarted at " + restartTime)
logInfo("JobGenerator's timer restarted at " + restartTime)
}
/** Generate jobs and perform checkpoint for the given `time`. */
def generateJobs(time: Time) {
private def generateJobs(time: Time) {
SparkEnv.set(ssc.env)
logInfo("\n-----------------------------------------------------\n")
graph.generateJobs(time).foreach(jobManager.runJob)
jobScheduler.runJobs(time, graph.generateJobs(time))
latestTime = time
doCheckpoint(time)
}
/**
* Clear old metadata assuming jobs of `time` have finished processing.
* And also perform checkpoint.
* On batch completion, clear old metadata and checkpoint computation.
*/
def clearOldMetadata(time: Time) {
private[streaming] def onBatchCompletion(time: Time) {
ssc.graph.clearOldMetadata(time)
doCheckpoint(time)
}
/** Perform checkpoint for the given `time`. */
def doCheckpoint(time: Time) = synchronized {
if (ssc.checkpointDuration != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
private def doCheckpoint(time: Time) = synchronized {
if (checkpointWriter != null && (time - graph.zeroTime).isMultipleOf(ssc.checkpointDuration)) {
logInfo("Checkpointing graph for time " + time)
ssc.graph.updateCheckpointData(time)
checkpointWriter.write(new Checkpoint(ssc, time))

View file

@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.scheduler
import org.apache.spark.Logging
import org.apache.spark.SparkEnv
import java.util.concurrent.{TimeUnit, ConcurrentHashMap, Executors}
import scala.collection.mutable.HashSet
import org.apache.spark.streaming._
/**
 * Schedules the jobs of each batch on a thread pool, tracks them per batch time as JobSets,
 * and reports batch start/completion to the StreamingListenerBus and the JobGenerator.
 *
 * The pool size is read from the system property "spark.streaming.concurrentJobs"
 * (default "1"). Because several JobHandlers of the same JobSet may run at once when that
 * value is greater than one, every state transition of a JobSet is performed while holding
 * the JobSet's monitor.
 *
 * @param ssc streaming context this scheduler belongs to
 */
private[streaming]
class JobScheduler(val ssc: StreamingContext) extends Logging {

  initLogging()

  // Job sets that have been submitted and not yet completed, keyed by batch time.
  val jobSets = new ConcurrentHashMap[Time, JobSet]
  val numConcurrentJobs = System.getProperty("spark.streaming.concurrentJobs", "1").toInt
  val executor = Executors.newFixedThreadPool(numConcurrentJobs)
  val generator = new JobGenerator(this)
  val listenerBus = new StreamingListenerBus()

  def clock = generator.clock

  /** Starts job generation. */
  def start() {
    generator.start()
  }

  /**
   * Stops job generation and shuts the executor down, forcing termination if running jobs
   * have not finished within 5 seconds.
   */
  def stop() {
    generator.stop()
    executor.shutdown()
    if (!executor.awaitTermination(5, TimeUnit.SECONDS)) {
      executor.shutdownNow()
    }
  }

  /** Registers `jobs` as the job set of `time` and submits each job to the executor. */
  def runJobs(time: Time, jobs: Seq[Job]) {
    if (jobs.isEmpty) {
      logInfo("No jobs added for time " + time)
    } else {
      val jobSet = new JobSet(time, jobs)
      // Register before submitting so that handlers can always look their job set up.
      jobSets.put(time, jobSet)
      jobSet.jobs.foreach(job => executor.execute(new JobHandler(job)))
      logInfo("Added jobs for time " + time)
    }
  }

  /** Returns the batch times whose job sets have not completed yet. */
  def getPendingTimes(): Array[Time] = {
    jobSets.keySet.toArray(new Array[Time](0))
  }

  /**
   * Bookkeeping done on the executor thread right before a job runs. Posts a BatchStarted
   * event when this is the first job of its job set to start.
   */
  private def beforeJobStart(job: Job) {
    val jobSet = jobSets.get(job.time)
    // Check-and-mark must be atomic, otherwise two concurrently starting jobs of the same
    // job set could both post BatchStarted.
    jobSet.synchronized {
      val isFirstJobOfJobSet = !jobSet.hasStarted
      // Record the start time first so the posted BatchInfo carries processingStartTime.
      jobSet.beforeJobStart(job)
      if (isFirstJobOfJobSet) {
        listenerBus.post(StreamingListenerBatchStarted(jobSet.toBatchInfo()))
      }
    }
    logInfo("Starting job " + job.id + " from job set of time " + jobSet.time)
    SparkEnv.set(ssc.env)
  }

  /**
   * Bookkeeping done on the executor thread after a job has run (successfully or not). When
   * the job was the last one of its job set, posts BatchCompleted, forgets the job set and
   * notifies the generator.
   */
  private def afterJobEnd(job: Job) {
    val jobSet = jobSets.get(job.time)
    // Remove-and-check must be atomic so that exactly one thread observes the completion.
    val batchCompleted = jobSet.synchronized {
      jobSet.afterJobStop(job)
      jobSet.hasCompleted
    }
    logInfo("Finished job " + job.id + " from job set of time " + jobSet.time)
    if (batchCompleted) {
      listenerBus.post(StreamingListenerBatchCompleted(jobSet.toBatchInfo()))
      jobSets.remove(jobSet.time)
      generator.onBatchCompletion(jobSet.time)
      logInfo("Total delay: %.3f s for time %s (execution: %.3f s)".format(
        jobSet.totalDelay / 1000.0, jobSet.time.toString,
        jobSet.processingDelay / 1000.0
      ))
    }
  }

  /** Runnable wrapping one job with the before/after bookkeeping; Exceptions are logged. */
  class JobHandler(job: Job) extends Runnable {
    def run() {
      beforeJobStart(job)
      try {
        job.run()
      } catch {
        case e: Exception =>
          logError("Running " + job + " failed", e)
      }
      afterJobEnd(job)
    }
  }
}

View file

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.scheduler
import scala.collection.mutable.HashSet
import org.apache.spark.streaming.Time
/**
 * The jobs of one batch time together with their timing bookkeeping.
 *
 * On construction every job is given an id from its position in `jobs` (via `Job.setId`)
 * and is recorded as incomplete. `processingStartTime` and `processingEndTime` hold -1
 * until the corresponding event has happened. Jobs of one set may be run by several
 * threads, so all methods that read or update the mutable state lock on this object.
 *
 * @param time batch time the jobs belong to
 * @param jobs jobs generated for that time
 */
private[streaming]
case class JobSet(time: Time, jobs: Seq[Job]) {

  private val incompleteJobs = new HashSet[Job]()
  var submissionTime = System.currentTimeMillis()
  var processingStartTime = -1L   // -1 means "not started yet"
  var processingEndTime = -1L     // -1 means "not finished yet"

  jobs.zipWithIndex.foreach { case (job, i) => job.setId(i) }
  incompleteJobs ++= jobs

  /** Records the processing start time when the first job of the set starts. */
  def beforeJobStart(job: Job) {
    synchronized {
      if (processingStartTime < 0) processingStartTime = System.currentTimeMillis()
    }
  }

  /** Marks `job` as done and records the end time once no job is left incomplete. */
  def afterJobStop(job: Job) {
    synchronized {
      incompleteJobs -= job
      if (incompleteJobs.isEmpty) processingEndTime = System.currentTimeMillis()
    }
  }

  /** True once a start time has been recorded (same `>= 0` test as `toBatchInfo`). */
  def hasStarted() = synchronized { processingStartTime >= 0 }

  /** True once every job of the set has been passed to `afterJobStop`. */
  def hasCompleted() = synchronized { incompleteJobs.isEmpty }

  /** processingEndTime - processingStartTime; only meaningful after completion. */
  def processingDelay = synchronized { processingEndTime - processingStartTime }

  /** processingEndTime - batch time in milliseconds; only meaningful after completion. */
  def totalDelay = synchronized {
    processingEndTime - time.milliseconds
  }

  /** Snapshot of the timing state; times that are still -1 are reported as None. */
  def toBatchInfo(): BatchInfo = synchronized {
    new BatchInfo(
      time,
      submissionTime,
      if (processingStartTime >= 0) Some(processingStartTime) else None,
      if (processingEndTime >= 0) Some(processingEndTime) else None
    )
  }
}

View file

@ -15,7 +15,7 @@
* limitations under the License.
*/
package org.apache.spark.streaming
package org.apache.spark.streaming.scheduler
import org.apache.spark.streaming.dstream.{NetworkInputDStream, NetworkReceiver}
import org.apache.spark.streaming.dstream.{StopReceiver, ReportBlock, ReportError}
@ -31,6 +31,7 @@ import akka.pattern.ask
import akka.util.duration._
import akka.dispatch._
import org.apache.spark.storage.BlockId
import org.apache.spark.streaming.{Time, StreamingContext}
private[streaming] sealed trait NetworkInputTrackerMessage
private[streaming] case class RegisterReceiver(streamId: Int, receiverActor: ActorRef) extends NetworkInputTrackerMessage

View file

@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.scheduler
/** Base type of the events that are delivered to a StreamingListener. */
sealed trait StreamingListenerEvent

/** Event carrying the BatchInfo of a batch whose processing has started. */
case class StreamingListenerBatchStarted(batchInfo: BatchInfo) extends StreamingListenerEvent

/** Event carrying the BatchInfo of a batch whose processing has completed. */
case class StreamingListenerBatchCompleted(batchInfo: BatchInfo) extends StreamingListenerEvent

/**
 * Callback interface for streaming events. Both callbacks do nothing by default, so an
 * implementation only overrides the ones it is interested in.
 */
trait StreamingListener {

  /**
   * Called when processing of a batch has started
   */
  def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {}

  /**
   * Called when processing of a batch has completed
   */
  def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {}
}

View file

@ -0,0 +1,81 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.scheduler
import org.apache.spark.Logging
import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer}
import java.util.concurrent.LinkedBlockingQueue
/**
 * Asynchronously passes StreamingListenerEvents to registered StreamingListeners.
 *
 * Events given to `post` are placed on a bounded queue; a single daemon thread named
 * "StreamingListenerBus", started when the bus is constructed, takes them off the queue and
 * invokes the matching callback on every registered listener.
 */
private[spark] class StreamingListenerBus() extends Logging {
  private val listeners = new ArrayBuffer[StreamingListener]() with SynchronizedBuffer[StreamingListener]

  /* Cap the capacity of the StreamingListenerEvent queue so we get an explicit error (rather
   * than an OOM exception) if it's perpetually being added to more quickly than it's being
   * drained. */
  private val EVENT_QUEUE_CAPACITY = 10000
  private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY)
  // Ensures the "queue full" error is logged only once rather than for every dropped event.
  private var queueFullErrorMessageLogged = false

  new Thread("StreamingListenerBus") {
    setDaemon(true)
    override def run() {
      while (true) {
        // StreamingListenerEvent is sealed, so these two cases cover every possible event.
        eventQueue.take match {
          case batchStarted: StreamingListenerBatchStarted =>
            listeners.foreach(_.onBatchStarted(batchStarted))
          case batchCompleted: StreamingListenerBatchCompleted =>
            listeners.foreach(_.onBatchCompleted(batchCompleted))
        }
      }
    }
  }.start()

  /** Registers a listener that will receive all events dispatched from now on. */
  def addListener(listener: StreamingListener) {
    listeners += listener
  }

  /**
   * Enqueues an event for asynchronous delivery without blocking. If the queue is full the
   * event is dropped, and an error is logged the first time this happens.
   */
  def post(event: StreamingListenerEvent) {
    val eventAdded = eventQueue.offer(event)
    if (!eventAdded && !queueFullErrorMessageLogged) {
      logError("Dropping StreamingListenerEvent because no remaining room in event queue. " +
        "This likely means one of the StreamingListeners is too slow and cannot keep up with " +
        "the rate at which batches are being processed by the scheduler.")
      queueFullErrorMessageLogged = true
    }
  }

  /**
   * Waits until there are no more events in the queue, or until the specified time has elapsed.
   * Used for testing only. Returns true if the queue has emptied and false if the specified time
   * elapsed before the queue emptied.
   */
  def waitUntilEmpty(timeoutMillis: Int): Boolean = {
    val finishTime = System.currentTimeMillis + timeoutMillis
    /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify
     * add overhead in the general case. */
    while (!eventQueue.isEmpty() && System.currentTimeMillis <= finishTime) {
      Thread.sleep(10)
    }
    eventQueue.isEmpty()
  }
}

View file

@ -26,18 +26,6 @@ import util.ManualClock
class BasicOperationsSuite extends TestSuiteBase {
override def framework() = "BasicOperationsSuite"
before {
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
}
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
test("map") {
val input = Seq(1 to 4, 5 to 8, 9 to 12)
testOperation(

View file

@ -34,30 +34,24 @@ import com.google.common.io.Files
* the checkpointing of a DStream's RDDs as well as the checkpointing of
* the whole DStream graph.
*/
class CheckpointSuite extends TestSuiteBase with BeforeAndAfter {
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
before {
FileUtils.deleteDirectory(new File(checkpointDir))
}
after {
if (ssc != null) ssc.stop()
FileUtils.deleteDirectory(new File(checkpointDir))
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
class CheckpointSuite extends TestSuiteBase {
var ssc: StreamingContext = null
override def framework = "CheckpointSuite"
override def batchDuration = Milliseconds(500)
override def actuallyWait = true
override def actuallyWait = true // to allow checkpoints to be written
override def beforeFunction() {
super.beforeFunction()
FileUtils.deleteDirectory(new File(checkpointDir))
}
override def afterFunction() {
super.afterFunction()
if (ssc != null) ssc.stop()
FileUtils.deleteDirectory(new File(checkpointDir))
}
test("basic rdd checkpoints + dstream graph checkpoint recovery") {

View file

@ -32,17 +32,22 @@ import collection.mutable.ArrayBuffer
* This testsuite tests master failures at random times while the stream is running using
* the real clock.
*/
class FailureSuite extends FunSuite with BeforeAndAfter with Logging {
class FailureSuite extends TestSuiteBase with Logging {
var directory = "FailureSuite"
val numBatches = 30
val batchDuration = Milliseconds(1000)
before {
override def batchDuration = Milliseconds(1000)
override def useManualClock = false
override def beforeFunction() {
super.beforeFunction()
FileUtils.deleteDirectory(new File(directory))
}
after {
override def afterFunction() {
super.afterFunction()
FileUtils.deleteDirectory(new File(directory))
}

View file

@ -50,18 +50,6 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val testPort = 9999
override def checkpointDir = "checkpoint"
before {
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
}
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
test("socket input stream") {
// Start the server
val testServer = new TestServer()

View file

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming
import org.apache.spark.streaming.scheduler._
import scala.collection.mutable.ArrayBuffer
import org.scalatest.matchers.ShouldMatchers
/**
 * Checks that a StreamingListener registered on a StreamingContext receives one BatchInfo
 * per processed batch, with all delays defined and non-negative and timestamps in order.
 */
class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {

  val input = (1 to 4).map(Seq(_)).toSeq
  val operation = (d: DStream[Int]) => d.map(x => x)

  // To make sure that the processing start and end times in collected
  // information are different for successive batches
  override def batchDuration = Milliseconds(100)
  override def actuallyWait = true

  test("basic BatchInfo generation") {
    val ssc = setupStreams(input, operation)
    val collector = new BatchInfoCollector
    ssc.addListener(collector)
    runStreams(ssc, input.size, input.size)
    val batchInfos = collector.batchInfos
    batchInfos should have size 4

    batchInfos.foreach(info => {
      info.schedulingDelay should not be None
      info.processingDelay should not be None
      info.totalDelay should not be None
      info.schedulingDelay.get should be >= 0L
      info.processingDelay.get should be >= 0L
      info.totalDelay.get should be >= 0L
    })

    isInIncreasingOrder(batchInfos.map(_.submissionTime)) should be (true)
    isInIncreasingOrder(batchInfos.map(_.processingStartTime.get)) should be (true)
    isInIncreasingOrder(batchInfos.map(_.processingEndTime.get)) should be (true)
  }

  /**
   * Check if a sequence of numbers is in increasing order. Equal neighbours are accepted,
   * i.e. the check is for a non-decreasing sequence; empty and single-element sequences
   * are in order.
   */
  def isInIncreasingOrder(seq: Seq[Long]): Boolean = {
    // Compare every element with its successor instead of returning from inside a closure.
    seq.zip(seq.drop(1)).forall { case (previous, next) => previous <= next }
  }

  /** Listener that collects information on processed batches */
  class BatchInfoCollector extends StreamingListener {
    // Callbacks run on the listener bus thread while the test thread reads the buffer,
    // so use a synchronized buffer.
    val batchInfos: ArrayBuffer[BatchInfo] =
      new ArrayBuffer[BatchInfo] with scala.collection.mutable.SynchronizedBuffer[BatchInfo]

    override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted) {
      batchInfos += batchCompleted.batchInfo
    }
  }
}

View file

@ -109,7 +109,7 @@ class TestOutputStreamWithPartitions[T: ClassManifest](parent: DStream[T],
trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Name of the framework for Spark context
def framework = "TestSuiteBase"
def framework = this.getClass.getSimpleName
// Master for Spark context
def master = "local[2]"
@ -126,9 +126,39 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Maximum time to wait before the test times out
def maxWaitTimeMillis = 10000
// Whether to use manual clock or not
def useManualClock = true
// Whether to actually wait in real time before changing manual clock
def actuallyWait = false
// Default before function for any streaming test suite. Override this
// if you want to add your stuff to "before" (i.e., don't call before { } )
def beforeFunction() {
if (useManualClock) {
System.setProperty(
"spark.streaming.clock",
"org.apache.spark.streaming.util.ManualClock"
)
} else {
System.clearProperty("spark.streaming.clock")
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
// Default after function for any streaming test suite. Override this
// if you want to add your stuff to "after" (i.e., don't call after { } )
def afterFunction() {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
before(beforeFunction)
after(afterFunction)
/**
* Set up required DStreams to test the DStream operation using the two sequences
* of input collections.

View file

@ -22,19 +22,9 @@ import collection.mutable.ArrayBuffer
class WindowOperationsSuite extends TestSuiteBase {
System.setProperty("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
override def maxWaitTimeMillis = 20000 // large window tests can sometimes take longer
override def framework = "WindowOperationsSuite"
override def maxWaitTimeMillis = 20000
override def batchDuration = Seconds(1)
after {
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
override def batchDuration = Seconds(1) // making sure its visible in this class
val largerSlideInput = Seq(
Seq(("a", 1)),