diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
index 042a2664a0..4e417679ca 100644
--- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -17,17 +17,12 @@
 package org.apache.spark
 
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Timer, TimerTask}
 import java.util.concurrent.ConcurrentHashMap
 import java.util.function.Consumer
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
 
@@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator(
     // reset when a barrier() call fails due to timeout.
     private var barrierEpoch: Int = 0
 
-    // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
+    // An array of RPCCallContexts for barrier tasks that are waiting for a reply to a
+    // barrier() call.
     private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)
 
-    // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
-    private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]
-
-    // The blocking requestMethod called by tasks to sync up for this stage attempt
-    private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER
-
     // A timer task that ensures we may timeout for a barrier() call.
     private var timerTask: TimerTask = null
 
@@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator(
 
     // Process the global sync request. The barrier() call succeed if collected enough requests
     // within a configured time, otherwise fail all the pending requests.
-    def handleRequest(
-        requester: RpcCallContext,
-        request: RequestToSync
-    ): Unit = synchronized {
+    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
      val taskId = request.taskAttemptId
      val epoch = request.barrierEpoch
-      val requestMethod = request.requestMethod
-      val partitionId = request.partitionId
-      val allGatherMessage = request match {
-        case ag: AllGatherRequestToSync => ag.allGatherMessage
-        case _ => ""
-      }
-
-      if (requesters.size == 0) {
-        requestMethodToSync = requestMethod
-      }
-
-      if (requestMethodToSync != requestMethod) {
-        requesters.foreach(
-          _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
-            s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
-            s"the current synchronized requestMethod `$requestMethodToSync`"
-          ))
-        )
-        cleanupBarrierStage(barrierId)
-      }
 
       // Require the number of tasks is correctly set from the BarrierTaskContext.
       require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
@@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator(
       }
       // Add the requester to array of RPCCallContexts pending for reply.
       requesters += requester
-      allGatherMessages(partitionId) = allGatherMessage
       logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
         s"$taskId, current progress: ${requesters.size}/$numTasks.")
       if (maybeFinishAllRequesters(requesters, numTasks)) {
@@ -196,7 +162,6 @@
           s"tasks, finished successfully.")
         barrierEpoch += 1
         requesters.clear()
-        allGatherMessages.clear()
         cancelTimerTask()
       }
     }
@@ -208,13 +173,7 @@
         requesters: ArrayBuffer[RpcCallContext],
         numTasks: Int): Boolean = {
       if (requesters.size == numTasks) {
-        requestMethodToSync match {
-          case RequestMethod.BARRIER =>
-            requesters.foreach(_.reply(""))
-          case RequestMethod.ALL_GATHER =>
-            val json: String = compact(render(allGatherMessages))
-            requesters.foreach(_.reply(json))
-        }
+        requesters.foreach(_.reply(()))
         true
       } else {
         false
@@ -227,7 +186,6 @@
       // messages come from current stage attempt shall fail.
       barrierEpoch = -1
       requesters.clear()
-      allGatherMessages.clear()
       cancelTimerTask()
     }
   }
@@ -241,11 +199,11 @@
   }
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case request: RequestToSync =>
+    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
       // Get or init the ContextBarrierState correspond to the stage attempt.
-      val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
+      val barrierId = ContextBarrierId(stageId, stageAttemptId)
       states.computeIfAbsent(barrierId,
-        (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
+        (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
       val barrierState = states.get(barrierId)
       barrierState.handleRequest(context, request)
 
@@ -258,16 +216,6 @@
 
 private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
 
-private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
-  def numTasks: Int
-  def stageId: Int
-  def stageAttemptId: Int
-  def taskAttemptId: Long
-  def barrierEpoch: Int
-  def partitionId: Int
-  def requestMethod: RequestMethod.Value
-}
-
 /**
  * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
  * identified by stageId + stageAttemptId + barrierEpoch.
@@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
  * @param stageId ID of current stage
  * @param stageAttemptId ID of current stage attempt
  * @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
- * @param partitionId ID of the current partition the task is assigned to
- * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
+ * @param barrierEpoch ID of the `barrier()` call, a task may contain multiple `barrier()` calls.
  */
-private[spark] case class BarrierRequestToSync(
-    numTasks: Int,
-    stageId: Int,
-    stageAttemptId: Int,
-    taskAttemptId: Long,
-    barrierEpoch: Int,
-    partitionId: Int,
-    requestMethod: RequestMethod.Value
-) extends RequestToSync
-
-/**
- * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
- * identified by stageId + stageAttemptId + barrierEpoch.
- *
- * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
- * @param stageId ID of current stage
- * @param stageAttemptId ID of current stage attempt
- * @param taskAttemptId Unique ID of current task
- * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
- * @param partitionId ID of the current partition the task is assigned to
- * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
- * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
- */
-private[spark] case class AllGatherRequestToSync(
-    numTasks: Int,
-    stageId: Int,
-    stageAttemptId: Int,
-    taskAttemptId: Long,
-    barrierEpoch: Int,
-    partitionId: Int,
-    requestMethod: RequestMethod.Value,
-    allGatherMessage: String
-) extends RequestToSync
-
-private[spark] object RequestMethod extends Enumeration {
-  val BARRIER, ALL_GATHER = Value
-}
+private[spark] case class RequestToSync(
+    numTasks: Int,
+    stageId: Int,
+    stageAttemptId: Int,
+    taskAttemptId: Long,
+    barrierEpoch: Int) extends BarrierCoordinatorMessage
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 2263538a11..3d369802f3 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -17,19 +17,11 @@
 package org.apache.spark
 
-import java.nio.charset.StandardCharsets.UTF_8
 import java.util.{Properties, Timer, TimerTask}
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.TimeoutException
 import scala.concurrent.duration._
-import scala.language.postfixOps
-
-import org.json4s.DefaultFormats
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.parse
 
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.executor.TaskMetrics
 
@@ -67,97 +59,6 @@ class BarrierTaskContext private[spark] (
   // from different tasks within the same barrier stage attempt to succeed.
   private lazy val numTasks = getTaskInfos().size
 
-  private def getRequestToSync(
-      numTasks: Int,
-      stageId: Int,
-      stageAttemptNumber: Int,
-      taskAttemptId: Long,
-      barrierEpoch: Int,
-      partitionId: Int,
-      requestMethod: RequestMethod.Value,
-      allGatherMessage: String
-  ): RequestToSync = {
-    requestMethod match {
-      case RequestMethod.BARRIER =>
-        BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
-          barrierEpoch, partitionId, requestMethod)
-      case RequestMethod.ALL_GATHER =>
-        AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
-          barrierEpoch, partitionId, requestMethod, allGatherMessage)
-    }
-  }
-
-  private def runBarrier(
-      requestMethod: RequestMethod.Value,
-      allGatherMessage: String = ""
-  ): String = {
-
-    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
-      s"the global sync, current barrier epoch is $barrierEpoch.")
-    logTrace("Current callSite: " + Utils.getCallSite())
-
-    val startTime = System.currentTimeMillis()
-    val timerTask = new TimerTask {
-      override def run(): Unit = {
-        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
-          s"under the global sync since $startTime, has been waiting for " +
-          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
-          s"current barrier epoch is $barrierEpoch.")
-      }
-    }
-    // Log the update of global sync every 60 seconds.
-    timer.schedule(timerTask, 60000, 60000)
-
-    var json: String = ""
-
-    try {
-      val abortableRpcFuture = barrierCoordinator.askAbortable[String](
-        message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
-          taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
-        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
-        // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
-        timeout = new RpcTimeout(365.days, "barrierTimeout"))
-
-      // Wait the RPC future to be completed, but every 1 second it will jump out waiting
-      // and check whether current spark task is killed. If killed, then throw
-      // a `TaskKilledException`, otherwise continue wait RPC until it completes.
-      try {
-        while (!abortableRpcFuture.toFuture.isCompleted) {
-          // wait RPC future for at most 1 second
-          try {
-            json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
-          } catch {
-            case _: TimeoutException | _: InterruptedException =>
-              // If `TimeoutException` thrown, waiting RPC future reach 1 second.
-              // If `InterruptedException` thrown, it is possible this task is killed.
-              // So in this two cases, we should check whether task is killed and then
-              // throw `TaskKilledException`
-              taskContext.killTaskIfInterrupted()
-          }
-        }
-      } finally {
-        abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
-      }
-
-      barrierEpoch += 1
-      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
-        "global sync successfully, waited for " +
-        s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
-        s"current barrier epoch is $barrierEpoch.")
-    } catch {
-      case e: SparkException =>
-        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
-          "to perform global sync, waited for " +
-          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
-          s"current barrier epoch is $barrierEpoch.")
-        throw e
-    } finally {
-      timerTask.cancel()
-      timer.purge()
-    }
-    json
-  }
-
   /**
    * :: Experimental ::
    * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
@@ -201,27 +102,67 @@
   @Experimental
   @Since("2.4.0")
   def barrier(): Unit = {
-    runBarrier(RequestMethod.BARRIER)
-    ()
-  }
+    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
+      s"the global sync, current barrier epoch is $barrierEpoch.")
+    logTrace("Current callSite: " + Utils.getCallSite())
 
-  /**
-   * :: Experimental ::
-   * Blocks until all tasks in the same stage have reached this routine. Each task passes in
-   * a message and returns with a list of all the messages passed in by each of those tasks.
-   *
-   * CAUTION! The allGather method requires the same precautions as the barrier method
-   *
-   * The message is type String rather than Array[Byte] because it is more convenient for
-   * the user at the cost of worse performance.
-   */
-  @Experimental
-  @Since("3.0.0")
-  def allGather(message: String): ArrayBuffer[String] = {
-    val json = runBarrier(RequestMethod.ALL_GATHER, message)
-    val jsonArray = parse(json)
-    implicit val formats = DefaultFormats
-    ArrayBuffer(jsonArray.extract[Array[String]]: _*)
+    val startTime = System.currentTimeMillis()
+    val timerTask = new TimerTask {
+      override def run(): Unit = {
+        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
+          s"under the global sync since $startTime, has been waiting for " +
+          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
+          s"current barrier epoch is $barrierEpoch.")
+      }
+    }
+    // Log the update of global sync every 60 seconds.
+    timer.schedule(timerTask, 60000, 60000)
+
+    try {
+      val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
+        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
+          barrierEpoch),
+        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
+        // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
+        timeout = new RpcTimeout(365.days, "barrierTimeout"))
+
+      // Wait for the RPC future to complete, but check every second whether the current
+      // Spark task has been killed. If it has, throw a `TaskKilledException`; otherwise
+      // keep waiting until the RPC completes.
+      try {
+        while (!abortableRpcFuture.toFuture.isCompleted) {
+          // Wait on the RPC future for at most 1 second.
+          try {
+            ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
+          } catch {
+            case _: TimeoutException | _: InterruptedException =>
+              // A `TimeoutException` means we have waited on the RPC future for 1 second;
+              // an `InterruptedException` may mean this task has been killed. In both
+              // cases, check whether the task has been killed and, if so, throw a
+              // `TaskKilledException`.
+              taskContext.killTaskIfInterrupted()
+          }
+        }
+      } finally {
+        abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
+      }
+
+      barrierEpoch += 1
+      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
+        "global sync successfully, waited for " +
+        s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
+        s"current barrier epoch is $barrierEpoch.")
+    } catch {
+      case e: SparkException =>
+        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
+          "to perform global sync, waited for " +
+          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
+          s"current barrier epoch is $barrierEpoch.")
+        throw e
+    } finally {
+      timerTask.cancel()
+      timer.purge()
+    }
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index fa8bf0fc06..658e0d593a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8
 import java.util.concurrent.atomic.AtomicBoolean
 
 import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuffer
 import scala.util.control.NonFatal
 
-import org.json4s.JsonAST._
-import org.json4s.JsonDSL._
-import org.json4s.jackson.JsonMethods.{compact, render}
-
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -243,18 +238,13 @@
             sock.setSoTimeout(10000)
             authHelper.authClient(sock)
             val input = new DataInputStream(sock.getInputStream())
-            val requestMethod = input.readInt()
-            // The BarrierTaskContext function may wait infinitely, socket shall not timeout
-            // before the function finishes.
-            sock.setSoTimeout(0)
-            requestMethod match {
+            input.readInt() match {
               case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-                barrierAndServe(requestMethod, sock)
-              case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-                val length = input.readInt()
-                val message = new Array[Byte](length)
-                input.readFully(message)
-                barrierAndServe(requestMethod, sock, new String(message, UTF_8))
+                // The barrier() call may block indefinitely, so the socket must not time out
+                // before the call finishes.
+                sock.setSoTimeout(0)
+                barrierAndServe(sock)
+
               case _ =>
                 val out = new DataOutputStream(new BufferedOutputStream(
                   sock.getOutputStream))
@@ -405,31 +395,15 @@
     }
 
     /**
-     * Gateway to call BarrierTaskContext methods.
+     * Gateway to call BarrierTaskContext.barrier().
     */
-    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
-      require(
-        serverSocket.isDefined,
-        "No available ServerSocket to redirect the BarrierTaskContext method call."
-      )
+    def barrierAndServe(sock: Socket): Unit = {
+      require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
+      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
       try {
-        var result: String = ""
-        requestMethod match {
-          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-            context.asInstanceOf[BarrierTaskContext].barrier()
-            result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
-          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-            val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather(
-              message
-            )
-            result = compact(render(JArray(
-              messages.map(
-                (message) => JString(message)
-              ).toList
-            )))
-        }
-        writeUTF(result, out)
+        context.asInstanceOf[BarrierTaskContext].barrier()
+        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
       } catch {
         case e: SparkException =>
           writeUTF(e.getMessage, out)
@@ -664,7 +638,6 @@ private[spark] object SpecialLengths {
 
 private[spark] object BarrierTaskContextMessageProtocol {
   val BARRIER_FUNCTION = 1
-  val ALL_GATHER_FUNCTION = 2
   val BARRIER_RESULT_SUCCESS = "success"
   val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
index ed38b7f0ec..fc8ac38479 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler
 
 import java.io.File
 
-import scala.collection.mutable.ArrayBuffer
 import scala.util.Random
 
 import org.apache.spark._
@@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
     assert(times.max - times.min <= 1000)
   }
 
-  test("share messages with allGather() call") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      // Pass partitionId message in
-      val message = context.partitionId().toString
-      val messages = context.allGather(message)
-      messages.toList.iterator
-    }
-    // Take a sorted list of all the partitionId messages
-    val messages = rdd2.collect().head
-    // All the task partitionIds are shared
-    for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
-  }
-
-  test("throw exception if we attempt to synchronize with different blocking calls") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      val partitionId = context.partitionId
-      if (partitionId == 0) {
-        context.barrier()
-      } else {
-        context.allGather(partitionId.toString)
-      }
-      Seq(null).iterator
-    }
-    val error = intercept[SparkException] {
-      rdd2.collect()
-    }.getMessage
-    assert(error.contains("does not match the current synchronized requestMethod"))
-  }
-
-  test("successively sync with allGather and barrier") {
-    val conf = new SparkConf()
-      .setMaster("local-cluster[4, 1, 1024]")
-      .setAppName("test-cluster")
-    sc = new SparkContext(conf)
-    val rdd = sc.makeRDD(1 to 10, 4)
-    val rdd2 = rdd.barrier().mapPartitions { it =>
-      val context = BarrierTaskContext.get()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      context.barrier()
-      val time1 = System.currentTimeMillis()
-      // Sleep for a random time before global sync.
-      Thread.sleep(Random.nextInt(1000))
-      // Pass partitionId message in
-      val message = context.partitionId().toString
-      val messages = context.allGather(message)
-      val time2 = System.currentTimeMillis()
-      Seq((time1, time2)).iterator
-    }
-    val times = rdd2.collect()
-    // All the tasks shall finish the first round of global sync within a short time slot.
-    val times1 = times.map(_._1)
-    assert(times1.max - times1.min <= 1000)
-
-    // All the tasks shall finish the second round of global sync within a short time slot.
-    val times2 = times.map(_._2)
-    assert(times2.max - times2.min <= 1000)
-  }
-
   test("support multiple barrier() call within a single task") {
     initLocalClusterSparkContext()
     val rdd = sc.makeRDD(1 to 10, 4)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 90bd2345ac..d648f63338 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -16,10 +16,9 @@
 #
 from __future__ import print_function
-import json
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
+from pyspark.serializers import write_int, UTF8Deserializer
 
 
 class TaskContext(object):
@@ -108,28 +107,18 @@ class TaskContext(object):
 
 BARRIER_FUNCTION = 1
-ALL_GATHER_FUNCTION = 2
 
 
-def _load_from_socket(port, auth_secret, function, all_gather_message=None):
+def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
     """
     (sockfile, sock) = local_connect_and_auth(port, auth_secret)
-
-    # The call may block forever, so no timeout
+    # The barrier() call may block forever, so no timeout
     sock.settimeout(None)
-
-    if function == BARRIER_FUNCTION:
-        # Make a barrier() function call.
-        write_int(function, sockfile)
-    elif function == ALL_GATHER_FUNCTION:
-        # Make a all_gather() function call.
-        write_int(function, sockfile)
-        write_with_length(all_gather_message.encode("utf-8"), sockfile)
-    else:
-        raise ValueError("Unrecognized function type")
+    # Make a barrier() function call.
+    write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
 
     # Collect result.
@@ -210,33 +199,7 @@ class BarrierTaskContext(TaskContext):
             raise Exception("Not supported to call barrier() before initialize " +
                             "BarrierTaskContext.")
         else:
-            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
-
-    def allGather(self, message=""):
-        """
-        .. note:: Experimental
-
-        This function blocks until all tasks in the same stage have reached this routine.
-        Each task passes in a message and returns with a list of all the messages passed in
-        by each of those tasks.
-
-        .. warning:: In a barrier stage, each task much have the same number of `allGather()`
-            calls, in all possible code branches.
-            Otherwise, you may get the job hanging or a SparkException after timeout.
-        """
-        if not isinstance(message, str):
-            raise ValueError("Argument `message` must be of type `str`")
-        elif self._port is None or self._secret is None:
-            raise Exception("Not supported to call barrier() before initialize " +
-                            "BarrierTaskContext.")
-        else:
-            gathered_items = _load_from_socket(
-                self._port,
-                self._secret,
-                ALL_GATHER_FUNCTION,
-                message,
-            )
-            return [e for e in json.loads(gathered_items)]
+            _load_from_socket(self._port, self._secret)
 
     def getTaskInfos(self):
         """
diff --git a/python/pyspark/tests/test_taskcontext.py b/python/pyspark/tests/test_taskcontext.py
index 0053aadd9c..68cfe81476 100644
--- a/python/pyspark/tests/test_taskcontext.py
+++ b/python/pyspark/tests/test_taskcontext.py
@@ -135,26 +135,6 @@ class TaskContextTests(PySparkTestCase):
         times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
         self.assertTrue(max(times) - min(times) < 1)
 
-    def test_all_gather(self):
-        """
-        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
-        within a stage and passes messages properly.
-        """
-        rdd = self.sc.parallelize(range(10), 4)
-
-        def f(iterator):
-            yield sum(iterator)
-
-        def context_barrier(x):
-            tc = BarrierTaskContext.get()
-            time.sleep(random.randint(1, 10))
-            out = tc.allGather(str(context.partitionId()))
-            pids = [int(e) for e in out]
-            return [pids]
-
-        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
-        self.assertTrue(pids == [0, 1, 2, 3])
-
     def test_barrier_infos(self):
         """
         Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the
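
After this revert, the only barrier primitive left on BarrierTaskContext is barrier(), and the coordinator replies with Unit rather than a JSON payload of gathered messages. The sketch below, modeled on the retained "global sync by barrier() call" test in BarrierTaskContextSuite above, shows the surviving usage pattern; it is a minimal illustration, not part of this patch, and the object name, master string, app name, and partition count are assumptions chosen for the example.

import org.apache.spark.{BarrierTaskContext, SparkConf, SparkContext}

object BarrierSyncSketch {
  def main(args: Array[String]): Unit = {
    // A 4-slot local cluster so all 4 barrier tasks can be scheduled concurrently.
    val conf = new SparkConf()
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("barrier-sync-sketch")
    val sc = new SparkContext(conf)

    val times = sc.makeRDD(1 to 10, 4).barrier().mapPartitions { it =>
      val context = BarrierTaskContext.get()
      // Tasks reach the barrier at different times...
      Thread.sleep(scala.util.Random.nextInt(1000))
      // ...but barrier() returns only once all 4 tasks have called it.
      context.barrier()
      Iterator.single(System.currentTimeMillis())
    }.collect()

    // All tasks leave the global sync within a narrow window.
    assert(times.max - times.min <= 1000)
    sc.stop()
  }
}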