1e6648d62f
Move Py4jCallbackConnectionCleaner to Streaming because the callback server starts only in StreamingContext. Author: Shixiong Zhu <shixiong@databricks.com> Closes #10621 from zsxwing/SPARK-12617-2.
454 lines
19 KiB
Python
454 lines
19 KiB
Python
#
|
|
# Licensed to the Apache Software Foundation (ASF) under one or more
|
|
# contributor license agreements. See the NOTICE file distributed with
|
|
# this work for additional information regarding copyright ownership.
|
|
# The ASF licenses this file to You under the Apache License, Version 2.0
|
|
# (the "License"); you may not use this file except in compliance with
|
|
# the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from __future__ import print_function
|
|
|
|
import os
|
|
import sys
|
|
from threading import RLock, Timer
|
|
|
|
from py4j.java_gateway import java_import, JavaObject
|
|
|
|
from pyspark import RDD, SparkConf
|
|
from pyspark.serializers import NoOpSerializer, UTF8Deserializer, CloudPickleSerializer
|
|
from pyspark.context import SparkContext
|
|
from pyspark.storagelevel import StorageLevel
|
|
from pyspark.streaming.dstream import DStream
|
|
from pyspark.streaming.util import TransformFunction, TransformFunctionSerializer
|
|
|
|
__all__ = ["StreamingContext"]
|
|
|
|
|
|
class Py4jCallbackConnectionCleaner(object):

    """
    A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617.
    It will scan all callback connections every 30 seconds and close the dead connections.
    """

    def __init__(self, gateway):
        # Py4j gateway whose callback server we periodically sweep.
        self._gateway = gateway
        # Once stopped, the cleaner can never be (re)started.
        self._stopped = False
        # Currently scheduled threading.Timer, if any.
        self._timer = None
        # Guards _stopped/_timer against concurrent start()/stop().
        self._lock = RLock()

    def start(self):
        """Kick off the periodic sweep; no-op if already stopped."""
        if self._stopped:
            return

        def sweep():
            # Imported lazily: py4j is only needed once the timer fires.
            from py4j.java_gateway import quiet_close, quiet_shutdown

            server = self._gateway._callback_server
            if server:
                with server.lock:
                    try:
                        dead = []
                        for conn in server.connections:
                            if conn.isAlive():
                                continue
                            # Close both directions of the dead connection.
                            quiet_close(conn.input)
                            quiet_shutdown(conn.socket)
                            quiet_close(conn.socket)
                            dead.append(conn)

                        for conn in dead:
                            server.connections.remove(conn)
                    except Exception:
                        # Best effort: a failed sweep must not kill the timer chain.
                        import traceback
                        traceback.print_exc()

            # Reschedule ourselves for the next 30-second sweep.
            self._start_timer(sweep)

        self._start_timer(sweep)

    def _start_timer(self, task):
        """Schedule *task* to run once in 30 seconds, unless stopped."""
        with self._lock:
            if not self._stopped:
                self._timer = Timer(30.0, task)
                # Daemonized so a pending sweep never blocks interpreter exit.
                self._timer.daemon = True
                self._timer.start()

    def stop(self):
        """Permanently stop the cleaner and cancel any pending sweep."""
        with self._lock:
            self._stopped = True
            pending, self._timer = self._timer, None
            if pending:
                pending.cancel()
|
|
|
|
|
|
class StreamingContext(object):
    """
    Main entry point for Spark Streaming functionality. A StreamingContext
    represents the connection to a Spark cluster, and can be used to create
    L{DStream} from various input sources. It can be created from an existing
    L{SparkContext}.
    After creating and transforming DStreams, the streaming computation can
    be started and stopped using `context.start()` and `context.stop()`,
    respectively. `context.awaitTermination()` allows the current thread
    to wait for the termination of the context by `stop()` or by an exception.
    """
    # Shared TransformFunction serializer; registered with the JVM at most
    # once (see the SPARK-12511 note in _ensure_initialized).
    _transformerSerializer = None

    # Reference to a currently active StreamingContext
    _activeContext = None

    # A cleaner to clean leak sockets of callback server every 30 seconds
    _py4j_cleaner = None

    def __init__(self, sparkContext, batchDuration=None, jssc=None):
        """
        Create a new StreamingContext.

        @param sparkContext: L{SparkContext} object.
        @param batchDuration: the time interval (in seconds) at which streaming
                              data will be divided into batches
        @param jssc: optional existing JavaStreamingContext to wrap; when None,
                     a new one is created using `batchDuration`.
        """
        self._sc = sparkContext
        self._jvm = self._sc._jvm
        self._jssc = jssc or self._initialize_context(self._sc, batchDuration)

    def _initialize_context(self, sc, duration):
        # Ensure the gateway, callback server and serializer are ready before
        # touching any JVM-side streaming classes.
        self._ensure_initialized()
        return self._jvm.JavaStreamingContext(sc._jsc, self._jduration(duration))

    def _jduration(self, seconds):
        """
        Create Duration object given number of seconds
        """
        return self._jvm.Duration(int(seconds * 1000))

    @classmethod
    def _ensure_initialized(cls):
        """
        One-time Py4j setup for streaming: import JVM streaming classes,
        start the callback server (once per gateway), and register the
        TransformFunction serializer (once per process).
        """
        SparkContext._ensure_initialized()
        gw = SparkContext._gateway

        java_import(gw.jvm, "org.apache.spark.streaming.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.java.*")
        java_import(gw.jvm, "org.apache.spark.streaming.api.python.*")

        # start callback server
        # getattr will fallback to JVM, so we cannot test by hasattr()
        if "_callback_server" not in gw.__dict__ or gw._callback_server is None:
            gw.callback_server_parameters.eager_load = True
            gw.callback_server_parameters.daemonize = True
            gw.callback_server_parameters.daemonize_connections = True
            # Bind to an ephemeral port so concurrent gateways don't collide.
            gw.callback_server_parameters.port = 0
            gw.start_callback_server(gw.callback_server_parameters)
            cbport = gw._callback_server.server_socket.getsockname()[1]
            gw._callback_server.port = cbport
            # gateway with real port
            gw._python_proxy_port = gw._callback_server.port
            # get the GatewayServer object in JVM by ID
            jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client)
            # update the port of CallbackClient with real port
            gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port)
            # Store the cleaner on the class. Previously it was bound to a
            # local variable, so the declared _py4j_cleaner class attribute
            # stayed None forever and the cleaner could never be stop()-ed.
            cls._py4j_cleaner = Py4jCallbackConnectionCleaner(gw)
            cls._py4j_cleaner.start()

        # register serializer for TransformFunction
        # it happens before creating SparkContext when loading from checkpointing
        if cls._transformerSerializer is None:
            transformer_serializer = TransformFunctionSerializer()
            transformer_serializer.init(
                SparkContext._active_spark_context, CloudPickleSerializer(), gw)
            # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM
            # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever.
            # (https://github.com/bartdag/py4j/pull/184)
            #
            # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when
            # calling "registerSerializer". If we call "registerSerializer" twice, the second
            # PythonProxyHandler will override the first one, then the first one will be GCed and
            # trigger "PythonProxyHandler.finalize". To avoid that, we should not call
            # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't
            # be GCed.
            #
            # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version.
            transformer_serializer.gateway.jvm.PythonDStream.registerSerializer(
                transformer_serializer)
            cls._transformerSerializer = transformer_serializer
        else:
            cls._transformerSerializer.init(
                SparkContext._active_spark_context, CloudPickleSerializer(), gw)

    @classmethod
    def getOrCreate(cls, checkpointPath, setupFunc):
        """
        Either recreate a StreamingContext from checkpoint data or create a new StreamingContext.
        If checkpoint data exists in the provided `checkpointPath`, then StreamingContext will be
        recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
        will be used to create a new context.

        @param checkpointPath: Checkpoint directory used in an earlier streaming program
        @param setupFunc: Function to create a new context and setup DStreams
        """
        cls._ensure_initialized()
        gw = SparkContext._gateway

        # Check whether valid checkpoint information exists in the given path
        ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath)
        if ssc_option.isEmpty():
            # No usable checkpoint: build a fresh context and enable
            # checkpointing to the same path for next time.
            ssc = setupFunc()
            ssc.checkpoint(checkpointPath)
            return ssc

        jssc = gw.jvm.JavaStreamingContext(ssc_option.get())

        # If there is already an active instance of Python SparkContext use it, or create a new one
        if not SparkContext._active_spark_context:
            jsc = jssc.sparkContext()
            conf = SparkConf(_jconf=jsc.getConf())
            SparkContext(conf=conf, gateway=gw, jsc=jsc)

        sc = SparkContext._active_spark_context

        # update ctx in serializer
        cls._transformerSerializer.ctx = sc
        return StreamingContext(sc, None, jssc)

    @classmethod
    def getActive(cls):
        """
        Return either the currently active StreamingContext (i.e., if there is a context started
        but not stopped) or None.
        """
        activePythonContext = cls._activeContext
        if activePythonContext is not None:
            # Verify that the current running Java StreamingContext is active and is the same one
            # backing the supposedly active Python context
            activePythonContextJavaId = activePythonContext._jssc.ssc().hashCode()
            activeJvmContextOption = activePythonContext._jvm.StreamingContext.getActive()

            if activeJvmContextOption.isEmpty():
                cls._activeContext = None
            elif activeJvmContextOption.get().hashCode() != activePythonContextJavaId:
                cls._activeContext = None
                raise Exception("JVM's active JavaStreamingContext is not the JavaStreamingContext "
                                "backing the active Python StreamingContext. This is unexpected.")
        return cls._activeContext

    @classmethod
    def getActiveOrCreate(cls, checkpointPath, setupFunc):
        """
        Either return the active StreamingContext (i.e. currently started but not stopped),
        or recreate a StreamingContext from checkpoint data or create a new StreamingContext
        using the provided setupFunc function. If the checkpointPath is None or does not contain
        valid checkpoint data, then setupFunc will be called to create a new context and setup
        DStreams.

        @param checkpointPath: Checkpoint directory used in an earlier streaming program. Can be
                               None if the intention is to always create a new context when there
                               is no active context.
        @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams
        """
        if setupFunc is None:
            raise Exception("setupFunc cannot be None")
        activeContext = cls.getActive()
        if activeContext is not None:
            return activeContext
        elif checkpointPath is not None:
            return cls.getOrCreate(checkpointPath, setupFunc)
        else:
            return setupFunc()

    @property
    def sparkContext(self):
        """
        Return SparkContext which is associated with this StreamingContext.
        """
        return self._sc

    def start(self):
        """
        Start the execution of the streams.
        """
        self._jssc.start()
        StreamingContext._activeContext = self

    def awaitTermination(self, timeout=None):
        """
        Wait for the execution to stop.

        @param timeout: time to wait in seconds
        """
        if timeout is None:
            self._jssc.awaitTermination()
        else:
            self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))

    def awaitTerminationOrTimeout(self, timeout):
        """
        Wait for the execution to stop. Return `true` if it's stopped; or
        throw the reported error during the execution; or `false` if the
        waiting time elapsed before returning from the method.

        @param timeout: time to wait in seconds
        """
        return self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))

    def stop(self, stopSparkContext=True, stopGraceFully=False):
        """
        Stop the execution of the streams, with option of ensuring all
        received data has been processed.

        @param stopSparkContext: Stop the associated SparkContext or not
        @param stopGraceFully: Stop gracefully by waiting for the processing
                               of all received data to be completed
        """
        self._jssc.stop(stopSparkContext, stopGraceFully)
        StreamingContext._activeContext = None
        if stopSparkContext:
            self._sc.stop()

    def remember(self, duration):
        """
        Set each DStreams in this context to remember RDDs it generated
        in the last given duration. DStreams remember RDDs only for a
        limited duration of time and releases them for garbage collection.
        This method allows the developer to specify how long to remember
        the RDDs (if the developer wishes to query old data outside the
        DStream computation).

        @param duration: Minimum duration (in seconds) that each DStream
                         should remember its RDDs
        """
        self._jssc.remember(self._jduration(duration))

    def checkpoint(self, directory):
        """
        Sets the context to periodically checkpoint the DStream operations for master
        fault-tolerance. The graph will be checkpointed every batch interval.

        @param directory: HDFS-compatible directory where the checkpoint data
                          will be reliably stored
        """
        self._jssc.checkpoint(directory)

    def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2):
        """
        Create an input from TCP source hostname:port. Data is received using
        a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited
        lines.

        @param hostname: Hostname to connect to for receiving data
        @param port: Port to connect to for receiving data
        @param storageLevel: Storage level to use for storing the received objects
        """
        jlevel = self._sc._getJavaStorageLevel(storageLevel)
        return DStream(self._jssc.socketTextStream(hostname, port, jlevel), self,
                       UTF8Deserializer())

    def textFileStream(self, directory):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as text files. Files must be written to the
        monitored directory by "moving" them from another location within the same
        file system. File names starting with . are ignored.
        """
        return DStream(self._jssc.textFileStream(directory), self, UTF8Deserializer())

    def binaryRecordsStream(self, directory, recordLength):
        """
        Create an input stream that monitors a Hadoop-compatible file system
        for new files and reads them as flat binary files with records of
        fixed length. Files must be written to the monitored directory by "moving"
        them from another location within the same file system.
        File names starting with . are ignored.

        @param directory: Directory to load data from
        @param recordLength: Length of each record in bytes
        """
        return DStream(self._jssc.binaryRecordsStream(directory, recordLength), self,
                       NoOpSerializer())

    def _check_serializers(self, rdds):
        # make sure they have same serializer
        if len(set(rdd._jrdd_deserializer for rdd in rdds)) > 1:
            for i in range(len(rdds)):
                # reset them to sc.serializer
                rdds[i] = rdds[i]._reserialize()

    def queueStream(self, rdds, oneAtATime=True, default=None):
        """
        Create an input stream from an queue of RDDs or list. In each batch,
        it will process either one or all of the RDDs returned by the queue.

        NOTE: changes to the queue after the stream is created will not be recognized.

        @param rdds: Queue of RDDs
        @param oneAtATime: pick one rdd each time or pick all of them once.
        @param default: The default rdd if no more in rdds
        """
        if default and not isinstance(default, RDD):
            default = self._sc.parallelize(default)

        if not rdds and default:
            # NOTE(review): this wraps the empty input, yielding a queue with a
            # single empty RDD while `default` supplies each batch's data.
            # `[default]` may have been intended — preserved as-is to avoid a
            # behavior change; confirm against upstream.
            rdds = [rdds]

        if rdds and not isinstance(rdds[0], RDD):
            rdds = [self._sc.parallelize(input) for input in rdds]
        self._check_serializers(rdds)

        queue = self._jvm.PythonDStream.toRDDQueue([r._jrdd for r in rdds])
        if default:
            default = default._reserialize(rdds[0]._jrdd_deserializer)
            jdstream = self._jssc.queueStream(queue, oneAtATime, default._jrdd)
        else:
            jdstream = self._jssc.queueStream(queue, oneAtATime)
        return DStream(jdstream, self, rdds[0]._jrdd_deserializer)

    def transform(self, dstreams, transformFunc):
        """
        Create a new DStream in which each RDD is generated by applying
        a function on RDDs of the DStreams. The order of the JavaRDDs in
        the transform function parameter will be the same as the order
        of corresponding DStreams in the list.
        """
        jdstreams = [d._jdstream for d in dstreams]
        # change the final serializer to sc.serializer
        func = TransformFunction(self._sc,
                                 lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
                                 *[d._jrdd_deserializer for d in dstreams])
        jfunc = self._jvm.TransformFunction(func)
        jdstream = self._jssc.transform(jdstreams, jfunc)
        return DStream(jdstream, self, self._sc.serializer)

    def union(self, *dstreams):
        """
        Create a unified DStream from multiple DStreams of the same
        type and same slide duration.
        """
        if not dstreams:
            raise ValueError("should have at least one DStream to union")
        if len(dstreams) == 1:
            return dstreams[0]
        if len(set(s._jrdd_deserializer for s in dstreams)) > 1:
            raise ValueError("All DStreams should have same serializer")
        if len(set(s._slideDuration for s in dstreams)) > 1:
            raise ValueError("All DStreams should have same slide duration")
        first = dstreams[0]
        jrest = [d._jdstream for d in dstreams[1:]]
        return DStream(self._jssc.union(first._jdstream, jrest), self, first._jrdd_deserializer)

    def addStreamingListener(self, streamingListener):
        """
        Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for
        receiving system events related to streaming.
        """
        self._jssc.addStreamingListener(self._jvm.JavaStreamingListenerWrapper(
            self._jvm.PythonStreamingListenerWrapper(streamingListener)))
|