[SPARK-3632] ConnectionManager can run out of receive threads with authentication on

If you turn authentication on and you are using a lot of executors. There is a chance that all the of the threads in the handleMessageExecutor could be waiting to send a message because they are blocked waiting on authentication to happen. This can cause a temporary deadlock until the connection times out.

To fix it, I got rid of the wait/notify and use a single outbox but only send security messages from it until authentication has completed.

Author: Thomas Graves <tgraves@apache.org>

Closes #2484 from tgravescs/cm_threads_auth and squashes the following commits:

a0a961d [Thomas Graves] give it a type
b6bc80b [Thomas Graves] Rework comments
d6d4175 [Thomas Graves] update from comments
081b765 [Thomas Graves] cleanup
4d7f8f5 [Thomas Graves] Change to not use wait/notify while waiting for authentication
This commit is contained in:
Thomas Graves 2014-10-02 13:52:54 -07:00 committed by Reynold Xin
parent 5db78e6b87
commit 127e97bee1
3 changed files with 63 additions and 79 deletions

View file

@ -103,10 +103,9 @@ import org.apache.spark.deploy.SparkHadoopUtil
* and a Server, so for a particular connection is has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
* If its acting as a client and trying to send a message to another ConnectionManager,
* it blocks the thread calling sendMessage until the SASL negotiation has occurred.
* The ConnectionManager tracks all the sendingConnections using the ConnectionId
* and waits for the response from the server and does the handshake.
* and waits for the response from the server and does the handshake before sending
* the real message.
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work

View file

@ -20,23 +20,27 @@ package org.apache.spark.network.nio
import java.net._
import java.nio._
import java.nio.channels._
import java.util.LinkedList
import org.apache.spark._
import scala.collection.mutable.{ArrayBuffer, HashMap, Queue}
import scala.collection.mutable.{ArrayBuffer, HashMap}
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId)
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
val securityMgr: SecurityManager)
extends Logging {
var sparkSaslServer: SparkSaslServer = null
var sparkSaslClient: SparkSaslClient = null
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId,
securityMgr_ : SecurityManager) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]),
id_, securityMgr_)
}
channel.configureBlocking(false)
@ -52,14 +56,6 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
/**
* Used to synchronize client requests: client's work-related requests must
* wait until SASL authentication completes.
*/
private val authenticated = new Object()
def getAuthenticated(): Object = authenticated
def isSaslComplete(): Boolean
def resetForceReregister(): Boolean
@ -192,22 +188,22 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
private[nio]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId, id_ : ConnectionId)
extends Connection(SocketChannel.open, selector_, remoteId_, id_) {
remoteId_ : ConnectionManagerId, id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(SocketChannel.open, selector_, remoteId_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslClient != null) sparkSaslClient.isComplete() else false
}
private class Outbox {
val messages = new Queue[Message]()
val messages = new LinkedList[Message]()
val defaultChunkSize = 65536
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized {
/* messages += message */
messages.enqueue(message)
messages.add(message)
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
@ -218,10 +214,27 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
val message = messages.dequeue()
val message = if (securityMgr.isAuthenticationEnabled() && !isSaslComplete()) {
// only allow sending of security messages until sasl is complete
var pos = 0
var securityMsg: Message = null
while (pos < messages.size() && securityMsg == null) {
if (messages.get(pos).isSecurityNeg) {
securityMsg = messages.remove(pos)
}
pos = pos + 1
}
// didn't find any security messages and auth isn't completed so return
if (securityMsg == null) return None
securityMsg
} else {
messages.removeFirst()
}
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages.enqueue(message)
messages.add(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug(
@ -273,6 +286,15 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
def registerAfterAuth(): Unit = {
outbox.synchronized {
needForceReregister = true
}
if (channel.isConnected) {
registerInterest()
}
}
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
@ -415,8 +437,9 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
private[spark] class ReceivingConnection(
channel_ : SocketChannel,
selector_ : Selector,
id_ : ConnectionId)
extends Connection(channel_, selector_, id_) {
id_ : ConnectionId,
securityMgr_ : SecurityManager)
extends Connection(channel_, selector_, id_, securityMgr_) {
def isSaslComplete(): Boolean = {
if (sparkSaslServer != null) sparkSaslServer.isComplete() else false

View file

@ -32,7 +32,7 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import org.apache.spark._
import org.apache.spark.util.{SystemClock, Utils}
import org.apache.spark.util.Utils
private[nio] class ConnectionManager(
@ -65,8 +65,6 @@ private[nio] class ConnectionManager(
private val selector = SelectorProvider.provider.openSelector()
private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
// default to 30 second timeout waiting for authentication
private val authTimeout = conf.getInt("spark.core.connection.auth.wait.timeout", 30)
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
private val handleMessageExecutor = new ThreadPoolExecutor(
@ -409,7 +407,8 @@ private[nio] class ConnectionManager(
while (newChannel != null) {
try {
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId)
val newConnection = new ReceivingConnection(newChannel, selector, newConnectionId,
securityManager)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
@ -527,9 +526,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
} else {
var replyToken : Array[Byte] = null
@ -538,9 +536,8 @@ private[nio] class ConnectionManager(
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll()
}
waitingConn.registerAfterAuth()
wakeupSelector()
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@ -574,9 +571,11 @@ private[nio] class ConnectionManager(
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
if (connection.isSaslComplete()) {
logDebug("Server sasl completed: " + connection.connectionId)
logDebug("Server sasl completed: " + connection.connectionId +
" for: " + connectionId)
} else {
logDebug("Server sasl not completed: " + connection.connectionId)
logDebug("Server sasl not completed: " + connection.connectionId +
" for: " + connectionId)
}
if (replyToken != null) {
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
@ -723,7 +722,8 @@ private[nio] class ConnectionManager(
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId)
logDebug("adding connectionsAwaitingSasl id: " + conn.connectionId +
" to: " + connManagerId)
} catch {
case e: Exception => {
logError("Error getting first response from the SaslClient.", e)
@ -744,7 +744,7 @@ private[nio] class ConnectionManager(
val inetSocketAddress = new InetSocketAddress(connManagerId.host, connManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connManagerId,
newConnectionId)
newConnectionId, securityManager)
logInfo("creating new sending connection for security! " + newConnectionId )
registerRequests.enqueue(newConnection)
@ -769,61 +769,23 @@ private[nio] class ConnectionManager(
connectionManagerId.port)
val newConnectionId = new ConnectionId(id, idCount.getAndIncrement.intValue)
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId,
newConnectionId)
newConnectionId, securityManager)
logTrace("creating new sending connection: " + newConnectionId)
registerRequests.enqueue(newConnection)
newConnection
}
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
if (authEnabled) {
checkSendAuthFirst(connectionManagerId, connection)
}
message.senderAddress = id.toSocketAddress()
logDebug("Before Sending [" + message + "] to [" + connectionManagerId + "]" + " " +
"connectionid: " + connection.connectionId)
if (authEnabled) {
// if we aren't authenticated yet lets block the senders until authentication completes
try {
connection.getAuthenticated().synchronized {
val clock = SystemClock
val startTime = clock.getTime()
while (!connection.isSaslComplete()) {
logDebug("getAuthenticated wait connectionid: " + connection.connectionId)
// have timeout in case remote side never responds
connection.getAuthenticated().wait(500)
if (((clock.getTime() - startTime) >= (authTimeout * 1000))
&& (!connection.isSaslComplete())) {
// took to long to authenticate the connection, something probably went wrong
throw new Exception("Took to long for authentication to " + connectionManagerId +
", waited " + authTimeout + "seconds, failing.")
}
}
}
} catch {
case e: Exception => logError("Exception while waiting for authentication.", e)
// need to tell sender it failed
messageStatuses.synchronized {
val s = messageStatuses.get(message.id)
s match {
case Some(msgStatus) => {
messageStatuses -= message.id
logInfo("Notifying " + msgStatus.connectionManagerId)
msgStatus.markDone(None)
}
case None => {
logError("no messageStatus for failed message id: " + message.id)
}
}
}
}
checkSendAuthFirst(connectionManagerId, connection)
}
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
connection.send(message)
wakeupSelector()
}