Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext"

This reverts commit 57254c9719.
This commit is contained in:
Xingbo Jiang 2020-02-13 17:43:55 -08:00
parent 57254c9719
commit fa3517cdb1
6 changed files with 92 additions and 394 deletions

View file

@ -17,17 +17,12 @@
package org.apache.spark
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer
import scala.collection.mutable.ArrayBuffer
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}
@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator(
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0
// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
// An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
// call.
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)
// An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]
// The blocking requestMethod called by tasks to sync up for this stage attempt
private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER
// A timer task that ensures we may timeout for a barrier() call.
private var timerTask: TimerTask = null
@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator(
// Process the global sync request. The barrier() call succeeds if enough requests are
// collected within a configured time; otherwise all the pending requests fail.
def handleRequest(
requester: RpcCallContext,
request: RequestToSync
): Unit = synchronized {
def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
val taskId = request.taskAttemptId
val epoch = request.barrierEpoch
val requestMethod = request.requestMethod
val partitionId = request.partitionId
val allGatherMessage = request match {
case ag: AllGatherRequestToSync => ag.allGatherMessage
case _ => ""
}
if (requesters.size == 0) {
requestMethodToSync = requestMethod
}
if (requestMethodToSync != requestMethod) {
requesters.foreach(
_.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
s"the current synchronized requestMethod `$requestMethodToSync`"
))
)
cleanupBarrierStage(barrierId)
}
// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator(
}
// Add the requester to array of RPCCallContexts pending for reply.
requesters += requester
allGatherMessages(partitionId) = allGatherMessage
logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
s"$taskId, current progress: ${requesters.size}/$numTasks.")
if (maybeFinishAllRequesters(requesters, numTasks)) {
@ -196,7 +162,6 @@ private[spark] class BarrierCoordinator(
s"tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
allGatherMessages.clear()
cancelTimerTask()
}
}
@ -208,13 +173,7 @@ private[spark] class BarrierCoordinator(
requesters: ArrayBuffer[RpcCallContext],
numTasks: Int): Boolean = {
if (requesters.size == numTasks) {
requestMethodToSync match {
case RequestMethod.BARRIER =>
requesters.foreach(_.reply(""))
case RequestMethod.ALL_GATHER =>
val json: String = compact(render(allGatherMessages))
requesters.foreach(_.reply(json))
}
requesters.foreach(_.reply(()))
true
} else {
false
@ -227,7 +186,6 @@ private[spark] class BarrierCoordinator(
// messages come from current stage attempt shall fail.
barrierEpoch = -1
requesters.clear()
allGatherMessages.clear()
cancelTimerTask()
}
}
@ -241,11 +199,11 @@ private[spark] class BarrierCoordinator(
}
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
case request: RequestToSync =>
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
// Get or init the ContextBarrierState corresponding to the stage attempt.
val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
val barrierId = ContextBarrierId(stageId, stageAttemptId)
states.computeIfAbsent(barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
val barrierState = states.get(barrierId)
barrierState.handleRequest(context, request)
@ -258,16 +216,6 @@ private[spark] class BarrierCoordinator(
private[spark] sealed trait BarrierCoordinatorMessage extends Serializable
private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
def numTasks: Int
def stageId: Int
def stageAttemptId: Int
def taskAttemptId: Long
def barrierEpoch: Int
def partitionId: Int
def requestMethod: RequestMethod.Value
}
/**
* A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param barrierEpoch ID of the `barrier()` call, a task may contain multiple `barrier()` calls.
*/
private[spark] case class BarrierRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value
) extends RequestToSync
/**
* A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
* @param partitionId ID of the current partition the task is assigned to
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
*/
private[spark] case class AllGatherRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
) extends RequestToSync
private[spark] object RequestMethod extends Enumeration {
val BARRIER, ALL_GATHER = Value
}
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int) extends BarrierCoordinatorMessage

View file

@ -17,19 +17,11 @@
package org.apache.spark
import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps
import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
@ -67,97 +59,6 @@ class BarrierTaskContext private[spark] (
// from different tasks within the same barrier stage attempt to succeed.
private lazy val numTasks = getTaskInfos().size
private def getRequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptNumber: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
requestMethod: RequestMethod.Value,
allGatherMessage: String
): RequestToSync = {
requestMethod match {
case RequestMethod.BARRIER =>
BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod)
case RequestMethod.ALL_GATHER =>
AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch, partitionId, requestMethod, allGatherMessage)
}
}
private def runBarrier(
requestMethod: RequestMethod.Value,
allGatherMessage: String = ""
): String = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
val startTime = System.currentTimeMillis()
val timerTask = new TimerTask {
override def run(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
s"under the global sync since $startTime, has been waiting for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
}
}
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
var json: String = ""
try {
val abortableRpcFuture = barrierCoordinator.askAbortable[String](
message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
// Wait the RPC future to be completed, but every 1 second it will jump out waiting
// and check whether current spark task is killed. If killed, then throw
// a `TaskKilledException`, otherwise continue wait RPC until it completes.
try {
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If `TimeoutException` thrown, waiting RPC future reach 1 second.
// If `InterruptedException` thrown, it is possible this task is killed.
// So in this two cases, we should check whether task is killed and then
// throw `TaskKilledException`
taskContext.killTaskIfInterrupted()
}
}
} finally {
abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
}
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
"to perform global sync, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
throw e
} finally {
timerTask.cancel()
timer.purge()
}
json
}
/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
@ -201,27 +102,67 @@ class BarrierTaskContext private[spark] (
@Experimental
@Since("2.4.0")
def barrier(): Unit = {
runBarrier(RequestMethod.BARRIER)
()
}
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
s"the global sync, current barrier epoch is $barrierEpoch.")
logTrace("Current callSite: " + Utils.getCallSite())
/**
* :: Experimental ::
* Blocks until all tasks in the same stage have reached this routine. Each task passes in
* a message and returns with a list of all the messages passed in by each of those tasks.
*
* CAUTION! The allGather method requires the same precautions as the barrier method
*
* The message is type String rather than Array[Byte] because it is more convenient for
* the user at the cost of worse performance.
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): ArrayBuffer[String] = {
val json = runBarrier(RequestMethod.ALL_GATHER, message)
val jsonArray = parse(json)
implicit val formats = DefaultFormats
ArrayBuffer(jsonArray.extract[Array[String]]: _*)
val startTime = System.currentTimeMillis()
val timerTask = new TimerTask {
override def run(): Unit = {
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
s"under the global sync since $startTime, has been waiting for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
}
}
// Log the update of global sync every 60 seconds.
timer.schedule(timerTask, 60000, 60000)
try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
barrierEpoch),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
// Wait for the RPC future to complete, but wake up every 1 second to check
// whether the current Spark task has been killed. If it has, throw
// a `TaskKilledException`; otherwise keep waiting until the RPC completes.
try {
while (!abortableRpcFuture.toFuture.isCompleted) {
// wait RPC future for at most 1 second
try {
ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
} catch {
case _: TimeoutException | _: InterruptedException =>
// If a `TimeoutException` is thrown, the wait on the RPC future reached 1 second.
// If an `InterruptedException` is thrown, this task may have been killed.
// In both cases, check whether the task has been killed and, if so,
// throw a `TaskKilledException`.
taskContext.killTaskIfInterrupted()
}
}
} finally {
abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
}
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
} catch {
case e: SparkException =>
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
"to perform global sync, waited for " +
s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
s"current barrier epoch is $barrierEpoch.")
throw e
} finally {
timerTask.cancel()
timer.purge()
}
}
/**

View file

@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@ -243,18 +238,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
sock.setSoTimeout(10000)
authHelper.authClient(sock)
val input = new DataInputStream(sock.getInputStream())
val requestMethod = input.readInt()
// The BarrierTaskContext function may wait infinitely, socket shall not timeout
// before the function finishes.
sock.setSoTimeout(0)
requestMethod match {
input.readInt() match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
barrierAndServe(requestMethod, sock)
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
val length = input.readInt()
val message = new Array[Byte](length)
input.readFully(message)
barrierAndServe(requestMethod, sock, new String(message, UTF_8))
// The barrier() function may wait infinitely, socket shall not timeout
// before the function finishes.
sock.setSoTimeout(0)
barrierAndServe(sock)
case _ =>
val out = new DataOutputStream(new BufferedOutputStream(
sock.getOutputStream))
@ -405,31 +395,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
}
/**
* Gateway to call BarrierTaskContext methods.
* Gateway to call BarrierTaskContext.barrier().
*/
def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
require(
serverSocket.isDefined,
"No available ServerSocket to redirect the BarrierTaskContext method call."
)
def barrierAndServe(sock: Socket): Unit = {
require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
var result: String = ""
requestMethod match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
context.asInstanceOf[BarrierTaskContext].barrier()
result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather(
message
)
result = compact(render(JArray(
messages.map(
(message) => JString(message)
).toList
)))
}
writeUTF(result, out)
context.asInstanceOf[BarrierTaskContext].barrier()
writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
} catch {
case e: SparkException =>
writeUTF(e.getMessage, out)
@ -664,7 +638,6 @@ private[spark] object SpecialLengths {
private[spark] object BarrierTaskContextMessageProtocol {
val BARRIER_FUNCTION = 1
val ALL_GATHER_FUNCTION = 2
val BARRIER_RESULT_SUCCESS = "success"
val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}

View file

@ -19,7 +19,6 @@ package org.apache.spark.scheduler
import java.io.File
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import org.apache.spark._
@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
assert(times.max - times.min <= 1000)
}
test("share messages with allGather() call") {
val conf = new SparkConf()
.setMaster("local-cluster[4, 1, 1024]")
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// Sleep for a random time before global sync.
Thread.sleep(Random.nextInt(1000))
// Pass partitionId message in
val message = context.partitionId().toString
val messages = context.allGather(message)
messages.toList.iterator
}
// Take a sorted list of all the partitionId messages
val messages = rdd2.collect().head
// All the task partitionIds are shared
for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
}
test("throw exception if we attempt to synchronize with different blocking calls") {
val conf = new SparkConf()
.setMaster("local-cluster[4, 1, 1024]")
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
val partitionId = context.partitionId
if (partitionId == 0) {
context.barrier()
} else {
context.allGather(partitionId.toString)
}
Seq(null).iterator
}
val error = intercept[SparkException] {
rdd2.collect()
}.getMessage
assert(error.contains("does not match the current synchronized requestMethod"))
}
test("successively sync with allGather and barrier") {
val conf = new SparkConf()
.setMaster("local-cluster[4, 1, 1024]")
.setAppName("test-cluster")
sc = new SparkContext(conf)
val rdd = sc.makeRDD(1 to 10, 4)
val rdd2 = rdd.barrier().mapPartitions { it =>
val context = BarrierTaskContext.get()
// Sleep for a random time before global sync.
Thread.sleep(Random.nextInt(1000))
context.barrier()
val time1 = System.currentTimeMillis()
// Sleep for a random time before global sync.
Thread.sleep(Random.nextInt(1000))
// Pass partitionId message in
val message = context.partitionId().toString
val messages = context.allGather(message)
val time2 = System.currentTimeMillis()
Seq((time1, time2)).iterator
}
val times = rdd2.collect()
// All the tasks shall finish the first round of global sync within a short time slot.
val times1 = times.map(_._1)
assert(times1.max - times1.min <= 1000)
// All the tasks shall finish the second round of global sync within a short time slot.
val times2 = times.map(_._2)
assert(times2.max - times2.min <= 1000)
}
test("support multiple barrier() call within a single task") {
initLocalClusterSparkContext()
val rdd = sc.makeRDD(1 to 10, 4)

View file

@ -16,10 +16,9 @@
#
from __future__ import print_function
import json
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
from pyspark.serializers import write_int, UTF8Deserializer
class TaskContext(object):
@ -108,28 +107,18 @@ class TaskContext(object):
BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2
def _load_from_socket(port, auth_secret, function, all_gather_message=None):
def _load_from_socket(port, auth_secret):
"""
Load data from a given socket. This is a blocking method, so it only returns when the socket
connection has been closed.
"""
(sockfile, sock) = local_connect_and_auth(port, auth_secret)
# The call may block forever, so no timeout
# The barrier() call may block forever, so no timeout
sock.settimeout(None)
if function == BARRIER_FUNCTION:
# Make a barrier() function call.
write_int(function, sockfile)
elif function == ALL_GATHER_FUNCTION:
# Make a all_gather() function call.
write_int(function, sockfile)
write_with_length(all_gather_message.encode("utf-8"), sockfile)
else:
raise ValueError("Unrecognized function type")
# Make a barrier() function call.
write_int(BARRIER_FUNCTION, sockfile)
sockfile.flush()
# Collect result.
@ -210,33 +199,7 @@ class BarrierTaskContext(TaskContext):
raise Exception("Not supported to call barrier() before initialize " +
"BarrierTaskContext.")
else:
_load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
def allGather(self, message=""):
"""
.. note:: Experimental
This function blocks until all tasks in the same stage have reached this routine.
Each task passes in a message and returns with a list of all the messages passed in
by each of those tasks.
.. warning:: In a barrier stage, each task must have the same number of `allGather()`
calls, in all possible code branches.
Otherwise, you may get the job hanging or a SparkException after timeout.
"""
if not isinstance(message, str):
raise ValueError("Argument `message` must be of type `str`")
elif self._port is None or self._secret is None:
raise Exception("Not supported to call barrier() before initialize " +
"BarrierTaskContext.")
else:
gathered_items = _load_from_socket(
self._port,
self._secret,
ALL_GATHER_FUNCTION,
message,
)
return [e for e in json.loads(gathered_items)]
_load_from_socket(self._port, self._secret)
def getTaskInfos(self):
"""

View file

@ -135,26 +135,6 @@ class TaskContextTests(PySparkTestCase):
times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
self.assertTrue(max(times) - min(times) < 1)
def test_all_gather(self):
"""
Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
within a stage and passes messages properly.
"""
rdd = self.sc.parallelize(range(10), 4)
def f(iterator):
yield sum(iterator)
def context_barrier(x):
tc = BarrierTaskContext.get()
time.sleep(random.randint(1, 10))
out = tc.allGather(str(context.partitionId()))
pids = [int(e) for e in out]
return [pids]
pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
self.assertTrue(pids == [0, 1, 2, 3])
def test_barrier_infos(self):
"""
Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the