[SPARK-2377] Python API for Streaming

This patch brings a Python API for Spark Streaming.

This patch is based on work from @giwa.
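
For reference, a minimal word-count sketch against the new API (illustrative only; the host and port are placeholders, mirroring the bundled network_wordcount.py example):

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

# Build a streaming context with 1-second batches on top of a regular SparkContext.
sc = SparkContext(appName="PythonStreamingSketch")
ssc = StreamingContext(sc, 1)

# Receive '\n'-delimited text over TCP and count words per batch.
lines = ssc.socketTextStream("localhost", 9999)
counts = lines.flatMap(lambda line: line.split(" ")) \
              .map(lambda word: (word, 1)) \
              .reduceByKey(lambda a, b: a + b)
counts.pprint()

ssc.start()
ssc.awaitTermination()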

Author: giwa <ugw.gi.world@gmail.com>
Author: Ken Takagiwa <ken@Kens-MacBook-Pro.local>
Author: Davies Liu <davies.liu@gmail.com>
Author: Ken Takagiwa <ken@kens-mbp.gateway.sonic.net>
Author: Tathagata Das <tathagata.das1565@gmail.com>
Author: Ken <ugw.gi.world@gmail.com>
Author: Ken Takagiwa <ugw.gi.world@gmail.com>
Author: Matthew Farrellee <matt@redhat.com>

Closes #2538 from davies/streaming and squashes the following commits:

64561e4 [Davies Liu] fix tests
331ecce [Davies Liu] fix example
3e2492b [Davies Liu] change updateStateByKey() to easy API
182be73 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming
02d0575 [Davies Liu] add wrapper for foreachRDD()
bebeb4a [Davies Liu] address all comments
6db00da [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming
8380064 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming
52c535b [Davies Liu] remove fix for sum()
e108ec1 [Davies Liu]  address comments
37fe06f [Davies Liu] use random port for callback server
d05871e [Davies Liu] remove reuse of PythonRDD
be5e5ff [Davies Liu] merge branch of env, make tests stable.
8071541 [Davies Liu] Merge branch 'env' into streaming
c7bbbce [Davies Liu] fix sphinx docs
6bb9d91 [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming
4d0ea8b [Davies Liu] clear reference of SparkEnv after stop
54bd92b [Davies Liu] improve tests
c2b31cb [Davies Liu] Merge branch 'master' of github.com:apache/spark into streaming
7a88f9f [Davies Liu] rollback RDD.setContext(), use textFileStream() to test checkpointing
bd8a4c2 [Davies Liu] fix scala style
7797c70 [Davies Liu] refactor
ff88bec [Davies Liu] rename RDDFunction to TransformFunction
d328aca [Davies Liu] fix serializer in queueStream
6f0da2f [Davies Liu] recover from checkpoint
fa7261b [Davies Liu] refactor
a13ff34 [Davies Liu] address comments
8466916 [Davies Liu] support checkpoint
9a16bd1 [Davies Liu] change number of partitions during tests
b98d63f [Davies Liu] change private[spark] to private[python]
eed6e2a [Davies Liu] rollback not needed changes
e00136b [Davies Liu] address comments
069a94c [Davies Liu] fix the number of partitions during window()
338580a [Davies Liu] change _first(), _take(), _collect() as private API
19797f9 [Davies Liu] clean up
6ebceca [Davies Liu] add more tests
c40c52d [Davies Liu] change first(), take(n) to has the same behavior as RDD
98ac6c2 [Davies Liu] support ssc.transform()
b983f0f [Davies Liu] address comments
847f9b9 [Davies Liu] add more docs, add first(), take()
e059ca2 [Davies Liu] move check of window into Python
fce0ef5 [Davies Liu] rafactor of foreachRDD()
7001b51 [Davies Liu] refactor of queueStream()
26ea396 [Davies Liu] refactor
74df565 [Davies Liu] fix print and docs
b32774c [Davies Liu] move java_import into streaming
604323f [Davies Liu] enable streaming tests
c499ba0 [Davies Liu] remove Time and Duration
3f0fb4b [Davies Liu] refactor fix tests
c28f520 [Davies Liu] support updateStateByKey
d357b70 [Davies Liu] support windowed dstream
bd13026 [Davies Liu] fix examples
eec401e [Davies Liu] refactor, combine TransformedRDD, fix reuse PythonRDD, fix union
9a57685 [Davies Liu] fix python style
bd27874 [Davies Liu] fix scala style
7339be0 [Davies Liu] delete tests
7f53086 [Davies Liu] support transform(), refactor and cleanup
df098fc [Davies Liu] Merge branch 'master' into giwa
550dfd9 [giwa] WIP fixing 1.1 merge
5cdb6fa [giwa] changed for SCCallSiteSync
e685853 [giwa] meged with rebased 1.1 branch
2d32a74 [giwa] added some StreamingContextTestSuite
4a59e1e [giwa] WIP:added more test for StreamingContext
8ffdbf1 [giwa] added atexit to handle callback server
d5f5fcb [giwa] added comment for StreamingContext.sparkContext
63c881a [giwa] added StreamingContext.sparkContext
d39f102 [giwa] added StreamingContext.remember
d542743 [giwa] clean up code
2fdf0de [Matthew Farrellee] Fix scalastyle errors
c0a06bc [giwa] delete not implemented functions
f385976 [giwa] delete inproper comments
b0f2015 [giwa] added comment in dstream._test_output
bebb3f3 [giwa] remove the last brank line
fbed8da [giwa] revert pom.xml
8ed93af [giwa] fixed explanaiton
066ba90 [giwa] revert pom.xml
fa4af88 [giwa] remove duplicated import
6ae3caa [giwa] revert pom.xml
7dc7391 [giwa] fixed typo
62dc7a3 [giwa] clean up exmples
f04882c [giwa] clen up examples
b171ec3 [giwa] fixed pep8 violation
f198d14 [giwa] clean up code
3166d31 [giwa] clean up
c00e091 [giwa] change test case not to use awaitTermination
e80647e [giwa] adopted the latest compression way of python command
58e41ff [giwa] merge with master
455e5af [giwa] removed wasted print in DStream
af336b7 [giwa] add comments
ddd4ee1 [giwa] added TODO coments
99ce042 [giwa] added saveAsTextFiles and saveAsPickledFiles
2a06cdb [giwa] remove waste duplicated code
c5ecfc1 [giwa] basic function test cases are passed
8dcda84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4
795b2cd [giwa] broke something
1e126bf [giwa] WIP: solved partitioned and None is not recognized
f67cf57 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test
953deb0 [giwa] edited the comment to add more precise description
af610d3 [giwa] removed unnesessary changes
c1d546e [giwa] fixed PEP-008 violation
99410be [giwa] delete waste file
b3b0362 [giwa] added basic operation test cases
9cde7c9 [giwa] WIP added test case
bd3ba53 [giwa] WIP
5c04a5f [giwa] WIP: added PythonTestInputStream
019ef38 [giwa] WIP
1934726 [giwa] update comment
376e3ac [giwa] WIP
932372a [giwa] clean up dstream.py
0b09cff [giwa] added stop in StreamingContext
92e333e [giwa] implemented reduce and count function in Dstream
1b83354 [giwa] Removed the waste line
88f7506 [Ken Takagiwa] Kill py4j callback server properly
54b5358 [Ken Takagiwa] tried to restart callback server
4f07163 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server.
fe02547 [Ken Takagiwa] remove waste file
2ad7bd3 [Ken Takagiwa] clean up codes
6197a11 [Ken Takagiwa] clean up code
eb4bf48 [Ken Takagiwa] fix map function
98c2a00 [Ken Takagiwa] added count operation but this implementation need double check
58591d2 [Ken Takagiwa] reduceByKey is working
0df7111 [Ken Takagiwa] delete old file
f485b1d [Ken Takagiwa] fied input of socketTextDStream
dd6de81 [Ken Takagiwa] initial commit for socketTextStream
247fd74 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10
4bcb318 [Ken Takagiwa] implementing transform function in Python
38adf95 [Ken Takagiwa] added reducedByKey not working yet
66fcfff [Ken Takagiwa] modify dstream.py to fix indent error
41886c2 [Ken Takagiwa] comment PythonDStream.PairwiseDStream
0b99bec [Ken] initial commit for pySparkStreaming
c214199 [giwa] added testcase for combineByKey
5625bdc [giwa] added gorupByKey testcase
10ab87b [giwa] added sparkContext as input parameter in StreamingContext
10b5b04 [giwa] removed wasted print in DStream
e54f986 [giwa] add comments
16aa64f [giwa] added TODO coments
74535d4 [giwa] added saveAsTextFiles and saveAsPickledFiles
f76c182 [giwa] remove waste duplicated code
18c8723 [giwa] modified streaming test case to add coment
13fb44c [giwa] basic function test cases are passed
3000b2b [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4
ff14070 [giwa] broke something
bcdec33 [giwa] WIP: solved partitioned and None is not recognized
270a9e1 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test
bb10956 [giwa] edited the comment to add more precise description
253a863 [giwa] removed unnesessary changes
3d37822 [giwa] fixed PEP-008 violation
f21cab3 [giwa] delete waste file
878bad7 [giwa] added basic operation test cases
ce2acd2 [giwa] WIP added test case
9ad6855 [giwa] WIP
1df77f5 [giwa] WIP: added PythonTestInputStream
1523b66 [giwa] WIP
8a0fbbc [giwa] update comment
fe648e3 [giwa] WIP
29c2bc5 [giwa] initial commit for testcase
4d40d63 [giwa] clean up dstream.py
c462bb3 [giwa] added stop in StreamingContext
d2c01ba [giwa] clean up examples
3c45cd2 [giwa] implemented reduce and count function in Dstream
b349649 [giwa] Removed the waste line
3b498e1 [Ken Takagiwa] Kill py4j callback server properly
84a9668 [Ken Takagiwa] tried to restart callback server
9ab8952 [Tathagata Das] Added extra line.
05e991b [Tathagata Das] Added missing file
b1d2a30 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server.
678e854 [Ken Takagiwa] remove waste file
0a8bbbb [Ken Takagiwa] clean up codes
bab31c1 [Ken Takagiwa] clean up code
72b9738 [Ken Takagiwa] fix map function
d3ee86a [Ken Takagiwa] added count operation but this implementation need double check
15feea9 [Ken Takagiwa] edit python sparkstreaming example
6f98e50 [Ken Takagiwa] reduceByKey is working
c455c8d [Ken Takagiwa] added reducedByKey not working yet
dc6995d [Ken Takagiwa] delete old file
b31446a [Ken Takagiwa] fixed typo of network_workdcount.py
ccfd214 [Ken Takagiwa] added doctest for pyspark.streaming.duration
0d1b954 [Ken Takagiwa] fied input of socketTextDStream
f746109 [Ken Takagiwa] initial commit for socketTextStream
bb7ccf3 [Ken Takagiwa] remove unused import in python
224fc5e [Ken Takagiwa] add empty line
d2099d8 [Ken Takagiwa] sorted the import following Spark coding convention
5bac7ec [Ken Takagiwa] revert streaming/pom.xml
e1df940 [Ken Takagiwa] revert pom.xml
494cae5 [Ken Takagiwa] remove not implemented DStream functions in python
17a74c6 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10
1a0f065 [Ken Takagiwa] implementing transform function in Python
d7b4d6f [Ken Takagiwa] added reducedByKey not working yet
87438e2 [Ken Takagiwa] modify dstream.py to fix indent error
b406252 [Ken Takagiwa] comment PythonDStream.PairwiseDStream
454981d [Ken] initial commit for pySparkStreaming
150b94c [giwa] added some StreamingContextTestSuite
f7bc8f9 [giwa] WIP:added more test for StreamingContext
ee50c5a [giwa] added atexit to handle callback server
fdc9125 [giwa] added comment for StreamingContext.sparkContext
f5bfb70 [giwa] added StreamingContext.sparkContext
da09768 [giwa] added StreamingContext.remember
d68b568 [giwa] clean up code
4afa390 [giwa] clean up code
1fd6bc7 [Ken Takagiwa] Merge pull request #2 from mattf/giwa-master
d9d59fe [Matthew Farrellee] Fix scalastyle errors
67473a9 [giwa] delete not implemented functions
c97377c [giwa] delete inproper comments
2ea769e [giwa] added comment in dstream._test_output
3b27bd4 [giwa] remove the last brank line
acfcaeb [giwa] revert pom.xml
93f7637 [giwa] fixed explanaiton
50fd6f9 [giwa] revert pom.xml
4f82c89 [giwa] remove duplicated import
9d1de23 [giwa] revert pom.xml
7339df2 [giwa] fixed typo
9c85e48 [giwa] clean up exmples
24f95db [giwa] clen up examples
0d30109 [giwa] fixed pep8 violation
b7dab85 [giwa] improve test case
583e66d [giwa] move tests for streaming inside streaming directory
1d84142 [giwa] remove unimplement test
f0ea311 [giwa] clean up code
171edeb [giwa] clean up
4dedd2d [giwa] change test case not to use awaitTermination
268a6a5 [giwa] Changed awaitTermination not to call awaitTermincation in Scala. Just use time.sleep instread
09a28bf [giwa] improve testcases
58150f5 [giwa] Changed the test case to focus the test operation
199e37f [giwa] adopted the latest compression way of python command
185fdbf [giwa] merge with master
f1798c4 [giwa] merge with master
e70f706 [giwa] added testcase for combineByKey
e162822 [giwa] added gorupByKey testcase
97742fe [giwa] added sparkContext as input parameter in StreamingContext
14d4c0e [giwa] removed wasted print in DStream
6d8190a [giwa] add comments
4aa99e4 [giwa] added TODO coments
e9fab72 [giwa] added saveAsTextFiles and saveAsPickledFiles
94f2b65 [giwa] remove waste duplicated code
580fbc2 [giwa] modified streaming test case to add coment
99e4bb3 [giwa] basic function test cases are passed
7051a84 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4
35933e1 [giwa] broke something
9767712 [giwa] WIP: solved partitioned and None is not recognized
4f2d7e6 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test
33c0f94d [giwa] edited the comment to add more precise description
774f18d [giwa] removed unnesessary changes
3a671cc [giwa] remove export PYSPARK_PYTHON in spark submit
8efa266 [giwa] fixed PEP-008 violation
fa75d71 [giwa] delete waste file
7f96294 [giwa] added basic operation test cases
3dda31a [giwa] WIP added test case
1f68b78 [giwa] WIP
c05922c [giwa] WIP: added PythonTestInputStream
1fd12ae [giwa] WIP
c880a33 [giwa] update comment
5d22c92 [giwa] WIP
ea4b06b [giwa] initial commit for testcase
5a9b525 [giwa] clean up dstream.py
79c5809 [giwa] added stop in StreamingContext
189dcea [giwa] clean up examples
b8d7d24 [giwa] implemented reduce and count function in Dstream
b6468e6 [giwa] Removed the waste line
b47b5fd [Ken Takagiwa] Kill py4j callback server properly
19ddcdd [Ken Takagiwa] tried to restart callback server
c9fc124 [Tathagata Das] Added extra line.
4caae3f [Tathagata Das] Added missing file
4eff053 [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server.
5e822d4 [Ken Takagiwa] remove waste file
aeaf8a5 [Ken Takagiwa] clean up codes
9fa249b [Ken Takagiwa] clean up code
05459c6 [Ken Takagiwa] fix map function
a9f4ecb [Ken Takagiwa] added count operation but this implementation need double check
d1ee6ca [Ken Takagiwa] edit python sparkstreaming example
0b8b7d0 [Ken Takagiwa] reduceByKey is working
d25d5cf [Ken Takagiwa] added reducedByKey not working yet
7f7c5d1 [Ken Takagiwa] delete old file
967dc26 [Ken Takagiwa] fixed typo of network_workdcount.py
57fb740 [Ken Takagiwa] added doctest for pyspark.streaming.duration
4b69fb1 [Ken Takagiwa] fied input of socketTextDStream
02f618a [Ken Takagiwa] initial commit for socketTextStream
4ce4058 [Ken Takagiwa] remove unused import in python
856d98e [Ken Takagiwa] add empty line
490e338 [Ken Takagiwa] sorted the import following Spark coding convention
5594bd4 [Ken Takagiwa] revert pom.xml
2adca84 [Ken Takagiwa] remove not implemented DStream functions in python
e551e13 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit
3758175 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit
c5518b4 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10
dcf243f [Ken Takagiwa] implementing transform function in Python
9af03f4 [Ken Takagiwa] added reducedByKey not working yet
6e0d9c7 [Ken Takagiwa] modify dstream.py to fix indent error
e497b9b [Ken Takagiwa] comment PythonDStream.PairwiseDStream
5c3a683 [Ken] initial commit for pySparkStreaming
665bfdb [giwa] added testcase for combineByKey
a3d2379 [giwa] added gorupByKey testcase
636090a [giwa] added sparkContext as input parameter in StreamingContext
e7ebb08 [giwa] removed wasted print in DStream
d8b593b [giwa] add comments
ea9c873 [giwa] added TODO coments
89ae38a [giwa] added saveAsTextFiles and saveAsPickledFiles
e3033fc [giwa] remove waste duplicated code
a14c7e1 [giwa] modified streaming test case to add coment
536def4 [giwa] basic function test cases are passed
2112638 [giwa] all tests are passed if numSlice is 2 and the numver of each input is over 4
080541a [giwa] broke something
0704b86 [giwa] WIP: solved partitioned and None is not recognized
90a6484 [giwa] added mapValues and flatMapVaules WIP for glom and mapPartitions test
a65f302 [giwa] edited the comment to add more precise description
bdde697 [giwa] removed unnesessary changes
e8c7bfc [giwa] remove export PYSPARK_PYTHON in spark submit
3334169 [giwa] fixed PEP-008 violation
db0a303 [giwa] delete waste file
2cfd3a0 [giwa] added basic operation test cases
90ae568 [giwa] WIP added test case
a120d07 [giwa] WIP
f671cdb [giwa] WIP: added PythonTestInputStream
56fae45 [giwa] WIP
e35e101 [giwa] Merge branch 'master' into testcase
ba5112d [giwa] update comment
28aa56d [giwa] WIP
fb08559 [giwa] initial commit for testcase
a613b85 [giwa] clean up dstream.py
c40c0ef [giwa] added stop in StreamingContext
31e4260 [giwa] clean up examples
d2127d6 [giwa] implemented reduce and count function in Dstream
48f7746 [giwa] Removed the waste line
0f83eaa [Ken Takagiwa] delete py4j 0.8.1
1679808 [Ken Takagiwa] Kill py4j callback server properly
f96cd4e [Ken Takagiwa] tried to restart callback server
fe86198 [Ken Takagiwa] add py4j 0.8.2.1 but server is not launched
1064fe0 [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark
28c6620 [Ken Takagiwa] Implemented DStream.foreachRDD in the Python API using Py4J callback server
85b0fe1 [Ken Takagiwa] Merge pull request #1 from tdas/python-foreach
54e2e8c [Tathagata Das] Added extra line.
e185338 [Tathagata Das] Added missing file
a778d4b [Tathagata Das] Implemented DStream.foreachRDD in the Python API using Py4J callback server.
cc2092b [Ken Takagiwa] remove waste file
d042ac6 [Ken Takagiwa] clean up codes
84a021f [Ken Takagiwa] clean up code
bd20e17 [Ken Takagiwa] fix map function
d01a125 [Ken Takagiwa] added count operation but this implementation need double check
7d05109 [Ken Takagiwa] merge with remote branch
ae464e0 [Ken Takagiwa] edit python sparkstreaming example
04af046 [Ken Takagiwa] reduceByKey is working
3b6d7b0 [Ken Takagiwa] implementing transform function in Python
571d52d [Ken Takagiwa] added reducedByKey not working yet
5720979 [Ken Takagiwa] delete old file
e604fcb [Ken Takagiwa] fixed typo of network_workdcount.py
4b7c08b [Ken Takagiwa] Merge branch 'master' of https://github.com/giwa/spark
ce7d426 [Ken Takagiwa] added doctest for pyspark.streaming.duration
a8c9fd5 [Ken Takagiwa] fixed for socketTextStream
a61fa9e [Ken Takagiwa] fied input of socketTextDStream
1e84f41 [Ken Takagiwa] initial commit for socketTextStream
6d012f7 [Ken Takagiwa] remove unused import in python
25d30d5 [Ken Takagiwa] add empty line
6e0a64a [Ken Takagiwa] sorted the import following Spark coding convention
fa4a7fc [Ken Takagiwa] revert streaming/pom.xml
8f8202b [Ken Takagiwa] revert streaming pom.xml
c9d79dd [Ken Takagiwa] revert pom.xml
57e3e52 [Ken Takagiwa] remove not implemented DStream functions in python
0a516f5 [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit
a7a0b5c [Ken Takagiwa] add coment for hack why PYSPARK_PYTHON is needed in spark-submit
72bfc66 [Ken Takagiwa] modified the code base on comment in https://github.com/tdas/spark/pull/10
69e9cd3 [Ken Takagiwa] implementing transform function in Python
94a0787 [Ken Takagiwa] added reducedByKey not working yet
88068cf [Ken Takagiwa] modify dstream.py to fix indent error
1367be5 [Ken Takagiwa] comment PythonDStream.PairwiseDStream
eb2b3ba [Ken] Merge remote-tracking branch 'upstream/master'
d8e51f9 [Ken] initial commit for pySparkStreaming
Authored by giwa on 2014-10-12 02:46:56 -07:00; committed by Tathagata Das
commit 69c67abaa9 (parent 7a3f589ef8)
17 changed files with 2133 additions and 13 deletions


@ -25,8 +25,6 @@ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collectio
import scala.collection.JavaConversions._
import scala.collection.mutable
import scala.language.existentials
import scala.reflect.ClassTag
import scala.util.{Try, Success, Failure}
import net.razorvine.pickle.{Pickler, Unpickler}
@ -42,7 +40,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD(
parent: RDD[_],
@transient parent: RDD[_],
command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
@ -55,9 +53,9 @@ private[spark] class PythonRDD(
val bufferSize = conf.getInt("spark.buffer.size", 65536)
val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
override def getPartitions = parent.partitions
override def getPartitions = firstParent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
override val partitioner = if (preservePartitoning) firstParent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
@ -234,7 +232,7 @@ private[spark] class PythonRDD(
dataOut.writeInt(command.length)
dataOut.write(command)
// Data values
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
dataOut.flush()
} catch {


@ -0,0 +1,49 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Counts words in new text files created in the given directory
Usage: hdfs_wordcount.py <directory>
<directory> is the directory that Spark Streaming will use to find and read new text files.
To run this on your local machine on directory `localdir`, run this example
$ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir
Then create a text file in `localdir` and the words in the file will get counted.
"""
import sys
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 2:
print >> sys.stderr, "Usage: hdfs_wordcount.py <directory>"
exit(-1)
sc = SparkContext(appName="PythonStreamingHDFSWordCount")
ssc = StreamingContext(sc, 1)
lines = ssc.textFileStream(sys.argv[1])
counts = lines.flatMap(lambda line: line.split(" "))\
.map(lambda x: (x, 1))\
.reduceByKey(lambda a, b: a+b)
counts.pprint()
ssc.start()
ssc.awaitTermination()


@ -0,0 +1,48 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
Usage: network_wordcount.py <hostname> <port>
<hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
To run this on your local machine, you need to first run a Netcat server
`$ nc -lk 9999`
and then run the example
`$ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localhost 9999`
"""
import sys
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: network_wordcount.py <hostname> <port>"
exit(-1)
sc = SparkContext(appName="PythonStreamingNetworkWordCount")
ssc = StreamingContext(sc, 1)
lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
counts = lines.flatMap(lambda line: line.split(" "))\
.map(lambda word: (word, 1))\
.reduceByKey(lambda a, b: a+b)
counts.pprint()
ssc.start()
ssc.awaitTermination()


@ -0,0 +1,57 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Counts words in UTF8 encoded, '\n' delimited text received from the
network every second.
Usage: stateful_network_wordcount.py <hostname> <port>
<hostname> and <port> describe the TCP server that Spark Streaming
would connect to receive data.
To run this on your local machine, you need to first run a Netcat server
`$ nc -lk 9999`
and then run the example
`$ bin/spark-submit examples/src/main/python/streaming/stateful_network_wordcount.py \
localhost 9999`
"""
import sys
from pyspark import SparkContext
from pyspark.streaming import StreamingContext
if __name__ == "__main__":
if len(sys.argv) != 3:
print >> sys.stderr, "Usage: stateful_network_wordcount.py <hostname> <port>"
exit(-1)
sc = SparkContext(appName="PythonStreamingStatefulNetworkWordCount")
ssc = StreamingContext(sc, 1)
ssc.checkpoint("checkpoint")
def updateFunc(new_values, last_sum):
return sum(new_values) + (last_sum or 0)
lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
running_counts = lines.flatMap(lambda line: line.split(" "))\
.map(lambda word: (word, 1))\
.updateStateByKey(updateFunc)
running_counts.pprint()
ssc.start()
ssc.awaitTermination()


@ -5,7 +5,7 @@ RULES = (
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
(r"[IBCM]{(.+)}", r"`\1`"),
(r"[IBCM]{([^}]+)}", r"`\1`"),
('pyspark.rdd.RDD', 'RDD'),
)


@ -13,6 +13,7 @@ Contents:
pyspark
pyspark.sql
pyspark.streaming
pyspark.mllib


@ -7,8 +7,9 @@ Subpackages
.. toctree::
:maxdepth: 1
pyspark.mllib
pyspark.sql
pyspark.streaming
pyspark.mllib
Contents
--------


@ -68,7 +68,7 @@ class SparkContext(object):
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None):
gateway=None, jsc=None):
"""
Create a new SparkContext. At least the master and app name should be set,
either through the named parameters here or through C{conf}.
@ -104,14 +104,14 @@ class SparkContext(object):
SparkContext._ensure_initialized(self, gateway=gateway)
try:
self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf)
conf, jsc)
except:
# If an error occurs, clean up in order to allow future SparkContext creation:
self.stop()
raise
def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer,
conf):
conf, jsc):
self.environment = environment or {}
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
@ -154,7 +154,7 @@ class SparkContext(object):
self.environment[varName] = v
# Create the Java SparkContext through Py4J
self._jsc = self._initialize_context(self._conf._jconf)
self._jsc = jsc or self._initialize_context(self._conf._jconf)
# Create a single Accumulator in Java that we'll send all our updates through;
# they will be passed back to us through a TCP server


@ -114,6 +114,9 @@ class Serializer(object):
def __repr__(self):
return "<%s object>" % self.__class__.__name__
def __hash__(self):
return hash(str(self))
class FramedSerializer(Serializer):


@ -0,0 +1,21 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark.streaming.context import StreamingContext
from pyspark.streaming.dstream import DStream
__all__ = ['StreamingContext', 'DStream']


@ -0,0 +1,325 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import sys
from py4j.java_collections import ListConverter
from py4j.java_gateway import java_import, JavaObject
from pyspark import RDD, SparkConf
from pyspark.serializers import UTF8Deserializer, CloudPickleSerializer
from pyspark.context import SparkContext
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.dstream import DStream
from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
__all__ = ["StreamingContext"]
def _daemonize_callback_server():
"""
Hack Py4J to daemonize callback server
The callback server's thread has daemon=False, so it will block the driver
from exiting if it is not shut down. The following code replaces `start()`
of CallbackServer with a new version that sets daemon=True for this
thread.
It also updates the port number (0) with the real port.
"""
# TODO: create a patch for Py4J
import socket
import py4j.java_gateway
logger = py4j.java_gateway.logger
from py4j.java_gateway import Py4JNetworkError
from threading import Thread
def start(self):
"""Starts the CallbackServer. This method should be called by the
client instead of run()."""
self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,
1)
try:
self.server_socket.bind((self.address, self.port))
if not self.port:
# update port with real port
self.port = self.server_socket.getsockname()[1]
except Exception as e:
msg = 'An error occurred while trying to start the callback server: %s' % e
logger.exception(msg)
raise Py4JNetworkError(msg)
# Maybe the thread needs to be cleaned up?
self.thread = Thread(target=self.run)
self.thread.daemon = True
self.thread.start()
py4j.java_gateway.CallbackServer.start = start
class StreamingContext(object):
"""
Main entry point for Spark Streaming functionality. A StreamingContext
represents the connection to a Spark cluster, and can be used to create
L{DStream}s from various input sources. It can be created from an existing L{SparkContext}.
After creating and transforming DStreams, the streaming computation can
be started and stopped using `context.start()` and `context.stop()`,
respectively. `context.awaitTermination()` allows the current thread
to wait for the termination of the context by `stop()` or by an exception.
"""
_transformerSerializer = None
def __init__(self, sparkContext, batchDuration=None, jssc=None):
"""
Create a new StreamingContext.
@param sparkContext: L{SparkContext} object.
@param batchDuration: the time interval (in seconds) at which streaming
data will be divided into batches
"""
self._sc = sparkContext
self._jvm = self._sc._jvm
self._jssc = jssc or self._initialize_context(self._sc, batchDuration)
def _initialize_context(self, sc, duration):
self._ensure_initialized()
return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))
def _jduration(self, seconds):
"""
Create Duration object given number of seconds
"""
return self._jvm.Duration(int(seconds * 1000))
@classmethod
def _ensure_initialized(cls):
SparkContext._ensure_initialized()
gw = SparkContext._gateway
java_import(gw.jvm, "org.apache.spark.streaming.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")
# start callback server
# getattr will fallback to JVM, so we cannot test by hasattr()
if "_callback_server" not in gw.__dict__:
_daemonize_callback_server()
# use random port
gw._start_callback_server(0)
# gateway with real port
gw._python_proxy_port = gw._callback_server.port
# get the GatewayServer object in JVM by ID
jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
# update the port of CallbackClient with real port
gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
# register serializer for TransformFunction
# it happens before creating SparkContext when loading from checkpointing
cls._transformerSerializer = TransformFunctionSerializer(
SparkContext._active_spark_context, CloudPickleSerializer(), gw)
@classmethod
def getOrCreate(cls, checkpointPath, setupFunc):
"""
Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
will be used to create a JavaStreamingContext.
@param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
@param setupFunc: Function to create a new JavaStreamingContext and set up DStreams
"""
# TODO: support checkpoint in HDFS
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
ssc = setupFunc()
ssc.checkpoint(checkpointPath)
return ssc
cls._ensure_initialized()
gw = SparkContext._gateway
try:
jssc = gw.jvm.JavaStreamingContext(checkpointPath)
except Exception:
print >>sys.stderr, "failed to load StreamingContext from checkpoint"
raise
jsc = jssc.sparkContext()
conf = SparkConf(_jconf=jsc.getConf())
sc = SparkContext(conf=conf, gateway=gw, jsc=jsc)
# update ctx in serializer
SparkContext._active_spark_context = sc
cls._transformerSerializer.ctx = sc
return StreamingContext(sc, None, jssc)
@property
def sparkContext(self):
"""
Return SparkContext which is associated with this StreamingContext.
"""
return self._sc
def start(self):
"""
Start the execution of the streams.
"""
self._jssc.start()
def awaitTermination(self, timeout=None):
"""
Wait for the execution to stop.
@param timeout: time to wait in seconds
"""
if timeout is None:
self._jssc.awaitTermination()
else:
self._jssc.awaitTermination(int(timeout * 1000))
def stop(self, stopSparkContext=True, stopGraceFully=False):
"""
Stop the execution of the streams, with the option of ensuring all
received data has been processed.
@param stopSparkContext: Stop the associated SparkContext or not
@param stopGraceFully: Stop gracefully by waiting for the processing
of all received data to be completed
"""
self._jssc.stop(stopSparkContext, stopGraceFully)
if stopSparkContext:
self._sc.stop()
def remember(self, duration):
"""
Set each DStream in this context to remember RDDs it generated
in the last given duration. DStreams remember RDDs only for a
limited duration of time and release them for garbage collection.
This method allows the developer to specify how long to remember
the RDDs (if the developer wishes to query old data outside the
DStream computation).
@param duration: Minimum duration (in seconds) that each DStream
should remember its RDDs
"""
self._jssc.remember(self._jduration(duration))
def checkpoint(self, directory):
"""
Sets the context to periodically checkpoint the DStream operations for master
fault-tolerance. The graph will be checkpointed every batch interval.
@param directory: HDFS-compatible directory where the checkpoint data
will be reliably stored
"""
self._jssc.checkpoint(directory)
def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2):
"""
Create an input stream from a TCP source hostname:port. Data is received using
a TCP socket and the received bytes are interpreted as UTF8 encoded ``\\n`` delimited
lines.
@param hostname: Hostname to connect to for receiving data
@param port: Port to connect to for receiving data
@param storageLevel: Storage level to use for storing the received objects
"""
jlevel = self._sc._getJavaStorageLevel(storageLevel)
return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
UTF8Deserializer())
def textFileStream(self, directory):
"""
Create an input stream that monitors a Hadoop-compatible file system
for new files and reads them as text files. Files must be written to the
monitored directory by "moving" them from another location within the same
file system. File names starting with . are ignored.
"""
return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())
def _check_serializers(self, rdds):
# make sure they have same serializer
if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
for i in range(len(rdds)):
# reset them to sc.serializer
rdds[i] = rdds[i]._reserialize()
def queueStream(self, rdds, oneAtATime=True, default=None):
"""
Create an input stream from a queue of RDDs or lists. In each batch,
it will process either one or all of the RDDs returned by the queue.
NOTE: changes to the queue after the stream is created will not be recognized.
@param rdds: Queue of RDDs
@param oneAtATime: pick one RDD each time or pick all of them at once.
@param default: The default RDD if no more RDDs are left in `rdds`
"""
if default and not isinstance(default, RDD):
default = self._sc.parallelize(default)
if not rdds and default:
rdds = [rdds]
if rdds and not isinstance(rdds[0], RDD):
rdds = [self._sc.parallelize(input) for input in rdds]
self._check_serializers(rdds)
jrdds = ListConverter().convert([r._jrdd for r in rdds],
SparkContext._gateway._gateway_client)
queue = self._jvm.PythonDStream.toRDDQueue(jrdds)
if default:
default = default._reserialize(rdds[0]._jrdd_deserializer)
jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
else:
jdstream = self._jssc.queueStream(queue, oneAtATime)
return DStream(jdstream, self, rdds[0]._jrdd_deserializer)
def transform(self, dstreams, transformFunc):
"""
Create a new DStream in which each RDD is generated by applying
a function on RDDs of the DStreams. The order of the JavaRDDs in
the transform function parameter will be the same as the order
of corresponding DStreams in the list.
"""
jdstreams = ListConverter().convert([d._jdstream for d in dstreams],
SparkContext._gateway._gateway_client)
# change the final serializer to sc.serializer
func = TransformFunction(self._sc,
lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
*[d._jrdd_deserializer for d in dstreams])
jfunc = self._jvm.TransformFunction(func)
jdstream = self._jssc.transform(jdstreams, jfunc)
return DStream(jdstream, self, self._sc.serializer)
def union(self, *dstreams):
"""
Create a unified DStream from multiple DStreams of the same
type and same slide duration.
"""
if not dstreams:
raise ValueError("should have at least one DStream to union")
if len(dstreams) == 1:
return dstreams[0]
if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
raise ValueError("All DStreams should have same serializer")
if len(set(s._slideDuration for s in dstreams)) > 1:
raise ValueError("All DStreams should have same slide duration")
first = dstreams[0]
jrest = ListConverter().convert([d._jdstream for d in dstreams[1:]],
SparkContext._gateway._gateway_client)
return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)
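
For illustration only (the checkpoint directory, input directory, and the pipeline inside setup() are placeholders, not part of this patch): a sketch of how a job can be recovered from checkpoint data with StreamingContext.getOrCreate, as described in the docstring above.

from pyspark import SparkContext
from pyspark.streaming import StreamingContext

def setup():
    # Called only when no checkpoint data exists; getOrCreate() enables
    # checkpointing to the given directory before returning the new context.
    sc = SparkContext(appName="RecoverableWordCount")
    ssc = StreamingContext(sc, 1)
    counts = ssc.textFileStream("/tmp/streaming-input") \
                .flatMap(lambda line: line.split(" ")) \
                .map(lambda word: (word, 1)) \
                .reduceByKey(lambda a, b: a + b)
    counts.pprint()
    return ssc

# Recreate the StreamingContext from checkpoint data if present, otherwise run setup().
ssc = StreamingContext.getOrCreate("/tmp/streaming-checkpoint", setup)
ssc.start()
ssc.awaitTermination()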


@ -0,0 +1,621 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import chain, ifilter, imap
import operator
import time
from datetime import datetime
from py4j.protocol import Py4JJavaError
from pyspark import RDD
from pyspark.storagelevel import StorageLevel
from pyspark.streaming.util import rddToFileName, TransformFunction
from pyspark.rdd import portable_hash
from pyspark.resultiterable import ResultIterable
__all__ = ["DStream"]
class DStream(object):
"""
A Discretized Stream (DStream), the basic abstraction in Spark Streaming,
is a continuous sequence of RDDs (of the same type) representing a
continuous stream of data (see L{RDD} in the Spark core documentation
for more details on RDDs).
DStreams can either be created from live data (such as data from TCP
sockets, Kafka, Flume, etc.) using a L{StreamingContext}, or they can be
generated by transforming existing DStreams using operations such as
`map`, `window` and `reduceByKeyAndWindow`. While a Spark Streaming
program is running, each DStream periodically generates an RDD, either
from live data or by transforming the RDD generated by a parent DStream.
Internally, a DStream is characterized by a few basic properties:
- A list of other DStreams that the DStream depends on
- A time interval at which the DStream generates an RDD
- A function that is used to generate an RDD after each time interval
"""
def __init__(self, jdstream, ssc, jrdd_deserializer):
self._jdstream = jdstream
self._ssc = ssc
self._sc = ssc._sc
self._jrdd_deserializer = jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
def context(self):
"""
Return the StreamingContext associated with this DStream
"""
return self._ssc
def count(self):
"""
Return a new DStream in which each RDD has a single element
generated by counting each RDD of this DStream.
"""
return self.mapPartitions(lambda i: [sum(1 for _ in i)]).reduce(operator.add)
def filter(self, f):
"""
Return a new DStream containing only the elements that satisfy predicate.
"""
def func(iterator):
return ifilter(f, iterator)
return self.mapPartitions(func, True)
def flatMap(self, f, preservesPartitioning=False):
"""
Return a new DStream by applying a function to all elements of
this DStream, and then flattening the results
"""
def func(s, iterator):
return chain.from_iterable(imap(f, iterator))
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def map(self, f, preservesPartitioning=False):
"""
Return a new DStream by applying a function to each element of DStream.
"""
def func(iterator):
return imap(f, iterator)
return self.mapPartitions(func, preservesPartitioning)
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new DStream in which each RDD is generated by applying
mapPartitions() to each RDDs of this DStream.
"""
def func(s, iterator):
return f(iterator)
return self.mapPartitionsWithIndex(func, preservesPartitioning)
def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
"""
Return a new DStream in which each RDD is generated by applying
mapPartitionsWithIndex() to each RDDs of this DStream.
"""
return self.transform(lambda rdd: rdd.mapPartitionsWithIndex(f, preservesPartitioning))
def reduce(self, func):
"""
Return a new DStream in which each RDD has a single element
generated by reducing each RDD of this DStream.
"""
return self.map(lambda x: (None, x)).reduceByKey(func, 1).map(lambda x: x[1])
def reduceByKey(self, func, numPartitions=None):
"""
Return a new DStream by applying reduceByKey to each RDD.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.combineByKey(lambda x: x, func, func, numPartitions)
def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
numPartitions=None):
"""
Return a new DStream by applying combineByKey to each RDD.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
def func(rdd):
return rdd.combineByKey(createCombiner, mergeValue, mergeCombiners, numPartitions)
return self.transform(func)
def partitionBy(self, numPartitions, partitionFunc=portable_hash):
"""
Return a copy of the DStream in which each RDD is partitioned
using the specified partitioner.
"""
return self.transform(lambda rdd: rdd.partitionBy(numPartitions, partitionFunc))
def foreachRDD(self, func):
"""
Apply a function to each RDD in this DStream.
"""
if func.func_code.co_argcount == 1:
old_func = func
func = lambda t, rdd: old_func(rdd)
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer)
api = self._ssc._jvm.PythonDStream
api.callForeachRDD(self._jdstream, jfunc)
def pprint(self):
"""
Print the first ten elements of each RDD generated in this DStream.
"""
def takeAndPrint(time, rdd):
taken = rdd.take(11)
print "-------------------------------------------"
print "Time: %s" % time
print "-------------------------------------------"
for record in taken[:10]:
print record
if len(taken) > 10:
print "..."
print
self.foreachRDD(takeAndPrint)
def mapValues(self, f):
"""
Return a new DStream by applying a map function to the value of
each key-value pair in this DStream without changing the key.
"""
map_values_fn = lambda (k, v): (k, f(v))
return self.map(map_values_fn, preservesPartitioning=True)
def flatMapValues(self, f):
"""
Return a new DStream by applying a flatmap function to the value
of each key-value pair in this DStream without changing the key.
"""
flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
return self.flatMap(flat_map_fn, preservesPartitioning=True)
def glom(self):
"""
Return a new DStream in which each RDD is generated by applying glom()
to each RDD of this DStream.
"""
def func(iterator):
yield list(iterator)
return self.mapPartitions(func)
def cache(self):
"""
Persist the RDDs of this DStream with the default storage level
(C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
self.persist(StorageLevel.MEMORY_ONLY_SER)
return self
def persist(self, storageLevel):
"""
Persist the RDDs of this DStream with the given storage level
"""
self.is_cached = True
javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
self._jdstream.persist(javaStorageLevel)
return self
def checkpoint(self, interval):
"""
Enable periodic checkpointing of RDDs of this DStream
@param interval: time in seconds, after each period of that, generated
RDD will be checkpointed
"""
self.is_checkpointed = True
self._jdstream.checkpoint(self._ssc._jduration(interval))
return self
def groupByKey(self, numPartitions=None):
"""
Return a new DStream by applying groupByKey on each RDD.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transform(lambda rdd: rdd.groupByKey(numPartitions))
def countByValue(self):
"""
Return a new DStream in which each RDD contains the counts of each
distinct value in each RDD of this DStream.
"""
return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count()
def saveAsTextFiles(self, prefix, suffix=None):
"""
Save each RDD in this DStream as a text file, using the string
representation of its elements.
"""
def saveAsTextFile(t, rdd):
path = rddToFileName(prefix, suffix, t)
try:
rdd.saveAsTextFile(path)
except Py4JJavaError as e:
# after recovered from checkpointing, the foreachRDD may
# be called twice
if 'FileAlreadyExistsException' not in str(e):
raise
return self.foreachRDD(saveAsTextFile)
# TODO: uncomment this once we have ssc.pickleFileStream()
# def saveAsPickleFiles(self, prefix, suffix=None):
# """
# Save each RDD in this DStream as a binary file, the elements are
# serialized by pickle.
# """
# def saveAsPickleFile(t, rdd):
# path = rddToFileName(prefix, suffix, t)
# try:
# rdd.saveAsPickleFile(path)
# except Py4JJavaError as e:
# # after recovered from checkpointing, the foreachRDD may
# # be called twice
# if 'FileAlreadyExistsException' not in str(e):
# raise
# return self.foreachRDD(saveAsPickleFile)
def transform(self, func):
"""
Return a new DStream in which each RDD is generated by applying a function
on each RDD of this DStream.
`func` can have one argument of `rdd`, or have two arguments of
(`time`, `rdd`)
"""
if func.func_code.co_argcount == 1:
oldfunc = func
func = lambda t, rdd: oldfunc(rdd)
assert func.func_code.co_argcount == 2, "func should take one or two arguments"
return TransformedDStream(self, func)
def transformWith(self, func, other, keepSerializer=False):
"""
Return a new DStream in which each RDD is generated by applying a function
on each RDD of this DStream and 'other' DStream.
`func` can have two arguments of (`rdd_a`, `rdd_b`) or have three
arguments of (`time`, `rdd_a`, `rdd_b`)
"""
if func.func_code.co_argcount == 2:
oldfunc = func
func = lambda t, a, b: oldfunc(a, b)
assert func.func_code.co_argcount == 3, "func should take two or three arguments"
jfunc = TransformFunction(self._sc, func, self._jrdd_deserializer, other._jrdd_deserializer)
dstream = self._sc._jvm.PythonTransformed2DStream(self._jdstream.dstream(),
other._jdstream.dstream(), jfunc)
jrdd_serializer = self._jrdd_deserializer if keepSerializer else self._sc.serializer
return DStream(dstream.asJavaDStream(), self._ssc, jrdd_serializer)
def repartition(self, numPartitions):
"""
Return a new DStream with an increased or decreased level of parallelism.
"""
return self.transform(lambda rdd: rdd.repartition(numPartitions))
@property
def _slideDuration(self):
"""
Return the slideDuration in seconds of this DStream
"""
return self._jdstream.dstream().slideDuration().milliseconds() / 1000.0
def union(self, other):
"""
Return a new DStream by unifying data of another DStream with this DStream.
@param other: Another DStream having the same interval (i.e., slideDuration)
as this DStream.
"""
if self._slideDuration != other._slideDuration:
raise ValueError("the two DStream should have same slide duration")
return self.transformWith(lambda a, b: a.union(b), other, True)
def cogroup(self, other, numPartitions=None):
"""
Return a new DStream by applying 'cogroup' between RDDs of this
DStream and `other` DStream.
Hash partitioning is used to generate the RDDs with `numPartitions` partitions.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.cogroup(b, numPartitions), other)
def join(self, other, numPartitions=None):
"""
Return a new DStream by applying 'join' between RDDs of this DStream and
`other` DStream.
Hash partitioning is used to generate the RDDs with `numPartitions`
partitions.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.join(b, numPartitions), other)
def leftOuterJoin(self, other, numPartitions=None):
"""
Return a new DStream by applying 'left outer join' between RDDs of this DStream and
`other` DStream.
Hash partitioning is used to generate the RDDs with `numPartitions`
partitions.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.leftOuterJoin(b, numPartitions), other)
def rightOuterJoin(self, other, numPartitions=None):
"""
Return a new DStream by applying 'right outer join' between RDDs of this DStream and
`other` DStream.
Hash partitioning is used to generate the RDDs with `numPartitions`
partitions.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.rightOuterJoin(b, numPartitions), other)
def fullOuterJoin(self, other, numPartitions=None):
"""
Return a new DStream by applying 'full outer join' between RDDs of this DStream and
`other` DStream.
Hash partitioning is used to generate the RDDs with `numPartitions`
partitions.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
return self.transformWith(lambda a, b: a.fullOuterJoin(b, numPartitions), other)
def _jtime(self, timestamp):
""" Convert datetime or unix_timestamp into Time
"""
if isinstance(timestamp, datetime):
timestamp = time.mktime(timestamp.timetuple())
return self._sc._jvm.Time(long(timestamp * 1000))
def slice(self, begin, end):
"""
Return all the RDDs between 'begin' to 'end' (both included)
`begin`, `end` could be datetime.datetime() or unix_timestamp
"""
jrdds = self._jdstream.slice(self._jtime(begin), self._jtime(end))
return [RDD(jrdd, self._sc, self._jrdd_deserializer) for jrdd in jrdds]
def _validate_window_param(self, window, slide):
duration = self._jdstream.dstream().slideDuration().milliseconds()
if int(window * 1000) % duration != 0:
raise ValueError("windowDuration must be multiple of the slide duration (%d ms)"
% duration)
if slide and int(slide * 1000) % duration != 0:
raise ValueError("slideDuration must be multiple of the slide duration (%d ms)"
% duration)
def window(self, windowDuration, slideDuration=None):
"""
Return a new DStream in which each RDD contains all the elements seen in a
sliding window of time over this DStream.
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
the new DStream will generate RDDs); must be a multiple of this
DStream's batching interval
"""
self._validate_window_param(windowDuration, slideDuration)
d = self._ssc._jduration(windowDuration)
if slideDuration is None:
return DStream(self._jdstream.window(d), self._ssc, self._jrdd_deserializer)
s = self._ssc._jduration(slideDuration)
return DStream(self._jdstream.window(d, s), self._ssc, self._jrdd_deserializer)
def reduceByWindow(self, reduceFunc, invReduceFunc, windowDuration, slideDuration):
"""
Return a new DStream in which each RDD has a single element generated by reducing all
elements in a sliding window over this DStream.
If `invReduceFunc` is not None, the reduction is done incrementally
using the old window's reduced value:
1. reduce the new values that entered the window (e.g., adding new counts)
2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
This is more efficient than when `invReduceFunc` is None.
@param reduceFunc: associative reduce function
@param invReduceFunc: inverse reduce function of `reduceFunc`
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
the new DStream will generate RDDs); must be a multiple of this
DStream's batching interval
"""
keyed = self.map(lambda x: (1, x))
reduced = keyed.reduceByKeyAndWindow(reduceFunc, invReduceFunc,
windowDuration, slideDuration, 1)
return reduced.map(lambda (k, v): v)
def countByWindow(self, windowDuration, slideDuration):
"""
Return a new DStream in which each RDD has a single element generated
by counting the number of elements in a window over this DStream.
windowDuration and slideDuration are as defined in the window() operation.
This is equivalent to window(windowDuration, slideDuration).count(),
but will be more efficient if window is large.
"""
return self.map(lambda x: 1).reduceByWindow(operator.add, operator.sub,
windowDuration, slideDuration)
def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=None):
"""
Return a new DStream in which each RDD contains the count of distinct elements in
RDDs in a sliding window over this DStream.
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
the new DStream will generate RDDs); must be a multiple of this
DStream's batching interval
@param numPartitions: number of partitions of each RDD in the new DStream.
"""
keyed = self.map(lambda x: (x, 1))
counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub,
windowDuration, slideDuration, numPartitions)
return counted.filter(lambda (k, v): v > 0).count()
def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None):
"""
Return a new DStream by applying `groupByKey` over a sliding window.
Similar to `DStream.groupByKey()`, but applies it over a sliding window.
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
the new DStream will generate RDDs); must be a multiple of this
DStream's batching interval
@param numPartitions: Number of partitions of each RDD in the new DStream.
"""
ls = self.mapValues(lambda x: [x])
grouped = ls.reduceByKeyAndWindow(lambda a, b: a.extend(b) or a, lambda a, b: a[len(b):],
windowDuration, slideDuration, numPartitions)
return grouped.mapValues(ResultIterable)
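# Usage sketch (assumes a DStream `pairs` of (key, value) tuples on a 1-second
# batch interval):
#
#     grouped = pairs.groupByKeyAndWindow(3, 1).mapValues(list)
#
# Each emitted RDD then holds, per key, the values seen during the last
# 3 seconds, refreshed every second.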
def reduceByKeyAndWindow(self, func, invFunc, windowDuration, slideDuration=None,
numPartitions=None, filterFunc=None):
"""
Return a new DStream by applying incremental `reduceByKey` over a sliding window.
The reduced value of a new window is calculated using the old window's reduced value:
1. reduce the new values that entered the window (e.g., adding new counts)
2. "inverse reduce" the old values that left the window (e.g., subtracting old counts)
If `invFunc` is None, all the RDDs in the window are reduced from scratch, which
may be slower than providing `invFunc`.
@param func: associative reduce function
@param invFunc: inverse function of `func`
@param windowDuration: width of the window; must be a multiple of this DStream's
batching interval
@param slideDuration: sliding interval of the window (i.e., the interval after which
the new DStream will generate RDDs); must be a multiple of this
DStream's batching interval
@param numPartitions: number of partitions of each RDD in the new DStream.
@param filterFunc: function to filter expired key-value pairs;
only pairs that satisfy the function are retained;
set this to None if you do not want to filter
"""
self._validate_window_param(windowDuration, slideDuration)
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
reduced = self.reduceByKey(func, numPartitions)
def reduceFunc(t, a, b):
b = b.reduceByKey(func, numPartitions)
r = a.union(b).reduceByKey(func, numPartitions) if a else b
if filterFunc:
r = r.filter(filterFunc)
return r
def invReduceFunc(t, a, b):
b = b.reduceByKey(func, numPartitions)
joined = a.leftOuterJoin(b, numPartitions)
return joined.mapValues(lambda (v1, v2): invFunc(v1, v2) if v2 is not None else v1)
jreduceFunc = TransformFunction(self._sc, reduceFunc, reduced._jrdd_deserializer)
if invFunc:
jinvReduceFunc = TransformFunction(self._sc, invReduceFunc, reduced._jrdd_deserializer)
else:
jinvReduceFunc = None
if slideDuration is None:
slideDuration = self._slideDuration
dstream = self._sc._jvm.PythonReducedWindowedDStream(reduced._jdstream.dstream(),
jreduceFunc, jinvReduceFunc,
self._ssc._jduration(windowDuration),
self._ssc._jduration(slideDuration))
return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
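# Usage sketch for an incremental windowed word count (assumes a DStream
# `words` of strings on a 1-second batch interval):
#
#     import operator
#     counts = words.map(lambda w: (w, 1)) \
#                   .reduceByKeyAndWindow(operator.add, operator.sub, 60, 10)
#
# With operator.sub as `invFunc`, counts that fall out of a large window can be
# subtracted from the previous result; passing invFunc=None always reduces the
# whole window from scratch.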
def updateStateByKey(self, updateFunc, numPartitions=None):
"""
Return a new "state" DStream where the state for each key is updated by applying
the given function on the previous state of the key and the new values of the key.
@param updateFunc: State update function. If this function returns None, then the
corresponding state key-value pair will be eliminated.
"""
if numPartitions is None:
numPartitions = self._sc.defaultParallelism
def reduceFunc(t, a, b):
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
g = a.cogroup(b, numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
jreduceFunc = TransformFunction(self._sc, reduceFunc,
self._sc.serializer, self._jrdd_deserializer)
dstream = self._sc._jvm.PythonStateDStream(self._jdstream.dstream(), jreduceFunc)
return DStream(dstream.asJavaDStream(), self._ssc, self._sc.serializer)
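# Usage sketch for a running count per key (assumes a DStream `pairs` of
# (word, 1) tuples and a StreamingContext with a checkpoint directory set,
# since state streams are checkpointed):
#
#     def updater(new_values, last_sum):
#         return sum(new_values) + (last_sum or 0)
#
#     running_counts = pairs.updateStateByKey(updater)
#
# Returning None from the update function drops the key from the state.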
class TransformedDStream(DStream):
"""
TransformedDStream is a DStream generated by a Python function
transforming each RDD of a DStream into another RDD.
Multiple consecutive transformations of a DStream can be combined into
one transformation.
"""
def __init__(self, prev, func):
self._ssc = prev._ssc
self._sc = self._ssc._sc
self._jrdd_deserializer = self._sc.serializer
self.is_cached = False
self.is_checkpointed = False
self._jdstream_val = None
if (isinstance(prev, TransformedDStream) and
not prev.is_cached and not prev.is_checkpointed):
prev_func = prev.func
self.func = lambda t, rdd: func(t, prev_func(t, rdd))
self.prev = prev.prev
else:
self.prev = prev
self.func = func
@property
def _jdstream(self):
if self._jdstream_val is not None:
return self._jdstream_val
jfunc = TransformFunction(self._sc, self.func, self.prev._jrdd_deserializer)
dstream = self._sc._jvm.PythonTransformedDStream(self.prev._jdstream.dstream(), jfunc)
self._jdstream_val = dstream.asJavaDStream()
return self._jdstream_val
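# Illustration of the fusion above (assumes a DStream `nums`, and that map()
# and filter() are built on transform(), as elsewhere in this module):
#
#     evens = nums.map(lambda x: x + 1).filter(lambda x: x % 2 == 0)
#
# Since the intermediate stream is neither cached nor checkpointed, both steps
# are folded into one TransformedDStream, so only a single
# PythonTransformedDStream is created on the JVM side for this chain.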


@@ -0,0 +1,545 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
from itertools import chain
import time
import operator
import unittest
import tempfile
from pyspark.context import SparkConf, SparkContext, RDD
from pyspark.streaming.context import StreamingContext
class PySparkStreamingTestCase(unittest.TestCase):
timeout = 10 # seconds
duration = 1
def setUp(self):
class_name = self.__class__.__name__
conf = SparkConf().set("spark.default.parallelism", 1)
self.sc = SparkContext(appName=class_name, conf=conf)
self.sc.setCheckpointDir("/tmp")
# TODO: decrease duration to speed up tests
self.ssc = StreamingContext(self.sc, self.duration)
def tearDown(self):
self.ssc.stop()
def wait_for(self, result, n):
start_time = time.time()
while len(result) < n and time.time() - start_time < self.timeout:
time.sleep(0.01)
if len(result) < n:
print "timeout after", self.timeout
def _take(self, dstream, n):
"""
Return the first `n` elements in the stream (will start and stop).
"""
results = []
def take(_, rdd):
if rdd and len(results) < n:
results.extend(rdd.take(n - len(results)))
dstream.foreachRDD(take)
self.ssc.start()
self.wait_for(results, n)
return results
def _collect(self, dstream, n, block=True):
"""
Collect each RDD into the returned list.
:return: list, which will have the collected items.
"""
result = []
def get_output(_, rdd):
if rdd and len(result) < n:
r = rdd.collect()
if r:
result.append(r)
dstream.foreachRDD(get_output)
if not block:
return result
self.ssc.start()
self.wait_for(result, n)
return result
def _test_func(self, input, func, expected, sort=False, input2=None):
"""
@param input: dataset for the test. This should be a list of lists.
@param func: wrapped function. This function should return a DStream object.
@param expected: expected output for this test case.
"""
if not isinstance(input[0], RDD):
input = [self.sc.parallelize(d, 1) for d in input]
input_stream = self.ssc.queueStream(input)
if input2 and not isinstance(input2[0], RDD):
input2 = [self.sc.parallelize(d, 1) for d in input2]
input_stream2 = self.ssc.queueStream(input2) if input2 is not None else None
# Apply test function to stream.
if input2:
stream = func(input_stream, input_stream2)
else:
stream = func(input_stream)
result = self._collect(stream, len(expected))
if sort:
self._sort_result_based_on_key(result)
self._sort_result_based_on_key(expected)
self.assertEqual(expected, result)
def _sort_result_based_on_key(self, outputs):
"""Sort the list based on first value."""
for output in outputs:
output.sort(key=lambda x: x[0])
class BasicOperationTests(PySparkStreamingTestCase):
def test_map(self):
"""Basic operation test for DStream.map."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.map(str)
expected = map(lambda x: map(str, x), input)
self._test_func(input, func, expected)
def test_flatMap(self):
"""Basic operation test for DStream.faltMap."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.flatMap(lambda x: (x, x * 2))
expected = map(lambda x: list(chain.from_iterable((map(lambda y: [y, y * 2], x)))),
input)
self._test_func(input, func, expected)
def test_filter(self):
"""Basic operation test for DStream.filter."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.filter(lambda x: x % 2 == 0)
expected = map(lambda x: filter(lambda y: y % 2 == 0, x), input)
self._test_func(input, func, expected)
def test_count(self):
"""Basic operation test for DStream.count."""
input = [range(5), range(10), range(20)]
def func(dstream):
return dstream.count()
expected = map(lambda x: [len(x)], input)
self._test_func(input, func, expected)
def test_reduce(self):
"""Basic operation test for DStream.reduce."""
input = [range(1, 5), range(5, 9), range(9, 13)]
def func(dstream):
return dstream.reduce(operator.add)
expected = map(lambda x: [reduce(operator.add, x)], input)
self._test_func(input, func, expected)
def test_reduceByKey(self):
"""Basic operation test for DStream.reduceByKey."""
input = [[("a", 1), ("a", 1), ("b", 1), ("b", 1)],
[("", 1), ("", 1), ("", 1), ("", 1)],
[(1, 1), (1, 1), (2, 1), (2, 1), (3, 1)]]
def func(dstream):
return dstream.reduceByKey(operator.add)
expected = [[("a", 2), ("b", 2)], [("", 4)], [(1, 2), (2, 2), (3, 1)]]
self._test_func(input, func, expected, sort=True)
def test_mapValues(self):
"""Basic operation test for DStream.mapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[("", 4), (1, 1), (2, 2), (3, 3)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.mapValues(lambda x: x + 10)
expected = [[("a", 12), ("b", 12), ("c", 11), ("d", 11)],
[("", 14), (1, 11), (2, 12), (3, 13)],
[(1, 11), (2, 11), (3, 11), (4, 11)]]
self._test_func(input, func, expected, sort=True)
def test_flatMapValues(self):
"""Basic operation test for DStream.flatMapValues."""
input = [[("a", 2), ("b", 2), ("c", 1), ("d", 1)],
[("", 4), (1, 1), (2, 1), (3, 1)],
[(1, 1), (2, 1), (3, 1), (4, 1)]]
def func(dstream):
return dstream.flatMapValues(lambda x: (x, x + 10))
expected = [[("a", 2), ("a", 12), ("b", 2), ("b", 12),
("c", 1), ("c", 11), ("d", 1), ("d", 11)],
[("", 4), ("", 14), (1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11)],
[(1, 1), (1, 11), (2, 1), (2, 11), (3, 1), (3, 11), (4, 1), (4, 11)]]
self._test_func(input, func, expected)
def test_glom(self):
"""Basic operation test for DStream.glom."""
input = [range(1, 5), range(5, 9), range(9, 13)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
return dstream.glom()
expected = [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]
self._test_func(rdds, func, expected)
def test_mapPartitions(self):
"""Basic operation test for DStream.mapPartitions."""
input = [range(1, 5), range(5, 9), range(9, 13)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
def f(iterator):
yield sum(iterator)
return dstream.mapPartitions(f)
expected = [[3, 7], [11, 15], [19, 23]]
self._test_func(rdds, func, expected)
def test_countByValue(self):
"""Basic operation test for DStream.countByValue."""
input = [range(1, 5) * 2, range(5, 7) + range(5, 9), ["a", "a", "b", ""]]
def func(dstream):
return dstream.countByValue()
expected = [[4], [4], [3]]
self._test_func(input, func, expected)
def test_groupByKey(self):
"""Basic operation test for DStream.groupByKey."""
input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
def func(dstream):
return dstream.groupByKey().mapValues(list)
expected = [[(1, [1]), (2, [1]), (3, [1]), (4, [1])],
[(1, [1, 1, 1]), (2, [1, 1]), (3, [1])],
[("a", [1, 1]), ("b", [1]), ("", [1, 1, 1])]]
self._test_func(input, func, expected, sort=True)
def test_combineByKey(self):
"""Basic operation test for DStream.combineByKey."""
input = [[(1, 1), (2, 1), (3, 1), (4, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1), (2, 1), (3, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1), ("", 1)]]
def func(dstream):
def add(a, b):
return a + str(b)
return dstream.combineByKey(str, add, add)
expected = [[(1, "1"), (2, "1"), (3, "1"), (4, "1")],
[(1, "111"), (2, "11"), (3, "1")],
[("a", "11"), ("b", "1"), ("", "111")]]
self._test_func(input, func, expected, sort=True)
def test_repartition(self):
input = [range(1, 5), range(5, 9)]
rdds = [self.sc.parallelize(r, 2) for r in input]
def func(dstream):
return dstream.repartition(1).glom()
expected = [[[1, 2, 3, 4]], [[5, 6, 7, 8]]]
self._test_func(rdds, func, expected)
def test_union(self):
input1 = [range(3), range(5), range(6)]
input2 = [range(3, 6), range(5, 6)]
def func(d1, d2):
return d1.union(d2)
expected = [range(6), range(6), range(6)]
self._test_func(input1, func, expected, input2=input2)
def test_cogroup(self):
input = [[(1, 1), (2, 1), (3, 1)],
[(1, 1), (1, 1), (1, 1), (2, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 1)]]
input2 = [[(1, 2)],
[(4, 1)],
[("a", 1), ("a", 1), ("b", 1), ("", 1), ("", 2)]]
def func(d1, d2):
return d1.cogroup(d2).mapValues(lambda vs: tuple(map(list, vs)))
expected = [[(1, ([1], [2])), (2, ([1], [])), (3, ([1], []))],
[(1, ([1, 1, 1], [])), (2, ([1], [])), (4, ([], [1]))],
[("a", ([1, 1], [1, 1])), ("b", ([1], [1])), ("", ([1, 1], [1, 2]))]]
self._test_func(input, func, expected, sort=True, input2=input2)
def test_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.join(b)
expected = [[('b', (2, 3))]]
self._test_func(input, func, expected, True, input2)
def test_left_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.leftOuterJoin(b)
expected = [[('a', (1, None)), ('b', (2, 3))]]
self._test_func(input, func, expected, True, input2)
def test_right_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.rightOuterJoin(b)
expected = [[('b', (2, 3)), ('c', (None, 4))]]
self._test_func(input, func, expected, True, input2)
def test_full_outer_join(self):
input = [[('a', 1), ('b', 2)]]
input2 = [[('b', 3), ('c', 4)]]
def func(a, b):
return a.fullOuterJoin(b)
expected = [[('a', (1, None)), ('b', (2, 3)), ('c', (None, 4))]]
self._test_func(input, func, expected, True, input2)
def test_update_state_by_key(self):
def updater(vs, s):
if not s:
s = []
s.extend(vs)
return s
input = [[('k', i)] for i in range(5)]
def func(dstream):
return dstream.updateStateByKey(updater)
expected = [[0], [0, 1], [0, 1, 2], [0, 1, 2, 3], [0, 1, 2, 3, 4]]
expected = [[('k', v)] for v in expected]
self._test_func(input, func, expected)
class WindowFunctionTests(PySparkStreamingTestCase):
timeout = 20
def test_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
return dstream.window(3, 1).count()
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
def test_count_by_window(self):
input = [range(1), range(2), range(3), range(4), range(5)]
def func(dstream):
return dstream.countByWindow(3, 1)
expected = [[1], [3], [6], [9], [12], [9], [5]]
self._test_func(input, func, expected)
def test_count_by_window_large(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
return dstream.countByWindow(5, 1)
expected = [[1], [3], [6], [10], [15], [20], [18], [15], [11], [6]]
self._test_func(input, func, expected)
def test_count_by_value_and_window(self):
input = [range(1), range(2), range(3), range(4), range(5), range(6)]
def func(dstream):
return dstream.countByValueAndWindow(5, 1)
expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]]
self._test_func(input, func, expected)
def test_group_by_key_and_window(self):
input = [[('a', i)] for i in range(5)]
def func(dstream):
return dstream.groupByKeyAndWindow(3, 1).mapValues(list)
expected = [[('a', [0])], [('a', [0, 1])], [('a', [0, 1, 2])], [('a', [1, 2, 3])],
[('a', [2, 3, 4])], [('a', [3, 4])], [('a', [4])]]
self._test_func(input, func, expected)
def test_reduce_by_invalid_window(self):
input1 = [range(3), range(5), range(1), range(6)]
d1 = self.ssc.queueStream(input1)
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 0.1, 0.1))
self.assertRaises(ValueError, lambda: d1.reduceByKeyAndWindow(None, None, 1, 0.1))
class StreamingContextTests(PySparkStreamingTestCase):
duration = 0.1
def _add_input_stream(self):
inputs = map(lambda x: range(1, x), range(101))
stream = self.ssc.queueStream(inputs)
self._collect(stream, 1, block=False)
def test_stop_only_streaming_context(self):
self._add_input_stream()
self.ssc.start()
self.ssc.stop(False)
self.assertEqual(len(self.sc.parallelize(range(5), 5).glom().collect()), 5)
def test_stop_multiple_times(self):
self._add_input_stream()
self.ssc.start()
self.ssc.stop()
self.ssc.stop()
def test_queue_stream(self):
input = [range(i + 1) for i in range(3)]
dstream = self.ssc.queueStream(input)
result = self._collect(dstream, 3)
self.assertEqual(input, result)
def test_text_file_stream(self):
d = tempfile.mkdtemp()
self.ssc = StreamingContext(self.sc, self.duration)
dstream2 = self.ssc.textFileStream(d).map(int)
result = self._collect(dstream2, 2, block=False)
self.ssc.start()
for name in ('a', 'b'):
time.sleep(1)
with open(os.path.join(d, name), "w") as f:
f.writelines(["%d\n" % i for i in range(10)])
self.wait_for(result, 2)
self.assertEqual([range(10), range(10)], result)
def test_union(self):
input = [range(i + 1) for i in range(3)]
dstream = self.ssc.queueStream(input)
dstream2 = self.ssc.queueStream(input)
dstream3 = self.ssc.union(dstream, dstream2)
result = self._collect(dstream3, 3)
expected = [i * 2 for i in input]
self.assertEqual(expected, result)
def test_transform(self):
dstream1 = self.ssc.queueStream([[1]])
dstream2 = self.ssc.queueStream([[2]])
dstream3 = self.ssc.queueStream([[3]])
def func(rdds):
rdd1, rdd2, rdd3 = rdds
return rdd2.union(rdd3).union(rdd1)
dstream = self.ssc.transform([dstream1, dstream2, dstream3], func)
self.assertEqual([2, 3, 1], self._take(dstream, 3))
class CheckpointTests(PySparkStreamingTestCase):
def setUp(self):
pass
def test_get_or_create(self):
inputd = tempfile.mkdtemp()
outputd = tempfile.mkdtemp() + "/"
def updater(vs, s):
return sum(vs, s or 0)
def setup():
conf = SparkConf().set("spark.default.parallelism", 1)
sc = SparkContext(conf=conf)
ssc = StreamingContext(sc, 0.5)
dstream = ssc.textFileStream(inputd).map(lambda x: (x, 1))
wc = dstream.updateStateByKey(updater)
wc.map(lambda x: "%s,%d" % x).saveAsTextFiles(outputd + "test")
wc.checkpoint(.5)
return ssc
cpd = tempfile.mkdtemp("test_streaming_cps")
self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
def check_output(n):
while not os.listdir(outputd):
time.sleep(0.1)
time.sleep(1) # make sure mtime is larger than the previous one
with open(os.path.join(inputd, str(n)), 'w') as f:
f.writelines(["%d\n" % i for i in range(10)])
while True:
p = os.path.join(outputd, max(os.listdir(outputd)))
if '_SUCCESS' not in os.listdir(p):
# not finished
time.sleep(0.01)
continue
ordd = ssc.sparkContext.textFile(p).map(lambda line: line.split(","))
d = ordd.values().map(int).collect()
if not d:
time.sleep(0.01)
continue
self.assertEqual(10, len(d))
s = set(d)
self.assertEqual(1, len(s))
m = s.pop()
if n > m:
continue
self.assertEqual(n, m)
break
check_output(1)
check_output(2)
ssc.stop(True, True)
time.sleep(1)
self.ssc = ssc = StreamingContext.getOrCreate(cpd, setup)
ssc.start()
check_output(3)
if __name__ == "__main__":
unittest.main()


@@ -0,0 +1,128 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import time
from datetime import datetime
import traceback
from pyspark import SparkContext, RDD
class TransformFunction(object):
"""
This class wraps a function RDD[X] -> RDD[Y] that was passed to
DStream.transform(), allowing it to be called from Java via Py4J's
callback server.
Java calls this function with a sequence of JavaRDDs and this function
returns a single JavaRDD pointer back to Java.
"""
_emptyRDD = None
def __init__(self, ctx, func, *deserializers):
self.ctx = ctx
self.func = func
self.deserializers = deserializers
def call(self, milliseconds, jrdds):
try:
if self.ctx is None:
self.ctx = SparkContext._active_spark_context
if not self.ctx or not self.ctx._jsc:
# stopped
return
# extend deserializers with the first one
sers = self.deserializers
if len(sers) < len(jrdds):
sers += (sers[0],) * (len(jrdds) - len(sers))
rdds = [RDD(jrdd, self.ctx, ser) if jrdd else None
for jrdd, ser in zip(jrdds, sers)]
t = datetime.fromtimestamp(milliseconds / 1000.0)
r = self.func(t, *rdds)
if r:
return r._jrdd
except Exception:
traceback.print_exc()
def __repr__(self):
return "TransformFunction(%s)" % self.func
class Java:
implements = ['org.apache.spark.streaming.api.python.PythonTransformFunction']
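# Rough call flow (assumes an active SparkContext `sc` and a two-argument
# function taking the batch time and an RDD):
#
#     tf = TransformFunction(sc, lambda t, rdd: rdd.map(str), sc.serializer)
#
# The JVM invokes tf.call(batch_time_ms, [jrdd]) through the Py4J callback
# server; call() rebuilds Python RDDs from the JavaRDDs, applies the wrapped
# function, and returns the resulting RDD's _jrdd back to Java.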
class TransformFunctionSerializer(object):
"""
This class implements a serializer for PythonTransformFunction Java
objects.
This is necessary because the Java PythonTransformFunction objects are
actually Py4J references to Python objects and thus are not directly
serializable. When Java needs to serialize a PythonTransformFunction,
it uses this class to invoke Python, which returns the serialized function
as a byte array.
"""
def __init__(self, ctx, serializer, gateway=None):
self.ctx = ctx
self.serializer = serializer
self.gateway = gateway or self.ctx._gateway
self.gateway.jvm.PythonDStream.registerSerializer(self)
def dumps(self, id):
try:
func = self.gateway.gateway_property.pool[id]
return bytearray(self.serializer.dumps((func.func, func.deserializers)))
except Exception:
traceback.print_exc()
def loads(self, bytes):
try:
f, deserializers = self.serializer.loads(str(bytes))
return TransformFunction(self.ctx, f, *deserializers)
except Exception:
traceback.print_exc()
def __repr__(self):
return "TransformFunctionSerializer(%s)" % self.serializer
class Java:
implements = ['org.apache.spark.streaming.api.python.PythonTransformFunctionSerializer']
def rddToFileName(prefix, suffix, timestamp):
"""
Return string prefix-time(.suffix)
>>> rddToFileName("spark", None, 12345678910)
'spark-12345678910'
>>> rddToFileName("spark", "tmp", 12345678910)
'spark-12345678910.tmp'
"""
if isinstance(timestamp, datetime):
seconds = time.mktime(timestamp.timetuple())
timestamp = long(seconds * 1000) + timestamp.microsecond / 1000
if suffix is None:
return prefix + "-" + str(timestamp)
else:
return prefix + "-" + str(timestamp) + "." + suffix
if __name__ == "__main__":
import doctest
doctest.testmod()


@@ -81,6 +81,11 @@ function run_mllib_tests() {
run_test "pyspark/mllib/tests.py"
}
function run_streaming_tests() {
run_test "pyspark/streaming/util.py"
run_test "pyspark/streaming/tests.py"
}
echo "Running PySpark tests. Output is in python/unit-tests.log."
export PYSPARK_PYTHON="python"
@@ -96,6 +101,7 @@ $PYSPARK_PYTHON --version
run_core_tests
run_sql_tests
run_mllib_tests
run_streaming_tests
# Try to test with PyPy
if [ $(which pypy) ]; then
@@ -105,6 +111,7 @@ if [ $(which pypy) ]; then
run_core_tests
run_sql_tests
run_streaming_tests
fi
if [[ $FAILED == 0 ]]; then


@@ -167,7 +167,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
new JavaPairDStream(dstream.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2])
}
/**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDDs
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
* of the RDD.


@@ -0,0 +1,316 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.api.python
import java.io.{ObjectInputStream, ObjectOutputStream}
import java.lang.reflect.Proxy
import java.util.{ArrayList => JArrayList, List => JList}
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.language.existentials
import py4j.GatewayServer
import org.apache.spark.api.java._
import org.apache.spark.api.python._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Interval, Duration, Time}
import org.apache.spark.streaming.dstream._
import org.apache.spark.streaming.api.java._
/**
* Interface for Python callback function which is used to transform RDDs
*/
private[python] trait PythonTransformFunction {
def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
}
/**
* Interface for Python Serializer to serialize PythonTransformFunction
*/
private[python] trait PythonTransformFunctionSerializer {
def dumps(id: String): Array[Byte]
def loads(bytes: Array[Byte]): PythonTransformFunction
}
/**
* Wraps a PythonTransformFunction (which is a Python object accessed through Py4J)
* so that it looks like a Scala function and can be transparently serialized and
* deserialized by Java.
*/
private[python] class TransformFunction(@transient var pfunc: PythonTransformFunction)
extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
Option(pfunc.call(time.milliseconds, List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
.map(_.rdd)
}
def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
}
// for function.Function2
def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
pfunc.call(time.milliseconds, rdds)
}
private def writeObject(out: ObjectOutputStream): Unit = {
val bytes = PythonTransformFunctionSerializer.serialize(pfunc)
out.writeInt(bytes.length)
out.write(bytes)
}
private def readObject(in: ObjectInputStream): Unit = {
val length = in.readInt()
val bytes = new Array[Byte](length)
in.readFully(bytes)
pfunc = PythonTransformFunctionSerializer.deserialize(bytes)
}
}
/**
* Helpers for PythonTransformFunctionSerializer
*
* PythonTransformFunctionSerializer is logically a singleton that happens to be
* implemented as a Python object.
*/
private[python] object PythonTransformFunctionSerializer {
/**
* A serializer in Python, used to serialize PythonTransformFunction
*/
private var serializer: PythonTransformFunctionSerializer = _
/*
* Register a serializer from Python, should be called during initialization
*/
def register(ser: PythonTransformFunctionSerializer): Unit = {
serializer = ser
}
def serialize(func: PythonTransformFunction): Array[Byte] = {
assert(serializer != null, "Serializer has not been registered!")
// get the id of PythonTransformFunction in py4j
val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
val f = h.getClass().getDeclaredField("id")
f.setAccessible(true)
val id = f.get(h).asInstanceOf[String]
serializer.dumps(id)
}
def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
assert(serializer != null, "Serializer has not been registered!")
serializer.loads(bytes)
}
}
/**
* Helper functions, which are called from Python via Py4J.
*/
private[python] object PythonDStream {
/**
* Cannot access PythonTransformFunctionSerializer.register() directly via Py4J;
* it fails with: Py4JError: PythonTransformFunctionSerializerregister does not exist in the JVM
*/
def registerSerializer(ser: PythonTransformFunctionSerializer): Unit = {
PythonTransformFunctionSerializer.register(ser)
}
/**
* Update the port of callback client to `port`
*/
def updatePythonGatewayPort(gws: GatewayServer, port: Int): Unit = {
val cl = gws.getCallbackClient
val f = cl.getClass.getDeclaredField("port")
f.setAccessible(true)
f.setInt(cl, port)
}
/**
* Helper function for DStream.foreachRDD();
* it cannot be named `foreachRDD` because that would confuse Py4J
*/
def callForeachRDD(jdstream: JavaDStream[Array[Byte]], pfunc: PythonTransformFunction) {
val func = new TransformFunction(pfunc)
jdstream.dstream.foreachRDD((rdd, time) => func(Some(rdd), time))
}
/**
* Convert a list of RDDs into a queue of RDDs, for ssc.queueStream()
*/
def toRDDQueue(rdds: JArrayList[JavaRDD[Array[Byte]]]): java.util.Queue[JavaRDD[Array[Byte]]] = {
val queue = new java.util.LinkedList[JavaRDD[Array[Byte]]]
rdds.forall(queue.add(_))
queue
}
}
/**
* Base class for PythonDStream with some common methods
*/
private[python] abstract class PythonDStream(
parent: DStream[_],
@transient pfunc: PythonTransformFunction)
extends DStream[Array[Byte]] (parent.ssc) {
val func = new TransformFunction(pfunc)
override def dependencies = List(parent)
override def slideDuration: Duration = parent.slideDuration
val asJavaDStream = JavaDStream.fromDStream(this)
}
/**
* Transformed DStream in Python.
*/
private[python] class PythonTransformedDStream (
parent: DStream[_],
@transient pfunc: PythonTransformFunction)
extends PythonDStream(parent, pfunc) {
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val rdd = parent.getOrCompute(validTime)
if (rdd.isDefined) {
func(rdd, validTime)
} else {
None
}
}
}
/**
* Transformed from two DStreams in Python.
*/
private[python] class PythonTransformed2DStream(
parent: DStream[_],
parent2: DStream[_],
@transient pfunc: PythonTransformFunction)
extends DStream[Array[Byte]] (parent.ssc) {
val func = new TransformFunction(pfunc)
override def dependencies = List(parent, parent2)
override def slideDuration: Duration = parent.slideDuration
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val empty: RDD[_] = ssc.sparkContext.emptyRDD
val rdd1 = parent.getOrCompute(validTime).getOrElse(empty)
val rdd2 = parent2.getOrCompute(validTime).getOrElse(empty)
func(Some(rdd1), Some(rdd2), validTime)
}
val asJavaDStream = JavaDStream.fromDStream(this)
}
/**
* similar to StateDStream
*/
private[python] class PythonStateDStream(
parent: DStream[Array[Byte]],
@transient reduceFunc: PythonTransformFunction)
extends PythonDStream(parent, reduceFunc) {
super.persist(StorageLevel.MEMORY_ONLY)
override val mustCheckpoint = true
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val lastState = getOrCompute(validTime - slideDuration)
val rdd = parent.getOrCompute(validTime)
if (rdd.isDefined) {
func(lastState, rdd, validTime)
} else {
lastState
}
}
}
/**
* similar to ReducedWindowedDStream
*/
private[python] class PythonReducedWindowedDStream(
parent: DStream[Array[Byte]],
@transient preduceFunc: PythonTransformFunction,
@transient pinvReduceFunc: PythonTransformFunction,
_windowDuration: Duration,
_slideDuration: Duration)
extends PythonDStream(parent, preduceFunc) {
super.persist(StorageLevel.MEMORY_ONLY)
override val mustCheckpoint = true
val invReduceFunc = new TransformFunction(pinvReduceFunc)
def windowDuration: Duration = _windowDuration
override def slideDuration: Duration = _slideDuration
override def parentRememberDuration: Duration = rememberDuration + windowDuration
override def compute(validTime: Time): Option[RDD[Array[Byte]]] = {
val currentTime = validTime
val current = new Interval(currentTime - windowDuration, currentTime)
val previous = current - slideDuration
// _____________________________
// | previous window _________|___________________
// |___________________| current window | --------------> Time
// |_____________________________|
//
// |________ _________| |________ _________|
// | |
// V V
// old RDDs new RDDs
//
val previousRDD = getOrCompute(previous.endTime)
// for a small window, reducing once is better than reducing twice
if (pinvReduceFunc != null && previousRDD.isDefined
&& windowDuration >= slideDuration * 5) {
// subtract the values from old RDDs
val oldRDDs = parent.slice(previous.beginTime + parent.slideDuration, current.beginTime)
val subtracted = if (oldRDDs.size > 0) {
invReduceFunc(previousRDD, Some(ssc.sc.union(oldRDDs)), validTime)
} else {
previousRDD
}
// add the RDDs of the reduced values in "new time steps"
val newRDDs = parent.slice(previous.endTime + parent.slideDuration, current.endTime)
if (newRDDs.size > 0) {
func(subtracted, Some(ssc.sc.union(newRDDs)), validTime)
} else {
subtracted
}
} else {
// Get the RDDs of the reduced values in current window
val currentRDDs = parent.slice(current.beginTime + parent.slideDuration, current.endTime)
if (currentRDDs.size > 0) {
func(None, Some(ssc.sc.union(currentRDDs)), validTime)
} else {
None
}
}
}
}