[SPARK-3632] ConnectionManager can run out of receive threads with authentication on
If you turn authentication on and you are using a lot of executors. There is a chance that all the of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out. To fix it, I got rid of the wait/notify and use a single outbox but only send security messages from it until authentication has completed. Author: Thomas Graves <tgraves@apache.org> Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits: a0a961d [Thomas Graves] give it a type b6bc80b [Thomas Graves] Rework comments d6d4175 [Thomas Graves] update from comments 081b765 [Thomas Graves] cleanup 4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication
This commit is contained in:
parent
5db78e6b87
commit
127e97bee1
|
@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
|
|||
* and a Server, so for a particular connection is has to determine what to do.
|
||||
* A ConnectionId was added to be able to track connections and is used to
|
||||
* match up incoming messages with connections waiting for authentication.
|
||||
* If its acting as a client and trying to send a message to another ConnectionManager,
|
||||
* it blocks the thread calling sendMessage until the SASL negotiation has occurred.
|
||||
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
|
||||
* and waits for the response from the server and does the handshake.
|
||||
* and waits for the response from the server and does the handshake before sending
|
||||
* the real message.
|
||||
*
|
||||
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
|
||||
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
|
||||
|
|
|
@ -20,23 +20,27 @@ package org.apache.spark.network.nio
|
|||
import java.net._
|
||||
import java.nio._
|
||||
import java.nio.channels._
|
||||
import java.util.LinkedList
|
||||
|
||||
import org.apache.spark._
|
||||
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap}
|
||||
|
||||
private[nio]
|
||||
abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
||||
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
|
||||
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
|
||||
val securityMgr: SecurityManager)
|
||||
extends Logging {
|
||||
|
||||
var sparkSaslServer: SparkSaslServer = null
|
||||
var sparkSaslClient: SparkSaslClient = null
|
||||
|
||||
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
|
||||
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
|
||||
securityMgr_ : SecurityManager) = {
|
||||
this(channel_, selector_,
|
||||
ConnectionManagerId.fromSocketAddress(
|
||||
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
|
||||
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
|
||||
id_, securityMgr_)
|
||||
}
|
||||
|
||||
channel.configureBlocking(false)
|
||||
|
@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
|||
|
||||
val remoteAddress = getRemoteAddress()
|
||||
|
||||
/**
|
||||
* Used to synchronize client requests: client's work-related requests must
|
||||
* wait until SASL authentication completes.
|
||||
*/
|
||||
private val authenticated = new Object()
|
||||
|
||||
def getAuthenticated(): Object = authenticated
|
||||
|
||||
def isSaslComplete(): Boolean
|
||||
|
||||
def resetForceReregister(): Boolean
|
||||
|
@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
|||
|
||||
private[nio]
|
||||
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
|
||||
remoteId_ : ConnectionManagerId, id_ : ConnectionId)
|
||||
extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
|
||||
remoteId_ : ConnectionManagerId, id_ : ConnectionId,
|
||||
securityMgr_ : SecurityManager)
|
||||
extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {
|
||||
|
||||
def isSaslComplete(): Boolean = {
|
||||
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
|
||||
}
|
||||
|
||||
private class Outbox {
|
||||
val messages = new Queue[Message]()
|
||||
val messages = new LinkedList[Message]()
|
||||
val defaultChunkSize = 65536
|
||||
var nextMessageToBeUsed = 0
|
||||
|
||||
def addMessage(message: Message) {
|
||||
messages.synchronized {
|
||||
/* messages += message */
|
||||
messages.enqueue(message)
|
||||
messages.add(message)
|
||||
logDebug("Added [" + message + "] to outbox for sending to " +
|
||||
"[" + getRemoteConnectionManagerId() + "]")
|
||||
}
|
||||
|
@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
|
|||
while (!messages.isEmpty) {
|
||||
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
|
||||
/* val message = messages(nextMessageToBeUsed) */
|
||||
val message = messages.dequeue()
|
||||
|
||||
val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
|
||||
// only allow sending of security messages until sasl is complete
|
||||
var pos = 0
|
||||
var securityMsg: Message = null
|
||||
while (pos < messages.size() && securityMsg == null) {
|
||||
if (messages.get(pos).isSecurityNeg) {
|
||||
securityMsg = messages.remove(pos)
|
||||
}
|
||||
pos = pos + 1
|
||||
}
|
||||
// didn't find any security messages and auth isn't completed so return
|
||||
if (securityMsg == null) return None
|
||||
securityMsg
|
||||
} else {
|
||||
messages.removeFirst()
|
||||
}
|
||||
|
||||
val chunk = message.getChunkForSending(defaultChunkSize)
|
||||
if (chunk.isDefined) {
|
||||
messages.enqueue(message)
|
||||
messages.add(message)
|
||||
nextMessageToBeUsed = nextMessageToBeUsed + 1
|
||||
if (!message.started) {
|
||||
logDebug(
|
||||
|
@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
|
|||
changeConnectionKeyInterest(DEFAULT_INTEREST)
|
||||
}
|
||||
|
||||
def registerAfterAuth(): Unit = {
|
||||
outbox.synchronized {
|
||||
needForceReregister = true
|
||||
}
|
||||
if (channel.isConnected) {
|
||||
registerInterest()
|
||||
}
|
||||
}
|
||||
|
||||
def send(message: Message) {
|
||||
outbox.synchronized {
|
||||
outbox.addMessage(message)
|
||||
|
@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
|
|||
private[spark] class ReceivingConnection(
|
||||
channel_ : SocketChannel,
|
||||
selector_ : Selector,
|
||||
id_ : ConnectionId)
|
||||
extends Connection(channel_, selector_, id_) {
|
||||
id_ : ConnectionId,
|
||||
securityMgr_ : SecurityManager)
|
||||
extends Connection(channel_, selector_, id_, securityMgr_) {
|
||||
|
||||
def isSaslComplete(): Boolean = {
|
||||
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false
|
||||
|
|
|
@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
|
|||
import scala.language.postfixOps
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.util.{SystemClock, Utils}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
private[nio] class ConnectionManager(
|
||||
|
@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
|
|||
private val selector = SelectorProvider.provider.openSelector()
|
||||
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
|
||||
|
||||
// default to 30 second timeout waiting for authentication
|
||||
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
|
||||
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
|
||||
|
||||
private val handleMessageExecutor = new ThreadPoolExecutor(
|
||||
|
@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
|
|||
while (newChannel != null) {
|
||||
try {
|
||||
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
|
||||
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
|
||||
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
|
||||
securityManager)
|
||||
newConnection.onReceive(receiveMessage)
|
||||
addListeners(newConnection)
|
||||
addConnection(newConnection)
|
||||
|
@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
|
|||
if (waitingConn.isSaslComplete()) {
|
||||
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
|
||||
connectionsAwaitingSasl -= waitingConn.connectionId
|
||||
waitingConn.getAuthenticated().synchronized {
|
||||
waitingConn.getAuthenticated().notifyAll()
|
||||
}
|
||||
waitingConn.registerAfterAuth()
|
||||
wakeupSelector()
|
||||
return
|
||||
} else {
|
||||
var replyToken : Array[Byte] = null
|
||||
|
@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
|
|||
if (waitingConn.isSaslComplete()) {
|
||||
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
|
||||
connectionsAwaitingSasl -= waitingConn.connectionId
|
||||
waitingConn.getAuthenticated().synchronized {
|
||||
waitingConn.getAuthenticated().notifyAll()
|
||||
}
|
||||
waitingConn.registerAfterAuth()
|
||||
wakeupSelector()
|
||||
return
|
||||
}
|
||||
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
|
||||
|
@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
|
|||
}
|
||||
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
|
||||
if (connection.isSaslComplete()) {
|
||||
logDebug("Server sasl completed: " + connection.connectionId)
|
||||
logDebug("Server sasl completed: " + connection.connectionId +
|
||||
" for: " + connectionId)
|
||||
} else {
|
||||
logDebug("Server sasl not completed: " + connection.connectionId)
|
||||
logDebug("Server sasl not completed: " + connection.connectionId +
|
||||
" for: " + connectionId)
|
||||
}
|
||||
if (replyToken != null) {
|
||||
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
|
||||
|
@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
|
|||
if (message == null) throw new Exception("Error creating security message")
|
||||
connectionsAwaitingSasl += ((conn.connectionId, conn))
|
||||
sendSecurityMessage(connManagerId, message)
|
||||
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
|
||||
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
|
||||
" to: " + connManagerId)
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logError("Error getting first response from the SaslClient.", e)
|
||||
|
@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
|
|||
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
|
||||
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
|
||||
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
|
||||
newConnectionId)
|
||||
newConnectionId, securityManager)
|
||||
logInfo("creating new sending connection for security! " + newConnectionId )
|
||||
registerRequests.enqueue(newConnection)
|
||||
|
||||
|
@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
|
|||
connectionManagerId.port)
|
||||
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
|
||||
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
|
||||
newConnectionId)
|
||||
newConnectionId, securityManager)
|
||||
logTrace("creating new sending connection: " + newConnectionId)
|
||||
registerRequests.enqueue(newConnection)
|
||||
|
||||
newConnection
|
||||
}
|
||||
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
|
||||
if (authEnabled) {
|
||||
checkSendAuthFirst(connectionManagerId, connection)
|
||||
}
|
||||
|
||||
message.senderAddress = id.toSocketAddress()
|
||||
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
|
||||
"connectionid: " + connection.connectionId)
|
||||
|
||||
if (authEnabled) {
|
||||
// if we aren't authenticated yet lets block the senders until authentication completes
|
||||
try {
|
||||
connection.getAuthenticated().synchronized {
|
||||
val clock = SystemClock
|
||||
val startTime = clock.getTime()
|
||||
|
||||
while (!connection.isSaslComplete()) {
|
||||
logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
|
||||
// have timeout in case remote side never responds
|
||||
connection.getAuthenticated().wait(500)
|
||||
if (((clock.getTime() - startTime) >= (authTimeout * 1000))
|
||||
&& (!connection.isSaslComplete())) {
|
||||
// took to long to authenticate the connection, something probably went wrong
|
||||
throw new Exception("Took to long for authentication to " + connectionManagerId +
|
||||
", waited " + authTimeout + "seconds, failing.")
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => logError("Exception while waiting for authentication.", e)
|
||||
|
||||
// need to tell sender it failed
|
||||
messageStatuses.synchronized {
|
||||
val s = messageStatuses.get(message.id)
|
||||
s match {
|
||||
case Some(msgStatus) => {
|
||||
messageStatuses -= message.id
|
||||
logInfo("Notifying " + msgStatus.connectionManagerId)
|
||||
msgStatus.markDone(None)
|
||||
}
|
||||
case None => {
|
||||
logError("no messageStatus for failed message id: " + message.id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
checkSendAuthFirst(connectionManagerId, connection)
|
||||
}
|
||||
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
|
||||
connection.send(message)
|
||||
|
||||
wakeupSelector()
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue