Do not launch JavaGateways on workers (SPARK-674).
The problem was that the gateway was being initialized whenever the pyspark.context module was loaded, so worker processes that merely imported the module launched JavaGateways of their own. The fix is lazy initialization: the gateway is launched only when a SparkContext instance is actually constructed. I also made the gateway and jvm variables private. This change yields a ~3-4x speedup when running the PySpark unit tests.
commit 9cc6ff9c4e
parent 571af31304
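In outline, the fix turns eager module-level gateway attributes into lazily initialized class attributes. Below is a minimal sketch of the pattern, distilled from the diff that follows; the "with SparkContext._lock:" guard and the pyspark.java_gateway import path are assumptions based on the surrounding code, and the constructor signature is abbreviated.

from threading import Lock

from pyspark.java_gateway import launch_gateway  # assumed import path


class SparkContext(object):
    # Previously these were eager class attributes, so merely importing
    # pyspark.context launched a Py4J gateway (and a JVM), even on workers:
    #     gateway = launch_gateway()
    #     jvm = gateway.jvm
    # Now they start out empty and are populated on first construction.
    _gateway = None
    _jvm = None
    _lock = Lock()

    def __init__(self, master, jobName):  # abbreviated signature
        with SparkContext._lock:  # assumed guard, mirroring _lock in the diff
            if not SparkContext._gateway:
                SparkContext._gateway = launch_gateway()
                SparkContext._jvm = SparkContext._gateway.jvm

Since worker processes unpickle functions and data but never construct a SparkContext, they no longer pay for a gateway launch at import time.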
pyspark/context.py
@@ -24,11 +24,10 @@ class SparkContext(object):
     broadcast variables on that cluster.
     """

-    gateway = launch_gateway()
-    jvm = gateway.jvm
-    _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
-    _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
-    _takePartition = jvm.PythonRDD.takePartition
+    _gateway = None
+    _jvm = None
+    _writeIteratorToPickleFile = None
+    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -56,6 +55,13 @@ class SparkContext(object):
                 raise ValueError("Cannot run multiple SparkContexts at once")
             else:
                 SparkContext._active_spark_context = self
+                if not SparkContext._gateway:
+                    SparkContext._gateway = launch_gateway()
+                    SparkContext._jvm = SparkContext._gateway.jvm
+                    SparkContext._writeIteratorToPickleFile = \
+                        SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+                    SparkContext._takePartition = \
+                        SparkContext._jvm.PythonRDD.takePartition
         self.master = master
         self.jobName = jobName
         self.sparkHome = sparkHome or None # None becomes null in Py4J
@@ -63,8 +69,8 @@ class SparkContext(object):
         self.batchSize = batchSize  # -1 represents a unlimited batch size

         # Create the Java SparkContext through Py4J
-        empty_string_array = self.gateway.new_array(self.jvm.String, 0)
-        self._jsc = self.jvm.JavaSparkContext(master, jobName, sparkHome,
+        empty_string_array = self._gateway.new_array(self._jvm.String, 0)
+        self._jsc = self._jvm.JavaSparkContext(master, jobName, sparkHome,
                                               empty_string_array)

         # Create a single Accumulator in Java that we'll send all our updates through;
@@ -72,8 +78,8 @@ class SparkContext(object):
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
         self._javaAccumulator = self._jsc.accumulator(
-                self.jvm.java.util.ArrayList(),
-                self.jvm.PythonAccumulatorParam(host, port))
+                self._jvm.java.util.ArrayList(),
+                self._jvm.PythonAccumulatorParam(host, port))

         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         # Broadcast's __reduce__ method stores Broadcast instances here.
@@ -127,7 +133,8 @@ class SparkContext(object):
         for x in c:
             write_with_length(dump_pickle(x), tempFile)
         tempFile.close()
-        jrdd = self._readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
+        readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
+        jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
         return RDD(jrdd, self)

     def textFile(self, name, minSplits=None):
pyspark/files.py
@@ -35,4 +35,4 @@ class SparkFiles(object):
             return cls._root_directory
         else:
             # This will have to change if we support multiple SparkContexts:
-            return cls._sc.jvm.spark.SparkFiles.getRootDirectory()
+            return cls._sc._jvm.spark.SparkFiles.getRootDirectory()
pyspark/rdd.py
@@ -407,7 +407,7 @@ class RDD(object):
             return (str(x).encode("utf-8") for x in iterator)
         keyed = PipelinedRDD(self, func)
         keyed._bypass_serializer = True
-        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)
+        keyed._jrdd.map(self.ctx._jvm.BytesToString()).saveAsTextFile(path)

     # Pair functions

@@ -550,8 +550,8 @@ class RDD(object):
             yield dump_pickle(Batch(items))
         keyed = PipelinedRDD(self, add_shuffle_key)
         keyed._bypass_serializer = True
-        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
-        partitioner = self.ctx.jvm.PythonPartitioner(numSplits,
+        pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+        partitioner = self.ctx._jvm.PythonPartitioner(numSplits,
                                                      id(partitionFunc))
         jrdd = pairRDD.partitionBy(partitioner).values()
         rdd = RDD(jrdd, self.ctx)
@@ -730,13 +730,13 @@ class PipelinedRDD(RDD):
         pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
         broadcast_vars = ListConverter().convert(
             [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
-            self.ctx.gateway._gateway_client)
+            self.ctx._gateway._gateway_client)
         self.ctx._pickled_broadcast_vars.clear()
         class_manifest = self._prev_jrdd.classManifest()
         env = copy.copy(self.ctx.environment)
         env['PYTHONPATH'] = os.environ.get("PYTHONPATH", "")
-        env = MapConverter().convert(env, self.ctx.gateway._gateway_client)
-        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
+        env = MapConverter().convert(env, self.ctx._gateway._gateway_client)
+        python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
             pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
             broadcast_vars, self.ctx._javaAccumulator, class_manifest)
         self._jrdd_val = python_rdd.asJavaRDD()
pyspark/tests.py
@@ -26,7 +26,7 @@ class PySparkTestCase(unittest.TestCase):
         sys.path = self._old_sys_path
         # To avoid Akka rebinding to the same port, since it doesn't unbind
         # immediately on shutdown
-        self.sc.jvm.System.clearProperty("spark.driver.port")
+        self.sc._jvm.System.clearProperty("spark.driver.port")


 class TestCheckpoint(PySparkTestCase):