Revert "[SPARK-30667][CORE] Add allGather method to BarrierTaskContext"
This reverts commit 57254c9719.

parent 57254c9719
commit fa3517cdb1
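For context, the reverted change let every task in a barrier stage exchange a message with all other tasks in that stage through a new BarrierTaskContext.allGather(message) method. A minimal usage sketch, adapted from the suite test that is also removed further down in this diff (the RDD and the SparkContext `sc` here are illustrative, not part of the diff):

    val rdd = sc.makeRDD(1 to 10, 4)
    val gathered = rdd.barrier().mapPartitions { iter =>
      val context = BarrierTaskContext.get()
      // Each task contributes its partition id and gets back every task's contribution.
      val messages = context.allGather(context.partitionId().toString)
      Iterator.single(messages.toList)
    }.collect()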
@@ -17,17 +17,12 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Timer, TimerTask}
import java.util.concurrent.ConcurrentHashMap
import java.util.function.Consumer

import scala.collection.mutable.ArrayBuffer

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark.internal.Logging
import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint}
import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted}

@@ -104,15 +99,10 @@ private[spark] class BarrierCoordinator(
    // reset when a barrier() call fails due to timeout.
    private var barrierEpoch: Int = 0

    // An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
    // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier()
    // call.
    private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)

    // An Array of allGather messages for barrier tasks that have made a blocking runBarrier() call
    private val allGatherMessages: ArrayBuffer[String] = new Array[String](numTasks).to[ArrayBuffer]

    // The blocking requestMethod called by tasks to sync up for this stage attempt
    private var requestMethodToSync: RequestMethod.Value = RequestMethod.BARRIER

    // A timer task that ensures we may timeout for a barrier() call.
    private var timerTask: TimerTask = null

@@ -140,32 +130,9 @@ private[spark] class BarrierCoordinator(

    // Process the global sync request. The barrier() call succeed if collected enough requests
    // within a configured time, otherwise fail all the pending requests.
    def handleRequest(
        requester: RpcCallContext,
        request: RequestToSync
    ): Unit = synchronized {
    def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized {
      val taskId = request.taskAttemptId
      val epoch = request.barrierEpoch
      val requestMethod = request.requestMethod
      val partitionId = request.partitionId
      val allGatherMessage = request match {
        case ag: AllGatherRequestToSync => ag.allGatherMessage
        case _ => ""
      }

      if (requesters.size == 0) {
        requestMethodToSync = requestMethod
      }

      if (requestMethodToSync != requestMethod) {
        requesters.foreach(
          _.sendFailure(new SparkException(s"$barrierId tried to use requestMethod " +
            s"`$requestMethod` during barrier epoch $barrierEpoch, which does not match " +
            s"the current synchronized requestMethod `$requestMethodToSync`"
          ))
        )
        cleanupBarrierStage(barrierId)
      }

      // Require the number of tasks is correctly set from the BarrierTaskContext.
      require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +

@@ -186,7 +153,6 @@ private[spark] class BarrierCoordinator(
      }
      // Add the requester to array of RPCCallContexts pending for reply.
      requesters += requester
      allGatherMessages(partitionId) = allGatherMessage
      logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " +
        s"$taskId, current progress: ${requesters.size}/$numTasks.")
      if (maybeFinishAllRequesters(requesters, numTasks)) {

@@ -196,7 +162,6 @@ private[spark] class BarrierCoordinator(
          s"tasks, finished successfully.")
        barrierEpoch += 1
        requesters.clear()
        allGatherMessages.clear()
        cancelTimerTask()
      }
    }

@@ -208,13 +173,7 @@ private[spark] class BarrierCoordinator(
        requesters: ArrayBuffer[RpcCallContext],
        numTasks: Int): Boolean = {
      if (requesters.size == numTasks) {
        requestMethodToSync match {
          case RequestMethod.BARRIER =>
            requesters.foreach(_.reply(""))
          case RequestMethod.ALL_GATHER =>
            val json: String = compact(render(allGatherMessages))
            requesters.foreach(_.reply(json))
        }
        requesters.foreach(_.reply(()))
        true
      } else {
        false

@@ -227,7 +186,6 @@ private[spark] class BarrierCoordinator(
      // messages come from current stage attempt shall fail.
      barrierEpoch = -1
      requesters.clear()
      allGatherMessages.clear()
      cancelTimerTask()
    }
  }

@@ -241,11 +199,11 @@ private[spark] class BarrierCoordinator(
  }

  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case request: RequestToSync =>
    case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) =>
      // Get or init the ContextBarrierState correspond to the stage attempt.
      val barrierId = ContextBarrierId(request.stageId, request.stageAttemptId)
      val barrierId = ContextBarrierId(stageId, stageAttemptId)
      states.computeIfAbsent(barrierId,
        (key: ContextBarrierId) => new ContextBarrierState(key, request.numTasks))
        (key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
      val barrierState = states.get(barrierId)

      barrierState.handleRequest(context, request)

@@ -258,16 +216,6 @@ private[spark] class BarrierCoordinator(

private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
  def numTasks: Int
  def stageId: Int
  def stageAttemptId: Int
  def taskAttemptId: Long
  def barrierEpoch: Int
  def partitionId: Int
  def requestMethod: RequestMethod.Value
}

/**
 * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is
 * identified by stageId + stageAttemptId + barrierEpoch.

@@ -276,44 +224,11 @@ private[spark] sealed trait RequestToSync extends BarrierCoordinatorMessage {
 * @param stageId ID of current stage
 * @param stageAttemptId ID of current stage attempt
 * @param taskAttemptId Unique ID of current task
 * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
 * @param partitionId ID of the current partition the task is assigned to
 * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
 * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls.
 */
private[spark] case class BarrierRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value
) extends RequestToSync

/**
 * A global sync request message from BarrierTaskContext, by `allGather()` call. Each request is
 * identified by stageId + stageAttemptId + barrierEpoch.
 *
 * @param numTasks The number of global sync requests the BarrierCoordinator shall receive
 * @param stageId ID of current stage
 * @param stageAttemptId ID of current stage attempt
 * @param taskAttemptId Unique ID of current task
 * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls
 * @param partitionId ID of the current partition the task is assigned to
 * @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
 * @param allGatherMessage Message sent from the BarrierTaskContext if requestMethod is ALL_GATHER
 */
private[spark] case class AllGatherRequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int,
    partitionId: Int,
    requestMethod: RequestMethod.Value,
    allGatherMessage: String
) extends RequestToSync

private[spark] object RequestMethod extends Enumeration {
  val BARRIER, ALL_GATHER = Value
}
private[spark] case class RequestToSync(
    numTasks: Int,
    stageId: Int,
    stageAttemptId: Int,
    taskAttemptId: Long,
    barrierEpoch: Int) extends BarrierCoordinatorMessage
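The coordinator hunks above remove the reply path that distinguished the two request methods: once all numTasks requests have arrived, the reverted code answered every pending RPC with an empty string for BARRIER, or with one JSON array holding the message collected from each partition for ALL_GATHER. A simplified, standalone sketch of building that ALL_GATHER reply payload with the json4s calls imported above (the messages here are illustrative, not part of the diff):

    import org.json4s.JsonAST.{JArray, JString}
    import org.json4s.jackson.JsonMethods.{compact, render}

    // One slot per partition, filled in as each task's request arrives.
    val allGatherMessages = Array("msg-0", "msg-1", "msg-2", "msg-3")
    // Every waiting task receives the identical array, e.g. ["msg-0","msg-1","msg-2","msg-3"].
    val reply: String = compact(render(JArray(allGatherMessages.map(JString(_)).toList)))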
@@ -17,19 +17,11 @@

package org.apache.spark

import java.nio.charset.StandardCharsets.UTF_8
import java.util.{Properties, Timer, TimerTask}

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.TimeoutException
import scala.concurrent.duration._
import scala.language.postfixOps

import org.json4s.DefaultFormats
import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics

@@ -67,97 +59,6 @@ class BarrierTaskContext private[spark] (
  // from different tasks within the same barrier stage attempt to succeed.
  private lazy val numTasks = getTaskInfos().size

  private def getRequestToSync(
      numTasks: Int,
      stageId: Int,
      stageAttemptNumber: Int,
      taskAttemptId: Long,
      barrierEpoch: Int,
      partitionId: Int,
      requestMethod: RequestMethod.Value,
      allGatherMessage: String
  ): RequestToSync = {
    requestMethod match {
      case RequestMethod.BARRIER =>
        BarrierRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
          barrierEpoch, partitionId, requestMethod)
      case RequestMethod.ALL_GATHER =>
        AllGatherRequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
          barrierEpoch, partitionId, requestMethod, allGatherMessage)
    }
  }

  private def runBarrier(
      requestMethod: RequestMethod.Value,
      allGatherMessage: String = ""
  ): String = {

    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
      s"the global sync, current barrier epoch is $barrierEpoch.")
    logTrace("Current callSite: " + Utils.getCallSite())

    val startTime = System.currentTimeMillis()
    val timerTask = new TimerTask {
      override def run(): Unit = {
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
          s"under the global sync since $startTime, has been waiting for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
      }
    }
    // Log the update of global sync every 60 seconds.
    timer.schedule(timerTask, 60000, 60000)

    var json: String = ""

    try {
      val abortableRpcFuture = barrierCoordinator.askAbortable[String](
        message = getRequestToSync(numTasks, stageId, stageAttemptNumber,
          taskAttemptId, barrierEpoch, partitionId, requestMethod, allGatherMessage),
        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
        // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
        timeout = new RpcTimeout(365.days, "barrierTimeout"))

      // Wait the RPC future to be completed, but every 1 second it will jump out waiting
      // and check whether current spark task is killed. If killed, then throw
      // a `TaskKilledException`, otherwise continue wait RPC until it completes.
      try {
        while (!abortableRpcFuture.toFuture.isCompleted) {
          // wait RPC future for at most 1 second
          try {
            json = ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
          } catch {
            case _: TimeoutException | _: InterruptedException =>
              // If `TimeoutException` thrown, waiting RPC future reach 1 second.
              // If `InterruptedException` thrown, it is possible this task is killed.
              // So in this two cases, we should check whether task is killed and then
              // throw `TaskKilledException`
              taskContext.killTaskIfInterrupted()
          }
        }
      } finally {
        abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
      }

      barrierEpoch += 1
      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
        "global sync successfully, waited for " +
        s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
        s"current barrier epoch is $barrierEpoch.")
    } catch {
      case e: SparkException =>
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
          "to perform global sync, waited for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
        throw e
    } finally {
      timerTask.cancel()
      timer.purge()
    }
    json
  }

  /**
   * :: Experimental ::
   * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to

@@ -201,27 +102,67 @@ class BarrierTaskContext private[spark] (
  @Experimental
  @Since("2.4.0")
  def barrier(): Unit = {
    runBarrier(RequestMethod.BARRIER)
    ()
  }
    logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " +
      s"the global sync, current barrier epoch is $barrierEpoch.")
    logTrace("Current callSite: " + Utils.getCallSite())

  /**
   * :: Experimental ::
   * Blocks until all tasks in the same stage have reached this routine. Each task passes in
   * a message and returns with a list of all the messages passed in by each of those tasks.
   *
   * CAUTION! The allGather method requires the same precautions as the barrier method
   *
   * The message is type String rather than Array[Byte] because it is more convenient for
   * the user at the cost of worse performance.
   */
  @Experimental
  @Since("3.0.0")
  def allGather(message: String): ArrayBuffer[String] = {
    val json = runBarrier(RequestMethod.ALL_GATHER, message)
    val jsonArray = parse(json)
    implicit val formats = DefaultFormats
    ArrayBuffer(jsonArray.extract[Array[String]]: _*)
    val startTime = System.currentTimeMillis()
    val timerTask = new TimerTask {
      override def run(): Unit = {
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " +
          s"under the global sync since $startTime, has been waiting for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
      }
    }
    // Log the update of global sync every 60 seconds.
    timer.schedule(timerTask, 60000, 60000)

    try {
      val abortableRpcFuture = barrierCoordinator.askAbortable[Unit](
        message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId,
          barrierEpoch),
        // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
        // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
        timeout = new RpcTimeout(365.days, "barrierTimeout"))

      // Wait the RPC future to be completed, but every 1 second it will jump out waiting
      // and check whether current spark task is killed. If killed, then throw
      // a `TaskKilledException`, otherwise continue wait RPC until it completes.
      try {
        while (!abortableRpcFuture.toFuture.isCompleted) {
          // wait RPC future for at most 1 second
          try {
            ThreadUtils.awaitResult(abortableRpcFuture.toFuture, 1.second)
          } catch {
            case _: TimeoutException | _: InterruptedException =>
              // If `TimeoutException` thrown, waiting RPC future reach 1 second.
              // If `InterruptedException` thrown, it is possible this task is killed.
              // So in this two cases, we should check whether task is killed and then
              // throw `TaskKilledException`
              taskContext.killTaskIfInterrupted()
          }
        }
      } finally {
        abortableRpcFuture.abort(taskContext.getKillReason().getOrElse("Unknown reason."))
      }

      barrierEpoch += 1
      logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
        "global sync successfully, waited for " +
        s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
        s"current barrier epoch is $barrierEpoch.")
    } catch {
      case e: SparkException =>
        logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " +
          "to perform global sync, waited for " +
          s"${MILLISECONDS.toSeconds(System.currentTimeMillis() - startTime)} seconds, " +
          s"current barrier epoch is $barrierEpoch.")
        throw e
    } finally {
      timerTask.cancel()
      timer.purge()
    }
  }

  /**
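Both the removed runBarrier and the restored barrier() body above follow the same waiting pattern: block on the coordinator's RPC future for at most one second at a time, and on each timeout or interrupt check whether the task has been killed, so a killed task fails promptly instead of sitting on the effectively unbounded 365-day RPC timeout. A self-contained sketch of that pattern using plain Scala futures (Spark's AbortableRpcFuture, ThreadUtils.awaitResult and killTaskIfInterrupted are replaced here by assumed stand-ins):

    import java.util.concurrent.TimeoutException
    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._

    def waitForReply(reply: Future[String], isKilled: () => Boolean): String = {
      var result: Option[String] = None
      while (result.isEmpty) {
        try {
          // Block for at most one second, then re-check whether the task was killed.
          result = Some(Await.result(reply, 1.second))
        } catch {
          case _: TimeoutException | _: InterruptedException =>
            if (isKilled()) throw new RuntimeException("Task killed while waiting for barrier sync")
        }
      }
      result.get
    }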
@@ -24,13 +24,8 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal

import org.json4s.JsonAST._
import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods.{compact, render}

import org.apache.spark._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}

@@ -243,18 +238,13 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
        sock.setSoTimeout(10000)
        authHelper.authClient(sock)
        val input = new DataInputStream(sock.getInputStream())
        val requestMethod = input.readInt()
        // The BarrierTaskContext function may wait infinitely, socket shall not timeout
        // before the function finishes.
        sock.setSoTimeout(0)
        requestMethod match {
        input.readInt() match {
          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
            barrierAndServe(requestMethod, sock)
          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
            val length = input.readInt()
            val message = new Array[Byte](length)
            input.readFully(message)
            barrierAndServe(requestMethod, sock, new String(message, UTF_8))
            // The barrier() function may wait infinitely, socket shall not timeout
            // before the function finishes.
            sock.setSoTimeout(0)
            barrierAndServe(sock)

          case _ =>
            val out = new DataOutputStream(new BufferedOutputStream(
              sock.getOutputStream))

@@ -405,31 +395,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
    }

    /**
     * Gateway to call BarrierTaskContext methods.
     * Gateway to call BarrierTaskContext.barrier().
     */
    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
      require(
        serverSocket.isDefined,
        "No available ServerSocket to redirect the BarrierTaskContext method call."
      )
    def barrierAndServe(sock: Socket): Unit = {
      require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")

      val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
      try {
        var result: String = ""
        requestMethod match {
          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
            context.asInstanceOf[BarrierTaskContext].barrier()
            result = BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS
          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
            val messages: ArrayBuffer[String] = context.asInstanceOf[BarrierTaskContext].allGather(
              message
            )
            result = compact(render(JArray(
              messages.map(
                (message) => JString(message)
              ).toList
            )))
        }
        writeUTF(result, out)
        context.asInstanceOf[BarrierTaskContext].barrier()
        writeUTF(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS, out)
      } catch {
        case e: SparkException =>
          writeUTF(e.getMessage, out)

@@ -664,7 +638,6 @@ private[spark] object SpecialLengths {

private[spark] object BarrierTaskContextMessageProtocol {
  val BARRIER_FUNCTION = 1
  val ALL_GATHER_FUNCTION = 2
  val BARRIER_RESULT_SUCCESS = "success"
  val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
}
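The BasePythonRunner hunks above also revert the wire protocol between the Python worker and the JVM: the worker first writes an int function id (1 for barrier, 2 for allGather), and for allGather follows it with an int byte length plus that many UTF-8 bytes; the JVM side then replies with a UTF string (a success marker, or the JSON array of gathered messages). A minimal sketch of the JVM-side read of that framing (the method name is illustrative; the constants mirror BarrierTaskContextMessageProtocol above):

    import java.io.DataInputStream
    import java.nio.charset.StandardCharsets.UTF_8

    def readBarrierRequest(input: DataInputStream): (Int, String) = {
      val requestMethod = input.readInt()  // 1 = BARRIER_FUNCTION, 2 = ALL_GATHER_FUNCTION
      val message =
        if (requestMethod == 2) {
          val length = input.readInt()     // length-prefixed UTF-8 message for allGather
          val bytes = new Array[Byte](length)
          input.readFully(bytes)
          new String(bytes, UTF_8)
        } else {
          ""
        }
      (requestMethod, message)
    }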
@@ -19,7 +19,6 @@ package org.apache.spark.scheduler

import java.io.File

import scala.collection.mutable.ArrayBuffer
import scala.util.Random

import org.apache.spark._

@@ -53,79 +52,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext {
    assert(times.max - times.min <= 1000)
  }

  test("share messages with allGather() call") {
    val conf = new SparkConf()
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("test-cluster")
    sc = new SparkContext(conf)
    val rdd = sc.makeRDD(1 to 10, 4)
    val rdd2 = rdd.barrier().mapPartitions { it =>
      val context = BarrierTaskContext.get()
      // Sleep for a random time before global sync.
      Thread.sleep(Random.nextInt(1000))
      // Pass partitionId message in
      val message = context.partitionId().toString
      val messages = context.allGather(message)
      messages.toList.iterator
    }
    // Take a sorted list of all the partitionId messages
    val messages = rdd2.collect().head
    // All the task partitionIds are shared
    for((x, i) <- messages.view.zipWithIndex) assert(x == i.toString)
  }

  test("throw exception if we attempt to synchronize with different blocking calls") {
    val conf = new SparkConf()
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("test-cluster")
    sc = new SparkContext(conf)
    val rdd = sc.makeRDD(1 to 10, 4)
    val rdd2 = rdd.barrier().mapPartitions { it =>
      val context = BarrierTaskContext.get()
      val partitionId = context.partitionId
      if (partitionId == 0) {
        context.barrier()
      } else {
        context.allGather(partitionId.toString)
      }
      Seq(null).iterator
    }
    val error = intercept[SparkException] {
      rdd2.collect()
    }.getMessage
    assert(error.contains("does not match the current synchronized requestMethod"))
  }

  test("successively sync with allGather and barrier") {
    val conf = new SparkConf()
      .setMaster("local-cluster[4, 1, 1024]")
      .setAppName("test-cluster")
    sc = new SparkContext(conf)
    val rdd = sc.makeRDD(1 to 10, 4)
    val rdd2 = rdd.barrier().mapPartitions { it =>
      val context = BarrierTaskContext.get()
      // Sleep for a random time before global sync.
      Thread.sleep(Random.nextInt(1000))
      context.barrier()
      val time1 = System.currentTimeMillis()
      // Sleep for a random time before global sync.
      Thread.sleep(Random.nextInt(1000))
      // Pass partitionId message in
      val message = context.partitionId().toString
      val messages = context.allGather(message)
      val time2 = System.currentTimeMillis()
      Seq((time1, time2)).iterator
    }
    val times = rdd2.collect()
    // All the tasks shall finish the first round of global sync within a short time slot.
    val times1 = times.map(_._1)
    assert(times1.max - times1.min <= 1000)

    // All the tasks shall finish the second round of global sync within a short time slot.
    val times2 = times.map(_._2)
    assert(times2.max - times2.min <= 1000)
  }

  test("support multiple barrier() call within a single task") {
    initLocalClusterSparkContext()
    val rdd = sc.makeRDD(1 to 10, 4)
@@ -16,10 +16,9 @@
#

from __future__ import print_function
import json

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int, write_with_length, UTF8Deserializer
from pyspark.serializers import write_int, UTF8Deserializer


class TaskContext(object):

@@ -108,28 +107,18 @@ class TaskContext(object):


BARRIER_FUNCTION = 1
ALL_GATHER_FUNCTION = 2


def _load_from_socket(port, auth_secret, function, all_gather_message=None):
def _load_from_socket(port, auth_secret):
    """
    Load data from a given socket, this is a blocking method thus only return when the socket
    connection has been closed.
    """
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)

    # The call may block forever, so no timeout
    # The barrier() call may block forever, so no timeout
    sock.settimeout(None)

    if function == BARRIER_FUNCTION:
        # Make a barrier() function call.
        write_int(function, sockfile)
    elif function == ALL_GATHER_FUNCTION:
        # Make a all_gather() function call.
        write_int(function, sockfile)
        write_with_length(all_gather_message.encode("utf-8"), sockfile)
    else:
        raise ValueError("Unrecognized function type")
    # Make a barrier() function call.
    write_int(BARRIER_FUNCTION, sockfile)
    sockfile.flush()

    # Collect result.

@@ -210,33 +199,7 @@ class BarrierTaskContext(TaskContext):
            raise Exception("Not supported to call barrier() before initialize " +
                            "BarrierTaskContext.")
        else:
            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)

    def allGather(self, message=""):
        """
        .. note:: Experimental

        This function blocks until all tasks in the same stage have reached this routine.
        Each task passes in a message and returns with a list of all the messages passed in
        by each of those tasks.

        .. warning:: In a barrier stage, each task much have the same number of `allGather()`
            calls, in all possible code branches.
            Otherwise, you may get the job hanging or a SparkException after timeout.
        """
        if not isinstance(message, str):
            raise ValueError("Argument `message` must be of type `str`")
        elif self._port is None or self._secret is None:
            raise Exception("Not supported to call barrier() before initialize " +
                            "BarrierTaskContext.")
        else:
            gathered_items = _load_from_socket(
                self._port,
                self._secret,
                ALL_GATHER_FUNCTION,
                message,
            )
            return [e for e in json.loads(gathered_items)]
            _load_from_socket(self._port, self._secret)

    def getTaskInfos(self):
        """
@@ -135,26 +135,6 @@ class TaskContextTests(PySparkTestCase):
        times = rdd.barrier().mapPartitions(f).map(context_barrier).collect()
        self.assertTrue(max(times) - min(times) < 1)

    def test_all_gather(self):
        """
        Verify that BarrierTaskContext.allGather() performs global sync among all barrier tasks
        within a stage and passes messages properly.
        """
        rdd = self.sc.parallelize(range(10), 4)

        def f(iterator):
            yield sum(iterator)

        def context_barrier(x):
            tc = BarrierTaskContext.get()
            time.sleep(random.randint(1, 10))
            out = tc.allGather(str(context.partitionId()))
            pids = [int(e) for e in out]
            return [pids]

        pids = rdd.barrier().mapPartitions(f).map(context_barrier).collect()[0]
        self.assertTrue(pids == [0, 1, 2, 3])

    def test_barrier_infos(self):
        """
        Verify that BarrierTaskContext.getTaskInfos() returns a list of all task infos in the