Merge branch 'apache-master' into transform
This commit is contained in:
commit
0400aba1c0
|
@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
|
|||
*/
|
||||
def persist(newLevel: StorageLevel): JavaDoubleRDD = {
  // Persist the wrapped Scala RDD at the requested level, then re-wrap it for the Java API.
  val persisted = srdd.persist(newLevel)
  fromRDD(persisted)
}
|
||||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 * This variant blocks until all blocks are deleted.
 *
 * @return a JavaDoubleRDD wrapping the unpersisted RDD
 */
def unpersist(): JavaDoubleRDD = {
  val unpersisted = srdd.unpersist()
  fromRDD(unpersisted)
}
|
||||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 *
 * @param blocking Whether to block until all blocks are deleted.
 * @return a JavaDoubleRDD wrapping the unpersisted RDD
 */
def unpersist(blocking: Boolean): JavaDoubleRDD = {
  val unpersisted = srdd.unpersist(blocking)
  fromRDD(unpersisted)
}
|
||||
|
||||
// first() has to be overridden here so that its return type is Double rather than Object.
override def first(): Double = {
  val head = srdd.first()
  head
}
|
||||
|
||||
|
|
|
@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
|
|||
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] = {
  // Persist the wrapped pair RDD at the requested level, then wrap the result again.
  val persisted = rdd.persist(newLevel)
  new JavaPairRDD[K, V](persisted)
}
|
||||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 * This variant blocks until all blocks are deleted.
 *
 * @return a JavaPairRDD wrapping the unpersisted RDD
 */
def unpersist(): JavaPairRDD[K, V] = {
  val unpersisted = rdd.unpersist()
  wrapRDD(unpersisted)
}
|
||||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 *
 * @param blocking Whether to block until all blocks are deleted.
 * @return a JavaPairRDD wrapping the unpersisted RDD
 */
def unpersist(blocking: Boolean): JavaPairRDD[K, V] = {
  val unpersisted = rdd.unpersist(blocking)
  wrapRDD(unpersisted)
}
|
||||
|
||||
// Transformations (return a new RDD)
|
||||
|
||||
/**
|
||||
|
|
|
@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] {
|
|||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 * This variant blocks until all blocks are deleted.
 *
 * @return a JavaRDD wrapping the unpersisted RDD
 */
def unpersist(): JavaRDD[T] = {
  val unpersisted = rdd.unpersist()
  wrapRDD(unpersisted)
}
|
||||
|
||||
/**
 * Marks the RDD as non-persistent and removes all of its blocks from memory and disk.
 *
 * @param blocking Whether to block until all blocks are deleted.
 * @return a JavaRDD wrapping the unpersisted RDD
 */
def unpersist(blocking: Boolean): JavaRDD[T] = {
  val unpersisted = rdd.unpersist(blocking)
  wrapRDD(unpersisted)
}
|
||||
|
||||
// Transformations (return a new RDD)
|
||||
|
||||
/**
|
||||
|
|
|
@ -68,6 +68,11 @@ class DAGScheduler(
|
|||
eventQueue.put(BeginEvent(task, taskInfo))
|
||||
}
|
||||
|
||||
/**
 * Called to report that a task has completed and that its result is being fetched remotely.
 * Puts a GettingResultEvent for the task on the scheduler's event queue.
 */
def taskGettingResult(task: Task[_], taskInfo: TaskInfo): Unit = {
  val event = GettingResultEvent(task, taskInfo)
  eventQueue.put(event)
}
|
||||
|
||||
// Called by TaskScheduler to report task completions or failures.
|
||||
def taskEnded(
|
||||
task: Task[_],
|
||||
|
@ -415,6 +420,9 @@ class DAGScheduler(
|
|||
case begin: BeginEvent =>
|
||||
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
|
||||
|
||||
case gettingResult: GettingResultEvent =>
|
||||
listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
|
||||
|
||||
case completion: CompletionEvent =>
|
||||
listenerBus.post(SparkListenerTaskEnd(
|
||||
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
|
||||
|
|
|
@ -53,6 +53,9 @@ private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
|
|||
/** Scheduler event carrying a task and the TaskInfo describing it. */
private[scheduler] case class BeginEvent(
    task: Task[_],
    taskInfo: TaskInfo)
  extends DAGSchedulerEvent
|
||||
|
||||
/** Scheduler event signalling that the result of the given task is being fetched remotely. */
private[scheduler] case class GettingResultEvent(
    task: Task[_],
    taskInfo: TaskInfo)
  extends DAGSchedulerEvent
|
||||
|
||||
private[scheduler] case class CompletionEvent(
|
||||
task: Task[_],
|
||||
reason: TaskEndReason,
|
||||
|
|
|
@ -31,6 +31,9 @@ case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
|
|||
|
||||
/** Listener event describing a task start; carries the task and its TaskInfo. */
case class SparkListenerTaskStart(
    task: Task[_],
    taskInfo: TaskInfo)
  extends SparkListenerEvents
|
||||
|
||||
/** Listener event describing a task whose result has started to be fetched remotely. */
case class SparkListenerTaskGettingResult(
    task: Task[_],
    taskInfo: TaskInfo)
  extends SparkListenerEvents
|
||||
|
||||
/** Listener event describing a task end: the task, why it ended, its TaskInfo and metrics. */
case class SparkListenerTaskEnd(
    task: Task[_],
    reason: TaskEndReason,
    taskInfo: TaskInfo,
    taskMetrics: TaskMetrics)
  extends SparkListenerEvents
|
||||
|
||||
|
@ -56,6 +59,12 @@ trait SparkListener {
|
|||
*/
|
||||
def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
  // Default implementation does nothing.
}
|
||||
|
||||
/**
 * Called when a task begins remotely fetching its result. It is not called for tasks that
 * do not need to fetch the result remotely.
 */
def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = {
  // Default implementation does nothing.
}
|
||||
|
||||
/**
|
||||
* Called when a task ends
|
||||
*/
|
||||
|
|
|
@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
|
|||
sparkListeners.foreach(_.onJobEnd(jobEnd))
|
||||
case taskStart: SparkListenerTaskStart =>
|
||||
sparkListeners.foreach(_.onTaskStart(taskStart))
|
||||
case taskGettingResult: SparkListenerTaskGettingResult =>
|
||||
sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
|
||||
case taskEnd: SparkListenerTaskEnd =>
|
||||
sparkListeners.foreach(_.onTaskEnd(taskEnd))
|
||||
case _ =>
|
||||
|
|
|
@ -31,9 +31,25 @@ class TaskInfo(
|
|||
val host: String,
|
||||
val taskLocality: TaskLocality.TaskLocality) {
|
||||
|
||||
/**
|
||||
* The time when the task started remotely getting the result. Will not be set if the
|
||||
* task result was sent immediately when the task finished (as opposed to sending an
|
||||
* IndirectTaskResult and later fetching the result from the block manager).
|
||||
*/
|
||||
var gettingResultTime: Long = 0
|
||||
|
||||
/**
|
||||
* The time when the task has completed successfully (including the time to remotely fetch
|
||||
* results, if necessary).
|
||||
*/
|
||||
var finishTime: Long = 0
|
||||
|
||||
var failed = false
|
||||
|
||||
/**
 * Records the time at which the task started remotely getting its result.
 *
 * @param time timestamp to store; defaults to the current system time in milliseconds
 */
def markGettingResult(time: Long = System.currentTimeMillis()): Unit = {
  gettingResultTime = time
}
|
||||
|
||||
/**
 * Records the time at which the task finished.
 *
 * @param time timestamp to store; defaults to the current system time in milliseconds
 */
def markSuccessful(time: Long = System.currentTimeMillis()): Unit = {
  finishTime = time
}
|
||||
|
@ -43,6 +59,8 @@ class TaskInfo(
|
|||
failed = true
|
||||
}
|
||||
|
||||
/** Whether a getting-result time has been recorded (gettingResultTime is non-zero). */
def gettingResult: Boolean = {
  gettingResultTime != 0
}
|
||||
|
||||
/** Whether a finish time has been recorded (finishTime is non-zero). */
def finished: Boolean = {
  finishTime != 0
}
|
||||
|
||||
/** Whether the task has finished and has not been marked as failed. */
def successful: Boolean = {
  !failed && finished
}
|
||||
|
@ -52,6 +70,8 @@ class TaskInfo(
|
|||
def status: String = {
|
||||
if (running)
|
||||
"RUNNING"
|
||||
else if (gettingResult)
|
||||
"GET RESULT"
|
||||
else if (failed)
|
||||
"FAILED"
|
||||
else if (successful)
|
||||
|
|
|
@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Forwards the "result is being fetched remotely" notification for task `tid` to the
 * task set manager that owns the task.
 */
def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long): Unit = {
  taskSetManager.handleTaskGettingResult(tid)
}
|
||||
|
||||
def handleSuccessfulTask(
|
||||
taskSetManager: ClusterTaskSetManager,
|
||||
tid: Long,
|
||||
|
|
|
@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager(
|
|||
sched.dagScheduler.taskStarted(task, info)
|
||||
}
|
||||
|
||||
/**
 * Marks the task as getting its result and notifies the DAGScheduler.
 */
def handleTaskGettingResult(tid: Long): Unit = {
  val taskInfo = taskInfos(tid)
  taskInfo.markGettingResult()
  val task = tasks(taskInfo.index)
  sched.dagScheduler.taskGettingResult(task, taskInfo)
}
|
||||
|
||||
/**
|
||||
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
|
||||
*/
|
||||
|
|
|
@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
|
|||
case directResult: DirectTaskResult[_] => directResult
|
||||
case IndirectTaskResult(blockId) =>
|
||||
logDebug("Fetching indirect task result for TID %s".format(tid))
|
||||
scheduler.handleTaskGettingResult(taskSetManager, tid)
|
||||
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
|
||||
if (!serializedTaskResult.isDefined) {
|
||||
/* We won't be able to get the task result if the machine that ran the task failed
|
||||
|
|
|
@ -115,7 +115,13 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
|
|||
taskList += ((taskStart.taskInfo, None, None))
|
||||
stageIdToTaskInfos(sid) = taskList
|
||||
}
|
||||
|
||||
|
||||
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit =
  synchronized {
    // Intentionally empty: the TaskInfo is not deep-copied when it is stored in
    // stageIdToTaskInfos, so the stored TaskInfo already has the updated status.
  }
|
||||
|
||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
|
||||
val sid = taskEnd.task.stageId
|
||||
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
|
||||
|
|
|
@ -17,22 +17,25 @@
|
|||
|
||||
package org.apache.spark.scheduler
|
||||
|
||||
import org.scalatest.{BeforeAndAfter, FunSuite}
|
||||
import org.apache.spark.{LocalSparkContext, SparkContext}
|
||||
import scala.collection.mutable
|
||||
import scala.collection.mutable.{Buffer, HashSet}
|
||||
|
||||
import org.scalatest.{BeforeAndAfterAll, FunSuite}
|
||||
import org.scalatest.matchers.ShouldMatchers
|
||||
|
||||
import org.apache.spark.{LocalSparkContext, SparkContext}
|
||||
import org.apache.spark.SparkContext._
|
||||
|
||||
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
|
||||
with BeforeAndAfter {
|
||||
with BeforeAndAfterAll {
|
||||
/** Length of time to wait while draining listener events. */
|
||||
val WAIT_TIMEOUT_MILLIS = 10000
|
||||
|
||||
before {
|
||||
sc = new SparkContext("local", "DAGSchedulerSuite")
|
||||
/** Clears the "spark.akka.frameSize" system property once the whole suite has run. */
override def afterAll(): Unit = {
  System.clearProperty("spark.akka.frameSize")
}
|
||||
|
||||
test("basic creation of StageInfo") {
|
||||
sc = new SparkContext("local", "DAGSchedulerSuite")
|
||||
val listener = new SaveStageInfo
|
||||
sc.addSparkListener(listener)
|
||||
val rdd1 = sc.parallelize(1 to 100, 4)
|
||||
|
@ -53,6 +56,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
|
|||
}
|
||||
|
||||
test("StageInfo with fewer tasks than partitions") {
|
||||
sc = new SparkContext("local", "DAGSchedulerSuite")
|
||||
val listener = new SaveStageInfo
|
||||
sc.addSparkListener(listener)
|
||||
val rdd1 = sc.parallelize(1 to 100, 4)
|
||||
|
@ -68,6 +72,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
|
|||
}
|
||||
|
||||
test("local metrics") {
|
||||
sc = new SparkContext("local", "DAGSchedulerSuite")
|
||||
val listener = new SaveStageInfo
|
||||
sc.addSparkListener(listener)
|
||||
sc.addSparkListener(new StatsReportListener)
|
||||
|
@ -129,15 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
|
|||
}
|
||||
}
|
||||
|
||||
test("onTaskGettingResult() called when result fetched remotely") {
  // Local cluster mode is required here, because results are never returned through the
  // block manager when using the LocalScheduler.
  sc = new SparkContext("local-cluster[1,1,512]", "test")

  val taskEvents = new SaveTaskEvents
  sc.addSparkListener(taskEvents)

  // Build a task whose result is larger than the akka frame size.
  // NOTE(review): the property is set after the SparkContext above has been created --
  // confirm that it still takes effect where it is needed.
  System.setProperty("spark.akka.frameSize", "1")
  val frameSizeBytes =
    sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
  val result = sc.parallelize(Seq(1), 1)
    .map(_ => (1 to frameSizeBytes).toArray)
    .reduce((first, _) => first)
  assert(result === (1 to frameSizeBytes).toArray)

  // Wait for the listener bus to drain before inspecting the recorded events.
  assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
  val taskIndex = 0
  assert(taskEvents.startedTasks.contains(taskIndex))
  assert(taskEvents.startedGettingResultTasks.contains(taskIndex))
  assert(taskEvents.endedTasks.contains(taskIndex))
}
|
||||
|
||||
test("onTaskGettingResult() not called when result sent directly") {
  // Local cluster mode is required here, because results are never returned through the
  // block manager when using the LocalScheduler.
  sc = new SparkContext("local-cluster[1,1,512]", "test")

  val taskEvents = new SaveTaskEvents
  sc.addSparkListener(taskEvents)

  // Build a task with a tiny result (a single Int), so that it is sent back directly.
  val result = sc.parallelize(Seq(1), 1).map(value => 2 * value).reduce((first, _) => first)
  assert(result === 2)

  // Wait for the listener bus to drain before inspecting the recorded events.
  assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
  val taskIndex = 0
  assert(taskEvents.startedTasks.contains(taskIndex))
  assert(taskEvents.startedGettingResultTasks.isEmpty)
  assert(taskEvents.endedTasks.contains(taskIndex))
}
|
||||
|
||||
/**
 * Asserts that the arithmetic mean of `m` is strictly positive.
 *
 * @param m values to average
 * @param msg message attached to the assertion when the mean is not positive
 */
def checkNonZeroAvg(m: Traversable[Long], msg: String): Unit = {
  val average = m.sum / m.size.toDouble
  assert(average > 0.0, msg)
}
|
||||
|
||||
/** Listener that records the StageInfo of every completed stage it is notified about. */
class SaveStageInfo extends SparkListener {
  // StageInfos in the order in which the completion events were received.
  val stageInfos = Buffer[StageInfo]()

  override def onStageCompleted(stage: StageCompleted) {
    stageInfos += stage.stage
  }
}
|
||||
|
||||
/**
 * Listener that records, per kind of task event, the index of every task for which that
 * event was received.
 */
class SaveTaskEvents extends SparkListener {
  val startedTasks = HashSet.empty[Int]
  val startedGettingResultTasks = HashSet.empty[Int]
  val endedTasks = HashSet.empty[Int]

  override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
    startedTasks += taskStart.taskInfo.index
  }

  override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult): Unit = {
    startedGettingResultTasks += taskGettingResult.taskInfo.index
  }

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    endedTasks += taskEnd.taskInfo.index
  }
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.streaming.examples
|
||||
|
||||
import org.apache.spark.streaming.{ Seconds, StreamingContext }
|
||||
import org.apache.spark.streaming.StreamingContext._
|
||||
import org.apache.spark.streaming.dstream.MQTTReceiver
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
import org.eclipse.paho.client.mqttv3.MqttClient
|
||||
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
|
||||
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
|
||||
import org.eclipse.paho.client.mqttv3.MqttException
|
||||
import org.eclipse.paho.client.mqttv3.MqttMessage
|
||||
import org.eclipse.paho.client.mqttv3.MqttTopic
|
||||
|
||||
/**
 * A simple Mqtt publisher for demonstration purposes, repeatedly publishes
 * Space separated String Message "hello mqtt demo for spark streaming"
 *
 * Usage: MQTTPublisher <MqttBrokerUrl> <topic>
 */
object MQTTPublisher {

  // Client used by main(); stays unset until main() has created it.
  var client: MqttClient = _

  def main(args: Array[String]) {
    if (args.length < 2) {
      System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
      System.exit(1)
    }

    // Only the first two arguments are used. Matching the whole argument list against a
    // two-element pattern would throw a MatchError as soon as extra arguments are passed.
    val Array(brokerUrl, topic) = args.take(2)

    try {
      val persistence: MqttClientPersistence = new MqttDefaultFilePersistence("/tmp")
      client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
    } catch {
      case e: MqttException =>
        // Without a client there is nothing to publish with; stop here instead of failing
        // later with a NullPointerException on the unset client.
        println("Exception Caught: " + e)
        System.exit(1)
    }

    client.connect()

    try {
      val msgtopic: MqttTopic = client.getTopic(topic)
      val msg: String = "hello mqtt demo for spark streaming"

      // Publishes until the process is killed or publishing fails.
      while (true) {
        val message: MqttMessage = new MqttMessage(msg.getBytes())
        msgtopic.publish(message)
        println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
      }
    } finally {
      // The loop above only ends by throwing, so the connection is released here.
      if (client.isConnected) {
        client.disconnect()
      }
    }
  }
}
|
||||
|
||||
/**
 * A sample wordcount with MqttStream stream
 *
 * To work with Mqtt, a Mqtt message broker/server is required.
 * Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker
 * In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto`
 * Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/
 * Example Java code for Mqtt Publisher and Subscriber can be found here
 * https://bitbucket.org/mkjinesh/mqttclient
 * Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
 * In local mode, <master> should be 'local[n]' with n > 1
 * <MqttbrokerUrl> and <topic> describe where Mqtt publisher is running.
 *
 * To run this example locally, you may run publisher as
 *    `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
 * and run the example as
 *    `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
 */
object MQTTWordCount {

  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println(
        "Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
          " In local mode, <master> should be 'local[n]' with n > 1")
      System.exit(1)
    }

    // Only the first three arguments are used. Matching the whole argument list against a
    // three-element pattern would throw a MatchError as soon as extra arguments are passed.
    val Array(master, brokerUrl, topic) = args.take(3)

    val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
      Seq(System.getenv("SPARK_EXAMPLES_JAR")))
    val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)

    // Split every received message on single spaces and count the occurrences of each word.
    val words = lines.flatMap(x => x.toString.split(" "))
    val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
    wordCounts.print()
    ssc.start()
  }
}
|
11
pom.xml
11
pom.xml
|
@ -147,6 +147,17 @@
|
|||
<enabled>false</enabled>
|
||||
</snapshots>
|
||||
</repository>
|
||||
<repository>
|
||||
<id>mqtt-repo</id>
|
||||
<name>MQTT Repository</name>
|
||||
<url>https://repo.eclipse.org/content/repositories/paho-releases/</url>
|
||||
<releases>
|
||||
<enabled>true</enabled>
|
||||
</releases>
|
||||
<snapshots>
|
||||
<enabled>false</enabled>
|
||||
</snapshots>
|
||||
</repository>
|
||||
</repositories>
|
||||
<pluginRepositories>
|
||||
<pluginRepository>
|
||||
|
|
|
@ -108,7 +108,10 @@ object SparkBuild extends Build {
|
|||
// Shared between both core and streaming.
|
||||
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
|
||||
|
||||
// For Sonatype publishing
|
||||
// Shared between both examples and streaming.
|
||||
resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"),
|
||||
|
||||
// For Sonatype publishing
|
||||
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
|
||||
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
|
||||
|
||||
|
@ -282,10 +285,11 @@ object SparkBuild extends Build {
|
|||
"Apache repo" at "https://repository.apache.org/content/repositories/releases"
|
||||
),
|
||||
libraryDependencies ++= Seq(
|
||||
"org.eclipse.paho" % "mqtt-client" % "0.4.0",
|
||||
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
|
||||
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
|
||||
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty),
|
||||
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
|
||||
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
|
||||
exclude("com.sun.jdmk", "jmxtools")
|
||||
exclude("com.sun.jmx", "jmxri")
|
||||
)
|
||||
|
|
|
@ -136,6 +136,11 @@
|
|||
<artifactId>slf4j-log4j12</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.eclipse.paho</groupId>
|
||||
<artifactId>mqtt-client</artifactId>
|
||||
<version>0.4.0</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
|
||||
|
|
|
@ -462,6 +462,21 @@ class StreamingContext private (
|
|||
inputStream
|
||||
}
|
||||
|
||||
/**
 * Create an input stream that receives messages pushed by a mqtt publisher.
 * @param brokerUrl Url of remote mqtt publisher
 * @param topic topic name to subscribe to
 * @param storageLevel RDD storage level. Defaults to StorageLevel.MEMORY_ONLY_SER_2.
 */
def mqttStream(
    brokerUrl: String,
    topic: String,
    storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = {
  val mqttInput = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel)
  // Register the stream with this context before handing it back to the caller.
  registerInputStream(mqttInput)
  mqttInput
}
|
||||
/**
|
||||
* Create a unified DStream from multiple DStreams of the same type and same slide duration.
|
||||
*/
|
||||
|
|
|
@ -0,0 +1,109 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.streaming.dstream
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext }
|
||||
|
||||
import java.util.Properties
|
||||
import java.util.concurrent.Executors
|
||||
import java.io.IOException
|
||||
|
||||
import org.eclipse.paho.client.mqttv3.MqttCallback
|
||||
import org.eclipse.paho.client.mqttv3.MqttClient
|
||||
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
|
||||
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
|
||||
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
|
||||
import org.eclipse.paho.client.mqttv3.MqttException
|
||||
import org.eclipse.paho.client.mqttv3.MqttMessage
|
||||
import org.eclipse.paho.client.mqttv3.MqttTopic
|
||||
|
||||
import scala.collection.Map
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
/**
 * Input stream that subscribes to messages from a Mqtt Broker.
 * Uses eclipse paho as MqttClient http://www.eclipse.org/paho/
 * @param brokerUrl Url of remote mqtt publisher
 * @param topic topic name to subscribe to
 * @param storageLevel RDD storage level.
 */
private[streaming]
class MQTTInputDStream[T: ClassManifest](
    @transient ssc_ : StreamingContext,
    brokerUrl: String,
    topic: String,
    storageLevel: StorageLevel
  ) extends NetworkInputDStream[T](ssc_) with Logging {

  /** Creates the receiver that feeds this stream. */
  def getReceiver(): NetworkReceiver[T] = {
    val receiver = new MQTTReceiver(brokerUrl, topic, storageLevel)
    // MQTTReceiver is declared as a NetworkReceiver[Any]; cast it to this stream's type.
    receiver.asInstanceOf[NetworkReceiver[T]]
  }
}
|
||||
|
||||
/**
 * Receiver that connects to a Mqtt broker, subscribes to one topic and hands the payload of
 * every arriving message, decoded as a String, to its block generator.
 *
 * @param brokerUrl Url of the Mqtt broker to connect to
 * @param topic topic name to subscribe to
 * @param storageLevel storage level passed to the block generator
 */
private[streaming]
class MQTTReceiver(brokerUrl: String,
    topic: String,
    storageLevel: StorageLevel
  ) extends NetworkReceiver[Any] {
  lazy protected val blockGenerator = new BlockGenerator(storageLevel)

  // Created in onStart() and kept so that onStop() can close the connection. It is null
  // until onStart() has run, and transient so that it is never serialized with the receiver.
  @transient private var client: MqttClient = null

  def onStop() {
    try {
      // Disconnect first so that no further messages are handed to a stopped generator.
      if (client != null && client.isConnected) {
        client.disconnect()
      }
    } finally {
      blockGenerator.stop()
    }
  }

  def onStart() {

    blockGenerator.start()

    // Set up persistence for messages
    val persistence: MqttClientPersistence = new MemoryPersistence()

    // Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistence.
    // A generated client id is used so that several receivers do not share one fixed id.
    client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)

    // Callback automatically triggers as and when new message arrives on specified topic
    val callback: MqttCallback = new MqttCallback() {

      // Handles Mqtt message; the payload is decoded with the platform default charset.
      override def messageArrived(arg0: String, arg1: MqttMessage) {
        blockGenerator += new String(arg1.getPayload())
      }

      override def deliveryComplete(arg0: IMqttDeliveryToken) {
      }

      override def connectionLost(arg0: Throwable) {
        logInfo("Connection lost " + arg0)
      }
    }

    // Set up callback for MqttClient. This happens before connecting and subscribing so
    // that messages arriving right after the subscription are not missed.
    client.setCallback(callback)

    // Connect to MqttBroker
    client.connect()

    // Subscribe to Mqtt topic
    client.subscribe(topic)
  }
}
|
Loading…
Reference in a new issue