[SPARK-16861][PYSPARK][CORE] Refactor PySpark accumulator API on top of Accumulator V2
## What changes were proposed in this pull request?

Move the internals of the PySpark accumulator API off the old, deprecated Accumulator API and onto the new AccumulatorV2 API.

## How was this patch tested?

The existing PySpark accumulator tests (both the unit tests and the doc tests at the start of accumulator.py).

Author: Holden Karau <holden@us.ibm.com>

Closes #14467 from holdenk/SPARK-16861-refactor-pyspark-accumulator-api.
Commit: 90d5754212 (parent: 5c5396cb47)
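For context, the API the PySpark internals move onto is `org.apache.spark.util.AccumulatorV2`. The sketch below shows that contract using the real Spark 2.x method names (`isZero`, `copy`, `reset`, `add`, `merge`, `value`); the `SumSketch` class itself is purely illustrative and is not part of this patch.

```scala
import org.apache.spark.util.AccumulatorV2

// Illustrative only: a minimal AccumulatorV2 implementation showing the
// contract that PythonAccumulatorV2 (via CollectionAccumulator) now follows.
class SumSketch extends AccumulatorV2[Long, Long] {
  private var sum = 0L
  override def isZero: Boolean = sum == 0L
  override def copy(): SumSketch = { val c = new SumSketch; c.sum = sum; c }
  override def reset(): Unit = sum = 0L
  override def add(v: Long): Unit = sum += v                                       // called on executors
  override def merge(other: AccumulatorV2[Long, Long]): Unit = sum += other.value  // called on the driver
  override def value: Long = sum
}
```

The diff replaces `PythonAccumulatorParam` (built on the deprecated `AccumulatorParam`/`Accumulator` pair) with `PythonAccumulatorV2`, which extends `CollectionAccumulator[Array[Byte]]` and only overrides `copyAndReset` and `merge`.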
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -20,7 +20,7 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.charset.StandardCharsets
-import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
@@ -38,7 +38,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
-import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util._
 
 
 private[spark] class PythonRDD(
@@ -75,7 +75,7 @@ private[spark] case class PythonFunction(
     pythonExec: String,
     pythonVer: String,
     broadcastVars: JList[Broadcast[PythonBroadcast]],
-    accumulator: Accumulator[JList[Array[Byte]]])
+    accumulator: PythonAccumulatorV2)
 
 /**
  * A wrapper for chained Python functions (from bottom to top).
@@ -200,7 +200,7 @@ private[spark] class PythonRunner(
             val updateLen = stream.readInt()
             val update = new Array[Byte](updateLen)
             stream.readFully(update)
-            accumulator += Collections.singletonList(update)
+            accumulator.add(update)
           }
           // Check whether the worker is ready to be re-used.
           if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
@@ -461,7 +461,7 @@ private[spark] object PythonRDD extends Logging {
       JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
     try {
-      val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
+      val objs = new mutable.ArrayBuffer[Array[Byte]]
       try {
         while (true) {
           val length = file.readInt()
@@ -866,11 +866,13 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
 }
 
 /**
- * Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
+ * Internal class that acts as an `AccumulatorV2` for Python accumulators. Inside, it
  * collects a list of pickled strings that we pass to Python through a socket.
  */
-private class PythonAccumulatorParam(@transient private val serverHost: String, serverPort: Int)
-  extends AccumulatorParam[JList[Array[Byte]]] {
+private[spark] class PythonAccumulatorV2(
+    @transient private val serverHost: String,
+    private val serverPort: Int)
+  extends CollectionAccumulator[Array[Byte]] {
 
   Utils.checkHost(serverHost, "Expected hostname")
 
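`PythonAccumulatorV2` inherits its buffering behavior from `CollectionAccumulator[Array[Byte]]` in `org.apache.spark.util`, so it only has to override `copyAndReset` and `merge` (next hunk). The sketch below is illustrative only (not Spark source) and shows the list-buffering semantics being relied on.

```scala
import java.util.{ArrayList => JArrayList, List => JList}

// Illustrative sketch of the behavior inherited from CollectionAccumulator:
// add() buffers each pickled update, value exposes the buffer, and the
// default merge simply concatenates the two buffers.
class ListBufferSketch[T] {
  private val buffer: JList[T] = new JArrayList[T]()

  def add(v: T): Unit = buffer.add(v)              // workers buffer updates locally
  def value: JList[T] = buffer                     // the driver reads the buffered updates
  def merge(other: ListBufferSketch[T]): Unit = buffer.addAll(other.value)
}
```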
@@ -880,30 +882,33 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
    * We try to reuse a single Socket to transfer accumulator updates, as they are all added
    * by the DAGScheduler's single-threaded RpcEndpoint anyway.
    */
-  @transient var socket: Socket = _
+  @transient private var socket: Socket = _
 
-  def openSocket(): Socket = synchronized {
+  private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
     }
     socket
   }
 
-  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
+  // Need to override so the types match with PythonFunction
+  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
 
-  override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = synchronized {
+  override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
+    val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
+    // This conditional isn't strictly needed - merging currently only happens on the
+    // driver program - but that isn't guaranteed, so keep it in case this changes.
     if (serverHost == null) {
-      // This happens on the worker node, where we just want to remember all the updates
-      val1.addAll(val2)
-      val1
+      // We are on the worker
+      super.merge(otherPythonAccumulator)
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
       val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
-      out.writeInt(val2.size)
-      for (array <- val2.asScala) {
+      val values = other.value
+      out.writeInt(values.size)
+      for (array <- values.asScala) {
         out.writeInt(array.length)
         out.write(array)
       }
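On the driver, `merge` forwards the buffered updates to the Python accumulator server over the reused socket: one int for the number of updates, then each update as a length-prefixed byte array, after which the driver waits for a single acknowledgement byte (next hunk). Below is a self-contained sketch of just that framing; `AccumulatorFramingSketch` and `writeUpdates` are hypothetical names, not part of the patch.

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Hypothetical helper sketching the framing merge() writes to the Python side.
object AccumulatorFramingSketch {
  def writeUpdates(out: DataOutputStream, updates: Seq[Array[Byte]]): Unit = {
    out.writeInt(updates.size)            // number of accumulator updates
    updates.foreach { bytes =>
      out.writeInt(bytes.length)          // length prefix
      out.write(bytes)                    // pickled update payload
    }
    out.flush()                           // the real code then blocks on a one-byte ack
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    writeUpdates(new DataOutputStream(buf), Seq(Array[Byte](1, 2, 3), Array[Byte](4)))
    assert(buf.toByteArray.length == 4 + (4 + 3) + (4 + 1))  // count + two length-prefixed payloads
  }
}
```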
@@ -913,7 +918,6 @@ private class PythonAccumulatorParam(@transient private val serverHost: String,
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      null
     }
   }
 }
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -173,9 +173,8 @@ class SparkContext(object):
         # they will be passed back to us through a TCP server
         self._accumulatorServer = accumulators._start_update_server()
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jsc.accumulator(
-            self._jvm.java.util.ArrayList(),
-            self._jvm.PythonAccumulatorParam(host, port))
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._jsc.sc().register(self._javaAccumulator)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
         self.pythonVer = "%d.%d" % sys.version_info[:2]
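Through Py4J this mirrors the driver-side registration pattern of the V2 API: rather than asking the JavaSparkContext to build an `Accumulator` from an `AccumulatorParam`, the driver constructs the accumulator itself and registers it with the `SparkContext`. A minimal Scala sketch of that pattern follows; it assumes a live `SparkContext`, and `CollectionAccumulator` stands in for the JVM-private `PythonAccumulatorV2`.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.util.CollectionAccumulator

// Hypothetical helper sketching explicit AccumulatorV2 registration on the driver.
object RegistrationSketch {
  def registerUpdateBuffer(sc: SparkContext): CollectionAccumulator[Array[Byte]] = {
    val acc = new CollectionAccumulator[Array[Byte]]()   // stand-in for PythonAccumulatorV2
    sc.register(acc, "pyspark-accumulator-updates")      // V2 accumulators are registered explicitly
    acc
  }
}
```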