6cf507685e
Re-implement the Python broadcast using files:

1. Serialize the Python object with cPickle and write it to disk.
2. Create a wrapper in the JVM for the dumped file, which reads the data from disk during serialization.
3. Use TorrentBroadcast or HttpBroadcast to transfer the compressed data to the executors.
4. During deserialization, write the data back to disk.
5. Pass the path to the Python worker, which reads the data from disk and unpickles it into a Python object only on first access.

This fixes the performance regression introduced in #2659 and performs similarly to 1.1, but it supports objects larger than 2G and improves memory efficiency (only one compressed copy is kept in the driver and in each executor).

Testing with a 500M broadcast and 4 tasks (excluding the benefit from worker reuse in 1.2):

name | 1.1 | 1.2 with this patch | improvement
---------|--------|---------|--------
python-broadcast-w-bytes | 25.20 | 9.33 | 170.13%
python-broadcast-w-set | 4.13 | 4.50 | -8.35%

Testing with 100 tasks (16 CPUs):

name | 1.1 | 1.2 with this patch | improvement
---------|--------|---------|--------
python-broadcast-w-bytes | 38.16 | 8.40 | 353.98%
python-broadcast-w-set | 23.29 | 9.59 | 142.80%

Author: Davies Liu <davies@databricks.com>

Closes #3417 from davies/pybroadcast and squashes the following commits:

50a58e0 [Davies Liu] address comments
b98de1d [Davies Liu] disable gc while unpickle
e5ee6b9 [Davies Liu] support large string
09303b8 [Davies Liu] read all data into memory
dde02dd [Davies Liu] improve performance of python broadcast
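Steps 1 and 5 amount to pickling straight to disk on the driver and keeping only the file path on the worker until the value is first touched. Below is a minimal sketch of that idea (Python 2, matching the code in this file); `dump_to_file`, `LazyBroadcast`, and `_UNSET` are illustrative names, not the actual PySpark API — the real implementation lives in `pyspark/broadcast.py`, and disabling gc around the unpickle mirrors commit b98de1d above:

```python
import cPickle
import gc
import os
import tempfile

_UNSET = object()  # sentinel, so a broadcast value of None still works


def dump_to_file(value, dirname):
    # Driver side (step 1): pickle the object straight to a file on disk.
    fd, path = tempfile.mkstemp(dir=dirname)
    with os.fdopen(fd, "wb") as f:
        cPickle.dump(value, f, 2)  # protocol 2: compact binary pickles
    return path


class LazyBroadcast(object):
    # Worker side (step 5): keep only the path; unpickle on first access.

    def __init__(self, path):
        self._path = path
        self._value = _UNSET

    @property
    def value(self):
        if self._value is _UNSET:
            gc.disable()  # skip GC passes while many objects are built (cf. b98de1d)
            try:
                with open(self._path, "rb") as f:
                    self._value = cPickle.load(f)
            finally:
                gc.enable()
        return self._value


if __name__ == "__main__":
    path = dump_to_file({"answer": 42}, tempfile.gettempdir())
    b = LazyBroadcast(path)
    print b.value  # the file is read and unpickled only here
```

With a path-based handle like this, a reused worker unpickles each broadcast at most once, and tasks that never touch `.value` skip the cost entirely.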
147 lines, 5.2 KiB, Python
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
|
|
Worker that receives input from Piped RDD.
|
|
"""
|
|
import os
|
|
import sys
|
|
import time
|
|
import socket
|
|
import traceback
|
|
import cProfile
|
|
import pstats
|
|
|
|
from pyspark.accumulators import _accumulatorRegistry
|
|
from pyspark.broadcast import Broadcast, _broadcastRegistry
|
|
from pyspark.files import SparkFiles
|
|
from pyspark.serializers import write_with_length, write_int, read_long, \
|
|
write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
|
|
from pyspark import shuffle
|
|
|
|
pickleSer = PickleSerializer()
|
|
utf8_deserializer = UTF8Deserializer()
|
|
|
|
|
|
def report_times(outfile, boot, init, finish):
    write_int(SpecialLengths.TIMING_DATA, outfile)
    write_long(1000 * boot, outfile)
    write_long(1000 * init, outfile)
    write_long(1000 * finish, outfile)

def add_path(path):
    # workers can be reused, so don't add the path multiple times
    if path not in sys.path:
        # overwrite system packages
        sys.path.insert(1, path)

def main(infile, outfile):
    try:
        boot_time = time.time()
        split_index = read_int(infile)
        if split_index == -1:  # for unit tests
            exit(-1)

        # initialize global state
        shuffle.MemoryBytesSpilled = 0
        shuffle.DiskBytesSpilled = 0
        _accumulatorRegistry.clear()

        # fetch name of workdir
        spark_files_dir = utf8_deserializer.loads(infile)
        SparkFiles._root_directory = spark_files_dir
        SparkFiles._is_running_on_worker = True

        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
        add_path(spark_files_dir)  # *.py files that were added will be copied here
        num_python_includes = read_int(infile)
        for _ in range(num_python_includes):
            filename = utf8_deserializer.loads(infile)
            add_path(os.path.join(spark_files_dir, filename))

        # fetch ids and paths of broadcast variables
        num_broadcast_variables = read_int(infile)
        for _ in range(num_broadcast_variables):
            bid = read_long(infile)
            if bid >= 0:
                # a non-negative id registers a new broadcast, backed by a
                # file on disk that is unpickled lazily on first access
                path = utf8_deserializer.loads(infile)
                _broadcastRegistry[bid] = Broadcast(path=path)
            else:
                # a negative id (encoded as -bid - 1) drops a broadcast that
                # is no longer needed by this reused worker
                bid = - bid - 1
                _broadcastRegistry.pop(bid)

        _accumulatorRegistry.clear()
        command = pickleSer._read_with_length(infile)
        if isinstance(command, Broadcast):
            # large commands arrive as a broadcast themselves
            command = pickleSer.loads(command.value)
        (func, stats, deserializer, serializer) = command
        init_time = time.time()

        def process():
            iterator = deserializer.load_stream(infile)
            serializer.dump_stream(func(split_index, iterator), outfile)

        if stats:
            p = cProfile.Profile()
            p.runcall(process)
            st = pstats.Stats(p)
            st.stream = None  # make it picklable
            stats.add(st.strip_dirs())
        else:
            process()
    except Exception:
        try:
            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
            write_with_length(traceback.format_exc(), outfile)
        except IOError:
            # the JVM closed the socket
            pass
        except Exception:
            # write the error to stderr if it happened while serializing
            print >> sys.stderr, "PySpark worker failed with exception:"
            print >> sys.stderr, traceback.format_exc()
        exit(-1)

    finish_time = time.time()
    report_times(outfile, boot_time, init_time, finish_time)
    write_long(shuffle.MemoryBytesSpilled, outfile)
    write_long(shuffle.DiskBytesSpilled, outfile)

    # Mark the beginning of the accumulators section of the output
    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
    write_int(len(_accumulatorRegistry), outfile)
    for (aid, accum) in _accumulatorRegistry.items():
        pickleSer._write_with_length((aid, accum._value), outfile)

    # check end of stream
    if read_int(infile) == SpecialLengths.END_OF_STREAM:
        write_int(SpecialLengths.END_OF_STREAM, outfile)
    else:
        # write a different value to tell JVM to not reuse this worker
        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
        exit(-1)

if __name__ == '__main__':
    # Read a local port to connect to from stdin
    java_port = int(sys.stdin.readline())
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect(("127.0.0.1", java_port))
    sock_file = sock.makefile("a+", 65536)
    main(sock_file, sock_file)