[SPARK-5154] [PySpark] [Streaming] Kafka streaming support in Python

This PR adds a Python API for the Spark Streaming Kafka data source.

```
    class KafkaUtils(__builtin__.object)
     |  Static methods defined here:
     |
     |  createStream(ssc, zkQuorum, groupId, topics, storageLevel=StorageLevel(True, True, False, False, 2), keyDecoder=<function utf8_decoder>, valueDecoder=<function utf8_decoder>)
     |      Create an input stream that pulls messages from a Kafka Broker.
     |
     |      :param ssc:  StreamingContext object
     |      :param zkQuorum:  Zookeeper quorum (hostname:port,hostname:port,..).
     |      :param groupId:  The group id for this consumer.
     |      :param topics:  Dict of (topic_name -> numPartitions) to consume.
     |                      Each partition is consumed in its own thread.
     |      :param storageLevel:  RDD storage level.
     |      :param keyDecoder:  A function used to decode key
     |      :param valueDecoder:  A function used to decode value
     |      :return: A DStream object
```
Run the example:

```
bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py localhost:2181 test
```
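
For a quick interactive feel of the API, here is a minimal sketch (the ZooKeeper address `localhost:2181`, the topic name `test`, and the consumer group id are placeholders, mirroring the wordcount example below):

```
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext(appName="KafkaSketch")
ssc = StreamingContext(sc, 2)  # 2-second batches

# Consume the "test" topic with one receiver thread; keys and values arrive UTF-8 decoded.
stream = KafkaUtils.createStream(ssc, "localhost:2181", "my-consumer-group", {"test": 1})
stream.count().pprint()

ssc.start()
ssc.awaitTermination()
```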

Author: Davies Liu <davies@databricks.com>
Author: Tathagata Das <tdas@databricks.com>

Closes #3715 from davies/kafka and squashes the following commits:

d93bfe0 [Davies Liu] Update make-distribution.sh
4280d04 [Davies Liu] address comments
e6d0427 [Davies Liu] Merge branch 'master' of github.com:apache/spark into kafka
f257071 [Davies Liu] add tests for null in RDD
23b039a [Davies Liu] address comments
9af51c4 [Davies Liu] Merge branch 'kafka' of github.com:davies/spark into kafka
a74da87 [Davies Liu] address comments
dc1eed0 [Davies Liu] Update kafka_wordcount.py
31e2317 [Davies Liu] Update kafka_wordcount.py
370ba61 [Davies Liu] Update kafka.py
97386b3 [Davies Liu] address comment
2c567a5 [Davies Liu] update logging and comment
33730d1 [Davies Liu] Merge branch 'master' of github.com:apache/spark into kafka
adeeb38 [Davies Liu] Merge pull request #3 from tdas/kafka-python-api
aea8953 [Tathagata Das] Kafka-assembly for Python API
eea16a7 [Davies Liu] refactor
f6ce899 [Davies Liu] add example and fix bugs
98c8d17 [Davies Liu] fix python style
5697a01 [Davies Liu] bypass decoder in scala
048dbe6 [Davies Liu] fix python style
75d485e [Davies Liu] add mqtt
07923c4 [Davies Liu] support kafka in Python
Authored by Davies Liu on 2015-02-02 19:16:27 -08:00; committed by Tathagata Das.
Commit 0561c45449 (parent 554403fd91): 10 changed files with 313 additions and 58 deletions.

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala:

@@ -316,6 +316,7 @@ private object SpecialLengths {
   val PYTHON_EXCEPTION_THROWN = -2
   val TIMING_DATA = -3
   val END_OF_STREAM = -4
+  val NULL = -5
 }

 private[spark] object PythonRDD extends Logging {
@@ -374,54 +375,25 @@ private[spark] object PythonRDD extends Logging {
   }

   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
-    // The right way to implement this would be to use TypeTags to get the full
-    // type of T. Since I don't want to introduce breaking changes throughout the
-    // entire Spark API, I have to use this hacky approach:
-    if (iter.hasNext) {
-      val first = iter.next()
-      val newIter = Seq(first).iterator ++ iter
-      first match {
-        case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
-            dataOut.writeInt(bytes.length)
-            dataOut.write(bytes)
-          }
-        case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach { str =>
-            writeUTF(str, dataOut)
-          }
-        case stream: PortableDataStream =>
-          newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            val bytes = stream.toArray()
-            dataOut.writeInt(bytes.length)
-            dataOut.write(bytes)
-          }
-        case (key: String, stream: PortableDataStream) =>
-          newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
-            case (key, stream) =>
-              writeUTF(key, dataOut)
-              val bytes = stream.toArray()
-              dataOut.writeInt(bytes.length)
-              dataOut.write(bytes)
-          }
-        case (key: String, value: String) =>
-          newIter.asInstanceOf[Iterator[(String, String)]].foreach {
-            case (key, value) =>
-              writeUTF(key, dataOut)
-              writeUTF(value, dataOut)
-          }
-        case (key: Array[Byte], value: Array[Byte]) =>
-          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
-            case (key, value) =>
-              dataOut.writeInt(key.length)
-              dataOut.write(key)
-              dataOut.writeInt(value.length)
-              dataOut.write(value)
-          }
-        case other =>
-          throw new SparkException("Unexpected element type " + first.getClass)
-      }
-    }
+
+    def write(obj: Any): Unit = obj match {
+      case null =>
+        dataOut.writeInt(SpecialLengths.NULL)
+      case arr: Array[Byte] =>
+        dataOut.writeInt(arr.length)
+        dataOut.write(arr)
+      case str: String =>
+        writeUTF(str, dataOut)
+      case stream: PortableDataStream =>
+        write(stream.toArray())
+      case (key, value) =>
+        write(key)
+        write(value)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+
+    iter.foreach(write)
   }
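
For context, an editorial sketch (not part of the diff): the writer above emits a plain length-prefixed binary stream, where each element is a 4-byte big-endian length followed by the payload, and the special length -5 (`SpecialLengths.NULL`) stands for a null element. A hypothetical Python reader, with the helper name `read_framed` and the sample payload chosen here for illustration, would interpret it like this:

```
import struct
from io import BytesIO

NULL = -5  # mirrors SpecialLengths.NULL on the JVM side

def read_framed(stream):
    """Yield raw byte payloads, or None for the NULL marker."""
    while True:
        header = stream.read(4)
        if len(header) < 4:
            return  # end of stream
        length = struct.unpack("!i", header)[0]
        if length == NULL:
            yield None
        else:
            yield stream.read(length)

# A stream equivalent to writing Iterator("a", null, "bc") from Scala:
data = BytesIO(struct.pack("!i", 1) + b"a" +
               struct.pack("!i", NULL) +
               struct.pack("!i", 2) + b"bc")
print(list(read_framed(data)))  # [b'a', None, b'bc']
```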

core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala:

@@ -22,6 +22,7 @@ import java.io.{File, InputStream, IOException, OutputStream}
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}

 private[spark] object PythonUtils {
   /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +40,8 @@ private[spark] object PythonUtils {
   def mergePythonPaths(paths: String*): String = {
     paths.filter(_ != "").mkString(File.pathSeparator)
   }
+
+  def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+    sc.parallelize(List("a", null, "b"))
+  }
 }

core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala:

@@ -23,11 +23,22 @@ import org.scalatest.FunSuite
 class PythonRDDSuite extends FunSuite {

   test("Writing large strings to the worker") {
     val input: List[String] = List("a"*100000)
     val buffer = new DataOutputStream(new ByteArrayOutputStream)
     PythonRDD.writeIteratorToStream(input.iterator, buffer)
   }

+  test("Handle nulls gracefully") {
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    // Should not throw an NPE when writing an Iterator that contains null;
+    // correctness of the round trip is tested on the Python side.
+    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
+    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
+    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
+    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
+    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
+    PythonRDD.writeIteratorToStream(
+      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
+  }
 }

examples/src/main/python/streaming/kafka_wordcount.py (new file):

@@ -0,0 +1,54 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
 Counts words in UTF8 encoded, '\n' delimited text received from Kafka every second.
 Usage: kafka_wordcount.py <zk> <topic>

 To run this on your local machine, you need to set up Kafka and create a producer first; see
 http://kafka.apache.org/documentation.html#quickstart

 and then run the example
    `$ bin/spark-submit --driver-class-path external/kafka-assembly/target/scala-*/\
      spark-streaming-kafka-assembly-*.jar examples/src/main/python/streaming/kafka_wordcount.py \
      localhost:2181 test`
"""

import sys

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print >> sys.stderr, "Usage: kafka_wordcount.py <zk> <topic>"
        exit(-1)

    sc = SparkContext(appName="PythonStreamingKafkaWordCount")
    ssc = StreamingContext(sc, 1)

    zkQuorum, topic = sys.argv[1:]
    kvs = KafkaUtils.createStream(ssc, zkQuorum, "spark-streaming-consumer", {topic: 1})
    lines = kvs.map(lambda x: x[1])
    counts = lines.flatMap(lambda line: line.split(" ")) \
        .map(lambda word: (word, 1)) \
        .reduceByKey(lambda a, b: a+b)
    counts.pprint()

    ssc.start()
    ssc.awaitTermination()

external/kafka-assembly/pom.xml (new file, 106 lines):

@@ -0,0 +1,106 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Licensed to the Apache Software Foundation (ASF) under one or more
~ contributor license agreements. See the NOTICE file distributed with
~ this work for additional information regarding copyright ownership.
~ The ASF licenses this file to You under the Apache License, Version 2.0
~ (the "License"); you may not use this file except in compliance with
~ the License. You may obtain a copy of the License at
~
~ http://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent</artifactId>
    <version>1.3.0-SNAPSHOT</version>
    <relativePath>../../pom.xml</relativePath>
  </parent>

  <groupId>org.apache.spark</groupId>
  <artifactId>spark-streaming-kafka-assembly_2.10</artifactId>
  <packaging>jar</packaging>
  <name>Spark Project External Kafka Assembly</name>
  <url>http://spark.apache.org/</url>

  <properties>
    <sbt.project.name>streaming-kafka-assembly</sbt.project.name>
    <spark.jar.dir>scala-${scala.binary.version}</spark.jar.dir>
    <spark.jar.basename>spark-streaming-kafka-assembly-${project.version}.jar</spark.jar.basename>
    <spark.jar>${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}</spark.jar>
  </properties>

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming-kafka_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
      <scope>provided</scope>
    </dependency>
  </dependencies>

  <build>
    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-shade-plugin</artifactId>
        <configuration>
          <shadedArtifactAttached>false</shadedArtifactAttached>
          <outputFile>${spark.jar}</outputFile>
          <artifactSet>
            <includes>
              <include>*:*</include>
            </includes>
          </artifactSet>
          <filters>
            <filter>
              <artifact>*:*</artifact>
              <excludes>
                <exclude>META-INF/*.SF</exclude>
                <exclude>META-INF/*.DSA</exclude>
                <exclude>META-INF/*.RSA</exclude>
              </excludes>
            </filter>
          </filters>
        </configuration>
        <executions>
          <execution>
            <phase>package</phase>
            <goals>
              <goal>shade</goal>
            </goals>
            <configuration>
              <transformers>
                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
                  <resource>reference.conf</resource>
                </transformer>
                <transformer implementation="org.apache.maven.plugins.shade.resource.DontIncludeResourceTransformer">
                  <resource>log4j.properties</resource>
                </transformer>
                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer"/>
                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"/>
              </transformers>
            </configuration>
          </execution>
        </executions>
      </plugin>
    </plugins>
  </build>
</project>

pom.xml:

@@ -1629,6 +1629,7 @@
       </properties>
       <modules>
         <module>external/kafka</module>
+        <module>external/kafka-assembly</module>
       </modules>
     </profile>

project/SparkBuild.scala:

@@ -44,8 +44,9 @@ object BuildCommons {
     sparkKinesisAsl) = Seq("yarn", "yarn-stable", "java8-tests", "ganglia-lgpl",
     "kinesis-asl").map(ProjectRef(buildLocation, _))

-  val assemblyProjects@Seq(assembly, examples, networkYarn) =
-    Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _))
+  val assemblyProjects@Seq(assembly, examples, networkYarn, streamingKafkaAssembly) =
+    Seq("assembly", "examples", "network-yarn", "streaming-kafka-assembly")
+      .map(ProjectRef(buildLocation, _))

   val tools = ProjectRef(buildLocation, "tools")
   // Root project.
@@ -300,7 +301,14 @@ object Assembly {
       sys.props.get("hadoop.version")
         .getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
     },
-    jarName in assembly := s"${moduleName.value}-${version.value}-hadoop${hadoopVersion.value}.jar",
+    jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
+      if (mName.contains("streaming-kafka-assembly")) {
+        // This must match the same name used in maven (see external/kafka-assembly/pom.xml)
+        s"${mName}-${v}.jar"
+      } else {
+        s"${mName}-${v}-hadoop${hv}.jar"
+      }
+    },
     mergeStrategy in assembly := {
       case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
       case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard

python/pyspark/serializers.py:

@@ -70,6 +70,7 @@ class SpecialLengths(object):
     PYTHON_EXCEPTION_THROWN = -2
     TIMING_DATA = -3
     END_OF_STREAM = -4
+    NULL = -5


 class Serializer(object):
@@ -133,6 +134,8 @@ class FramedSerializer(Serializer):

     def _write_with_length(self, obj, stream):
         serialized = self.dumps(obj)
+        if serialized is None:
+            raise ValueError("serialized value should not be None")
         if len(serialized) > (1 << 31):
             raise ValueError("can not serialize object larger than 2G")
         write_int(len(serialized), stream)
@@ -145,8 +148,10 @@ class FramedSerializer(Serializer):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
+        elif length == SpecialLengths.NULL:
+            return None
         obj = stream.read(length)
-        if obj == "":
+        if len(obj) < length:
             raise EOFError
         return self.loads(obj)
@@ -484,6 +489,8 @@ class UTF8Deserializer(Serializer):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
+        elif length == SpecialLengths.NULL:
+            return None
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
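
To see the effect of the new NULL marker end to end, here is a small hypothetical round-trip check (not part of the patch); it assumes a local PySpark installation on the PYTHONPATH and relies on the existing `write_int` and `UTF8Deserializer` helpers in `pyspark.serializers`:

```
from io import BytesIO

from pyspark.serializers import UTF8Deserializer, write_int

buf = BytesIO()
write_int(1, buf)      # length of "a"
buf.write(b"a")
write_int(-5, buf)     # SpecialLengths.NULL -> deserialized as None
write_int(1, buf)      # length of "b"
buf.write(b"b")
buf.seek(0)

# Expected: ['a', None, 'b'] (strings come back as unicode when use_unicode is enabled)
print(list(UTF8Deserializer().load_stream(buf)))
```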

python/pyspark/streaming/kafka.py (new file):

@@ -0,0 +1,83 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from py4j.java_collections import MapConverter
from py4j.java_gateway import java_import, Py4JError

from pyspark.storagelevel import StorageLevel
from pyspark.serializers import PairDeserializer, NoOpSerializer
from pyspark.streaming import DStream

__all__ = ['KafkaUtils', 'utf8_decoder']


def utf8_decoder(s):
    """ Decode bytes as UTF-8 (returns None for None input) """
    return s and s.decode('utf-8')


class KafkaUtils(object):

    @staticmethod
    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
                     storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                     keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
        """
        Create an input stream that pulls messages from a Kafka Broker.

        :param ssc:  StreamingContext object
        :param zkQuorum:  Zookeeper quorum (hostname:port,hostname:port,..).
        :param groupId:  The group id for this consumer.
        :param topics:  Dict of (topic_name -> numPartitions) to consume.
                        Each partition is consumed in its own thread.
        :param kafkaParams:  Additional params for Kafka.
        :param storageLevel:  RDD storage level.
        :param keyDecoder:  A function used to decode key (default is utf8_decoder).
        :param valueDecoder:  A function used to decode value (default is utf8_decoder).
        :return: A DStream object
        """
        java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils")

        kafkaParams.update({
            "zookeeper.connect": zkQuorum,
            "group.id": groupId,
            "zookeeper.connection.timeout.ms": "10000",
        })
        if not isinstance(topics, dict):
            raise TypeError("topics should be dict")
        jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
        jlevel = ssc._sc._getJavaStorageLevel(storageLevel)

        def getClassByName(name):
            return ssc._jvm.org.apache.spark.util.Utils.classForName(name)

        try:
            array = getClassByName("[B")
            decoder = getClassByName("kafka.serializer.DefaultDecoder")
            jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder,
                                                       jparam, jtopics, jlevel)
        except Py4JError, e:
            # TODO: use --jars once it also works on the driver
            if not e.message or 'call a package' in e.message:
                print "No kafka package, please put the assembly jar into classpath:"
                print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \
                      "scala-*/spark-streaming-kafka-assembly-*.jar"
            raise e
        ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
        stream = DStream(jstream, ssc, ser)
        return stream.map(lambda (k, v): (keyDecoder(k), valueDecoder(v)))
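
As a usage note, the `keyDecoder`/`valueDecoder` hooks let callers plug in their own deserialization. A short illustrative sketch (the topic name `events` and the JSON-encoded message values are assumptions, not part of the patch):

```
import json

from pyspark import SparkContext
from pyspark.streaming import StreamingContext
from pyspark.streaming.kafka import KafkaUtils

sc = SparkContext(appName="KafkaJsonSketch")
ssc = StreamingContext(sc, 1)

# Keep keys as UTF-8 strings, but parse values as JSON objects.
stream = KafkaUtils.createStream(
    ssc, "localhost:2181", "json-consumer", {"events": 1},
    valueDecoder=lambda v: json.loads(v) if v is not None else None)

stream.map(lambda kv: kv[1]).pprint()
ssc.start()
ssc.awaitTermination()
```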

python/pyspark/tests.py:

@@ -47,9 +47,10 @@ else:

 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer, CompressedSerializer
+    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -716,6 +717,13 @@ class RDDTests(ReusedPySparkTestCase):
         wr_s21 = rdd.sample(True, 0.4, 21).collect()
         self.assertNotEqual(set(wr_s11), set(wr_s21))

+    def test_null_in_rdd(self):
+        jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
+        rdd = RDD(jrdd, self.sc, UTF8Deserializer())
+        self.assertEqual([u"a", None, u"b"], rdd.collect())
+        rdd = RDD(jrdd, self.sc, NoOpSerializer())
+        self.assertEqual(["a", None, "b"], rdd.collect())
+
     def test_multiple_python_java_RDD_conversions(self):
         # Regression test for SPARK-5361
         data = [