Merge branch 'apache-master' into transform

Tathagata Das 2013-10-24 11:05:00 -07:00
commit 0400aba1c0
19 changed files with 417 additions and 10 deletions


@@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
// first() has to be overridden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()


@@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**


@@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**

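For illustration, a minimal sketch of how the two unpersist overloads added above might be used; the object name, SparkContext setup, and RDD contents here are assumptions, and the sketch is written against the Scala RDD API that JavaRDD, JavaPairRDD and JavaDoubleRDD delegate to.

// Illustrative sketch, not part of this commit.
import org.apache.spark.SparkContext
import org.apache.spark.storage.StorageLevel

object UnpersistSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "UnpersistSketch")
    val cached = sc.parallelize(1 to 1000, 4).persist(StorageLevel.MEMORY_ONLY)
    cached.count()  // materialize the cached blocks

    // Blocks until all blocks for the RDD are removed from memory and disk.
    cached.unpersist()

    // Alternatively, return immediately and let the blocks be removed asynchronously.
    // cached.unpersist(blocking = false)

    sc.stop()
  }
}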

@@ -68,6 +68,11 @@ class DAGScheduler(
eventQueue.put(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(GettingResultEvent(task, taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
def taskEnded(
task: Task[_],
@@ -415,6 +420,9 @@ class DAGScheduler(
case begin: BeginEvent =>
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
case gettingResult: GettingResultEvent =>
listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
case completion: CompletionEvent =>
listenerBus.post(SparkListenerTaskEnd(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))


@@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler]
case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,


@@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
case class SparkListenerTaskGettingResult(
task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@@ -56,6 +59,12 @@ trait SparkListener {
*/
def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
*/

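As a hedged sketch of how the new hook can be consumed, here is an illustrative listener (the class name and log messages are made up) that can be registered the same way the test suite below registers its listeners, via sc.addSparkListener.

// Illustrative sketch, not part of this commit.
import org.apache.spark.scheduler._

class RemoteResultLogger extends SparkListener {
  // Fires only for tasks whose results come back as an IndirectTaskResult
  // and have to be fetched from the block manager.
  override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
    println("Remotely fetching result for task index " + taskGettingResult.taskInfo.index)
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    println("Task " + taskEnd.taskInfo.index + " ended with status " + taskEnd.taskInfo.status)
  }
}

// Usage (sketch): sc.addSparkListener(new RemoteResultLogger)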

@@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
case taskGettingResult: SparkListenerTaskGettingResult =>
sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>


@@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
/**
* The time when the task started remotely getting the result. Will not be set if the
* task result was sent immediately when the task finished (as opposed to sending an
* IndirectTaskResult and later fetching the result from the block manager).
*/
var gettingResultTime: Long = 0
/**
* The time when the task has completed successfully (including the time to remotely fetch
* results, if necessary).
*/
var finishTime: Long = 0
var failed = false
def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
def gettingResult: Boolean = gettingResultTime != 0
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
else if (gettingResult)
"GET RESULT"
else if (failed)
"FAILED"
else if (successful)

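A small illustrative helper (assumed names, not part of this commit) showing how the new TaskInfo fields can be combined in a listener to estimate how long the remote result fetch took for a task.

// Illustrative sketch, not part of this commit.
import org.apache.spark.scheduler._

class ResultFetchTimer extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
    val info = taskEnd.taskInfo
    // gettingResultTime stays 0 when the result was returned directly with the status update.
    if (info.gettingResult && info.finished) {
      val fetchMillis = info.finishTime - info.gettingResultTime
      println("Task " + info.index + " spent " + fetchMillis + " ms fetching its result remotely")
    }
  }
}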

@@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,


@@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager(
sched.dagScheduler.taskStarted(task, info)
}
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/


@@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed


@@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
taskList += ((taskStart.taskInfo, None, None))
stageIdToTaskInfos(sid) = taskList
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
= synchronized {
// Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
// stageIdToTaskInfos already has the updated status.
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())


@@ -17,22 +17,25 @@
package org.apache.spark.scheduler
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.{LocalSparkContext, SparkContext}
import scala.collection.mutable
import scala.collection.mutable.{Buffer, HashSet}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
with BeforeAndAfter {
with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
before {
sc = new SparkContext("local", "DAGSchedulerSuite")
override def afterAll {
System.clearProperty("spark.akka.frameSize")
}
test("basic creation of StageInfo") {
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("StageInfo with fewer tasks than partitions") {
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
@@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
test("local metrics") {
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
test("onTaskGettingResult() called when result fetched remotely") {
// Need to use local-cluster mode here, because results are never returned through the
// block manager when using the LocalScheduler.
sc = new SparkContext("local-cluster[1,1,512]", "test")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
// Make a task whose result is larger than the akka frame size
System.setProperty("spark.akka.frameSize", "1")
val akkaFrameSize =
sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
assert(listener.endedTasks.contains(TASK_INDEX))
}
test("onTaskGettingResult() not called when result sent directly") {
// Need to use local-cluster mode here, because results are never returned through the
// block manager when using the LocalScheduler.
sc = new SparkContext("local-cluster[1,1,512]", "test")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
// Make a task whose result is smaller than the akka frame size, so it is sent back directly.
val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
assert(result === 2)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.isEmpty)
assert(listener.endedTasks.contains(TASK_INDEX))
}
def checkNonZeroAvg(m: Traversable[Long], msg: String) {
assert(m.sum / m.size.toDouble > 0.0, msg)
}
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stage
}
}
class SaveTaskEvents extends SparkListener {
val startedTasks = new HashSet[Int]()
val startedGettingResultTasks = new HashSet[Int]()
val endedTasks = new HashSet[Int]()
override def onTaskStart(taskStart: SparkListenerTaskStart) {
startedTasks += taskStart.taskInfo.index
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
endedTasks += taskEnd.taskInfo.index
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
startedGettingResultTasks += taskGettingResult.taskInfo.index
}
}
}


@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.examples
import org.apache.spark.streaming.{ Seconds, StreamingContext }
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.MQTTReceiver
import org.apache.spark.storage.StorageLevel
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
/**
* A simple MQTT publisher for demonstration purposes; it repeatedly publishes the
* space-separated string message "hello mqtt demo for spark streaming".
*/
object MQTTPublisher {
var client: MqttClient = _
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
System.exit(1)
}
val Seq(brokerUrl, topic) = args.toSeq
try {
val persistence: MqttClientPersistence = new MqttDefaultFilePersistence("/tmp")
client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
} catch {
case e: MqttException => println("Exception Caught: " + e)
}
client.connect()
val msgtopic: MqttTopic = client.getTopic(topic);
val msg: String = "hello mqtt demo for spark streaming"
while (true) {
val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes())
msgtopic.publish(message);
println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
}
client.disconnect()
}
}
/**
* A sample word count over an MQTT stream.
*
* To work with MQTT, an MQTT message broker/server is required.
* Mosquitto (http://mosquitto.org/) is an open source MQTT broker.
* On Ubuntu, Mosquitto can be installed with `$ sudo apt-get install mosquitto`.
* The Eclipse Paho project provides a Java MQTT client library: http://www.eclipse.org/paho/
* Example Java code for an MQTT publisher and subscriber can be found at https://bitbucket.org/mkjinesh/mqttclient
* Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
* In local mode, <master> should be 'local[n]' with n > 1
* <MqttbrokerUrl> and <topic> describe where the MQTT publisher is running.
*
* To run this example locally, you may run publisher as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
* and run the example as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
*/
object MQTTWordCount {
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println(
"Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
" In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
}
val Seq(master, brokerUrl, topic) = args.toSeq
val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)
val words = lines.flatMap(x => x.toString.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
wordCounts.print()
ssc.start()
}
}

pom.xml

@@ -147,6 +147,17 @@
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>mqtt-repo</id>
<name>MQTT Repository</name>
<url>https://repo.eclipse.org/content/repositories/paho-releases/</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<pluginRepositories>
<pluginRepository>


@@ -108,7 +108,10 @@ object SparkBuild extends Build {
// Shared between both core and streaming.
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
// For Sonatype publishing
// Shared between both examples and streaming.
resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"),
// For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@@ -282,10 +285,11 @@ object SparkBuild extends Build {
"Apache repo" at "https://repository.apache.org/content/repositories/releases"
),
libraryDependencies ++= Seq(
"org.eclipse.paho" % "mqtt-client" % "0.4.0",
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty),
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
exclude("com.sun.jdmk", "jmxtools")
exclude("com.sun.jmx", "jmxri")
)

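For applications built outside the Spark tree, a rough sbt equivalent of the repository and dependency added above (the build file layout is an assumption) would be:

// build.sbt (sketch)
resolvers += "MQTT Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"

libraryDependencies += "org.eclipse.paho" % "mqtt-client" % "0.4.0"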

@@ -136,6 +136,11 @@
<artifactId>slf4j-log4j12</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.paho</groupId>
<artifactId>mqtt-client</artifactId>
<version>0.4.0</version>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>


@@ -462,6 +462,21 @@ class StreamingContext private (
inputStream
}
/**
* Create an input stream that receives messages pushed by an MQTT publisher.
* @param brokerUrl URL of the remote MQTT publisher
* @param topic Topic name to subscribe to
* @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_ONLY_SER_2.
*/
def mqttStream(
brokerUrl: String,
topic: String,
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = {
val inputStream = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel)
registerInputStream(inputStream)
inputStream
}
/**
* Create a unified DStream from multiple DStreams of the same type and same slide duration.
*/

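A minimal sketch of calling the new helper from user code; the object name, broker URL and topic are placeholders, and the bundled MQTTWordCount example above shows the same flow end to end.

// Illustrative sketch, not part of this commit.
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
import org.apache.spark.streaming.StreamingContext._

object MqttStreamSketch {
  def main(args: Array[String]) {
    val ssc = new StreamingContext("local[2]", "MqttStreamSketch", Seconds(2))
    // Pass an explicit storage level to override the MEMORY_ONLY_SER_2 default.
    val lines = ssc.mqttStream("tcp://localhost:1883", "sensors/temperature", StorageLevel.MEMORY_ONLY)
    lines.flatMap(_.split(" ")).map(w => (w, 1)).reduceByKey(_ + _).print()
    ssc.start()
  }
}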

@@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext }
import java.util.Properties
import java.util.concurrent.Executors
import java.io.IOException
import org.eclipse.paho.client.mqttv3.MqttCallback
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
/**
* Input stream that subscribes to messages from an MQTT broker.
* Uses the Eclipse Paho MQTT client: http://www.eclipse.org/paho/
* @param brokerUrl URL of the remote MQTT publisher
* @param topic Topic name to subscribe to
* @param storageLevel RDD storage level.
*/
private[streaming]
class MQTTInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
brokerUrl: String,
topic: String,
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_) with Logging {
def getReceiver(): NetworkReceiver[T] = {
new MQTTReceiver(brokerUrl, topic, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
class MQTTReceiver(brokerUrl: String,
topic: String,
storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
def onStop() {
blockGenerator.stop()
}
def onStart() {
blockGenerator.start()
// Set up in-memory persistence for messages
val persistence: MqttClientPersistence = new MemoryPersistence()
// Initialize the MQTT client with the broker URL, a client ID and the persistence layer
val client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", persistence)
// Connect to MqttBroker
client.connect()
// Subscribe to Mqtt topic
client.subscribe(topic)
// The callback is triggered automatically whenever a new message arrives on the subscribed topic
val callback: MqttCallback = new MqttCallback() {
// Handles Mqtt message
override def messageArrived(arg0: String, arg1: MqttMessage) {
blockGenerator += new String(arg1.getPayload())
}
override def deliveryComplete(arg0: IMqttDeliveryToken) {
}
override def connectionLost(arg0: Throwable) {
logInfo("Connection lost " + arg0)
}
}
// Set up callback for MqttClient
client.setCallback(callback)
}
}