Checkpoint commit - compiles and passes a lot of tests - not all though, looking into FileSuite issues
parent 6798a09df8
commit d90d2af103
core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala (new file, 18 lines)
@@ -0,0 +1,18 @@
package spark.deploy

/**
 * Contains util methods to interact with Hadoop from spark.
 */
object SparkHadoopUtil {

  def getUserNameFromEnvironment(): String = {
    // defaulting to -D ...
    System.getProperty("user.name")
  }

  def runAsUser(func: (Product) => Unit, args: Product) {

    // Add support, if it exists - for now, simply run func !
    func(args)
  }
}
@@ -0,0 +1,59 @@
package spark.deploy

import collection.mutable.HashMap
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import java.security.PrivilegedExceptionAction

/**
 * Contains util methods to interact with Hadoop from spark.
 */
object SparkHadoopUtil {

  val yarnConf = new YarnConfiguration(new Configuration())

  def getUserNameFromEnvironment(): String = {
    // defaulting to env if -D is not present ...
    val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))

    // If nothing found, default to user we are running as
    if (retval == null) System.getProperty("user.name") else retval
  }

  def runAsUser(func: (Product) => Unit, args: Product) {
    runAsUser(func, args, getUserNameFromEnvironment())
  }

  def runAsUser(func: (Product) => Unit, args: Product, user: String) {
    // println("running as user " + jobUserName)

    UserGroupInformation.setConfiguration(yarnConf)
    val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
    appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
      def run: AnyRef = {
        func(args)
        // no return value ...
        null
      }
    })
  }

  // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
  def isYarnMode(): Boolean = {
    val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
    java.lang.Boolean.valueOf(yarnMode)
  }

  // Set an env variable indicating we are running in YARN mode.
  // Note that anything with SPARK prefix gets propagated to all (remote) processes
  def setYarnMode() {
    System.setProperty("SPARK_YARN_MODE", "true")
  }

  def setYarnMode(env: HashMap[String, String]) {
    env("SPARK_YARN_MODE") = "true"
  }
}
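Both SparkHadoopUtil variants above expose getUserNameFromEnvironment and runAsUser, so callers compiled against either Hadoop profile look the same; only the YARN variant additionally carries the isYarnMode/setYarnMode flag that later files in this commit consult. A minimal caller sketch, where doWork and its tuple argument are hypothetical stand-ins for real job code:

import spark.deploy.SparkHadoopUtil

object SparkHadoopUtilUsageSketch {
  // Hypothetical job body; Product lets a tuple of arguments be passed through runAsUser.
  def doWork(args: Product) {
    println("running as " + SparkHadoopUtil.getUserNameFromEnvironment() + " with " + args)
  }

  def main(argv: Array[String]) {
    // In the hadoop1 build this simply calls doWork directly; in the hadoop2-yarn build it
    // wraps the call in UserGroupInformation.doAs for the resolved user.
    SparkHadoopUtil.runAsUser(doWork _, Tuple1("some-arg"))
  }
}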
@@ -0,0 +1,342 @@
package spark.deploy.yarn

import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import scala.collection.JavaConversions._
import spark.{SparkContext, Logging, Utils}
import org.apache.hadoop.security.UserGroupInformation
import java.security.PrivilegedExceptionAction

class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {

  def this(args: ApplicationMasterArguments) = this(args, new Configuration())

  private var rpc: YarnRPC = YarnRPC.create(conf)
  private var resourceManager: AMRMProtocol = null
  private var appAttemptId: ApplicationAttemptId = null
  private var userThread: Thread = null
  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  private var yarnAllocator: YarnAllocationHandler = null

  def run() {
    // Initialization
    val jobUserName = Utils.getUserNameFromEnvironment()
    logInfo("running as user " + jobUserName)

    // run as user ...
    UserGroupInformation.setConfiguration(yarnConf)
    val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
    appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
      def run: AnyRef = {
        runImpl()
        return null
      }
    })
  }

  private def runImpl() {
    appAttemptId = getApplicationAttemptId()
    resourceManager = registerWithResourceManager()
    val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()

    // Compute number of threads for akka
    val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()

    if (minimumMemory > 0) {
      val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
      val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)

      if (numCore > 0) {
        // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
        // TODO: Uncomment when hadoop is on a version which has this fixed.
        // args.workerCores = numCore
      }
    }

    // Workaround until hadoop moves to something which has
    // https://issues.apache.org/jira/browse/HADOOP-8406
    // ignore result
    // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
    // Hence args.workerCores = numCore disabled above. Any better option ?
    // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)

    ApplicationMaster.register(this)
    // Start the user's JAR
    userThread = startUserClass()

    // This is a bit hacky, but we need to wait until the spark.master.port property has
    // been set by the Thread executing the user class.
    waitForSparkMaster()

    // Allocate all containers
    allocateWorkers()

    // Wait for the user class to finish
    userThread.join()

    // Finish the ApplicationMaster
    finishApplicationMaster()
    // TODO: Exit based on success/failure
    System.exit(0)
  }
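For intuition, the numCore computation in runImpl above is just integer ceiling division of the per-worker memory ask (worker memory plus the allocator's memory overhead) by the scheduler's minimum allocation. A worked example with illustrative numbers (MEMORY_OVERHEAD itself is defined in YarnAllocationHandler and its value is not shown in this hunk):

// Illustrative values only.
val minimumMemory = 1024   // MB, minimum resource capability reported by the RM
val mem = 1536             // args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD (assumed)
val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
// 1536 / 1024 = 1 and the remainder is non-zero, so numCore = 2, i.e. ceil(1536 / 1024.0).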

  private def getApplicationAttemptId(): ApplicationAttemptId = {
    val envs = System.getenv()
    val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
    val containerId = ConverterUtils.toContainerId(containerIdString)
    val appAttemptId = containerId.getApplicationAttemptId()
    logInfo("ApplicationAttemptId: " + appAttemptId)
    return appAttemptId
  }

  private def registerWithResourceManager(): AMRMProtocol = {
    val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
      YarnConfiguration.RM_SCHEDULER_ADDRESS,
      YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
    logInfo("Connecting to ResourceManager at " + rmAddress)
    return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
  }

  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
    logInfo("Registering the ApplicationMaster")
    val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
      .asInstanceOf[RegisterApplicationMasterRequest]
    appMasterRequest.setApplicationAttemptId(appAttemptId)
    // Setting this to master host,port - so that the ApplicationReport at the client has some sensible info.
    // Users can then monitor stderr/stdout on that node if required.
    appMasterRequest.setHost(Utils.localHostName())
    appMasterRequest.setRpcPort(0)
    // What do we provide here ? Might make sense to expose something sensible later ?
    appMasterRequest.setTrackingUrl("")
    return resourceManager.registerApplicationMaster(appMasterRequest)
  }

  private def waitForSparkMaster() {
    logInfo("Waiting for spark master to be reachable.")
    var masterUp = false
    while (!masterUp) {
      val masterHost = System.getProperty("spark.master.host")
      val masterPort = System.getProperty("spark.master.port")
      try {
        val socket = new Socket(masterHost, masterPort.toInt)
        socket.close()
        logInfo("Master now available: " + masterHost + ":" + masterPort)
        masterUp = true
      } catch {
        case e: Exception =>
          logError("Failed to connect to master at " + masterHost + ":" + masterPort)
          Thread.sleep(100)
      }
    }
  }

  private def startUserClass(): Thread = {
    logInfo("Starting the user JAR in a separate Thread")
    val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
      .getMethod("main", classOf[Array[String]])
    val t = new Thread {
      override def run() {
        var mainArgs: Array[String] = null
        var startIndex = 0

        // I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now !
        if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") {
          // ensure that the first param is ALWAYS "yarn-standalone"
          mainArgs = new Array[String](args.userArgs.size() + 1)
          mainArgs.update(0, "yarn-standalone")
          startIndex = 1
        }
        else {
          mainArgs = new Array[String](args.userArgs.size())
        }

        args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size())

        mainMethod.invoke(null, mainArgs)
      }
    }
    t.start()
    return t
  }

  private def allocateWorkers() {
    logInfo("Waiting for spark context initialization")

    try {
      var sparkContext: SparkContext = null
      ApplicationMaster.sparkContextRef.synchronized {
        var count = 0
        while (ApplicationMaster.sparkContextRef.get() == null) {
          logInfo("Waiting for spark context initialization ... " + count)
          count = count + 1
          ApplicationMaster.sparkContextRef.wait(10000L)
        }
        sparkContext = ApplicationMaster.sparkContextRef.get()
        assert(sparkContext != null)
        this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
      }

      logInfo("Allocating " + args.numWorkers + " workers.")
      // Wait until all containers have finished
      // TODO: This is a bit ugly. Can we make it nicer?
      // TODO: Handle container failure
      while (yarnAllocator.getNumWorkersRunning < args.numWorkers &&
        // If the user thread exits, then quit !
        userThread.isAlive) {

        this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
        ApplicationMaster.incrementAllocatorLoop(1)
        Thread.sleep(100)
      }
    } finally {
      // in case of exceptions, etc - ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT :
      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
      ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
    }
    logInfo("All workers have launched.")

    // Launch a progress reporter thread, else the app will get killed after the expiration (default: 10 mins) timeout
    if (userThread.isAlive) {
      // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.

      val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
      // must be <= timeoutInterval / 2.
      // On the other hand, also ensure that we are reasonably responsive without causing too many requests to the RM.
      // so at least 1 minute or timeoutInterval / 10 - whichever is higher.
      val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
      launchReporterThread(interval)
    }
  }
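To make the reporter interval arithmetic at the end of allocateWorkers concrete, here it is evaluated with the fallback expiry value used above (an actual cluster may configure RM_AM_EXPIRY_INTERVAL_MS differently):

val timeoutInterval = 120000   // fallback used in the yarnConf.getInt call above, in ms
val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
// timeoutInterval / 2 = 60000, max(12000, 60000) = 60000, so interval = 60000 ms:
// the reporter pings the RM about once a minute, well inside the expiry window.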
  // TODO: We might want to extend this to allocate more containers in case they die !
  private def launchReporterThread(_sleepTime: Long): Thread = {
    val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime

    val t = new Thread {
      override def run() {
        while (userThread.isAlive) {
          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
          if (missingWorkerCount > 0) {
            logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
            yarnAllocator.allocateContainers(missingWorkerCount)
          }
          else sendProgress()
          Thread.sleep(sleepTime)
        }
      }
    }
    // setting to daemon status, though this is usually not a good idea.
    t.setDaemon(true)
    t.start()
    logInfo("Started progress reporter thread - sleep time : " + sleepTime)
    return t
  }

  private def sendProgress() {
    logDebug("Sending progress")
    // simulated with an allocate request with no nodes requested ...
    yarnAllocator.allocateContainers(0)
  }

  /*
  def printContainers(containers: List[Container]) = {
    for (container <- containers) {
      logInfo("Launching shell command on a new container."
        + ", containerId=" + container.getId()
        + ", containerNode=" + container.getNodeId().getHost()
        + ":" + container.getNodeId().getPort()
        + ", containerNodeURI=" + container.getNodeHttpAddress()
        + ", containerState" + container.getState()
        + ", containerResourceMemory"
        + container.getResource().getMemory())
    }
  }
  */

  def finishApplicationMaster() {
    val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
      .asInstanceOf[FinishApplicationMasterRequest]
    finishReq.setAppAttemptId(appAttemptId)
    // TODO: Check if the application has failed or succeeded
    finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
    resourceManager.finishApplicationMaster(finishReq)
  }

}

object ApplicationMaster {
  // number of times to wait for the allocator loop to complete.
  // each loop iteration waits for 100ms, so a maximum of 3 seconds.
  // This is to ensure that we have a reasonable number of containers before we start.
  // TODO: Currently, the task-to-container mapping is computed once (TaskSetManager) - which need not be optimal as more
  // containers become available. Might need to handle this better.
  private val ALLOCATOR_LOOP_WAIT_COUNT = 30
  def incrementAllocatorLoop(by: Int) {
    val count = yarnAllocatorLoop.getAndAdd(by)
    if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
      yarnAllocatorLoop.synchronized {
        // to wake threads off wait ...
        yarnAllocatorLoop.notifyAll()
      }
    }
  }

  private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()

  def register(master: ApplicationMaster) {
    applicationMasters.add(master)
  }

  val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)

  def sparkContextInitialized(sc: SparkContext): Boolean = {
    var modified = false
    sparkContextRef.synchronized {
      modified = sparkContextRef.compareAndSet(null, sc)
      sparkContextRef.notifyAll()
    }

    // Add a shutdown hook - as a best effort in case users do not call sc.stop or do System.exit
    // Should not really have to do this, but it helps yarn to evict resources earlier.
    // not to mention, it prevents the Client from declaring failure even though we exited properly.
    if (modified) {
      Runtime.getRuntime().addShutdownHook(new Thread with Logging {
        // This is not just to log, but also to ensure that the log system is initialized for this instance when we actually are 'run'
        logInfo("Adding shutdown hook for context " + sc)
        override def run() {
          logInfo("Invoking sc stop from shutdown hook")
          sc.stop()
          // best case ...
          for (master <- applicationMasters) master.finishApplicationMaster
        }
      })
    }

    // Wait for initialization to complete and at least 'some' nodes to get allocated
    yarnAllocatorLoop.synchronized {
      while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
        yarnAllocatorLoop.wait(1000L)
      }
    }
    modified
  }

  def main(argStrings: Array[String]) {
    val args = new ApplicationMasterArguments(argStrings)
    new ApplicationMaster(args).run()
  }
}
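The counter handshake above only works if something calls sparkContextInitialized once the user's SparkContext comes up; that call site is not part of this hunk. A hypothetical sketch of what it would look like from inside SparkContext's initialization (here `this` stands for the SparkContext under construction):

// Hypothetical call site (not shown in this diff): when running under YARN, hand the
// new SparkContext to the ApplicationMaster so container allocation can begin; this
// call blocks briefly until some workers have been allocated.
if (spark.deploy.SparkHadoopUtil.isYarnMode()) {
  spark.deploy.yarn.ApplicationMaster.sparkContextInitialized(this)
}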
@@ -0,0 +1,78 @@
package spark.deploy.yarn

import spark.util.IntParam
import collection.mutable.ArrayBuffer

class ApplicationMasterArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var userArgs: Seq[String] = Seq[String]()
  var workerMemory = 1024
  var workerCores = 1
  var numWorkers = 2

  parseArgs(args.toList)

  private def parseArgs(inputArgs: List[String]): Unit = {
    val userArgsBuffer = new ArrayBuffer[String]()

    var args = inputArgs

    while (! args.isEmpty) {

      args match {
        case ("--jar") :: value :: tail =>
          userJar = value
          args = tail

        case ("--class") :: value :: tail =>
          userClass = value
          args = tail

        case ("--args") :: value :: tail =>
          userArgsBuffer += value
          args = tail

        case ("--num-workers") :: IntParam(value) :: tail =>
          numWorkers = value
          args = tail

        case ("--worker-memory") :: IntParam(value) :: tail =>
          workerMemory = value
          args = tail

        case ("--worker-cores") :: IntParam(value) :: tail =>
          workerCores = value
          args = tail

        case Nil =>
          if (userJar == null || userClass == null) {
            printUsageAndExit(1)
          }

        case _ =>
          printUsageAndExit(1, args)
      }
    }

    userArgs = userArgsBuffer.readOnly
  }

  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
    if (unknownParam != null) {
      System.err.println("Unknown/unsupported param " + unknownParam)
    }
    System.err.println(
      "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
      "Options:\n" +
      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
      "                       Multiple invocations are possible, each will be passed in order.\n" +
      "                       Note that the first argument will ALWAYS be yarn-standalone : it will be added if missing.\n" +
      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
      "  --worker-cores NUM   Number of cores for the workers (Default: 1)\n" +
      "  --worker-memory MEM  Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
    System.exit(exitCode)
  }
}
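A usage sketch for the parser above (all paths and class names are made up): repeated --args flags accumulate in order, and the numeric options go through IntParam.

val amArgs = new ApplicationMasterArguments(Array(
  "--jar", "hdfs:///tmp/app.jar",          // hypothetical path
  "--class", "org.example.MyApp",          // hypothetical user class
  "--args", "yarn-standalone",
  "--args", "input.txt",
  "--num-workers", "4",
  "--worker-memory", "2048",
  "--worker-cores", "2"))
// amArgs.userArgs == Seq("yarn-standalone", "input.txt"); amArgs.numWorkers == 4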
core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala (new file, 326 lines)
@@ -0,0 +1,326 @@
package spark.deploy.yarn

import java.net.{InetSocketAddress, URI}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import spark.{Logging, Utils}
import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import spark.deploy.SparkHadoopUtil

class Client(conf: Configuration, args: ClientArguments) extends Logging {

  def this(args: ClientArguments) = this(new Configuration(), args)

  var applicationsManager: ClientRMProtocol = null
  var rpc: YarnRPC = YarnRPC.create(conf)
  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  def run() {
    connectToASM()
    logClusterResourceDetails()

    val newApp = getNewApplication()
    val appId = newApp.getApplicationId()

    verifyClusterResources(newApp)
    val appContext = createApplicationSubmissionContext(appId)
    val localResources = prepareLocalResources(appId, "spark")
    val env = setupLaunchEnv(localResources)
    val amContainer = createContainerLaunchContext(newApp, localResources, env)

    appContext.setQueue(args.amQueue)
    appContext.setAMContainerSpec(amContainer)
    appContext.setUser(args.amUser)

    submitApp(appContext)

    monitorApplication(appId)
    System.exit(0)
  }

  def connectToASM() {
    val rmAddress: InetSocketAddress = NetUtils.createSocketAddr(
      yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS)
    )
    logInfo("Connecting to ResourceManager at " + rmAddress)
    applicationsManager = rpc.getProxy(classOf[ClientRMProtocol], rmAddress, conf)
      .asInstanceOf[ClientRMProtocol]
  }

  def logClusterResourceDetails() {
    val clusterMetrics: YarnClusterMetrics = getYarnClusterMetrics
    logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)

    /*
    val clusterNodeReports: List[NodeReport] = getNodeReports
    logDebug("Got Cluster node info from ASM")
    for (node <- clusterNodeReports) {
      logDebug("Got node report from ASM for, nodeId=" + node.getNodeId + ", nodeAddress=" + node.getHttpAddress +
        ", nodeRackName=" + node.getRackName + ", nodeNumContainers=" + node.getNumContainers + ", nodeHealthStatus=" + node.getNodeHealthStatus)
    }
    */

    val queueInfo: QueueInfo = getQueueInfo(args.amQueue)
    logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
      ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
      ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
  }

  def getYarnClusterMetrics: YarnClusterMetrics = {
    val request: GetClusterMetricsRequest = Records.newRecord(classOf[GetClusterMetricsRequest])
    val response: GetClusterMetricsResponse = applicationsManager.getClusterMetrics(request)
    return response.getClusterMetrics
  }

  def getNodeReports: List[NodeReport] = {
    val request: GetClusterNodesRequest = Records.newRecord(classOf[GetClusterNodesRequest])
    val response: GetClusterNodesResponse = applicationsManager.getClusterNodes(request)
    return response.getNodeReports.toList
  }

  def getQueueInfo(queueName: String): QueueInfo = {
    val request: GetQueueInfoRequest = Records.newRecord(classOf[GetQueueInfoRequest])
    request.setQueueName(queueName)
    request.setIncludeApplications(true)
    request.setIncludeChildQueues(false)
    request.setRecursive(false)
    Records.newRecord(classOf[GetQueueInfoRequest])
    return applicationsManager.getQueueInfo(request).getQueueInfo
  }

  def getNewApplication(): GetNewApplicationResponse = {
    logInfo("Requesting new Application")
    val request = Records.newRecord(classOf[GetNewApplicationRequest])
    val response = applicationsManager.getNewApplication(request)
    logInfo("Got new ApplicationId: " + response.getApplicationId())
    return response
  }

  def verifyClusterResources(app: GetNewApplicationResponse) = {
    val maxMem = app.getMaximumResourceCapability().getMemory()
    logInfo("Max mem capability of resources in this cluster " + maxMem)

    // If the cluster does not have enough memory resources, exit.
    val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
    if (requestedMem > maxMem) {
      logError("Cluster cannot satisfy memory resource request of " + requestedMem)
      System.exit(1)
    }
  }

  def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
    logInfo("Setting up application submission context for ASM")
    val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
    appContext.setApplicationId(appId)
    appContext.setApplicationName("Spark")
    return appContext
  }

  def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
    logInfo("Preparing Local resources")
    val locaResources = HashMap[String, LocalResource]()
    // Upload Spark and the application JAR to the remote file system
    // Add them as local resources to the AM
    val fs = FileSystem.get(conf)
    Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
      .foreach { case (destName, _localPath) =>
        val localPath: String = if (_localPath != null) _localPath.trim() else ""
        if (! localPath.isEmpty()) {
          val src = new Path(localPath)
          val pathSuffix = appName + "/" + appId.getId() + destName
          val dst = new Path(fs.getHomeDirectory(), pathSuffix)
          logInfo("Uploading " + src + " to " + dst)
          fs.copyFromLocalFile(false, true, src, dst)
          val destStatus = fs.getFileStatus(dst)

          val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
          amJarRsrc.setType(LocalResourceType.FILE)
          amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
          amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
          amJarRsrc.setTimestamp(destStatus.getModificationTime())
          amJarRsrc.setSize(destStatus.getLen())
          locaResources(destName) = amJarRsrc
        }
      }
    return locaResources
  }

  def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
    logInfo("Setting up the launch environment")
    val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)

    val env = new HashMap[String, String]()
    Apps.addToEnvironment(env, Environment.USER.name, args.amUser)

    // If log4j present, ensure ours overrides all others
    if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")

    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
    Client.populateHadoopClasspath(yarnConf, env)
    SparkHadoopUtil.setYarnMode(env)
    env("SPARK_YARN_JAR_PATH") =
      localResources("spark.jar").getResource().getScheme.toString() + "://" +
      localResources("spark.jar").getResource().getFile().toString()
    env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
    env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()

    env("SPARK_YARN_USERJAR_PATH") =
      localResources("app.jar").getResource().getScheme.toString() + "://" +
      localResources("app.jar").getResource().getFile().toString()
    env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
    env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()

    if (log4jConfLocalRes != null) {
      env("SPARK_YARN_LOG4J_PATH") =
        log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
      env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
      env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
    }

    // Add each SPARK-* key to the environment
    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
    return env
  }
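The SPARK_YARN_* entries written into env above form a small contract with the worker side: WorkerRunnable.prepareLocalResources, later in this commit, reads the same keys back inside each container to reconstruct the LocalResource records, roughly as in this excerpt-style sketch:

// Worker-side read-back of the values set in setupLaunchEnv (see WorkerRunnable below).
val sparkJarUri = new java.net.URI(System.getenv("SPARK_YARN_JAR_PATH"))
val sparkJarTimestamp = System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong
val sparkJarSize = System.getenv("SPARK_YARN_JAR_SIZE").toLong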
  def userArgsToString(clientArgs: ClientArguments): String = {
    val prefix = " --args "
    val args = clientArgs.userArgs
    val retval = new StringBuilder()
    for (arg <- args) {
      retval.append(prefix).append(" '").append(arg).append("' ")
    }

    retval.toString
  }

  def createContainerLaunchContext(newApp: GetNewApplicationResponse,
                                   localResources: HashMap[String, LocalResource],
                                   env: HashMap[String, String]): ContainerLaunchContext = {
    logInfo("Setting up container launch context")
    val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
    amContainer.setLocalResources(localResources)
    amContainer.setEnvironment(env)

    val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()

    var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
      (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD

    // Extra options for the JVM
    var JAVA_OPTS = ""

    // Add Xmx for am memory
    JAVA_OPTS += "-Xmx" + amMemory + "m "

    // Commenting it out for now - so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
    // The context is, the default gc for server class machines ends up using all cores to do gc - hence if there are multiple containers in the same
    // node, spark gc affects all other containers' performance (which can also be other spark containers)
    // Instead of using this, rely on cpusets by YARN to enforce that spark behaves 'properly' in multi-tenant environments. Not sure how the default java gc behaves if it is
    // limited to a subset of cores on a node.
    if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
      // In our expts, using the (default) throughput collector has severe perf ramifications in multi-tenant machines
      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
      JAVA_OPTS += " -XX:+CMSIncrementalMode "
      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
    }
    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
    }

    // Command for the ApplicationMaster
    val commands = List[String]("java " +
      " -server " +
      JAVA_OPTS +
      " spark.deploy.yarn.ApplicationMaster" +
      " --class " + args.userClass +
      " --jar " + args.userJar +
      userArgsToString(args) +
      " --worker-memory " + args.workerMemory +
      " --worker-cores " + args.workerCores +
      " --num-workers " + args.numWorkers +
      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
    logInfo("Command for the ApplicationMaster: " + commands(0))
    amContainer.setCommands(commands)

    val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
    // Memory for the ApplicationMaster
    capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
    amContainer.setResource(capability)

    return amContainer
  }
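The amMemory rounding at the top of createContainerLaunchContext rounds the requested AM memory up to the next multiple of the RM's minimum allocation and then leaves room for the overhead inside that grant. A worked example with illustrative numbers (YarnAllocationHandler.MEMORY_OVERHEAD is defined elsewhere in this commit; 384 MB is an assumption here):

// Illustrative only.
val amMemoryRequested = 512      // args.amMemory
val minResMemory = 1024          // scheduler minimum allocation from the RM
val overhead = 384               // assumed value of YarnAllocationHandler.MEMORY_OVERHEAD
val amMemory = ((amMemoryRequested / minResMemory) * minResMemory) +
  (if (0 != (amMemoryRequested % minResMemory)) minResMemory else 0) - overhead
// ((512 / 1024) * 1024) = 0, remainder is non-zero so add 1024, minus 384 => amMemory = 640,
// i.e. the 512 MB ask is rounded up to the 1024 MB the RM will actually grant, and the
// JVM heap (-Xmx640m) is sized to leave the assumed overhead inside that allocation.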
  def submitApp(appContext: ApplicationSubmissionContext) = {
    // Create the request to send to the applications manager
    val appRequest = Records.newRecord(classOf[SubmitApplicationRequest])
      .asInstanceOf[SubmitApplicationRequest]
    appRequest.setApplicationSubmissionContext(appContext)
    // Submit the application to the applications manager
    logInfo("Submitting application to ASM")
    applicationsManager.submitApplication(appRequest)
  }

  def monitorApplication(appId: ApplicationId): Boolean = {
    while (true) {
      Thread.sleep(1000)
      val reportRequest = Records.newRecord(classOf[GetApplicationReportRequest])
        .asInstanceOf[GetApplicationReportRequest]
      reportRequest.setApplicationId(appId)
      val reportResponse = applicationsManager.getApplicationReport(reportRequest)
      val report = reportResponse.getApplicationReport()

      logInfo("Application report from ASM: \n" +
        "\t application identifier: " + appId.toString() + "\n" +
        "\t appId: " + appId.getId() + "\n" +
        "\t clientToken: " + report.getClientToken() + "\n" +
        "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
        "\t appMasterHost: " + report.getHost() + "\n" +
        "\t appQueue: " + report.getQueue() + "\n" +
        "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
        "\t appStartTime: " + report.getStartTime() + "\n" +
        "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
        "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
        "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
        "\t appUser: " + report.getUser()
      )

      val state = report.getYarnApplicationState()
      val dsStatus = report.getFinalApplicationStatus()
      if (state == YarnApplicationState.FINISHED ||
        state == YarnApplicationState.FAILED ||
        state == YarnApplicationState.KILLED) {
        return true
      }
    }
    return true
  }
}

object Client {
  def main(argStrings: Array[String]) {
    val args = new ClientArguments(argStrings)
    SparkHadoopUtil.setYarnMode()
    new Client(args).run
  }

  // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
    for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
    }
  }
}
@@ -0,0 +1,104 @@
package spark.deploy.yarn

import spark.util.MemoryParam
import spark.util.IntParam
import collection.mutable.{ArrayBuffer, HashMap}
import spark.scheduler.{InputFormatInfo, SplitInfo}

// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
class ClientArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var userArgs: Seq[String] = Seq[String]()
  var workerMemory = 1024
  var workerCores = 1
  var numWorkers = 2
  var amUser = System.getProperty("user.name")
  var amQueue = System.getProperty("QUEUE", "default")
  var amMemory: Int = 512
  // TODO
  var inputFormatInfo: List[InputFormatInfo] = null

  parseArgs(args.toList)

  private def parseArgs(inputArgs: List[String]): Unit = {
    val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
    val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()

    var args = inputArgs

    while (! args.isEmpty) {

      args match {
        case ("--jar") :: value :: tail =>
          userJar = value
          args = tail

        case ("--class") :: value :: tail =>
          userClass = value
          args = tail

        case ("--args") :: value :: tail =>
          userArgsBuffer += value
          args = tail

        case ("--master-memory") :: MemoryParam(value) :: tail =>
          amMemory = value
          args = tail

        case ("--num-workers") :: IntParam(value) :: tail =>
          numWorkers = value
          args = tail

        case ("--worker-memory") :: MemoryParam(value) :: tail =>
          workerMemory = value
          args = tail

        case ("--worker-cores") :: IntParam(value) :: tail =>
          workerCores = value
          args = tail

        case ("--user") :: value :: tail =>
          amUser = value
          args = tail

        case ("--queue") :: value :: tail =>
          amQueue = value
          args = tail

        case Nil =>
          if (userJar == null || userClass == null) {
            printUsageAndExit(1)
          }

        case _ =>
          printUsageAndExit(1, args)
      }
    }

    userArgs = userArgsBuffer.readOnly
    inputFormatInfo = inputFormatMap.values.toList
  }


  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
    if (unknownParam != null) {
      System.err.println("Unknown/unsupported param " + unknownParam)
    }
    System.err.println(
      "Usage: spark.deploy.yarn.Client [options] \n" +
      "Options:\n" +
      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
      "                       Multiple invocations are possible, each will be passed in order.\n" +
      "                       Note that the first argument will ALWAYS be yarn-standalone : it will be added if missing.\n" +
      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
      "  --worker-cores NUM   Number of cores for the workers (Default: 1)\n" +
      "  --worker-memory MEM  Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
      "  --user USERNAME      Run the ApplicationMaster as a different user\n"
    )
    System.exit(exitCode)
  }

}
@@ -0,0 +1,171 @@
package spark.deploy.yarn

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment

import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap

import spark.{Logging, Utils}

class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
    slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
    extends Runnable with Logging {

  var rpc: YarnRPC = YarnRPC.create(conf)
  var cm: ContainerManager = null
  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  def run = {
    logInfo("Starting Worker Container")
    cm = connectToCM
    startContainer
  }

  def startContainer = {
    logInfo("Setting up ContainerLaunchContext")

    val ctx = Records.newRecord(classOf[ContainerLaunchContext])
      .asInstanceOf[ContainerLaunchContext]

    ctx.setContainerId(container.getId())
    ctx.setResource(container.getResource())
    val localResources = prepareLocalResources
    ctx.setLocalResources(localResources)

    val env = prepareEnvironment
    ctx.setEnvironment(env)

    // Extra options for the JVM
    var JAVA_OPTS = ""
    // Set the JVM memory
    val workerMemoryString = workerMemory + "m"
    JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
    }
    // Commenting it out for now - so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
    // The context is, the default gc for server class machines ends up using all cores to do gc - hence if there are multiple containers in the same
    // node, spark gc affects all other containers' performance (which can also be other spark containers)
    // Instead of using this, rely on cpusets by YARN to enforce that spark behaves 'properly' in multi-tenant environments. Not sure how the default java gc behaves if it is
    // limited to a subset of cores on a node.
    /*
    else {
      // If no java_opts specified, default to using -XX:+CMSIncrementalMode
      // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we dont want to mess with it.
      // In our expts, using the (default) throughput collector has severe perf ramifications in multi-tenant machines
      // The options are based on
      // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
      JAVA_OPTS += " -XX:+CMSIncrementalMode "
      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
    }
    */

    ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
    val commands = List[String]("java " +
      " -server " +
      // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
      // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
      // TODO: If the OOM is not recoverable by rescheduling it on a different node, then do 'something' to fail the job ... akin to blacklisting trackers in mapred ?
      " -XX:OnOutOfMemoryError='kill %p' " +
      JAVA_OPTS +
      " spark.executor.StandaloneExecutorBackend " +
      masterAddress + " " +
      slaveId + " " +
      hostname + " " +
      workerCores +
      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
    logInfo("Setting up worker with commands: " + commands)
    ctx.setCommands(commands)

    // Send the start request to the ContainerManager
    val startReq = Records.newRecord(classOf[StartContainerRequest])
      .asInstanceOf[StartContainerRequest]
    startReq.setContainerLaunchContext(ctx)
    cm.startContainer(startReq)
  }
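For concreteness, with workerMemory = 2048 and workerCores = 2 and placeholder values for the constructor parameters (everything in angle brackets is hypothetical), the single command string built above comes out roughly as follows:

// Approximate rendering of commands(0); <masterAddress>, <slaveId>, <hostname> and
// <LOG_DIR> stand in for the real values substituted at launch time, and SPARK_JAVA_OPTS
// is assumed to be unset.
val renderedCommand =
  "java  -server  -XX:OnOutOfMemoryError='kill %p' -Xms2048m -Xmx2048m " +
  " spark.executor.StandaloneExecutorBackend <masterAddress> <slaveId> <hostname> 2" +
  " 1> <LOG_DIR>/stdout 2> <LOG_DIR>/stderr"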
  def prepareLocalResources: HashMap[String, LocalResource] = {
    logInfo("Preparing Local resources")
    val locaResources = HashMap[String, LocalResource]()

    // Spark JAR
    val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
    sparkJarResource.setType(LocalResourceType.FILE)
    sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
    sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
      new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
    sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
    sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
    locaResources("spark.jar") = sparkJarResource
    // User JAR
    val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
    userJarResource.setType(LocalResourceType.FILE)
    userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
    userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
      new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
    userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
    userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
    locaResources("app.jar") = userJarResource

    // Log4j conf - if available
    if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
      val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
      log4jConfResource.setType(LocalResourceType.FILE)
      log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
      log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
        new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
      log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
      log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
      locaResources("log4j.properties") = log4jConfResource
    }


    logInfo("Prepared Local resources " + locaResources)
    return locaResources
  }

  def prepareEnvironment: HashMap[String, String] = {
    val env = new HashMap[String, String]()
    // should we add this ?
    Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())

    // If log4j present, ensure ours overrides all others
    if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
      // Which is correct ?
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
    }

    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
    Client.populateHadoopClasspath(yarnConf, env)

    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
    return env
  }

  def connectToCM: ContainerManager = {
    val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
    val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
    logInfo("Connecting to ContainerManager at " + cmHostPortStr)
    return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
  }

}
@@ -0,0 +1,547 @@
package spark.deploy.yarn

import spark.{Logging, Utils}
import spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.yarn.api.AMRMProtocol
import collection.JavaConversions._
import collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.hadoop.conf.Configuration
import java.util.{Collections, Set => JSet}
import java.lang.{Boolean => JBoolean}

object AllocationType extends Enumeration("HOST", "RACK", "ANY") {
  type AllocationType = Value
  val HOST, RACK, ANY = Value
}

// too many params ? refactor it 'somehow' ?
// needs to be mt-safe
// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
// more proactive and decoupled.
// Note that right now, we assume all node asks are uniform in terms of capabilities and priority
// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
// on how we are requesting containers.
private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
    val appAttemptId: ApplicationAttemptId,
    val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
    val preferredHostToCount: Map[String, Int],
    val preferredRackToCount: Map[String, Int])
  extends Logging {


  // These three are locked on allocatedHostToContainersMap. Complementary data structures
  // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
  // allocatedContainerToHostMap: container to host mapping
  private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
  private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
  // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
  // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
  private val allocatedRackCount = new HashMap[String, Int]()

  // containers which have been released.
  private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
  // containers to be released in the next request to the RM
  private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]

  private val numWorkersRunning = new AtomicInteger()
  // Used to generate a unique id per worker
  private val workerIdCounter = new AtomicInteger()
  private val lastResponseId = new AtomicInteger()

  def getNumWorkersRunning: Int = numWorkersRunning.intValue


  def isResourceConstraintSatisfied(container: Container): Boolean = {
    container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
  }
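The memory check above is the only filter applied to containers the RM hands back. A worked example (MEMORY_OVERHEAD lives in the YarnAllocationHandler companion object, which falls past the point where this hunk is cut off, so its value here is an assumption):

// Assumed values for illustration only.
val workerMemory = 1024          // MB requested per worker
val memoryOverhead = 384         // assumed value of YarnAllocationHandler.MEMORY_OVERHEAD
val containerMemory = 1536       // what the RM actually granted to this container
val usable = containerMemory >= (workerMemory + memoryOverhead)
// 1536 >= 1408, so this container is kept; anything smaller goes straight to releasedContainerList.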
def allocateContainers(workersToRequest: Int) {
|
||||
// We need to send the request only once from what I understand ... but for now, not modifying this much.
|
||||
|
||||
// Keep polling the Resource Manager for containers
|
||||
val amResp = allocateWorkerResources(workersToRequest).getAMResponse
|
||||
|
||||
val _allocatedContainers = amResp.getAllocatedContainers()
|
||||
if (_allocatedContainers.size > 0) {
|
||||
|
||||
|
||||
logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
|
||||
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
|
||||
", pendingReleaseContainers : " + pendingReleaseContainers)
|
||||
logDebug("Cluster Resources: " + amResp.getAvailableResources)
|
||||
|
||||
val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
|
||||
|
||||
// ignore if not satisfying constraints {
|
||||
for (container <- _allocatedContainers) {
|
||||
if (isResourceConstraintSatisfied(container)) {
|
||||
// allocatedContainers += container
|
||||
|
||||
val host = container.getNodeId.getHost
|
||||
val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
|
||||
|
||||
containers += container
|
||||
}
|
||||
// Add all ignored containers to released list
|
||||
else releasedContainerList.add(container.getId())
|
||||
}
|
||||
|
||||
// Find the appropriate containers to use
|
||||
// Slightly non trivial groupBy I guess ...
|
||||
val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
|
||||
val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
|
||||
val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
|
||||
|
||||
for (candidateHost <- hostToContainers.keySet)
|
||||
{
|
||||
val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
|
||||
val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
|
||||
|
||||
var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
|
||||
assert(remainingContainers != null)
|
||||
|
||||
if (requiredHostCount >= remainingContainers.size){
|
||||
// Since we got <= required containers, add all to dataLocalContainers
|
||||
dataLocalContainers.put(candidateHost, remainingContainers)
|
||||
// all consumed
|
||||
remainingContainers = null
|
||||
}
|
||||
else if (requiredHostCount > 0) {
|
||||
// container list has more containers than we need for data locality.
|
||||
// Split into two : data local container count of (remainingContainers.size - requiredHostCount)
|
||||
// and rest as remainingContainer
|
||||
val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
|
||||
dataLocalContainers.put(candidateHost, dataLocal)
|
||||
// remainingContainers = remaining
|
||||
|
||||
// yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
|
||||
// add remaining to release list. If we have insufficient containers, next allocation cycle
|
||||
// will reallocate (but wont treat it as data local)
|
||||
for (container <- remaining) releasedContainerList.add(container.getId())
|
||||
remainingContainers = null
|
||||
}
|
||||
|
||||
// now rack local
|
||||
if (remainingContainers != null){
|
||||
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
|
||||
|
||||
if (rack != null){
|
||||
val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
|
||||
val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
|
||||
rackLocalContainers.get(rack).getOrElse(List()).size
|
||||
|
||||
|
||||
if (requiredRackCount >= remainingContainers.size){
|
||||
// Add all to dataLocalContainers
|
||||
dataLocalContainers.put(rack, remainingContainers)
|
||||
// all consumed
|
||||
remainingContainers = null
|
||||
}
|
||||
else if (requiredRackCount > 0) {
|
||||
// container list has more containers than we need for data locality.
|
||||
// Split into two : data local container count of (remainingContainers.size - requiredRackCount)
|
||||
// and rest as remainingContainer
|
||||
val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
|
||||
val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
|
||||
|
||||
existingRackLocal ++= rackLocal
|
||||
remainingContainers = remaining
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If still not consumed, then it is off rack host - add to that list.
|
||||
if (remainingContainers != null){
|
||||
offRackContainers.put(candidateHost, remainingContainers)
|
||||
}
|
||||
}
|
||||
|
||||
// Now that we have split the containers into various groups, go through them in order :
|
||||
// first host local, then rack local and then off rack (everything else).
|
||||
// Note that the list we create below tries to ensure that not all containers end up within a host
|
||||
// if there are sufficiently large number of hosts/containers.
|
||||
|
||||
val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
|
||||
allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
|
||||
allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
|
||||
allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
|
||||
|
||||
// Run each of the allocated containers
|
||||
for (container <- allocatedContainers) {
|
||||
val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
|
||||
val workerHostname = container.getNodeId.getHost
|
||||
val containerId = container.getId
|
||||
|
||||
assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
|
||||
|
||||
if (numWorkersRunningNow > maxWorkers) {
|
||||
logInfo("Ignoring container " + containerId + " at host " + workerHostname +
|
||||
" since we already have the required number of containers")
|
||||
releasedContainerList.add(containerId)
|
||||
// reset counter back to old value.
|
||||
numWorkersRunning.decrementAndGet()
|
||||
}
|
||||
else {
|
||||
// Deallocate + allocate can result in reusing IDs wrongly - so use a different counter (workerIdCounter).
|
||||
val workerId = workerIdCounter.incrementAndGet().toString
|
||||
val masterUrl = "akka://spark@%s:%s/user/%s".format(
|
||||
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
|
||||
StandaloneSchedulerBackend.ACTOR_NAME)
|
||||
|
||||
logInfo("launching container on " + containerId + " host " + workerHostname)
|
||||
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
|
||||
pendingReleaseContainers.remove(containerId)
|
||||
|
||||
val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
|
||||
allocatedHostToContainersMap.synchronized {
|
||||
val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
|
||||
|
||||
containerSet += containerId
|
||||
allocatedContainerToHostMap.put(containerId, workerHostname)
|
||||
if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
|
||||
}
|
||||
|
||||
new Thread(
|
||||
new WorkerRunnable(container, conf, masterUrl, workerId,
|
||||
workerHostname, workerMemory, workerCores)
|
||||
).start()
|
||||
}
|
||||
}
|
||||
logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
|
||||
_allocatedContainers.size + "), current count " + numWorkersRunning.get() +
|
||||
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
|
||||
}
|
||||
|
||||
|
||||
val completedContainers = amResp.getCompletedContainersStatuses()
|
||||
if (completedContainers.size > 0){
|
||||
logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
|
||||
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
|
||||
|
||||
for (completedContainer <- completedContainers){
|
||||
val containerId = completedContainer.getContainerId
|
||||
|
||||
// Was this released by us ? If yes, then simply remove from containerSet and move on.
|
||||
if (pendingReleaseContainers.containsKey(containerId)) {
|
||||
pendingReleaseContainers.remove(containerId)
|
||||
}
|
||||
else {
|
||||
// simply decrement count - next iteration of ReporterThread will take care of allocating !
|
||||
numWorkersRunning.decrementAndGet()
|
||||
logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
|
||||
" httpaddress: " + completedContainer.getDiagnostics)
|
||||
}
|
||||
|
||||
allocatedHostToContainersMap.synchronized {
|
||||
if (allocatedContainerToHostMap.containsKey(containerId)) {
|
||||
val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
|
||||
assert (host != null)
|
||||
|
||||
val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
|
||||
assert (containerSet != null)
|
||||
|
||||
containerSet -= containerId
|
||||
if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
|
||||
else allocatedHostToContainersMap.update(host, containerSet)
|
||||
|
||||
allocatedContainerToHostMap -= containerId
|
||||
|
||||
// doing this within locked context, sigh ... move to outside ?
|
||||
val rack = YarnAllocationHandler.lookupRack(conf, host)
|
||||
if (rack != null) {
|
||||
val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
|
||||
if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
|
||||
else allocatedRackCount.remove(rack)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
logDebug("After completed " + completedContainers.size + " containers, current count " +
|
||||
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
|
||||
", pendingReleaseContainers : " + pendingReleaseContainers)
|
||||
}
|
||||
}
|
||||
|
||||
def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
|
||||
// First generate modified racks and new set of hosts under it : then issue requests
|
||||
val rackToCounts = new HashMap[String, Int]()
|
||||
|
||||
// Within this lock - used to read/write to the rack related maps too.
|
||||
for (container <- hostContainers) {
|
||||
val candidateHost = container.getHostName
|
||||
val candidateNumContainers = container.getNumContainers
|
||||
assert(YarnAllocationHandler.ANY_HOST != candidateHost)
|
||||
|
||||
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
|
||||
if (rack != null) {
|
||||
var count = rackToCounts.getOrElse(rack, 0)
|
||||
count += candidateNumContainers
|
||||
rackToCounts.put(rack, count)
|
||||
}
|
||||
}
|
||||
|
||||
val requestedContainers: ArrayBuffer[ResourceRequest] =
|
||||
new ArrayBuffer[ResourceRequest](rackToCounts.size)
|
||||
for ((rack, count) <- rackToCounts){
|
||||
requestedContainers +=
|
||||
createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
|
||||
}
|
||||
|
||||
requestedContainers.toList
|
||||
}
|
||||
|
||||
def allocatedContainersOnHost(host: String): Int = {
|
||||
var retval = 0
|
||||
allocatedHostToContainersMap.synchronized {
|
||||
retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
|
||||
}
|
||||
retval
|
||||
}
|
||||
|
||||
def allocatedContainersOnRack(rack: String): Int = {
|
||||
var retval = 0
|
||||
allocatedHostToContainersMap.synchronized {
|
||||
retval = allocatedRackCount.getOrElse(rack, 0)
|
||||
}
|
||||
retval
|
||||
}
|
||||
|
||||
private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
|
||||
|
||||
var resourceRequests: List[ResourceRequest] = null
|
||||
|
||||
// default.
|
||||
if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
|
||||
logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
|
||||
resourceRequests = List(
|
||||
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
|
||||
}
|
||||
else {
|
||||
// Request containers for all preferred hosts; anything beyond (numWorkers -
|
||||
// candidates.size) is requested via the default (ANY) allocation policy.
|
||||
val hostContainerRequests: ArrayBuffer[ResourceRequest] =
|
||||
new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
|
||||
for ((candidateHost, candidateCount) <- preferredHostToCount) {
|
||||
val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
|
||||
|
||||
if (requiredCount > 0) {
|
||||
hostContainerRequests +=
|
||||
createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
|
||||
}
|
||||
}
|
||||
val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
|
||||
|
||||
val anyContainerRequests: ResourceRequest =
|
||||
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
|
||||
|
||||
val containerRequests: ArrayBuffer[ResourceRequest] =
|
||||
new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
|
||||
|
||||
containerRequests ++= hostContainerRequests
|
||||
containerRequests ++= rackContainerRequests
|
||||
containerRequests += anyContainerRequests
|
||||
|
||||
resourceRequests = containerRequests.toList
|
||||
}
|
||||
|
||||
val req = Records.newRecord(classOf[AllocateRequest])
|
||||
req.setResponseId(lastResponseId.incrementAndGet)
|
||||
req.setApplicationAttemptId(appAttemptId)
|
||||
|
||||
req.addAllAsks(resourceRequests)
|
||||
|
||||
val releasedContainerList = createReleasedContainerList()
|
||||
req.addAllReleases(releasedContainerList)
|
||||
|
||||
|
||||
|
||||
if (numWorkers > 0) {
|
||||
logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
|
||||
}
|
||||
else {
|
||||
logDebug("Empty allocation req .. release : " + releasedContainerList)
|
||||
}
|
||||
|
||||
for (req <- resourceRequests) {
|
||||
logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
|
||||
", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
|
||||
}
|
||||
resourceManager.allocate(req)
|
||||
}
|
||||
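A small worked sketch of the requiredCount arithmetic in allocateWorkerResources above, using a hypothetical preference map: only hosts that still need containers produce a host-level request, and the rest falls through to the rack and ANY requests.

object RequiredCountExample {
  def main(args: Array[String]) {
    val preferredHostToCount = Map("hostA" -> 3, "hostB" -> 1)
    val allocatedOnHost = Map("hostA" -> 1, "hostB" -> 2)
    for ((host, wanted) <- preferredHostToCount) {
      val required = wanted - allocatedOnHost.getOrElse(host, 0)
      if (required > 0) println(host + " -> request " + required + " container(s)")
    }
    // prints: hostA -> request 2 container(s)   (hostB already has enough)
  }
}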
|
||||
|
||||
private def createResourceRequest(requestType: AllocationType.AllocationType,
|
||||
resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
|
||||
|
||||
// If a hostname is specified, we need at least two requests - node-local and rack-local.
|
||||
// There must also be a third request for ANY host; that one is handled specially.
|
||||
requestType match {
|
||||
case AllocationType.HOST => {
|
||||
assert (YarnAllocationHandler.ANY_HOST != resource)
|
||||
|
||||
val hostname = resource
|
||||
val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
|
||||
|
||||
// add to host->rack mapping
|
||||
YarnAllocationHandler.populateRackInfo(conf, hostname)
|
||||
|
||||
nodeLocal
|
||||
}
|
||||
|
||||
case AllocationType.RACK => {
|
||||
val rack = resource
|
||||
createResourceRequestImpl(rack, numWorkers, priority)
|
||||
}
|
||||
|
||||
case AllocationType.ANY => {
|
||||
createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
|
||||
}
|
||||
|
||||
case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
|
||||
}
|
||||
}
|
||||
|
||||
private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
|
||||
|
||||
val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
|
||||
val memCapability = Records.newRecord(classOf[Resource])
|
||||
// There probably is some overhead here, let's reserve a bit more memory.
|
||||
memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
|
||||
rsrcRequest.setCapability(memCapability)
|
||||
|
||||
val pri = Records.newRecord(classOf[Priority])
|
||||
pri.setPriority(priority)
|
||||
rsrcRequest.setPriority(pri)
|
||||
|
||||
rsrcRequest.setHostName(hostname)
|
||||
|
||||
rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
|
||||
rsrcRequest
|
||||
}
|
||||
|
||||
def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
|
||||
|
||||
val retval = new ArrayBuffer[ContainerId](1)
|
||||
// iterator on COW list ...
|
||||
for (container <- releasedContainerList.iterator()){
|
||||
retval += container
|
||||
}
|
||||
// remove from the original list.
|
||||
if (! retval.isEmpty) {
|
||||
releasedContainerList.removeAll(retval)
|
||||
for (v <- retval) pendingReleaseContainers.put(v, true)
|
||||
logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
|
||||
pendingReleaseContainers)
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
}
|
||||
|
||||
object YarnAllocationHandler {
|
||||
|
||||
val ANY_HOST = "*"
|
||||
// all requests are issued with same priority : we do not (yet) have any distinction between
|
||||
// request types (like map/reduce in hadoop for example)
|
||||
val PRIORITY = 1
|
||||
|
||||
// Additional memory overhead - in mb
|
||||
val MEMORY_OVERHEAD = 384
|
||||
|
||||
// host to rack map - saved from allocation requests
|
||||
// We are expecting this not to change.
|
||||
// Note that it is possible for this to change : and RM will indicate that to us via update
|
||||
// response to allocate. But we are punting on handling that for now.
|
||||
private val hostToRack = new ConcurrentHashMap[String, String]()
|
||||
private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
|
||||
|
||||
def newAllocator(conf: Configuration,
|
||||
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
|
||||
args: ApplicationMasterArguments,
|
||||
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
|
||||
|
||||
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
|
||||
|
||||
|
||||
new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
|
||||
args.workerMemory, args.workerCores, hostToCount, rackToCount)
|
||||
}
|
||||
|
||||
def newAllocator(conf: Configuration,
|
||||
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
|
||||
maxWorkers: Int, workerMemory: Int, workerCores: Int,
|
||||
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
|
||||
|
||||
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
|
||||
|
||||
new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
|
||||
workerMemory, workerCores, hostToCount, rackToCount)
|
||||
}
|
||||
|
||||
// A simple method to compute host and rack weights from the split info map.
|
||||
private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
|
||||
// host to count, rack to count
|
||||
(Map[String, Int], Map[String, Int]) = {
|
||||
|
||||
if (input == null) return (Map[String, Int](), Map[String, Int]())
|
||||
|
||||
val hostToCount = new HashMap[String, Int]
|
||||
val rackToCount = new HashMap[String, Int]
|
||||
|
||||
for ((host, splits) <- input) {
|
||||
val hostCount = hostToCount.getOrElse(host, 0)
|
||||
hostToCount.put(host, hostCount + splits.size)
|
||||
|
||||
val rack = lookupRack(conf, host)
|
||||
if (rack != null){
|
||||
val rackCount = rackToCount.getOrElse(rack, 0)
|
||||
rackToCount.put(rack, rackCount + splits.size)
|
||||
}
|
||||
}
|
||||
|
||||
(hostToCount.toMap, rackToCount.toMap)
|
||||
}
|
||||
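A simplified, standalone sketch of the weight computation above, using a canned host-to-rack map instead of the RackResolver lookup; host names and split counts are hypothetical. It shows why the rack map must be keyed by rack, not host.

import scala.collection.mutable.HashMap

object NodeWeightExample {
  def main(args: Array[String]) {
    val hostToRack = Map("hostA" -> "/rack1", "hostB" -> "/rack1")
    val splitsPerHost = Map("hostA" -> 3, "hostB" -> 2)   // i.e. splits.size per host

    val hostToCount = new HashMap[String, Int]
    val rackToCount = new HashMap[String, Int]
    for ((host, numSplits) <- splitsPerHost) {
      hostToCount(host) = hostToCount.getOrElse(host, 0) + numSplits
      for (rack <- hostToRack.get(host)) {
        rackToCount(rack) = rackToCount.getOrElse(rack, 0) + numSplits
      }
    }
    println(hostToCount)   // hostA -> 3, hostB -> 2
    println(rackToCount)   // /rack1 -> 5
  }
}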
|
||||
def lookupRack(conf: Configuration, host: String): String = {
|
||||
if (!hostToRack.containsKey(host)) populateRackInfo(conf, host)
|
||||
hostToRack.get(host)
|
||||
}
|
||||
|
||||
def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
|
||||
val set = rackToHostSet.get(rack)
|
||||
if (set == null) return None
|
||||
|
||||
// No better way to get a Set[String] from JSet ?
|
||||
val convertedSet: collection.mutable.Set[String] = set
|
||||
Some(convertedSet.toSet)
|
||||
}
|
||||
|
||||
def populateRackInfo(conf: Configuration, hostname: String) {
|
||||
Utils.checkHost(hostname)
|
||||
|
||||
if (!hostToRack.containsKey(hostname)) {
|
||||
// If there are repeated failures to resolve, add to an ignore list ?
|
||||
val rackInfo = RackResolver.resolve(conf, hostname)
|
||||
if (rackInfo != null && rackInfo.getNetworkLocation != null) {
|
||||
val rack = rackInfo.getNetworkLocation
|
||||
hostToRack.put(hostname, rack)
|
||||
if (! rackToHostSet.containsKey(rack)) {
|
||||
rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
|
||||
}
|
||||
rackToHostSet.get(rack).add(hostname)
|
||||
|
||||
// Since RackResolver caches, we are disabling this for now ...
|
||||
} /* else {
|
||||
// right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
|
||||
hostToRack.put(hostname, null)
|
||||
} */
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,42 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import spark._
|
||||
import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
|
||||
/**
|
||||
*
|
||||
* This is a simple extension to ClusterScheduler that ensures appropriate initialization of the ApplicationMaster, etc. is done.
|
||||
*/
|
||||
private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
|
||||
|
||||
def this(sc: SparkContext) = this(sc, new Configuration())
|
||||
|
||||
// Nothing else for now ... initialize the ApplicationMaster, which needs the SparkContext to determine how to allocate.
|
||||
// Note that only the first SparkContext created has any effect (ideally there should be only one SparkContext anyway);
|
||||
// subsequent creations are ignored, since nodes are already allocated by then.
|
||||
|
||||
|
||||
// By default, rack is unknown
|
||||
override def getRackForHost(hostPort: String): Option[String] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
val retval = YarnAllocationHandler.lookupRack(conf, host)
|
||||
if (retval != null) Some(retval) else None
|
||||
}
|
||||
|
||||
// By default, if rack is unknown, return nothing
|
||||
override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
|
||||
if (rack == null) return None
|
||||
|
||||
YarnAllocationHandler.fetchCachedHostsForRack(rack)
|
||||
}
|
||||
|
||||
override def postStartHook() {
|
||||
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
|
||||
if (sparkContextInitialized){
|
||||
// Wait a few seconds for the slaves to bootstrap and register with the master - best-effort attempt.
|
||||
Thread.sleep(3000L)
|
||||
}
|
||||
logInfo("YarnClusterScheduler.postStartHook done")
|
||||
}
|
||||
}
|
18
core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
Normal file
18
core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala
Normal file
|
@ -0,0 +1,18 @@
|
|||
package spark.deploy
|
||||
|
||||
/**
|
||||
* Contains util methods to interact with Hadoop from spark.
|
||||
*/
|
||||
object SparkHadoopUtil {
|
||||
|
||||
def getUserNameFromEnvironment(): String = {
|
||||
// defaulting to -D ...
|
||||
System.getProperty("user.name")
|
||||
}
|
||||
|
||||
def runAsUser(func: (Product) => Unit, args: Product) {
|
||||
|
||||
// Add support, if exists - for now, simply run func !
|
||||
func(args)
|
||||
}
|
||||
}
|
|
@ -8,12 +8,20 @@ import scala.collection.mutable.Set
|
|||
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
|
||||
import org.objectweb.asm.commons.EmptyVisitor
|
||||
import org.objectweb.asm.Opcodes._
|
||||
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
|
||||
|
||||
private[spark] object ClosureCleaner extends Logging {
|
||||
// Get an ASM class reader for a given class from the JAR that loaded it
|
||||
private def getClassReader(cls: Class[_]): ClassReader = {
|
||||
new ClassReader(cls.getResourceAsStream(
|
||||
cls.getName.replaceFirst("^.*\\.", "") + ".class"))
|
||||
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
|
||||
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
|
||||
val resourceStream = cls.getResourceAsStream(className)
|
||||
// todo: Fixme - continuing with earlier behavior ...
|
||||
if (resourceStream == null) return new ClassReader(resourceStream)
|
||||
|
||||
val baos = new ByteArrayOutputStream(128)
|
||||
Utils.copyStream(resourceStream, baos, true)
|
||||
new ClassReader(new ByteArrayInputStream(baos.toByteArray))
|
||||
}
|
||||
|
||||
// Check whether a class represents a Scala closure
|
||||
|
|
|
@ -3,18 +3,25 @@ package spark
|
|||
import spark.storage.BlockManagerId
|
||||
|
||||
private[spark] class FetchFailedException(
|
||||
val bmAddress: BlockManagerId,
|
||||
val shuffleId: Int,
|
||||
val mapId: Int,
|
||||
val reduceId: Int,
|
||||
taskEndReason: TaskEndReason,
|
||||
message: String,
|
||||
cause: Throwable)
|
||||
extends Exception {
|
||||
|
||||
override def getMessage(): String =
|
||||
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
|
||||
def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
|
||||
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
|
||||
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
|
||||
cause)
|
||||
|
||||
def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
|
||||
this(FetchFailed(null, shuffleId, -1, reduceId),
|
||||
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
|
||||
|
||||
override def getMessage(): String = message
|
||||
|
||||
|
||||
override def getCause(): Throwable = cause
|
||||
|
||||
def toTaskEndReason: TaskEndReason =
|
||||
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
|
||||
def toTaskEndReason: TaskEndReason = taskEndReason
|
||||
|
||||
}
|
||||
|
|
|
@ -68,6 +68,10 @@ trait Logging {
|
|||
if (log.isErrorEnabled) log.error(msg, throwable)
|
||||
}
|
||||
|
||||
protected def isTraceEnabled(): Boolean = {
|
||||
log.isTraceEnabled
|
||||
}
|
||||
|
||||
// Method for ensuring that logging is initialized, to avoid having multiple
|
||||
// threads do it concurrently (as SLF4J initialization is not thread safe).
|
||||
protected def initLogging() { log }
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package spark
|
||||
|
||||
import java.io._
|
||||
import java.util.concurrent.ConcurrentHashMap
|
||||
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
|
@ -12,8 +11,7 @@ import akka.dispatch._
|
|||
import akka.pattern.ask
|
||||
import akka.remote._
|
||||
import akka.util.Duration
|
||||
import akka.util.Timeout
|
||||
import akka.util.duration._
|
||||
|
||||
|
||||
import spark.scheduler.MapStatus
|
||||
import spark.storage.BlockManagerId
|
||||
|
@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
|
|||
|
||||
private[spark] class MapOutputTracker extends Logging {
|
||||
|
||||
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
|
||||
|
||||
// Set to the MapOutputTrackerActor living on the driver
|
||||
var trackerActor: ActorRef = _
|
||||
|
||||
var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
|
||||
// Incremented every time a fetch fails so that client nodes know to clear
|
||||
// their cache of map output locations if this happens.
|
||||
|
@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
|
||||
// Cache a serialized version of the output statuses for each shuffle to send them out faster
|
||||
var cacheGeneration = generation
|
||||
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
|
||||
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
|
||||
|
||||
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
|
||||
|
||||
|
@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
// throw a SparkException if this fails.
|
||||
def askTracker(message: Any): Any = {
|
||||
try {
|
||||
val timeout = 10.seconds
|
||||
val future = trackerActor.ask(message)(timeout)
|
||||
return Await.result(future, timeout)
|
||||
} catch {
|
||||
|
@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
|
||||
def registerShuffle(shuffleId: Int, numMaps: Int) {
|
||||
if (mapStatuses.get(shuffleId) != None) {
|
||||
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
|
||||
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
|
||||
}
|
||||
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
|
||||
}
|
||||
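The switch to putIfAbsent above closes a check-then-put race. A standalone sketch of the same idea on a plain java.util.concurrent.ConcurrentHashMap (not Spark's TimeStampedHashMap, whose putIfAbsent returns an Option instead of null):

import java.util.concurrent.ConcurrentHashMap

object RegisterOnceExample {
  private val registered = new ConcurrentHashMap[Int, Array[String]]()

  def register(shuffleId: Int, numMaps: Int) {
    // putIfAbsent returns null only for the thread that actually inserted the entry.
    if (registered.putIfAbsent(shuffleId, new Array[String](numMaps)) != null) {
      throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
    }
  }

  def main(args: Array[String]) {
    register(0, 4)
    try { register(0, 4) } catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}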
|
||||
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
|
||||
|
@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
|
||||
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
|
||||
var array = mapStatuses(shuffleId)
|
||||
if (array != null) {
|
||||
var arrayOpt = mapStatuses.get(shuffleId)
|
||||
if (arrayOpt.isDefined && arrayOpt.get != null) {
|
||||
var array = arrayOpt.get
|
||||
array.synchronized {
|
||||
if (array(mapId) != null && array(mapId).location == bmAddress) {
|
||||
array(mapId) = null
|
||||
|
@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
|
||||
// Remembers which map output locations are currently being fetched on a worker
|
||||
val fetching = new HashSet[Int]
|
||||
private val fetching = new HashSet[Int]
|
||||
|
||||
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
|
||||
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
|
||||
val statuses = mapStatuses.get(shuffleId).orNull
|
||||
if (statuses == null) {
|
||||
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
|
||||
var fetchedStatuses: Array[MapStatus] = null
|
||||
fetching.synchronized {
|
||||
if (fetching.contains(shuffleId)) {
|
||||
// Someone else is fetching it; wait for them to be done
|
||||
|
@ -132,31 +132,49 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
case e: InterruptedException =>
|
||||
}
|
||||
}
|
||||
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
|
||||
} else {
|
||||
}
|
||||
|
||||
// Either while we waited the fetch happened successfully, or
|
||||
// someone fetched it in between the get and the fetching.synchronized.
|
||||
fetchedStatuses = mapStatuses.get(shuffleId).orNull
|
||||
if (fetchedStatuses == null) {
|
||||
// We have to do the fetch, get others to wait for us.
|
||||
fetching += shuffleId
|
||||
}
|
||||
}
|
||||
// We won the race to fetch the output locs; do so
|
||||
logInfo("Doing the fetch; tracker actor = " + trackerActor)
|
||||
val host = System.getProperty("spark.hostname", Utils.localHostName)
|
||||
// This try-finally prevents hangs due to timeouts:
|
||||
var fetchedStatuses: Array[MapStatus] = null
|
||||
try {
|
||||
val fetchedBytes =
|
||||
askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
|
||||
fetchedStatuses = deserializeStatuses(fetchedBytes)
|
||||
logInfo("Got the output locations")
|
||||
mapStatuses.put(shuffleId, fetchedStatuses)
|
||||
} finally {
|
||||
fetching.synchronized {
|
||||
fetching -= shuffleId
|
||||
fetching.notifyAll()
|
||||
|
||||
if (fetchedStatuses == null) {
|
||||
// We won the race to fetch the output locs; do so
|
||||
logInfo("Doing the fetch; tracker actor = " + trackerActor)
|
||||
val hostPort = Utils.localHostPort()
|
||||
// This try-finally prevents hangs due to timeouts:
|
||||
var fetchedStatuses: Array[MapStatus] = null
|
||||
try {
|
||||
val fetchedBytes =
|
||||
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
|
||||
fetchedStatuses = deserializeStatuses(fetchedBytes)
|
||||
logInfo("Got the output locations")
|
||||
mapStatuses.put(shuffleId, fetchedStatuses)
|
||||
} finally {
|
||||
fetching.synchronized {
|
||||
fetching -= shuffleId
|
||||
fetching.notifyAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
|
||||
if (fetchedStatuses != null) {
|
||||
fetchedStatuses.synchronized {
|
||||
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
|
||||
}
|
||||
}
|
||||
else{
|
||||
throw new FetchFailedException(null, shuffleId, -1, reduceId,
|
||||
new Exception("Missing all output locations for shuffle " + shuffleId))
|
||||
}
|
||||
} else {
|
||||
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
|
||||
statuses.synchronized {
|
||||
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
|
||||
}
|
||||
}
|
||||
}
|
||||
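A simplified, single-monitor sketch of the fetch-or-wait coordination used above: the first thread to miss the cache does the fetch, later callers wait on the same monitor and re-check. The real code also deserializes remote bytes; here a string stands in for the fetched statuses.

import scala.collection.mutable.{HashMap, HashSet}

object FetchOrWaitExample {
  private val cache = new HashMap[Int, String]
  private val fetching = new HashSet[Int]

  def get(key: Int): String = {
    fetching.synchronized {
      while (fetching.contains(key)) fetching.wait()
      if (cache.contains(key)) return cache(key)   // fetched while we waited
      fetching += key                              // we won the race; we will fetch
    }
    try {
      val value = "value-" + key                   // stands in for the remote fetch
      fetching.synchronized { cache(key) = value }
      value
    } finally {
      fetching.synchronized { fetching -= key; fetching.notifyAll() }
    }
  }

  def main(args: Array[String]) { println(get(42)) }
}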
|
||||
|
@ -194,7 +212,8 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
generationLock.synchronized {
|
||||
if (newGen > generation) {
|
||||
logInfo("Updating generation to " + newGen + " and clearing cache")
|
||||
mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
mapStatuses.clear()
|
||||
generation = newGen
|
||||
}
|
||||
}
|
||||
|
@ -232,10 +251,13 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
// Serialize an array of map output locations into an efficient byte format so that we can send
|
||||
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
|
||||
// generally be pretty compressible because many map outputs will be on the same hostname.
|
||||
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
|
||||
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
|
||||
val out = new ByteArrayOutputStream
|
||||
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
|
||||
objOut.writeObject(statuses)
|
||||
// Since statuses can be modified in parallel, sync on it
|
||||
statuses.synchronized {
|
||||
objOut.writeObject(statuses)
|
||||
}
|
||||
objOut.close()
|
||||
out.toByteArray
|
||||
}
|
||||
|
@ -243,7 +265,10 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
// Opposite of serializeStatuses.
|
||||
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
|
||||
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
|
||||
objIn.readObject().asInstanceOf[Array[MapStatus]]
|
||||
objIn.readObject().
|
||||
// // Drop all nulls from statuses - not sure why they are occurring though. Causes an NPE downstream in the slave if present.
|
||||
// comment this out - nulls could be due to missing location ?
|
||||
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
|
||||
}
|
||||
}
|
||||
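A self-contained round-trip sketch of the GZIP-plus-Java-serialization scheme described above; strings stand in for MapStatus objects, and the repeated hostnames show why the payload compresses well.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object GzipStatusExample {
  def serialize(statuses: Array[String]): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(statuses)
    objOut.close()                 // finishes the GZIP stream as well
    out.toByteArray
  }

  def deserialize(bytes: Array[Byte]): Array[String] = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    objIn.readObject().asInstanceOf[Array[String]]
  }

  def main(args: Array[String]) {
    val statuses = Array.fill(100)("host-1:7077")
    val bytes = serialize(statuses)
    println("compressed to " + bytes.length + " bytes, round-trips: " +
      (deserialize(bytes).toSeq == statuses.toSeq))
  }
}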
|
||||
|
@ -253,14 +278,11 @@ private[spark] object MapOutputTracker {
|
|||
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
|
||||
// any of the statuses is null (indicating a missing location due to a failed mapper),
|
||||
// throw a FetchFailedException.
|
||||
def convertMapStatuses(
|
||||
private def convertMapStatuses(
|
||||
shuffleId: Int,
|
||||
reduceId: Int,
|
||||
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
|
||||
if (statuses == null) {
|
||||
throw new FetchFailedException(null, shuffleId, -1, reduceId,
|
||||
new Exception("Missing all output locations for shuffle " + shuffleId))
|
||||
}
|
||||
assert (statuses != null)
|
||||
statuses.map {
|
||||
status =>
|
||||
if (status == null) {
|
||||
|
|
|
@ -37,7 +37,7 @@ import spark.partial.PartialResult
|
|||
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
|
||||
import spark.scheduler._
|
||||
import spark.scheduler.local.LocalScheduler
|
||||
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
|
||||
import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
|
||||
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
|
||||
import spark.storage.BlockManagerUI
|
||||
import spark.util.{MetadataCleaner, TimeStampedHashMap}
|
||||
|
@ -59,7 +59,10 @@ class SparkContext(
|
|||
val appName: String,
|
||||
val sparkHome: String = null,
|
||||
val jars: Seq[String] = Nil,
|
||||
val environment: Map[String, String] = Map())
|
||||
val environment: Map[String, String] = Map(),
|
||||
// This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
|
||||
// This is typically generated from InputFormatInfo.computePreferredLocations: host -> set of data-local splits on that host.
|
||||
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
|
||||
extends Logging {
|
||||
|
||||
// Ensure logging is initialized before we spawn any threads
|
||||
|
@ -67,7 +70,7 @@ class SparkContext(
|
|||
|
||||
// Set Spark driver host and port system properties
|
||||
if (System.getProperty("spark.driver.host") == null) {
|
||||
System.setProperty("spark.driver.host", Utils.localIpAddress)
|
||||
System.setProperty("spark.driver.host", Utils.localHostName())
|
||||
}
|
||||
if (System.getProperty("spark.driver.port") == null) {
|
||||
System.setProperty("spark.driver.port", "0")
|
||||
|
@ -99,7 +102,7 @@ class SparkContext(
|
|||
|
||||
|
||||
// Add each JAR given through the constructor
|
||||
jars.foreach { addJar(_) }
|
||||
if (jars != null) jars.foreach { addJar(_) }
|
||||
|
||||
// Environment variables to pass to our executors
|
||||
private[spark] val executorEnvs = HashMap[String, String]()
|
||||
|
@ -111,7 +114,7 @@ class SparkContext(
|
|||
executorEnvs(key) = value
|
||||
}
|
||||
}
|
||||
executorEnvs ++= environment
|
||||
if (environment != null) executorEnvs ++= environment
|
||||
|
||||
// Create and start the scheduler
|
||||
private var taskScheduler: TaskScheduler = {
|
||||
|
@ -164,6 +167,22 @@ class SparkContext(
|
|||
}
|
||||
scheduler
|
||||
|
||||
case "yarn-standalone" =>
|
||||
val scheduler = try {
|
||||
val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
|
||||
val cons = clazz.getConstructor(classOf[SparkContext])
|
||||
cons.newInstance(this).asInstanceOf[ClusterScheduler]
|
||||
} catch {
|
||||
// TODO: Enumerate the exact reasons why it can fail
|
||||
// But irrespective of it, it means we cannot proceed !
|
||||
case th: Throwable => {
|
||||
throw new SparkException("YARN mode not available ?", th)
|
||||
}
|
||||
}
|
||||
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
|
||||
case _ =>
|
||||
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
|
||||
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
|
||||
|
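A hedged usage sketch of the new "yarn-standalone" master value. It assumes the YARN scheduler classes and an appropriately built assembly are on the driver classpath; otherwise the reflective YarnClusterScheduler lookup above throws the SparkException shown.

import spark.SparkContext

object YarnAppSketch {
  def main(args: Array[String]) {
    // "yarn-standalone" triggers the reflective YarnClusterScheduler creation added above.
    val sc = new SparkContext("yarn-standalone", "yarn-example")
    val evens = sc.parallelize(1 to 1000).filter(_ % 2 == 0).count()
    println("even numbers: " + evens)
    sc.stop()
  }
}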
@ -183,7 +202,7 @@ class SparkContext(
|
|||
}
|
||||
taskScheduler.start()
|
||||
|
||||
private var dagScheduler = new DAGScheduler(taskScheduler)
|
||||
@volatile private var dagScheduler = new DAGScheduler(taskScheduler)
|
||||
dagScheduler.start()
|
||||
|
||||
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
|
||||
|
@ -207,6 +226,9 @@ class SparkContext(
|
|||
|
||||
private[spark] var checkpointDir: Option[String] = None
|
||||
|
||||
// Post init
|
||||
taskScheduler.postStartHook()
|
||||
|
||||
// Methods for creating RDDs
|
||||
|
||||
/** Distribute a local Scala collection to form an RDD. */
|
||||
|
@ -471,7 +493,7 @@ class SparkContext(
|
|||
*/
|
||||
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
|
||||
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
|
||||
(blockManagerId.ip + ":" + blockManagerId.port, mem)
|
||||
(blockManagerId.host + ":" + blockManagerId.port, mem)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -527,10 +549,13 @@ class SparkContext(
|
|||
|
||||
/** Shut down the SparkContext. */
|
||||
def stop() {
|
||||
if (dagScheduler != null) {
|
||||
// Do this only if not stopped already - best-effort.
|
||||
// Prevents an NPE if stop() is called more than once.
|
||||
val dagSchedulerCopy = dagScheduler
|
||||
dagScheduler = null
|
||||
if (dagSchedulerCopy != null) {
|
||||
metadataCleaner.cancel()
|
||||
dagScheduler.stop()
|
||||
dagScheduler = null
|
||||
dagSchedulerCopy.stop()
|
||||
taskScheduler = null
|
||||
// TODO: Cache.stop()?
|
||||
env.stop()
|
||||
|
@ -546,6 +571,7 @@ class SparkContext(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* Get Spark's home location from either a value set through the constructor,
|
||||
* or the spark.home Java property, or the SPARK_HOME environment variable
|
||||
|
|
|
@ -72,6 +72,16 @@ object SparkEnv extends Logging {
|
|||
System.setProperty("spark.driver.port", boundPort.toString)
|
||||
}
|
||||
|
||||
// set only if unset until now.
|
||||
if (System.getProperty("spark.hostPort", null) == null) {
|
||||
if (!isDriver){
|
||||
// unexpected
|
||||
Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
|
||||
}
|
||||
Utils.checkHost(hostname)
|
||||
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
|
||||
}
|
||||
|
||||
val classLoader = Thread.currentThread.getContextClassLoader
|
||||
|
||||
// Create an instance of the class named by the given Java system property, or by
|
||||
|
@ -88,9 +98,10 @@ object SparkEnv extends Logging {
|
|||
logInfo("Registering " + name)
|
||||
actorSystem.actorOf(Props(newActor), name = name)
|
||||
} else {
|
||||
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
|
||||
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
|
||||
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
|
||||
val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
|
||||
Utils.checkHost(driverHost, "Expected hostname")
|
||||
val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
|
||||
logInfo("Connecting to " + name + ": " + url)
|
||||
actorSystem.actorFor(url)
|
||||
}
|
||||
|
|
|
@ -1,18 +1,18 @@
|
|||
package spark
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
|
||||
import java.util.{Locale, Random, UUID}
|
||||
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
|
||||
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap}
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.io.Source
|
||||
import com.google.common.io.Files
|
||||
import com.google.common.util.concurrent.ThreadFactoryBuilder
|
||||
import scala.Some
|
||||
import spark.serializer.SerializerInstance
|
||||
import spark.deploy.SparkHadoopUtil
|
||||
|
||||
/**
|
||||
* Various utility methods used by Spark.
|
||||
|
@ -68,6 +68,41 @@ private object Utils extends Logging {
|
|||
return buf
|
||||
}
|
||||
|
||||
|
||||
private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
|
||||
|
||||
// Register the path to be deleted via shutdown hook
|
||||
def registerShutdownDeleteDir(file: File) {
|
||||
val absolutePath = file.getAbsolutePath()
|
||||
shutdownDeletePaths.synchronized {
|
||||
shutdownDeletePaths += absolutePath
|
||||
}
|
||||
}
|
||||
|
||||
// Is the path already registered to be deleted via a shutdown hook ?
|
||||
def hasShutdownDeleteDir(file: File): Boolean = {
|
||||
val absolutePath = file.getAbsolutePath()
|
||||
shutdownDeletePaths.synchronized {
|
||||
shutdownDeletePaths.contains(absolutePath)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: if file is child of some registered path, while not equal to it, then return true; else false
|
||||
// This is to ensure that two shutdown hooks do not try to delete each other's paths - which would result in an IOException
|
||||
// and incomplete cleanup
|
||||
def hasRootAsShutdownDeleteDir(file: File): Boolean = {
|
||||
|
||||
val absolutePath = file.getAbsolutePath()
|
||||
|
||||
val retval = shutdownDeletePaths.synchronized {
|
||||
shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined
|
||||
}
|
||||
|
||||
if (retval) logInfo("path = " + file + ", already present as root for deletion.")
|
||||
|
||||
retval
|
||||
}
|
||||
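A minimal sketch of the parent/child check above, assuming a plain HashSet guarded by synchronized as in the real code: a path nested under an already registered root does not need its own delete, which avoids two shutdown hooks racing over the same tree.

import java.io.File
import scala.collection.mutable.HashSet

object ShutdownDirExample {
  private val registered = new HashSet[String]()

  def register(f: File) { registered.synchronized { registered += f.getAbsolutePath } }

  def hasRegisteredParent(f: File): Boolean = {
    val p = f.getAbsolutePath
    registered.synchronized {
      registered.exists(root => p != root && p.startsWith(root))
    }
  }

  def main(args: Array[String]) {
    val root = new File(System.getProperty("java.io.tmpdir"), "spark-tmp")
    register(root)
    println(hasRegisteredParent(new File(root, "child")))   // true
    println(hasRegisteredParent(root))                      // false
  }
}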
|
||||
/** Create a temporary directory inside the given parent directory */
|
||||
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
|
||||
var attempts = 0
|
||||
|
@ -86,10 +121,14 @@ private object Utils extends Logging {
|
|||
}
|
||||
} catch { case e: IOException => ; }
|
||||
}
|
||||
|
||||
registerShutdownDeleteDir(dir)
|
||||
|
||||
// Add a shutdown hook to delete the temp dir when the JVM exits
|
||||
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
|
||||
override def run() {
|
||||
Utils.deleteRecursively(dir)
|
||||
// Attempt to delete only if no parent path of this directory is already registered.
|
||||
if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
|
||||
}
|
||||
})
|
||||
return dir
|
||||
|
@ -227,8 +266,10 @@ private object Utils extends Logging {
|
|||
|
||||
/**
|
||||
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
|
||||
* Note, this is typically not used from within core spark.
|
||||
*/
|
||||
lazy val localIpAddress: String = findLocalIpAddress()
|
||||
lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
|
||||
|
||||
private def findLocalIpAddress(): String = {
|
||||
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
|
||||
|
@ -266,6 +307,8 @@ private object Utils extends Logging {
|
|||
* hostname it reports to the master.
|
||||
*/
|
||||
def setCustomHostname(hostname: String) {
|
||||
// DEBUG code
|
||||
Utils.checkHost(hostname)
|
||||
customHostname = Some(hostname)
|
||||
}
|
||||
|
||||
|
@ -273,7 +316,90 @@ private object Utils extends Logging {
|
|||
* Get the local machine's hostname.
|
||||
*/
|
||||
def localHostName(): String = {
|
||||
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
|
||||
// customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
|
||||
customHostname.getOrElse(localIpAddressHostname)
|
||||
}
|
||||
|
||||
def getAddressHostName(address: String): String = {
|
||||
InetAddress.getByName(address).getHostName
|
||||
}
|
||||
|
||||
|
||||
|
||||
def localHostPort(): String = {
|
||||
val retval = System.getProperty("spark.hostPort", null)
|
||||
if (retval == null) {
|
||||
logErrorWithStack("spark.hostPort not set but invoking localHostPort")
|
||||
return localHostName()
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
|
||||
// Used by DEBUG code : remove when all testing done
|
||||
def checkHost(host: String, message: String = "") {
|
||||
// Currently catches only the IPv4 pattern; this is just a debugging tool - not rigorous !
|
||||
if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
|
||||
Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
|
||||
}
|
||||
if (Utils.parseHostPort(host)._2 != 0){
|
||||
Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
|
||||
}
|
||||
}
|
||||
|
||||
// Used by DEBUG code : remove when all testing done
|
||||
def checkHostPort(hostPort: String, message: String = "") {
|
||||
val (host, port) = Utils.parseHostPort(hostPort)
|
||||
checkHost(host)
|
||||
if (port <= 0){
|
||||
Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
|
||||
}
|
||||
}
|
||||
|
||||
def getUserNameFromEnvironment(): String = {
|
||||
SparkHadoopUtil.getUserNameFromEnvironment
|
||||
}
|
||||
|
||||
// Used by DEBUG code : remove when all testing done
|
||||
def logErrorWithStack(msg: String) {
|
||||
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
|
||||
// temp code for debug
|
||||
System.exit(-1)
|
||||
}
|
||||
|
||||
// Typically, this will be on the order of the number of nodes in the cluster.
|
||||
// If not, we should change it to LRUCache or something.
|
||||
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
|
||||
def parseHostPort(hostPort: String): (String, Int) = {
|
||||
{
|
||||
// Check cache first.
|
||||
var cached = hostPortParseResults.get(hostPort)
|
||||
if (cached != null) return cached
|
||||
}
|
||||
|
||||
val indx: Int = hostPort.lastIndexOf(':')
|
||||
// This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
|
||||
// For now, we assume that if a port exists, it is valid - we do not check that it is an int > 0.
|
||||
if (-1 == indx) {
|
||||
val retval = (hostPort, 0)
|
||||
hostPortParseResults.put(hostPort, retval)
|
||||
return retval
|
||||
}
|
||||
|
||||
val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
|
||||
hostPortParseResults.putIfAbsent(hostPort, retval)
|
||||
hostPortParseResults.get(hostPort)
|
||||
}
|
||||
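A standalone sketch mirroring the cached host:port parsing above. The real method lives in spark.Utils; this copy exists only to illustrate the cache-then-putIfAbsent flow and the (host, 0) fallback when no port is present.

import java.util.concurrent.ConcurrentHashMap

object HostPortExample {
  private val cache = new ConcurrentHashMap[String, (String, Int)]()

  def parseHostPort(hostPort: String): (String, Int) = {
    val cached = cache.get(hostPort)
    if (cached != null) return cached
    val idx = hostPort.lastIndexOf(':')
    val retval =
      if (idx == -1) (hostPort, 0)
      else (hostPort.substring(0, idx).trim(), hostPort.substring(idx + 1).trim().toInt)
    cache.putIfAbsent(hostPort, retval)
    cache.get(hostPort)
  }

  def main(args: Array[String]) {
    println(parseHostPort("node1.example.com:41414"))   // (node1.example.com, 41414)
    println(parseHostPort("node1.example.com"))         // (node1.example.com, 0)
  }
}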
|
||||
def addIfNoPort(hostPort: String, port: Int): String = {
|
||||
if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port)
|
||||
|
||||
// This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
|
||||
// For now, we assume that if a port exists, it is valid - we do not check that it is an int > 0.
|
||||
val indx: Int = hostPort.lastIndexOf(':')
|
||||
if (-1 != indx) return hostPort
|
||||
|
||||
hostPort + ":" + port
|
||||
}
|
||||
|
||||
private[spark] val daemonThreadFactory: ThreadFactory =
|
||||
|
|
|
@ -278,6 +278,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
|
|||
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
|
||||
extends AccumulatorParam[JList[Array[Byte]]] {
|
||||
|
||||
Utils.checkHost(serverHost, "Expected hostname")
|
||||
|
||||
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
|
||||
|
||||
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
|
||||
|
|
|
@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
|
|||
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
|
||||
import spark.deploy.worker.ExecutorRunner
|
||||
import scala.collection.immutable.List
|
||||
import spark.Utils
|
||||
|
||||
|
||||
private[spark] sealed trait DeployMessage extends Serializable
|
||||
|
@ -19,7 +20,10 @@ case class RegisterWorker(
|
|||
memory: Int,
|
||||
webUiPort: Int,
|
||||
publicAddress: String)
|
||||
extends DeployMessage
|
||||
extends DeployMessage {
|
||||
Utils.checkHost(host, "Required hostname")
|
||||
assert (port > 0)
|
||||
}
|
||||
|
||||
private[spark]
|
||||
case class ExecutorStateChanged(
|
||||
|
@ -58,7 +62,9 @@ private[spark]
|
|||
case class RegisteredApplication(appId: String) extends DeployMessage
|
||||
|
||||
private[spark]
|
||||
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
|
||||
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
|
||||
Utils.checkHostPort(hostPort, "Required hostport")
|
||||
}
|
||||
|
||||
private[spark]
|
||||
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
|
||||
|
@ -81,6 +87,9 @@ private[spark]
|
|||
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
|
||||
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
|
||||
|
||||
Utils.checkHost(host, "Required hostname")
|
||||
assert (port > 0)
|
||||
|
||||
def uri = "spark://" + host + ":" + port
|
||||
}
|
||||
|
||||
|
@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
|
|||
private[spark]
|
||||
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
|
||||
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
|
||||
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
|
||||
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
|
||||
|
||||
Utils.checkHost(host, "Required hostname")
|
||||
assert (port > 0)
|
||||
}
|
||||
|
|
|
@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
|
|||
def write(obj: WorkerInfo) = JsObject(
|
||||
"id" -> JsString(obj.id),
|
||||
"host" -> JsString(obj.host),
|
||||
"port" -> JsNumber(obj.port),
|
||||
"webuiaddress" -> JsString(obj.webUiAddress),
|
||||
"cores" -> JsNumber(obj.cores),
|
||||
"coresused" -> JsNumber(obj.coresUsed),
|
||||
|
|
|
@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
|
|||
private[spark]
|
||||
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
|
||||
|
||||
private val localIpAddress = Utils.localIpAddress
|
||||
private val localHostname = Utils.localHostName()
|
||||
private val masterActorSystems = ArrayBuffer[ActorSystem]()
|
||||
private val workerActorSystems = ArrayBuffer[ActorSystem]()
|
||||
|
||||
|
@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
|
|||
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
|
||||
|
||||
/* Start the Master */
|
||||
val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
|
||||
val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
|
||||
masterActorSystems += masterSystem
|
||||
val masterUrl = "spark://" + localIpAddress + ":" + masterPort
|
||||
val masterUrl = "spark://" + localHostname + ":" + masterPort
|
||||
|
||||
/* Start the Workers */
|
||||
for (workerNum <- 1 to numWorkers) {
|
||||
val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
|
||||
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
|
||||
memoryPerWorker, masterUrl, null, Some(workerNum))
|
||||
workerActorSystems += workerSystem
|
||||
}
|
||||
|
|
|
@ -59,10 +59,10 @@ private[spark] class Client(
|
|||
markDisconnected()
|
||||
context.stop(self)
|
||||
|
||||
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
|
||||
case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
|
||||
val fullId = appId + "/" + id
|
||||
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
|
||||
listener.executorAdded(fullId, workerId, host, cores, memory)
|
||||
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
|
||||
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
|
||||
|
||||
case ExecutorUpdated(id, state, message, exitStatus) =>
|
||||
val fullId = appId + "/" + id
|
||||
|
|
|
@ -12,7 +12,7 @@ private[spark] trait ClientListener {
|
|||
|
||||
def disconnected(): Unit
|
||||
|
||||
def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
|
||||
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
|
||||
|
||||
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
|
||||
}
|
||||
|
|
|
@ -16,7 +16,7 @@ private[spark] object TestClient {
|
|||
System.exit(0)
|
||||
}
|
||||
|
||||
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
|
||||
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
|
||||
|
||||
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
|
||||
}
|
||||
|
|
|
@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
|
|||
import spark.util.AkkaUtils
|
||||
|
||||
|
||||
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
|
||||
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
|
||||
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
|
||||
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
|
||||
|
||||
|
@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
|
|||
|
||||
var firstApp: Option[ApplicationInfo] = None
|
||||
|
||||
Utils.checkHost(host, "Expected hostname")
|
||||
|
||||
val masterPublicAddress = {
|
||||
val envVar = System.getenv("SPARK_PUBLIC_DNS")
|
||||
if (envVar != null) envVar else ip
|
||||
if (envVar != null) envVar else host
|
||||
}
|
||||
|
||||
// As a temporary workaround before better ways of configuring memory, we allow users to set
|
||||
|
@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
|
|||
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
|
||||
|
||||
override def preStart() {
|
||||
logInfo("Starting Spark master at spark://" + ip + ":" + port)
|
||||
logInfo("Starting Spark master at spark://" + host + ":" + port)
|
||||
// Listen for remote client disconnection events, since they don't go through Akka's watch()
|
||||
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
|
||||
startWebUi()
|
||||
|
@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
|
|||
}
|
||||
|
||||
case RequestMasterState => {
|
||||
sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
|
||||
sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
|
|||
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
|
||||
worker.addExecutor(exec)
|
||||
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
|
||||
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
|
||||
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
|
||||
}
|
||||
|
||||
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
|
||||
publicAddress: String): WorkerInfo = {
|
||||
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
|
||||
workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
|
||||
workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
|
||||
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
|
||||
workers += worker
|
||||
idToWorker(worker.id) = worker
|
||||
|
@ -307,7 +309,7 @@ private[spark] object Master {
|
|||
|
||||
def main(argStrings: Array[String]) {
|
||||
val args = new MasterArguments(argStrings)
|
||||
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
|
||||
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
|
||||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
||||
|
|
|
@ -7,13 +7,13 @@ import spark.Utils
|
|||
* Command-line parser for the master.
|
||||
*/
|
||||
private[spark] class MasterArguments(args: Array[String]) {
|
||||
var ip = Utils.localHostName()
|
||||
var host = Utils.localHostName()
|
||||
var port = 7077
|
||||
var webUiPort = 8080
|
||||
|
||||
// Check for settings in environment variables
|
||||
if (System.getenv("SPARK_MASTER_IP") != null) {
|
||||
ip = System.getenv("SPARK_MASTER_IP")
|
||||
if (System.getenv("SPARK_MASTER_HOST") != null) {
|
||||
host = System.getenv("SPARK_MASTER_HOST")
|
||||
}
|
||||
if (System.getenv("SPARK_MASTER_PORT") != null) {
|
||||
port = System.getenv("SPARK_MASTER_PORT").toInt
|
||||
|
@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
|
|||
|
||||
def parse(args: List[String]): Unit = args match {
|
||||
case ("--ip" | "-i") :: value :: tail =>
|
||||
ip = value
|
||||
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
|
||||
host = value
|
||||
parse(tail)
|
||||
|
||||
case ("--host" | "-h") :: value :: tail =>
|
||||
Utils.checkHost(value, "Please use hostname " + value)
|
||||
host = value
|
||||
parse(tail)
|
||||
|
||||
case ("--port" | "-p") :: IntParam(value) :: tail =>
|
||||
|
@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
|
|||
"Usage: Master [options]\n" +
|
||||
"\n" +
|
||||
"Options:\n" +
|
||||
" -i IP, --ip IP IP address or DNS name to listen on\n" +
|
||||
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
|
||||
" -h HOST, --host HOST Hostname to listen on\n" +
|
||||
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
|
||||
" --webui-port PORT Port for web UI (default: 8080)")
|
||||
System.exit(exitCode)
|
||||
|
|
|
@ -2,6 +2,7 @@ package spark.deploy.master
|
|||
|
||||
import akka.actor.ActorRef
|
||||
import scala.collection.mutable
|
||||
import spark.Utils
|
||||
|
||||
private[spark] class WorkerInfo(
|
||||
val id: String,
|
||||
|
@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
|
|||
val webUiPort: Int,
|
||||
val publicAddress: String) {
|
||||
|
||||
Utils.checkHost(host, "Expected hostname")
|
||||
assert (port > 0)
|
||||
|
||||
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
|
||||
var state: WorkerState.Value = WorkerState.ALIVE
|
||||
var coresUsed = 0
|
||||
|
@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
|
|||
def coresFree: Int = cores - coresUsed
|
||||
def memoryFree: Int = memory - memoryUsed
|
||||
|
||||
def hostPort: String = {
|
||||
assert (port > 0)
|
||||
host + ":" + port
|
||||
}
|
||||
|
||||
def addExecutor(exec: ExecutorInfo) {
|
||||
executors(exec.fullId) = exec
|
||||
coresUsed += exec.cores
|
||||
|
|
|
@ -21,11 +21,13 @@ private[spark] class ExecutorRunner(
|
|||
val memory: Int,
|
||||
val worker: ActorRef,
|
||||
val workerId: String,
|
||||
val hostname: String,
|
||||
val hostPort: String,
|
||||
val sparkHome: File,
|
||||
val workDir: File)
|
||||
extends Logging {
|
||||
|
||||
Utils.checkHostPort(hostPort, "Expected hostport")
|
||||
|
||||
val fullId = appId + "/" + execId
|
||||
var workerThread: Thread = null
|
||||
var process: Process = null
|
||||
|
@@ -68,7 +70,7 @@ private[spark] class ExecutorRunner(
  /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
  def substituteVariables(argument: String): String = argument match {
    case "{{EXECUTOR_ID}}" => execId.toString
    case "{{HOSTNAME}}" => hostname
    case "{{HOSTPORT}}" => hostPort
    case "{{CORES}}" => cores.toString
    case other => other
  }
|
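As a rough illustration (hypothetical template, not the actual command the worker builds), substituteVariables is applied to each argument, so a templated command expands like this:

  // Hypothetical values: execId = 3, hostPort = "worker1.example.com:43210", cores = 4
  val template = Seq("{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}", "--some-literal-arg")
  val expanded = template.map(substituteVariables)
  // expanded == Seq("3", "worker1.example.com:43210", "4", "--some-literal-arg")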
|
@ -16,7 +16,7 @@ import spark.deploy.master.Master
|
|||
import java.io.File
|
||||
|
||||
private[spark] class Worker(
|
||||
ip: String,
|
||||
host: String,
|
||||
port: Int,
|
||||
webUiPort: Int,
|
||||
cores: Int,
|
||||
|
@ -25,6 +25,9 @@ private[spark] class Worker(
|
|||
workDirPath: String = null)
|
||||
extends Actor with Logging {
|
||||
|
||||
Utils.checkHost(host, "Expected hostname")
|
||||
assert (port > 0)
|
||||
|
||||
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
|
||||
|
||||
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
|
||||
|
@ -39,7 +42,7 @@ private[spark] class Worker(
|
|||
val finishedExecutors = new HashMap[String, ExecutorRunner]
|
||||
val publicAddress = {
|
||||
val envVar = System.getenv("SPARK_PUBLIC_DNS")
|
||||
if (envVar != null) envVar else ip
|
||||
if (envVar != null) envVar else host
|
||||
}
|
||||
|
||||
var coresUsed = 0
|
||||
|
@ -64,7 +67,7 @@ private[spark] class Worker(
|
|||
|
||||
override def preStart() {
|
||||
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
|
||||
ip, port, cores, Utils.memoryMegabytesToString(memory)))
|
||||
host, port, cores, Utils.memoryMegabytesToString(memory)))
|
||||
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
|
||||
logInfo("Spark home: " + sparkHome)
|
||||
createWorkDir()
|
||||
|
@ -75,7 +78,7 @@ private[spark] class Worker(
|
|||
def connectToMaster() {
|
||||
logInfo("Connecting to master " + masterUrl)
|
||||
master = context.actorFor(Master.toAkkaUrl(masterUrl))
|
||||
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
|
||||
master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
|
||||
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
|
||||
context.watch(master) // Doesn't work with remote actors, but useful for testing
|
||||
}
|
||||
|
@ -106,7 +109,7 @@ private[spark] class Worker(
|
|||
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
|
||||
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
|
||||
val manager = new ExecutorRunner(
|
||||
appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
|
||||
appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
|
||||
executors(appId + "/" + execId) = manager
|
||||
manager.start()
|
||||
coresUsed += cores_
|
||||
|
@ -141,7 +144,7 @@ private[spark] class Worker(
|
|||
masterDisconnected()
|
||||
|
||||
case RequestWorkerState => {
|
||||
sender ! WorkerState(ip, port, workerId, executors.values.toList,
|
||||
sender ! WorkerState(host, port, workerId, executors.values.toList,
|
||||
finishedExecutors.values.toList, masterUrl, cores, memory,
|
||||
coresUsed, memoryUsed, masterWebUiUrl)
|
||||
}
|
||||
|
@ -156,7 +159,7 @@ private[spark] class Worker(
|
|||
}
|
||||
|
||||
def generateWorkerId(): String = {
|
||||
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
|
||||
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
|
||||
}
|
||||
|
||||
override def postStop() {
|
||||
|
@ -167,7 +170,7 @@ private[spark] class Worker(
|
|||
private[spark] object Worker {
|
||||
def main(argStrings: Array[String]) {
|
||||
val args = new WorkerArguments(argStrings)
|
||||
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
|
||||
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
|
||||
args.memory, args.master, args.workDir)
|
||||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
|
|||
* Command-line parser for the master.
|
||||
*/
|
||||
private[spark] class WorkerArguments(args: Array[String]) {
|
||||
var ip = Utils.localHostName()
|
||||
var host = Utils.localHostName()
|
||||
var port = 0
|
||||
var webUiPort = 8081
|
||||
var cores = inferDefaultCores()
|
||||
|
@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
|
|||
|
||||
def parse(args: List[String]): Unit = args match {
|
||||
case ("--ip" | "-i") :: value :: tail =>
|
||||
ip = value
|
||||
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
|
||||
host = value
|
||||
parse(tail)
|
||||
|
||||
case ("--host" | "-h") :: value :: tail =>
|
||||
Utils.checkHost(value, "Please use hostname " + value)
|
||||
host = value
|
||||
parse(tail)
|
||||
|
||||
case ("--port" | "-p") :: IntParam(value) :: tail =>
|
||||
|
@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
|
|||
" -c CORES, --cores CORES Number of cores to use\n" +
|
||||
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
|
||||
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
|
||||
" -i IP, --ip IP IP address or DNS name to listen on\n" +
|
||||
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
|
||||
" -h HOST, --host HOST Hostname to listen on\n" +
|
||||
" -p PORT, --port PORT Port to listen on (default: random)\n" +
|
||||
" --webui-port PORT Port for web UI (default: 8081)")
|
||||
System.exit(exitCode)
|
||||
|
|
|
@@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
  initLogging()

  // No ip or host:port - just hostname
  Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
  // must not have port specified.
  assert (0 == Utils.parseHostPort(slaveHostname)._2)

  // Make sure the local hostname we report matches the cluster scheduler's name for this host
  Utils.setCustomHostname(slaveHostname)
|
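The assert above relies on Utils.parseHostPort returning a zero port for a bare hostname; a small sketch of the assumed behaviour (not part of this diff):

  // Assumed behaviour of the helper, shown with hypothetical values.
  assert(Utils.parseHostPort("worker1.example.com:43210") == ("worker1.example.com", 43210))
  assert(Utils.parseHostPort("worker1.example.com")._2 == 0) // bare hostname => port 0, so the check above passes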
||||
|
|
|
@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
|
|||
import spark.scheduler.cluster.LaunchTask
|
||||
import spark.scheduler.cluster.RegisterExecutorFailed
|
||||
import spark.scheduler.cluster.RegisterExecutor
|
||||
import spark.Utils
|
||||
import spark.deploy.SparkHadoopUtil
|
||||
|
||||
private[spark] class StandaloneExecutorBackend(
|
||||
driverUrl: String,
|
||||
executorId: String,
|
||||
hostname: String,
|
||||
hostPort: String,
|
||||
cores: Int)
|
||||
extends Actor
|
||||
with ExecutorBackend
|
||||
with Logging {
|
||||
|
||||
Utils.checkHostPort(hostPort, "Expected hostport")
|
||||
|
||||
var executor: Executor = null
|
||||
var driver: ActorRef = null
|
||||
|
||||
override def preStart() {
|
||||
logInfo("Connecting to driver: " + driverUrl)
|
||||
driver = context.actorFor(driverUrl)
|
||||
driver ! RegisterExecutor(executorId, hostname, cores)
|
||||
driver ! RegisterExecutor(executorId, hostPort, cores)
|
||||
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
|
||||
context.watch(driver) // Doesn't work with remote actors, but useful for testing
|
||||
}
|
||||
|
@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
|
|||
override def receive = {
|
||||
case RegisteredExecutor(sparkProperties) =>
|
||||
logInfo("Successfully registered with driver")
|
||||
executor = new Executor(executorId, hostname, sparkProperties)
|
||||
// Make this host instead of hostPort ?
|
||||
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
|
||||
|
||||
case RegisterExecutorFailed(message) =>
|
||||
logError("Slave registration failed: " + message)
|
||||
|
@@ -63,11 +68,29 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
  def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
    SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
  }

  // This will be run 'as' the user
  def run0(args: Product) {
    assert(4 == args.productArity)
    runImpl(args.productElement(0).asInstanceOf[String],
      args.productElement(1).asInstanceOf[String],
      args.productElement(2).asInstanceOf[String],
      args.productElement(3).asInstanceOf[Int])
  }

  private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
    // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
    // before getting started with all our system properties, etc
    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
    // Debug code
    Utils.checkHost(hostname)
    // set it
    val sparkHostPort = hostname + ":" + boundPort
    System.setProperty("spark.hostPort", sparkHostPort)
    val actor = actorSystem.actorOf(
      Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)),
      Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
      name = "Executor")
    actorSystem.awaitTermination()
  }
|
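A minimal sketch (hypothetical caller, assuming only the SparkHadoopUtil shown in this commit) of the Product-based indirection used above: the entry point packs its arguments into a tuple, runAsUser hands that tuple back to the supplied function, and the function unpacks it by position.

  // Hypothetical example - any tuple is a Product, so the positional unpacking works the same way.
  def greet(args: Product) {
    val name = args.productElement(0).asInstanceOf[String]
    val cores = args.productElement(1).asInstanceOf[Int]
    println("hello " + name + ", cores = " + cores)
  }
  SparkHadoopUtil.runAsUser(greet, ("alice", 4))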
|
@ -13,7 +13,7 @@ import java.net._
|
|||
|
||||
private[spark]
|
||||
abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
||||
val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
|
||||
val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging {
|
||||
def this(channel_ : SocketChannel, selector_ : Selector) = {
|
||||
this(channel_, selector_,
|
||||
ConnectionManagerId.fromSocketAddress(
|
||||
|
@ -33,15 +33,42 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
|||
|
||||
val remoteAddress = getRemoteAddress()
|
||||
|
||||
// Read channels typically do not register for write and write does not for read
|
||||
// Now, we do have write registering for read too (temporarily), but this is to detect
|
||||
// channel close NOT to actually read/consume data on it !
|
||||
// How does this work if/when we move to SSL ?
|
||||
|
||||
// What is the interest to register with selector for when we want this connection to be selected
|
||||
def registerInterest()
|
||||
// What is the interest to register with selector for when we want this connection to be de-selected
|
||||
// Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack, it will be
|
||||
// SelectionKey.OP_READ (until we fix it properly)
|
||||
def unregisterInterest()
|
||||
|
||||
// On receiving a read event, should we change the interest for this channel or not ?
|
||||
// Will be true for ReceivingConnection, false for SendingConnection.
|
||||
def changeInterestForRead(): Boolean
|
||||
|
||||
// On receiving a write event, should we change the interest for this channel or not ?
|
||||
// Will be false for ReceivingConnection, true for SendingConnection.
|
||||
// Actually, for now, should not get triggered for ReceivingConnection
|
||||
def changeInterestForWrite(): Boolean
|
||||
|
||||
def getRemoteConnectionManagerId(): ConnectionManagerId = {
|
||||
socketRemoteConnectionManagerId
|
||||
}
|
||||
|
||||
def key() = channel.keyFor(selector)
|
||||
|
||||
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
|
||||
|
||||
def read() {
|
||||
// Returns whether we have to register for further reads or not.
|
||||
def read(): Boolean = {
|
||||
throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
|
||||
}
|
||||
|
||||
def write() {
|
||||
// Returns whether we have to register for further writes or not.
|
||||
def write(): Boolean = {
|
||||
throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
|
||||
}
|
||||
|
||||
|
@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
|||
if (onExceptionCallback != null) {
|
||||
onExceptionCallback(this, e)
|
||||
} else {
|
||||
logError("Error in connection to " + remoteConnectionManagerId +
|
||||
logError("Error in connection to " + getRemoteConnectionManagerId() +
|
||||
" and OnExceptionCallback not registered", e)
|
||||
}
|
||||
}
|
||||
|
@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
|
|||
if (onCloseCallback != null) {
|
||||
onCloseCallback(this)
|
||||
} else {
|
||||
logWarning("Connection to " + remoteConnectionManagerId +
|
||||
logWarning("Connection to " + getRemoteConnectionManagerId() +
|
||||
" closed and OnExceptionCallback not registered")
|
||||
}
|
||||
|
||||
|
@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
messages.synchronized{
|
||||
/*messages += message*/
|
||||
messages.enqueue(message)
|
||||
logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
|
||||
logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
}
|
||||
return chunk
|
||||
} else {
|
||||
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
|
||||
/*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
|
||||
message.finishTime = System.currentTimeMillis
|
||||
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
|
||||
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
|
||||
"] in " + message.timeTaken )
|
||||
}
|
||||
}
|
||||
|
@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
messages.enqueue(message)
|
||||
nextMessageToBeUsed = nextMessageToBeUsed + 1
|
||||
if (!message.started) {
|
||||
logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
|
||||
logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
|
||||
message.started = true
|
||||
message.startTime = System.currentTimeMillis
|
||||
}
|
||||
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
|
||||
logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
|
||||
return chunk
|
||||
} else {
|
||||
message.finishTime = System.currentTimeMillis
|
||||
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
|
||||
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
|
||||
"] in " + message.timeTaken )
|
||||
}
|
||||
}
|
||||
|
@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
}
|
||||
}
|
||||
|
||||
val outbox = new Outbox(1)
|
||||
private val outbox = new Outbox(1)
|
||||
val currentBuffers = new ArrayBuffer[ByteBuffer]()
|
||||
|
||||
/*channel.socket.setSendBufferSize(256 * 1024)*/
|
||||
|
||||
override def getRemoteAddress() = address
|
||||
|
||||
val DEFAULT_INTEREST = SelectionKey.OP_READ
|
||||
|
||||
override def registerInterest() {
|
||||
// Registering read too - does not really help in most cases, but for some
|
||||
// it does - so let us keep it for now.
|
||||
changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
|
||||
}
|
||||
|
||||
override def unregisterInterest() {
|
||||
changeConnectionKeyInterest(DEFAULT_INTEREST)
|
||||
}
|
||||
|
||||
def send(message: Message) {
|
||||
outbox.synchronized {
|
||||
outbox.addMessage(message)
|
||||
if (channel.isConnected) {
|
||||
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
|
||||
registerInterest()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MUST be called within the selector loop
|
||||
def connect() {
|
||||
try{
|
||||
channel.connect(address)
|
||||
channel.register(selector, SelectionKey.OP_CONNECT)
|
||||
channel.connect(address)
|
||||
logInfo("Initiating connection to [" + address + "]")
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
|
@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
}
|
||||
}
|
||||
|
||||
def finishConnect() {
|
||||
def finishConnect(force: Boolean): Boolean = {
|
||||
try {
|
||||
channel.finishConnect
|
||||
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
|
||||
// Typically, this should finish immediately since it was triggered by a connect
|
||||
// selection - though need not necessarily always complete successfully.
|
||||
val connected = channel.finishConnect
|
||||
if (!force && !connected) {
|
||||
logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
|
||||
return false
|
||||
}
|
||||
|
||||
// Fallback to previous behavior - assume finishConnect completed
|
||||
// This will happen only when finishConnect failed for some repeated number of times (10 or so)
|
||||
// Is highly unlikely unless there was an unclean close of socket, etc
|
||||
registerInterest()
|
||||
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
|
||||
return true
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logWarning("Error finishing connection to " + address, e)
|
||||
callOnExceptionCallback(e)
|
||||
// ignore
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def write() {
|
||||
override def write(): Boolean = {
|
||||
try{
|
||||
while(true) {
|
||||
if (currentBuffers.size == 0) {
|
||||
|
@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
currentBuffers ++= chunk.buffers
|
||||
}
|
||||
case None => {
|
||||
changeConnectionKeyInterest(SelectionKey.OP_READ)
|
||||
return
|
||||
// changeConnectionKeyInterest(0)
|
||||
/*key.interestOps(0)*/
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
|
|||
currentBuffers -= buffer
|
||||
}
|
||||
if (writtenBytes < remainingBytes) {
|
||||
return
|
||||
// re-register for write.
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
|
||||
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
|
||||
callOnExceptionCallback(e)
|
||||
close()
|
||||
return false
|
||||
}
|
||||
}
|
||||
// should not happen - to keep scala compiler happy
|
||||
return true
|
||||
}
|
||||
|
||||
override def read() {
|
||||
// This is a hack to determine if remote socket was closed or not.
|
||||
// SendingConnection DOES NOT expect to receive any data - if it does, it is an error
|
||||
// For a bunch of cases, read will return -1 in case remote socket is closed : hence we
|
||||
// register for reads to determine that.
|
||||
override def read(): Boolean = {
|
||||
// We don't expect the other side to send anything; so, we just read to detect an error or EOF.
|
||||
try {
|
||||
val length = channel.read(ByteBuffer.allocate(1))
|
||||
if (length == -1) { // EOF
|
||||
close()
|
||||
} else if (length > 0) {
|
||||
logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
|
||||
logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
|
||||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
|
||||
logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
|
||||
callOnExceptionCallback(e)
|
||||
close()
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
override def changeInterestForRead(): Boolean = false
|
||||
|
||||
override def changeInterestForWrite(): Boolean = true
|
||||
}
|
||||
|
||||
|
||||
// Must be created within selector loop - else deadlock
|
||||
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
|
||||
extends Connection(channel_, selector_) {
|
||||
|
||||
|
@ -298,13 +367,13 @@ extends Connection(channel_, selector_) {
|
|||
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
|
||||
newMessage.started = true
|
||||
newMessage.startTime = System.currentTimeMillis
|
||||
logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
|
||||
logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
|
||||
messages += ((newMessage.id, newMessage))
|
||||
newMessage
|
||||
}
|
||||
|
||||
val message = messages.getOrElseUpdate(header.id, createNewMessage)
|
||||
logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
|
||||
logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
|
||||
message.getChunkForReceiving(header.chunkSize)
|
||||
}
|
||||
|
||||
|
@ -317,6 +386,26 @@ extends Connection(channel_, selector_) {
|
|||
}
|
||||
}
|
||||
|
||||
@volatile private var inferredRemoteManagerId: ConnectionManagerId = null
|
||||
override def getRemoteConnectionManagerId(): ConnectionManagerId = {
|
||||
val currId = inferredRemoteManagerId
|
||||
if (currId != null) currId else super.getRemoteConnectionManagerId()
|
||||
}
|
||||
|
||||
// The receiver's remote address is the local socket on the remote side, which is NOT the connection manager id of the receiver.
|
||||
// We infer that from the messages we receive on the receiver socket.
|
||||
private def processConnectionManagerId(header: MessageChunkHeader) {
|
||||
val currId = inferredRemoteManagerId
|
||||
if (header.address == null || currId != null) return
|
||||
|
||||
val managerId = ConnectionManagerId.fromSocketAddress(header.address)
|
||||
|
||||
if (managerId != null) {
|
||||
inferredRemoteManagerId = managerId
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
val inbox = new Inbox()
|
||||
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
|
||||
var onReceiveCallback: (Connection , Message) => Unit = null
|
||||
|
@ -324,17 +413,18 @@ extends Connection(channel_, selector_) {
|
|||
|
||||
channel.register(selector, SelectionKey.OP_READ)
|
||||
|
||||
override def read() {
|
||||
override def read(): Boolean = {
|
||||
try {
|
||||
while (true) {
|
||||
if (currentChunk == null) {
|
||||
val headerBytesRead = channel.read(headerBuffer)
|
||||
if (headerBytesRead == -1) {
|
||||
close()
|
||||
return
|
||||
return false
|
||||
}
|
||||
if (headerBuffer.remaining > 0) {
|
||||
return
|
||||
// re-register for read event ...
|
||||
return true
|
||||
}
|
||||
headerBuffer.flip
|
||||
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
|
||||
|
@ -342,6 +432,9 @@ extends Connection(channel_, selector_) {
|
|||
}
|
||||
val header = MessageChunkHeader.create(headerBuffer)
|
||||
headerBuffer.clear()
|
||||
|
||||
processConnectionManagerId(header)
|
||||
|
||||
header.typ match {
|
||||
case Message.BUFFER_MESSAGE => {
|
||||
if (header.totalSize == 0) {
|
||||
|
@ -349,7 +442,8 @@ extends Connection(channel_, selector_) {
|
|||
onReceiveCallback(this, Message.create(header))
|
||||
}
|
||||
currentChunk = null
|
||||
return
|
||||
// re-register for read event ...
|
||||
return true
|
||||
} else {
|
||||
currentChunk = inbox.getChunk(header).orNull
|
||||
}
|
||||
|
@ -362,10 +456,11 @@ extends Connection(channel_, selector_) {
|
|||
|
||||
val bytesRead = channel.read(currentChunk.buffer)
|
||||
if (bytesRead == 0) {
|
||||
return
|
||||
// re-register for read event ...
|
||||
return true
|
||||
} else if (bytesRead == -1) {
|
||||
close()
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
|
||||
|
@ -376,7 +471,7 @@ extends Connection(channel_, selector_) {
|
|||
if (bufferMessage.isCompletelyReceived) {
|
||||
bufferMessage.flip
|
||||
bufferMessage.finishTime = System.currentTimeMillis
|
||||
logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
|
||||
logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
|
||||
if (onReceiveCallback != null) {
|
||||
onReceiveCallback(this, bufferMessage)
|
||||
}
|
||||
|
@ -387,12 +482,31 @@ extends Connection(channel_, selector_) {
|
|||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
|
||||
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
|
||||
callOnExceptionCallback(e)
|
||||
close()
|
||||
return false
|
||||
}
|
||||
}
|
||||
// should not happen - to keep scala compiler happy
|
||||
return true
|
||||
}
|
||||
|
||||
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
|
||||
|
||||
override def changeInterestForRead(): Boolean = true
|
||||
|
||||
override def changeInterestForWrite(): Boolean = {
|
||||
throw new IllegalStateException("Unexpected invocation right now")
|
||||
}
|
||||
|
||||
override def registerInterest() {
|
||||
// Registering read too - does not really help in most cases, but for some
|
||||
// it does - so let us keep it for now.
|
||||
changeConnectionKeyInterest(SelectionKey.OP_READ)
|
||||
}
|
||||
|
||||
override def unregisterInterest() {
|
||||
changeConnectionKeyInterest(0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,12 +6,12 @@ import java.nio._
|
|||
import java.nio.channels._
|
||||
import java.nio.channels.spi._
|
||||
import java.net._
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
|
||||
|
||||
import scala.collection.mutable.HashSet
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.SynchronizedMap
|
||||
import scala.collection.mutable.SynchronizedQueue
|
||||
import scala.collection.mutable.Queue
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
|
||||
|
@ -19,6 +19,10 @@ import akka.util.Duration
|
|||
import akka.util.duration._
|
||||
|
||||
private[spark] case class ConnectionManagerId(host: String, port: Int) {
|
||||
// DEBUG code
|
||||
Utils.checkHost(host)
|
||||
assert (port > 0)
|
||||
|
||||
def toSocketAddress() = new InetSocketAddress(host, port)
|
||||
}
|
||||
|
||||
|
@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
def markDone() { completionHandler(this) }
|
||||
}
|
||||
|
||||
val selector = SelectorProvider.provider.openSelector()
|
||||
val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
|
||||
val serverChannel = ServerSocketChannel.open()
|
||||
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
|
||||
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
|
||||
val messageStatuses = new HashMap[Int, MessageStatus]
|
||||
val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
|
||||
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
|
||||
val sendMessageRequests = new Queue[(Message, SendingConnection)]
|
||||
private val selector = SelectorProvider.provider.openSelector()
|
||||
|
||||
private val handleMessageExecutor = new ThreadPoolExecutor(
  System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
  System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
  System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
  new LinkedBlockingDeque[Runnable]())

private val handleReadWriteExecutor = new ThreadPoolExecutor(
  System.getProperty("spark.core.connection.io.threads.min","4").toInt,
  System.getProperty("spark.core.connection.io.threads.max","32").toInt,
  System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
  new LinkedBlockingDeque[Runnable]())

// Use a different, smaller thread pool for connects - it is used infrequently, for very short-lived tasks that should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
  System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
  System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
  System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
  new LinkedBlockingDeque[Runnable]())
|
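All three pool sizes above are read from system properties, so they can be tuned without code changes; a minimal sketch (hypothetical values, set before the ConnectionManager is constructed):

  // Property names are the ones read above; the values here are examples only.
  System.setProperty("spark.core.connection.handler.threads.max", "128")
  System.setProperty("spark.core.connection.io.threads.min", "8")
  System.setProperty("spark.core.connection.connect.threads.keepalive", "120")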
||||
private val serverChannel = ServerSocketChannel.open()
|
||||
private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
|
||||
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
|
||||
private val messageStatuses = new HashMap[Int, MessageStatus]
|
||||
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
|
||||
private val registerRequests = new SynchronizedQueue[SendingConnection]
|
||||
|
||||
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
|
||||
|
||||
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
|
||||
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
|
||||
|
||||
serverChannel.configureBlocking(false)
|
||||
serverChannel.socket.setReuseAddress(true)
|
||||
|
@ -66,45 +88,138 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
|
||||
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
|
||||
|
||||
val selectorThread = new Thread("connection-manager-thread") {
|
||||
private val selectorThread = new Thread("connection-manager-thread") {
|
||||
override def run() = ConnectionManager.this.run()
|
||||
}
|
||||
selectorThread.setDaemon(true)
|
||||
selectorThread.start()
|
||||
|
||||
private def run() {
|
||||
private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
|
||||
|
||||
private def triggerWrite(key: SelectionKey) {
|
||||
val conn = connectionsByKey.getOrElse(key, null)
|
||||
if (conn == null) return
|
||||
|
||||
writeRunnableStarted.synchronized {
|
||||
// So that we do not trigger more write events while processing this one.
|
||||
// The write method will re-register when done.
|
||||
if (conn.changeInterestForWrite()) conn.unregisterInterest()
|
||||
if (writeRunnableStarted.contains(key)) {
|
||||
// key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
|
||||
return
|
||||
}
|
||||
|
||||
writeRunnableStarted += key
|
||||
}
|
||||
handleReadWriteExecutor.execute(new Runnable {
|
||||
override def run() {
|
||||
var register: Boolean = false
|
||||
try {
|
||||
register = conn.write()
|
||||
} finally {
|
||||
writeRunnableStarted.synchronized {
|
||||
writeRunnableStarted -= key
|
||||
if (register && conn.changeInterestForWrite()) {
|
||||
conn.registerInterest()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} )
|
||||
}
|
||||
|
||||
private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
|
||||
|
||||
private def triggerRead(key: SelectionKey) {
|
||||
val conn = connectionsByKey.getOrElse(key, null)
|
||||
if (conn == null) return
|
||||
|
||||
readRunnableStarted.synchronized {
|
||||
// So that we do not trigger more read events while processing this one.
|
||||
// The read method will re-register when done.
|
||||
if (conn.changeInterestForRead())conn.unregisterInterest()
|
||||
if (readRunnableStarted.contains(key)) {
|
||||
return
|
||||
}
|
||||
|
||||
readRunnableStarted += key
|
||||
}
|
||||
handleReadWriteExecutor.execute(new Runnable {
|
||||
override def run() {
|
||||
var register: Boolean = false
|
||||
try {
|
||||
register = conn.read()
|
||||
} finally {
|
||||
readRunnableStarted.synchronized {
|
||||
readRunnableStarted -= key
|
||||
if (register && conn.changeInterestForRead()) {
|
||||
conn.registerInterest()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} )
|
||||
}
|
||||
|
||||
private def triggerConnect(key: SelectionKey) {
|
||||
val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
|
||||
if (conn == null) return
|
||||
|
||||
// prevent other events from being triggered
|
||||
// Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
|
||||
conn.changeConnectionKeyInterest(0)
|
||||
|
||||
handleConnectExecutor.execute(new Runnable {
|
||||
override def run() {
|
||||
|
||||
var tries: Int = 10
|
||||
while (tries >= 0) {
|
||||
if (conn.finishConnect(false)) return
|
||||
// Sleep ?
|
||||
Thread.sleep(1)
|
||||
tries -= 1
|
||||
}
|
||||
|
||||
// Fallback to previous behavior: we should not really get here, since this method was
// triggered because the channel became connectable - but at times the first finishConnect
// need not succeed, hence the loop retrying a few times.
|
||||
conn.finishConnect(true)
|
||||
}
|
||||
} )
|
||||
}
|
||||
|
||||
def run() {
|
||||
try {
|
||||
while(!selectorThread.isInterrupted) {
|
||||
for ((connectionManagerId, sendingConnection) <- connectionRequests) {
|
||||
sendingConnection.connect()
|
||||
addConnection(sendingConnection)
|
||||
connectionRequests -= connectionManagerId
|
||||
}
|
||||
sendMessageRequests.synchronized {
|
||||
while (!sendMessageRequests.isEmpty) {
|
||||
val (message, connection) = sendMessageRequests.dequeue
|
||||
connection.send(message)
|
||||
}
|
||||
while (! registerRequests.isEmpty) {
|
||||
val conn: SendingConnection = registerRequests.dequeue
|
||||
addListeners(conn)
|
||||
conn.connect()
|
||||
addConnection(conn)
|
||||
}
|
||||
|
||||
while (!keyInterestChangeRequests.isEmpty) {
|
||||
while(!keyInterestChangeRequests.isEmpty) {
|
||||
val (key, ops) = keyInterestChangeRequests.dequeue
|
||||
val connection = connectionsByKey(key)
|
||||
val lastOps = key.interestOps()
|
||||
key.interestOps(ops)
|
||||
val connection = connectionsByKey.getOrElse(key, null)
|
||||
if (connection != null) {
|
||||
val lastOps = key.interestOps()
|
||||
key.interestOps(ops)
|
||||
|
||||
def intToOpStr(op: Int): String = {
|
||||
val opStrs = ArrayBuffer[String]()
|
||||
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
|
||||
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
|
||||
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
|
||||
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
|
||||
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
|
||||
// hot loop - prevent materialization of string if trace not enabled.
|
||||
if (isTraceEnabled()) {
|
||||
def intToOpStr(op: Int): String = {
|
||||
val opStrs = ArrayBuffer[String]()
|
||||
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
|
||||
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
|
||||
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
|
||||
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
|
||||
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
|
||||
}
|
||||
|
||||
logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
|
||||
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
|
||||
}
|
||||
}
|
||||
|
||||
logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
|
||||
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
|
||||
|
||||
}
|
||||
|
||||
val selectedKeysCount = selector.select()
|
||||
|
@ -123,12 +238,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
if (key.isValid) {
|
||||
if (key.isAcceptable) {
|
||||
acceptConnection(key)
|
||||
} else if (key.isConnectable) {
|
||||
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
|
||||
} else if (key.isReadable) {
|
||||
connectionsByKey(key).read()
|
||||
} else if (key.isWritable) {
|
||||
connectionsByKey(key).write()
|
||||
} else
|
||||
if (key.isConnectable) {
|
||||
triggerConnect(key)
|
||||
} else
|
||||
if (key.isReadable) {
|
||||
triggerRead(key)
|
||||
} else
|
||||
if (key.isWritable) {
|
||||
triggerWrite(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -138,94 +256,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
private def acceptConnection(key: SelectionKey) {
|
||||
def acceptConnection(key: SelectionKey) {
|
||||
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
|
||||
val newChannel = serverChannel.accept()
|
||||
val newConnection = new ReceivingConnection(newChannel, selector)
|
||||
newConnection.onReceive(receiveMessage)
|
||||
newConnection.onClose(removeConnection)
|
||||
addConnection(newConnection)
|
||||
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
|
||||
|
||||
var newChannel = serverChannel.accept()
|
||||
|
||||
// accept them all in a tight loop. non blocking accept with no processing, should be fine
|
||||
while (newChannel != null) {
|
||||
try {
|
||||
val newConnection = new ReceivingConnection(newChannel, selector)
|
||||
newConnection.onReceive(receiveMessage)
|
||||
addListeners(newConnection)
|
||||
addConnection(newConnection)
|
||||
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
|
||||
} catch {
|
||||
// might happen in case of issues with registering with selector
|
||||
case e: Exception => logError("Error in accept loop", e)
|
||||
}
|
||||
|
||||
newChannel = serverChannel.accept()
|
||||
}
|
||||
}
|
||||
|
||||
private def addConnection(connection: Connection) {
|
||||
connectionsByKey += ((connection.key, connection))
|
||||
if (connection.isInstanceOf[SendingConnection]) {
|
||||
val sendingConnection = connection.asInstanceOf[SendingConnection]
|
||||
connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
|
||||
}
|
||||
private def addListeners(connection: Connection) {
|
||||
connection.onKeyInterestChange(changeConnectionKeyInterest)
|
||||
connection.onException(handleConnectionError)
|
||||
connection.onClose(removeConnection)
|
||||
}
|
||||
|
||||
private def removeConnection(connection: Connection) {
|
||||
def addConnection(connection: Connection) {
|
||||
connectionsByKey += ((connection.key, connection))
|
||||
}
|
||||
|
||||
def removeConnection(connection: Connection) {
|
||||
connectionsByKey -= connection.key
|
||||
if (connection.isInstanceOf[SendingConnection]) {
|
||||
val sendingConnection = connection.asInstanceOf[SendingConnection]
|
||||
val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
|
||||
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
|
||||
|
||||
connectionsById -= sendingConnectionManagerId
|
||||
try {
|
||||
if (connection.isInstanceOf[SendingConnection]) {
|
||||
val sendingConnection = connection.asInstanceOf[SendingConnection]
|
||||
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
|
||||
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
|
||||
|
||||
messageStatuses.synchronized {
|
||||
messageStatuses
|
||||
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
|
||||
logInfo("Notifying " + status)
|
||||
status.synchronized {
|
||||
status.attempted = true
|
||||
status.acked = false
|
||||
status.markDone()
|
||||
}
|
||||
connectionsById -= sendingConnectionManagerId
|
||||
|
||||
messageStatuses.synchronized {
|
||||
messageStatuses
|
||||
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
|
||||
logInfo("Notifying " + status)
|
||||
status.synchronized {
|
||||
status.attempted = true
|
||||
status.acked = false
|
||||
status.markDone()
|
||||
}
|
||||
})
|
||||
|
||||
messageStatuses.retain((i, status) => {
|
||||
status.connectionManagerId != sendingConnectionManagerId
|
||||
})
|
||||
}
|
||||
} else if (connection.isInstanceOf[ReceivingConnection]) {
|
||||
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
|
||||
val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
|
||||
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
|
||||
|
||||
messageStatuses.retain((i, status) => {
|
||||
status.connectionManagerId != sendingConnectionManagerId
|
||||
})
|
||||
}
|
||||
} else if (connection.isInstanceOf[ReceivingConnection]) {
|
||||
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
|
||||
val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
|
||||
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
|
||||
|
||||
val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
|
||||
if (sendingConnectionManagerId == null) {
|
||||
logError("Corresponding SendingConnectionManagerId not found")
|
||||
return
|
||||
}
|
||||
logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
|
||||
|
||||
val sendingConnection = connectionsById(sendingConnectionManagerId)
|
||||
sendingConnection.close()
|
||||
connectionsById -= sendingConnectionManagerId
|
||||
|
||||
messageStatuses.synchronized {
|
||||
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
|
||||
logInfo("Notifying " + s)
|
||||
s.synchronized {
|
||||
s.attempted = true
|
||||
s.acked = false
|
||||
s.markDone()
|
||||
}
|
||||
val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
|
||||
if (! sendingConnectionOpt.isDefined) {
|
||||
logError("Corresponding SendingConnectionManagerId not found")
|
||||
return
|
||||
}
|
||||
|
||||
messageStatuses.retain((i, status) => {
|
||||
status.connectionManagerId != sendingConnectionManagerId
|
||||
})
|
||||
val sendingConnection = sendingConnectionOpt.get
|
||||
connectionsById -= remoteConnectionManagerId
|
||||
sendingConnection.close()
|
||||
|
||||
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
|
||||
|
||||
assert (sendingConnectionManagerId == remoteConnectionManagerId)
|
||||
|
||||
messageStatuses.synchronized {
|
||||
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
|
||||
logInfo("Notifying " + s)
|
||||
s.synchronized {
|
||||
s.attempted = true
|
||||
s.acked = false
|
||||
s.markDone()
|
||||
}
|
||||
}
|
||||
|
||||
messageStatuses.retain((i, status) => {
|
||||
status.connectionManagerId != sendingConnectionManagerId
|
||||
})
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
// So that the selection keys can be removed.
|
||||
wakeupSelector()
|
||||
}
|
||||
}
|
||||
|
||||
private def handleConnectionError(connection: Connection, e: Exception) {
|
||||
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
|
||||
def handleConnectionError(connection: Connection, e: Exception) {
|
||||
logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
|
||||
removeConnection(connection)
|
||||
}
|
||||
|
||||
private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
|
||||
def changeConnectionKeyInterest(connection: Connection, ops: Int) {
|
||||
keyInterestChangeRequests += ((connection.key, ops))
|
||||
// so that registrations happen !
|
||||
wakeupSelector()
|
||||
}
|
||||
|
||||
private def receiveMessage(connection: Connection, message: Message) {
|
||||
def receiveMessage(connection: Connection, message: Message) {
|
||||
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
|
||||
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
|
||||
val runnable = new Runnable() {
|
||||
|
@ -293,18 +433,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
|
||||
def startNewConnection(): SendingConnection = {
|
||||
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
|
||||
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
|
||||
new SendingConnection(inetSocketAddress, selector, connectionManagerId))
|
||||
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
|
||||
registerRequests.enqueue(newConnection)
|
||||
|
||||
newConnection
|
||||
}
|
||||
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
|
||||
val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
|
||||
// I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
|
||||
// If we do re-add it, we should consistently use it everywhere I guess ?
|
||||
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
|
||||
message.senderAddress = id.toSocketAddress()
|
||||
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
|
||||
/*connection.send(message)*/
|
||||
sendMessageRequests.synchronized {
|
||||
sendMessageRequests += ((message, connection))
|
||||
}
|
||||
connection.send(message)
|
||||
|
||||
wakeupSelector()
|
||||
}
|
||||
|
||||
private def wakeupSelector() {
|
||||
selector.wakeup()
|
||||
}
|
||||
|
||||
|
@ -337,6 +481,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
logWarning("All connections not cleaned up")
|
||||
}
|
||||
handleMessageExecutor.shutdown()
|
||||
handleReadWriteExecutor.shutdown()
|
||||
handleConnectExecutor.shutdown()
|
||||
logInfo("ConnectionManager stopped")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ private[spark] class MessageChunkHeader(
|
|||
val other: Int,
|
||||
val address: InetSocketAddress) {
|
||||
lazy val buffer = {
|
||||
// No need to change this, at 'use' time, we do a reverse lookup of the hostname. Refer to network.Connection
|
||||
val ip = address.getAddress.getAddress()
|
||||
val port = address.getPort()
|
||||
ByteBuffer.
|
||||
|
|
|
@ -50,6 +50,11 @@ class DAGScheduler(
|
|||
eventQueue.put(ExecutorLost(execId))
|
||||
}
|
||||
|
||||
// Called by TaskScheduler when a host is added
|
||||
override def executorGained(execId: String, hostPort: String) {
|
||||
eventQueue.put(ExecutorGained(execId, hostPort))
|
||||
}
|
||||
|
||||
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
|
||||
override def taskSetFailed(taskSet: TaskSet, reason: String) {
|
||||
eventQueue.put(TaskSetFailed(taskSet, reason))
|
||||
|
@ -113,7 +118,7 @@ class DAGScheduler(
|
|||
if (!cacheLocs.contains(rdd.id)) {
|
||||
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
|
||||
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
|
||||
locations => locations.map(_.ip).toList
|
||||
locations => locations.map(_.hostPort).toList
|
||||
}.toArray
|
||||
}
|
||||
cacheLocs(rdd.id)
|
||||
|
@ -293,6 +298,9 @@ class DAGScheduler(
|
|||
submitStage(finalStage)
|
||||
}
|
||||
|
||||
case ExecutorGained(execId, hostPort) =>
|
||||
handleExecutorGained(execId, hostPort)
|
||||
|
||||
case ExecutorLost(execId) =>
|
||||
handleExecutorLost(execId)
|
||||
|
||||
|
@ -631,6 +639,14 @@ class DAGScheduler(
|
|||
}
|
||||
}
|
||||
|
||||
private def handleExecutorGained(execId: String, hostPort: String) {
|
||||
// remove from failedGeneration(execId) ?
|
||||
if (failedGeneration.contains(execId)) {
|
||||
logInfo("Host gained which was in lost list earlier: " + hostPort)
|
||||
failedGeneration -= execId
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
|
||||
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
|
||||
|
|
|
@ -32,6 +32,10 @@ private[spark] case class CompletionEvent(
|
|||
taskMetrics: TaskMetrics)
|
||||
extends DAGSchedulerEvent
|
||||
|
||||
private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
|
||||
Utils.checkHostPort(hostPort, "Required hostport")
|
||||
}
|
||||
|
||||
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
|
||||
|
||||
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
|
||||
|
|
156
core/src/main/scala/spark/scheduler/InputFormatInfo.scala
Normal file
|
@ -0,0 +1,156 @@
|
|||
package spark.scheduler
|
||||
|
||||
import spark.Logging
|
||||
import scala.collection.immutable.Set
|
||||
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
|
||||
import org.apache.hadoop.util.ReflectionUtils
|
||||
import org.apache.hadoop.mapreduce.Job
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
|
||||
/**
|
||||
* Parses and holds information about inputFormat (and files) specified as a parameter.
|
||||
*/
|
||||
class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
|
||||
val path: String) extends Logging {
|
||||
|
||||
var mapreduceInputFormat: Boolean = false
|
||||
var mapredInputFormat: Boolean = false
|
||||
|
||||
validate()
|
||||
|
||||
override def toString(): String = {
|
||||
"InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
|
||||
}
|
||||
|
||||
override def hashCode(): Int = {
|
||||
var hashCode = inputFormatClazz.hashCode
|
||||
hashCode = hashCode * 31 + path.hashCode
|
||||
hashCode
|
||||
}
|
||||
|
||||
// Since we are not doing canonicalization of the path, this can be wrong: like relative vs absolute path
// .. which is fine, this is a best-effort attempt to remove duplicates - right ?
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case that: InputFormatInfo => {
|
||||
// not checking config - that should be fine, right ?
|
||||
this.inputFormatClazz == that.inputFormatClazz &&
|
||||
this.path == that.path
|
||||
}
|
||||
case _ => false
|
||||
}
|
||||
|
||||
private def validate() {
|
||||
logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
|
||||
|
||||
try {
|
||||
if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
|
||||
logDebug("inputformat is from mapreduce package")
|
||||
mapreduceInputFormat = true
|
||||
}
|
||||
else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
|
||||
logDebug("inputformat is from mapred package")
|
||||
mapredInputFormat = true
|
||||
}
|
||||
else {
|
||||
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
|
||||
" is NOT a supported input format ? does not implement either of the supported hadoop api's")
|
||||
}
|
||||
}
|
||||
catch {
|
||||
case e: ClassNotFoundException => {
|
||||
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// This method does not expect failures, since validate has already passed ...
|
||||
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
|
||||
val conf = new JobConf(configuration)
|
||||
FileInputFormat.setInputPaths(conf, path)
|
||||
|
||||
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
|
||||
ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
|
||||
org.apache.hadoop.mapreduce.InputFormat[_, _]]
|
||||
val job = new Job(conf)
|
||||
|
||||
val retval = new ArrayBuffer[SplitInfo]()
|
||||
val list = instance.getSplits(job)
|
||||
for (split <- list) {
|
||||
retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
|
||||
}
|
||||
|
||||
return retval.toSet
|
||||
}
|
||||
|
||||
// This method does not expect failures, since validate has already passed ...
|
||||
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
|
||||
val jobConf = new JobConf(configuration)
|
||||
FileInputFormat.setInputPaths(jobConf, path)
|
||||
|
||||
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
|
||||
ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
|
||||
org.apache.hadoop.mapred.InputFormat[_, _]]
|
||||
|
||||
val retval = new ArrayBuffer[SplitInfo]()
|
||||
instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
|
||||
elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
|
||||
)
|
||||
|
||||
return retval.toSet
|
||||
}
|
||||
|
||||
private def findPreferredLocations(): Set[SplitInfo] = {
|
||||
logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
|
||||
", inputFormatClazz : " + inputFormatClazz)
|
||||
if (mapreduceInputFormat) {
|
||||
return prefLocsFromMapreduceInputFormat()
|
||||
}
|
||||
else {
|
||||
assert(mapredInputFormat)
|
||||
return prefLocsFromMapredInputFormat()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
object InputFormatInfo {
  /**
    Computes the preferred locations based on input(s) and returns a location-to-block map.
    Typical use of this method for allocation would follow an algorithm like this
    (which is what we currently do in the YARN branch) :
    a) For each host, count the number of splits hosted on that host.
    b) Decrement the currently allocated containers on that host.
    c) Compute rack info for each host and update the rack -> count map based on (b).
    d) Allocate nodes based on (c).
    e) On the allocation result, ensure that we don't allocate "too many" jobs on a single node
       (even if data locality on that node is very high) : this is to prevent fragility of the job
       if a single host (or a small set of hosts) goes down.

    Go back to (a) until the required nodes are allocated.

    If a node 'dies', follow the same procedure.

    PS: I know the wording here is weird, hopefully it makes some sense !
  */
  def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {

    val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
    for (inputSplit <- formats) {
      val splits = inputSplit.findPreferredLocations()

      for (split <- splits) {
        val location = split.hostLocation
        val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
        set += split
      }
    }

    nodeToSplit
  }
}
|
|
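A minimal sketch of driver-side use of computePreferredLocations (hypothetical path and input format, using only the classes introduced in this commit), following step (a) of the comment above by ranking hosts by how many splits they serve:

  import org.apache.hadoop.conf.Configuration
  import org.apache.hadoop.mapred.TextInputFormat

  // Hypothetical input: one mapred-style text input rooted at an example path.
  val formats = Seq(new InputFormatInfo(new Configuration(), classOf[TextInputFormat], "hdfs:///example/events"))
  val hostToSplits = InputFormatInfo.computePreferredLocations(formats)

  // Step (a): hosts ordered by the number of splits they serve, most loaded first.
  val hostsByLoad = hostToSplits.toSeq.sortBy(x => -x._2.size).map(_._1)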
@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U](
|
|||
rdd.partitions(partition)
|
||||
}
|
||||
|
||||
// data locality is on a per host basis, not hyper specific to container (host:port). Unique on set of hosts.
|
||||
val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
|
||||
|
||||
{
|
||||
// DEBUG code
|
||||
preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
|
||||
}
|
||||
|
||||
override def run(attemptId: Long): U = {
|
||||
val context = new TaskContext(stageId, partition, attemptId)
|
||||
metrics = Some(context.taskMetrics)
|
||||
|
@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U](
|
|||
}
|
||||
}
|
||||
|
||||
override def preferredLocations: Seq[String] = locs
|
||||
override def preferredLocations: Seq[String] = preferredLocs
|
||||
|
||||
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
|
||||
|
||||
|
|
|
@ -77,13 +77,21 @@ private[spark] class ShuffleMapTask(
|
|||
var rdd: RDD[_],
|
||||
var dep: ShuffleDependency[_,_],
|
||||
var partition: Int,
|
||||
@transient var locs: Seq[String])
|
||||
@transient private var locs: Seq[String])
|
||||
extends Task[MapStatus](stageId)
|
||||
with Externalizable
|
||||
with Logging {
|
||||
|
||||
protected def this() = this(0, null, null, 0, null)
|
||||
|
||||
// Data locality is on a per-host basis, not specific to a container (host:port). Unique over the set of hosts.
|
||||
private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
|
||||
|
||||
{
|
||||
// DEBUG code
|
||||
preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
|
||||
}
|
||||
|
||||
var split = if (rdd == null) {
|
||||
null
|
||||
} else {
|
||||
|
@ -154,7 +162,7 @@ private[spark] class ShuffleMapTask(
|
|||
}
|
||||
}
|
||||
|
||||
override def preferredLocations: Seq[String] = locs
|
||||
override def preferredLocations: Seq[String] = preferredLocs
|
||||
|
||||
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
|
||||
}
|
||||
|
|
61
core/src/main/scala/spark/scheduler/SplitInfo.scala
Normal file
|
@ -0,0 +1,61 @@
|
|||
package spark.scheduler
|
||||
|
||||
import collection.mutable.ArrayBuffer
|
||||
|
||||
// Information about a specific split instance : handles both the mapred and mapreduce InputSplit variants,
|
||||
// so that we do not need to worry about the differences between them.
|
||||
class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
|
||||
val length: Long, val underlyingSplit: Any) {
|
||||
override def toString(): String = {
|
||||
"SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
|
||||
", hostLocation : " + hostLocation + ", path : " + path +
|
||||
", length : " + length + ", underlyingSplit " + underlyingSplit
|
||||
}
|
||||
|
||||
override def hashCode(): Int = {
|
||||
var hashCode = inputFormatClazz.hashCode
|
||||
hashCode = hashCode * 31 + hostLocation.hashCode
|
||||
hashCode = hashCode * 31 + path.hashCode
|
||||
// ignore overflow ? It is hashcode anyway !
|
||||
hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
|
||||
hashCode
|
||||
}
|
||||
|
||||
// This is practically useless since most of the Split impl's don't seem to implement equals :-(
|
||||
// So unless there is identity equality between underlyingSplits, it will always fail even if it
|
||||
// is pointing to same block.
|
||||
override def equals(other: Any): Boolean = other match {
|
||||
case that: SplitInfo => {
|
||||
this.hostLocation == that.hostLocation &&
|
||||
this.inputFormatClazz == that.inputFormatClazz &&
|
||||
this.path == that.path &&
|
||||
this.length == that.length &&
|
||||
// other split specific checks (like start for FileSplit)
|
||||
this.underlyingSplit == that.underlyingSplit
|
||||
}
|
||||
case _ => false
|
||||
}
|
||||
}
|
||||
|
||||
object SplitInfo {
|
||||
|
||||
def toSplitInfo(inputFormatClazz: Class[_], path: String,
|
||||
mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
|
||||
val retval = new ArrayBuffer[SplitInfo]()
|
||||
val length = mapredSplit.getLength
|
||||
for (host <- mapredSplit.getLocations) {
|
||||
retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
|
||||
}
|
||||
retval
|
||||
}
|
||||
|
||||
def toSplitInfo(inputFormatClazz: Class[_], path: String,
|
||||
mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
|
||||
val retval = new ArrayBuffer[SplitInfo]()
|
||||
val length = mapreduceSplit.getLength
|
||||
for (host <- mapreduceSplit.getLocations) {
|
||||
retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
|
||||
}
|
||||
retval
|
||||
}
|
||||
}
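
A hypothetical usage sketch (not from this commit) of the helpers above, showing how one Hadoop split expands into one SplitInfo per replica host; the path and host names are made up:

// Illustrative only.
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}

val split = new FileSplit(new Path("/data/part-00000"), 0L, 128L * 1024 * 1024,
  Array("host1", "host2", "host3"))
val infos = SplitInfo.toSplitInfo(classOf[TextInputFormat], "/data/part-00000", split)
// infos has three entries, one per replica host; putting them into a Set keeps all three,
// since hostLocation participates in equals/hashCode.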
|
|
@ -10,6 +10,10 @@ package spark.scheduler
|
|||
private[spark] trait TaskScheduler {
|
||||
def start(): Unit
|
||||
|
||||
// Invoked after system has successfully initialized (typically in spark context).
|
||||
// Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registrations, etc.
|
||||
def postStartHook() { }
|
||||
|
||||
// Disconnect from the cluster.
|
||||
def stop(): Unit
|
||||
|
||||
|
|
|
@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener {
|
|||
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
|
||||
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
|
||||
|
||||
// A node was added to the cluster.
|
||||
def executorGained(execId: String, hostPort: String): Unit
|
||||
|
||||
// A node was lost from the cluster.
|
||||
def executorLost(execId: String): Unit
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import java.io.{File, FileInputStream, FileOutputStream}
|
||||
import java.lang.{Boolean => JBoolean}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
|
@ -25,6 +25,30 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
|
||||
// Threshold above which we warn the user that the initial TaskSet may be starved
|
||||
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
|
||||
// How often to revive offers in case there are pending tasks - that is how often to try to get
|
||||
// tasks scheduled in case there are nodes available : default 0 is to disable it - to preserve existing behavior
|
||||
// Note that this is required because of the delays introduced by data locality waits, etc.
|
||||
// TODO: rename property ?
|
||||
val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
|
||||
|
||||
/*
|
||||
This property controls how aggressively we modulate waiting for host-local task scheduling.
|
||||
To elaborate, currently there is a time limit (3 sec by default) to ensure that spark attempts to wait for host locality of tasks before
|
||||
scheduling on other nodes. We have modified this in yarn branch such that offers to task set happen in prioritized order :
|
||||
host-local, rack-local and then others
|
||||
But once all available host local (and no pref) tasks are scheduled, instead of waiting for 3 sec before
|
||||
scheduling to other nodes (which degrades performance for time sensitive tasks and on larger clusters), we can
|
||||
modulate that : to also allow rack local nodes or any node. The default is still set to HOST_LOCAL - so that previous behavior is
|
||||
maintained. This is to allow tuning the tension between pulling rdd data off node and scheduling computation asap.
|
||||
|
||||
TODO: rename property ? The value is one of
|
||||
- HOST_LOCAL (default, no change w.r.t current behavior),
|
||||
- RACK_LOCAL and
|
||||
- ANY
|
||||
|
||||
Note that this property makes more sense when used in conjunction with spark.tasks.revive.interval > 0 : else it is not very effective.
|
||||
*/
|
||||
val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL"))
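
For reference, a hedged example of how a driver might set the two properties introduced above (spark.tasks.revive.interval and spark.tasks.schedule.aggression); the values are purely illustrative, not recommendations:

// Illustrative values only.
System.setProperty("spark.tasks.revive.interval", "500")            // revive offers every 500 ms
System.setProperty("spark.tasks.schedule.aggression", "RACK_LOCAL") // relax to rack-local after the wait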
|
||||
|
||||
val activeTaskSets = new HashMap[String, TaskSetManager]
|
||||
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
|
||||
|
@ -33,9 +57,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
val taskIdToExecutorId = new HashMap[Long, String]
|
||||
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
|
||||
|
||||
var hasReceivedTask = false
|
||||
var hasLaunchedTask = false
|
||||
val starvationTimer = new Timer(true)
|
||||
@volatile private var hasReceivedTask = false
|
||||
@volatile private var hasLaunchedTask = false
|
||||
private val starvationTimer = new Timer(true)
|
||||
|
||||
// Incrementing Mesos task IDs
|
||||
val nextTaskId = new AtomicLong(0)
|
||||
|
@ -43,11 +67,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
// Which executor IDs we have executors on
|
||||
val activeExecutorIds = new HashSet[String]
|
||||
|
||||
// TODO: We might want to remove this and merge it with execId datastructures - but later.
|
||||
// Which hosts in the cluster are alive (contains hostPort's)
|
||||
private val hostPortsAlive = new HashSet[String]
|
||||
private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
|
||||
|
||||
// The set of executors we have on each host; this is used to compute hostsAlive, which
|
||||
// in turn is used to decide when we can attain data locality on a given host
|
||||
val executorsByHost = new HashMap[String, HashSet[String]]
|
||||
val executorsByHostPort = new HashMap[String, HashSet[String]]
|
||||
|
||||
val executorIdToHost = new HashMap[String, String]
|
||||
val executorIdToHostPort = new HashMap[String, String]
|
||||
|
||||
// JAR server, if any JARs were added by the user to the SparkContext
|
||||
var jarServer: HttpServer = null
|
||||
|
@ -75,11 +104,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
override def start() {
|
||||
backend.start()
|
||||
|
||||
if (System.getProperty("spark.speculation", "false") == "true") {
|
||||
if (JBoolean.getBoolean("spark.speculation")) {
|
||||
new Thread("ClusterScheduler speculation check") {
|
||||
setDaemon(true)
|
||||
|
||||
override def run() {
|
||||
logInfo("Starting speculative execution thread")
|
||||
while (true) {
|
||||
try {
|
||||
Thread.sleep(SPECULATION_INTERVAL)
|
||||
|
@ -91,6 +121,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
|
||||
// Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ?
|
||||
if (TASK_REVIVAL_INTERVAL > 0) {
|
||||
new Thread("ClusterScheduler task offer revival check") {
|
||||
setDaemon(true)
|
||||
|
||||
override def run() {
|
||||
logInfo("Starting speculative task offer revival thread")
|
||||
while (true) {
|
||||
try {
|
||||
Thread.sleep(TASK_REVIVAL_INTERVAL)
|
||||
} catch {
|
||||
case e: InterruptedException => {}
|
||||
}
|
||||
|
||||
if (hasPendingTasks()) backend.reviveOffers()
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
}
|
||||
|
||||
override def submitTasks(taskSet: TaskSet) {
|
||||
|
@ -139,22 +190,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
SparkEnv.set(sc.env)
|
||||
// Mark each slave as alive and remember its hostname
|
||||
for (o <- offers) {
|
||||
executorIdToHost(o.executorId) = o.hostname
|
||||
if (!executorsByHost.contains(o.hostname)) {
|
||||
executorsByHost(o.hostname) = new HashSet()
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(o.hostPort)
|
||||
|
||||
executorIdToHostPort(o.executorId) = o.hostPort
|
||||
if (! executorsByHostPort.contains(o.hostPort)) {
|
||||
executorsByHostPort(o.hostPort) = new HashSet[String]()
|
||||
}
|
||||
|
||||
hostPortsAlive += o.hostPort
|
||||
hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
|
||||
executorGained(o.executorId, o.hostPort)
|
||||
}
|
||||
// Build a list of tasks to assign to each slave
|
||||
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
|
||||
val availableCpus = offers.map(o => o.cores).toArray
|
||||
var launchedTask = false
|
||||
|
||||
|
||||
for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
|
||||
|
||||
// Split offers based on host local, rack local and off-rack tasks.
|
||||
val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
|
||||
val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
|
||||
val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
|
||||
|
||||
for (i <- 0 until offers.size) {
|
||||
val hostPort = offers(i).hostPort
|
||||
// DEBUG code
|
||||
Utils.checkHostPort(hostPort)
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i)))
|
||||
if (numHostLocalTasks > 0){
|
||||
val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
|
||||
for (j <- 0 until numHostLocalTasks) list += i
|
||||
}
|
||||
|
||||
val numRackLocalTasks = math.max(0,
|
||||
// Remove host local tasks (which are also rack local btw !) from this
|
||||
math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i)))
|
||||
if (numRackLocalTasks > 0){
|
||||
val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
|
||||
for (j <- 0 until numRackLocalTasks) list += i
|
||||
}
|
||||
if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){
|
||||
// add to others list - spread even this across cluster.
|
||||
val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
|
||||
list += i
|
||||
}
|
||||
}
|
||||
|
||||
val offersPriorityList = new ArrayBuffer[Int](
|
||||
hostLocalOffers.size + rackLocalOffers.size + otherOffers.size)
|
||||
// First host local, then rack, then others
|
||||
val numHostLocalOffers = {
|
||||
val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers)
|
||||
offersPriorityList ++= hostLocalPriorityList
|
||||
hostLocalPriorityList.size
|
||||
}
|
||||
val numRackLocalOffers = {
|
||||
val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
|
||||
offersPriorityList ++= rackLocalPriorityList
|
||||
rackLocalPriorityList.size
|
||||
}
|
||||
offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
|
||||
|
||||
var lastLoop = false
|
||||
val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
|
||||
case TaskLocality.HOST_LOCAL => numHostLocalOffers
|
||||
case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers
|
||||
case TaskLocality.ANY => offersPriorityList.size
|
||||
}
|
||||
|
||||
do {
|
||||
launchedTask = false
|
||||
for (i <- 0 until offers.size) {
|
||||
var loopCount = 0
|
||||
for (i <- offersPriorityList) {
|
||||
val execId = offers(i).executorId
|
||||
val host = offers(i).hostname
|
||||
manager.slaveOffer(execId, host, availableCpus(i)) match {
|
||||
val hostPort = offers(i).hostPort
|
||||
|
||||
// If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing)
|
||||
val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
|
||||
|
||||
// If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ...
|
||||
loopCount += 1
|
||||
|
||||
manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
|
||||
case Some(task) =>
|
||||
tasks(i) += task
|
||||
val tid = task.taskId
|
||||
|
@ -162,15 +283,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
taskSetTaskIds(manager.taskSet.id) += tid
|
||||
taskIdToExecutorId(tid) = execId
|
||||
activeExecutorIds += execId
|
||||
executorsByHost(host) += execId
|
||||
executorsByHostPort(hostPort) += execId
|
||||
availableCpus(i) -= 1
|
||||
launchedTask = true
|
||||
|
||||
case None => {}
|
||||
}
|
||||
}
|
||||
// Loop once more - when lastLoop = true, we try to schedule tasks on all nodes irrespective of
|
||||
// data locality (we still go in order of priority : but that would not change anything since
|
||||
// if data local tasks had been available, we would have scheduled them already)
|
||||
if (lastLoop) {
|
||||
// prevent more looping
|
||||
launchedTask = false
|
||||
} else if (!lastLoop && !launchedTask) {
|
||||
// Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL
|
||||
if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) {
|
||||
// fudge launchedTask to ensure we loop once more
|
||||
launchedTask = true
|
||||
// don't loop anymore
|
||||
lastLoop = true
|
||||
}
|
||||
}
|
||||
} while (launchedTask)
|
||||
}
|
||||
|
||||
if (tasks.size > 0) {
|
||||
hasLaunchedTask = true
|
||||
}
|
||||
|
@ -256,10 +393,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
if (jarServer != null) {
|
||||
jarServer.stop()
|
||||
}
|
||||
|
||||
// sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
|
||||
// TODO: Do something better !
|
||||
Thread.sleep(5000L)
|
||||
}
|
||||
|
||||
override def defaultParallelism() = backend.defaultParallelism()
|
||||
|
||||
|
||||
// Check for speculatable tasks in all our active jobs.
|
||||
def checkSpeculatableTasks() {
|
||||
var shouldRevive = false
|
||||
|
@ -273,12 +415,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
}
|
||||
}
|
||||
|
||||
// Check for pending tasks in all our active jobs.
|
||||
def hasPendingTasks(): Boolean = {
|
||||
synchronized {
|
||||
activeTaskSetsQueue.exists( _.hasPendingTasks() )
|
||||
}
|
||||
}
|
||||
|
||||
def executorLost(executorId: String, reason: ExecutorLossReason) {
|
||||
var failedExecutor: Option[String] = None
|
||||
|
||||
synchronized {
|
||||
if (activeExecutorIds.contains(executorId)) {
|
||||
val host = executorIdToHost(executorId)
|
||||
logError("Lost executor %s on %s: %s".format(executorId, host, reason))
|
||||
val hostPort = executorIdToHostPort(executorId)
|
||||
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
|
||||
removeExecutor(executorId)
|
||||
failedExecutor = Some(executorId)
|
||||
} else {
|
||||
|
@ -296,19 +446,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
}
|
||||
}
|
||||
|
||||
/** Get a list of hosts that currently have executors */
|
||||
def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
|
||||
|
||||
/** Remove an executor from all our data structures and mark it as lost */
|
||||
private def removeExecutor(executorId: String) {
|
||||
activeExecutorIds -= executorId
|
||||
val host = executorIdToHost(executorId)
|
||||
val execs = executorsByHost.getOrElse(host, new HashSet)
|
||||
val hostPort = executorIdToHostPort(executorId)
|
||||
if (hostPortsAlive.contains(hostPort)) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
hostPortsAlive -= hostPort
|
||||
hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
|
||||
}
|
||||
|
||||
val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
|
||||
execs -= executorId
|
||||
if (execs.isEmpty) {
|
||||
executorsByHost -= host
|
||||
executorsByHostPort -= hostPort
|
||||
}
|
||||
executorIdToHost -= executorId
|
||||
activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
|
||||
executorIdToHostPort -= executorId
|
||||
activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort))
|
||||
}
|
||||
|
||||
def executorGained(execId: String, hostPort: String) {
|
||||
listener.executorGained(execId, hostPort)
|
||||
}
|
||||
|
||||
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
|
||||
val retval = hostToAliveHostPorts.get(host)
|
||||
if (retval.isDefined) {
|
||||
return Some(retval.get.toSet)
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
// By default, rack is unknown
|
||||
def getRackForHost(value: String): Option[String] = None
|
||||
|
||||
// By default, (cached) hosts for rack is unknown
|
||||
def getCachedHostsForRack(rack: String): Option[Set[String]] = None
|
||||
}
|
||||
|
||||
|
||||
object ClusterScheduler {
|
||||
|
||||
// Used to 'spray' available containers across the available set to ensure that too many containers on the same host
|
||||
// are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available
|
||||
// to execute a task)
|
||||
// For example: yarn can return more containers than we would have requested under ANY; this method
|
||||
// prioritizes how to use the allocated containers.
|
||||
// flatten the map such that the array buffer entries are spread out across the returned value.
|
||||
// given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, c1>, <h5, c1>,
|
||||
// the return value would be something like : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
|
||||
// We then 'use' the containers in this order (consuming only the top K from this list where
|
||||
// K = number to be used). This is to ensure that if we have multiple eligible allocations,
|
||||
// they don't end up allocating all containers on a small number of hosts - increasing probability of
|
||||
// multiple container failure when a host goes down.
|
||||
// Note, there is bias for keys with higher number of entries in value to be picked first (by design)
|
||||
// Also note that invocation of this method is expected to have containers of same 'type'
|
||||
// (host-local, rack-local, off-rack) and not across types : so that reordering is simply better from
|
||||
// the available list - everything else being same.
|
||||
// That is, we first consume data-local, then rack-local and finally off-rack nodes. So the
|
||||
// prioritization from this method applies to within each category
|
||||
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
|
||||
val _keyList = new ArrayBuffer[K](map.size)
|
||||
_keyList ++= map.keys
|
||||
|
||||
// order keyList based on population of value in map
|
||||
val keyList = _keyList.sortWith(
|
||||
(left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
|
||||
)
|
||||
|
||||
val retval = new ArrayBuffer[T](keyList.size * 2)
|
||||
var index = 0
|
||||
var found = true
|
||||
|
||||
while (found){
|
||||
found = false
|
||||
for (key <- keyList) {
|
||||
val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
|
||||
assert(containerList != null)
|
||||
// Get the index'th entry for this host - if present
|
||||
if (index < containerList.size){
|
||||
retval += containerList.apply(index)
|
||||
found = true
|
||||
}
|
||||
}
|
||||
index += 1
|
||||
}
|
||||
|
||||
retval.toList
|
||||
}
|
||||
}
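
A hypothetical usage sketch (not part of the commit) of prioritizeContainers, reproducing the ordering walked through in the comment above:

import scala.collection.mutable.{ArrayBuffer, HashMap}

// Containers grouped by host, as in the example in the comment.
val byHost = HashMap(
  "h1" -> ArrayBuffer("h1c1", "h1c2", "h1c3", "h1c4", "h1c5"),
  "h2" -> ArrayBuffer("h2c1", "h2c2", "h2c3"),
  "h3" -> ArrayBuffer("h3c1", "h3c2"),
  "h4" -> ArrayBuffer("h4c1"),
  "h5" -> ArrayBuffer("h5c1"))
val ordered = ClusterScheduler.prioritizeContainers(byHost)
// ordered ~= List(h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5);
// hosts with equally long lists may swap positions, since the sort key is only list size.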
|
||||
|
|
|
@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend(
|
|||
val driverUrl = "akka://spark@%s:%s/user/%s".format(
|
||||
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
|
||||
StandaloneSchedulerBackend.ACTOR_NAME)
|
||||
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
|
||||
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}")
|
||||
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
|
||||
val sparkHome = sc.getSparkHome().getOrElse(
|
||||
throw new IllegalArgumentException("must supply spark home for spark standalone"))
|
||||
|
@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend(
|
|||
}
|
||||
}
|
||||
|
||||
override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
|
||||
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
|
||||
executorId, host, cores, Utils.memoryMegabytesToString(memory)))
|
||||
override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
|
||||
logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
|
||||
executorId, hostPort, cores, Utils.memoryMegabytesToString(memory)))
|
||||
}
|
||||
|
||||
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {
|
||||
|
|
|
@ -3,6 +3,7 @@ package spark.scheduler.cluster
|
|||
import spark.TaskState.TaskState
|
||||
import java.nio.ByteBuffer
|
||||
import spark.util.SerializableBuffer
|
||||
import spark.Utils
|
||||
|
||||
private[spark] sealed trait StandaloneClusterMessage extends Serializable
|
||||
|
||||
|
@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess
|
|||
|
||||
// Executors to driver
|
||||
private[spark]
|
||||
case class RegisterExecutor(executorId: String, host: String, cores: Int)
|
||||
extends StandaloneClusterMessage
|
||||
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
|
||||
extends StandaloneClusterMessage {
|
||||
Utils.checkHostPort(hostPort, "Expected host port")
|
||||
}
|
||||
|
||||
private[spark]
|
||||
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)
|
||||
|
|
|
@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
|
|||
import akka.actor._
|
||||
import akka.util.duration._
|
||||
import akka.pattern.ask
|
||||
import akka.util.Duration
|
||||
|
||||
import spark.{SparkException, Logging, TaskState}
|
||||
import spark.{Utils, SparkException, Logging, TaskState}
|
||||
import akka.dispatch.Await
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
|
||||
|
@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
var totalCoreCount = new AtomicInteger(0)
|
||||
|
||||
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
|
||||
val executorActor = new HashMap[String, ActorRef]
|
||||
val executorAddress = new HashMap[String, Address]
|
||||
val executorHost = new HashMap[String, String]
|
||||
val freeCores = new HashMap[String, Int]
|
||||
val actorToExecutorId = new HashMap[ActorRef, String]
|
||||
val addressToExecutorId = new HashMap[Address, String]
|
||||
private val executorActor = new HashMap[String, ActorRef]
|
||||
private val executorAddress = new HashMap[String, Address]
|
||||
private val executorHostPort = new HashMap[String, String]
|
||||
private val freeCores = new HashMap[String, Int]
|
||||
private val actorToExecutorId = new HashMap[ActorRef, String]
|
||||
private val addressToExecutorId = new HashMap[Address, String]
|
||||
|
||||
override def preStart() {
|
||||
// Listen for remote client disconnection events, since they don't go through Akka's watch()
|
||||
|
@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
}
|
||||
|
||||
def receive = {
|
||||
case RegisterExecutor(executorId, host, cores) =>
|
||||
case RegisterExecutor(executorId, hostPort, cores) =>
|
||||
Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
|
||||
if (executorActor.contains(executorId)) {
|
||||
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
|
||||
} else {
|
||||
|
@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
sender ! RegisteredExecutor(sparkProperties)
|
||||
context.watch(sender)
|
||||
executorActor(executorId) = sender
|
||||
executorHost(executorId) = host
|
||||
executorHostPort(executorId) = hostPort
|
||||
freeCores(executorId) = cores
|
||||
executorAddress(executorId) = sender.path.address
|
||||
actorToExecutorId(sender) = executorId
|
||||
|
@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
// Make fake resource offers on all executors
|
||||
def makeOffers() {
|
||||
launchTasks(scheduler.resourceOffers(
|
||||
executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
|
||||
executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
|
||||
}
|
||||
|
||||
// Make fake resource offers on just one executor
|
||||
def makeOffers(executorId: String) {
|
||||
launchTasks(scheduler.resourceOffers(
|
||||
Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
|
||||
Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
|
||||
}
|
||||
|
||||
// Launch tasks returned by a set of resource offers
|
||||
|
@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
actorToExecutorId -= executorActor(executorId)
|
||||
addressToExecutorId -= executorAddress(executorId)
|
||||
executorActor -= executorId
|
||||
executorHost -= executorId
|
||||
executorHostPort -= executorId
|
||||
freeCores -= executorId
|
||||
executorHost -= executorId
|
||||
executorHostPort -= executorId
|
||||
totalCoreCount.addAndGet(-numCores)
|
||||
scheduler.executorLost(executorId, SlaveLost(reason))
|
||||
}
|
||||
|
@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
while (iterator.hasNext) {
|
||||
val entry = iterator.next
|
||||
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
|
||||
if (key.startsWith("spark.")) {
|
||||
if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
|
||||
properties += ((key, value))
|
||||
}
|
||||
}
|
||||
|
@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
|
|||
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
|
||||
}
|
||||
|
||||
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
|
||||
|
||||
override def stop() {
|
||||
try {
|
||||
if (driverActor != null) {
|
||||
val timeout = 5.seconds
|
||||
val future = driverActor.ask(StopDriver)(timeout)
|
||||
Await.result(future, timeout)
|
||||
}
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import spark.Utils
|
||||
|
||||
/**
|
||||
* Information about a running task attempt inside a TaskSet.
|
||||
*/
|
||||
|
@ -9,8 +11,11 @@ class TaskInfo(
|
|||
val index: Int,
|
||||
val launchTime: Long,
|
||||
val executorId: String,
|
||||
val host: String,
|
||||
val preferred: Boolean) {
|
||||
val hostPort: String,
|
||||
val taskLocality: TaskLocality.TaskLocality) {
|
||||
|
||||
Utils.checkHostPort(hostPort, "Expected hostport")
|
||||
|
||||
var finishTime: Long = 0
|
||||
var failed = false
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import java.util.Arrays
|
||||
import java.util.{HashMap => JHashMap}
|
||||
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
|
@ -14,6 +13,36 @@ import spark.scheduler._
|
|||
import spark.TaskState.TaskState
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging {
|
||||
|
||||
val HOST_LOCAL, RACK_LOCAL, ANY = Value
|
||||
|
||||
type TaskLocality = Value
|
||||
|
||||
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
|
||||
|
||||
constraint match {
|
||||
case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL
|
||||
case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL
|
||||
// For anything else, allow
|
||||
case _ => true
|
||||
}
|
||||
}
|
||||
|
||||
def parse(str: String): TaskLocality = {
|
||||
// better way to do this ?
|
||||
try {
|
||||
TaskLocality.withName(str)
|
||||
} catch {
|
||||
case nEx: NoSuchElementException => {
|
||||
logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL");
|
||||
// default to preserve earlier behavior
|
||||
HOST_LOCAL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
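
A brief illustration (not from the commit) of the constraint semantics defined above: a RACK_LOCAL constraint admits host-local and rack-local tasks, while ANY admits everything.

TaskLocality.isAllowed(TaskLocality.HOST_LOCAL, TaskLocality.RACK_LOCAL) // false
TaskLocality.isAllowed(TaskLocality.RACK_LOCAL, TaskLocality.HOST_LOCAL) // true
TaskLocality.isAllowed(TaskLocality.ANY, TaskLocality.RACK_LOCAL)        // true
TaskLocality.parse("RACK_LOCAL")  // RACK_LOCAL
TaskLocality.parse("bogus")       // logs a warning and falls back to HOST_LOCAL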
|
||||
|
||||
/**
|
||||
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
|
||||
*/
|
||||
|
@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
// Last time when we launched a preferred task (for delay scheduling)
|
||||
var lastPreferredLaunchTime = System.currentTimeMillis
|
||||
|
||||
// List of pending tasks for each node. These collections are actually
|
||||
// List of pending tasks for each node (hyper local to container). These collections are actually
|
||||
// treated as stacks, in which new tasks are added to the end of the
|
||||
// ArrayBuffer and removed from the end. This makes it faster to detect
|
||||
// tasks that repeatedly fail because whenever a task failed, it is put
|
||||
// back at the head of the stack. They are also only cleaned up lazily;
|
||||
// when a task is launched, it remains in all the pending lists except
|
||||
// the one that it was launched from, but gets removed from them later.
|
||||
val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node.
|
||||
// Essentially, similar to pendingTasksForHostPort, except at host level
|
||||
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node based on rack locality.
|
||||
// Essentially, similar to pendingTasksForHost, except at rack level
|
||||
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List containing pending tasks with no locality preferences
|
||||
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
|
||||
|
@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
addPendingTask(i)
|
||||
}
|
||||
|
||||
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = {
|
||||
// DEBUG code
|
||||
_taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations))
|
||||
|
||||
val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else {
|
||||
// Expand set to include all 'seen' rack local hosts.
|
||||
// This works since container allocation/management happens within master - so any rack locality information is updated in master.
|
||||
// Best-effort, and maybe sort of a kludge for now ... rework it later ?
|
||||
val hosts = new HashSet[String]
|
||||
_taskPreferredLocations.foreach(h => {
|
||||
val rackOpt = scheduler.getRackForHost(h)
|
||||
if (rackOpt.isDefined) {
|
||||
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
|
||||
if (hostsOpt.isDefined) {
|
||||
hosts ++= hostsOpt.get
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that irrespective of what scheduler says, host is always added !
|
||||
hosts += h
|
||||
})
|
||||
|
||||
hosts
|
||||
}
|
||||
|
||||
val retval = new ArrayBuffer[String]
|
||||
scheduler.synchronized {
|
||||
for (prefLocation <- taskPreferredLocations) {
|
||||
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation)
|
||||
if (aliveLocationsOpt.isDefined) {
|
||||
retval ++= aliveLocationsOpt.get
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
|
||||
// Add a task to all the pending-task lists that it should be on.
|
||||
private def addPendingTask(index: Int) {
|
||||
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
|
||||
if (locations.size == 0) {
|
||||
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
|
||||
// hostPort <-> host conversion). But not doing it, for simplicity's sake. If this becomes a performance issue, modify it.
|
||||
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched)
|
||||
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
|
||||
|
||||
if (rackLocalLocations.size == 0) {
|
||||
// Current impl ensures this.
|
||||
assert (hostLocalLocations.size == 0)
|
||||
pendingTasksWithNoPrefs += index
|
||||
} else {
|
||||
for (host <- locations) {
|
||||
val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
|
||||
|
||||
// host locality
|
||||
for (hostPort <- hostLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
|
||||
hostPortList += index
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
|
||||
hostList += index
|
||||
}
|
||||
|
||||
// rack locality
|
||||
for (rackLocalHostPort <- rackLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(rackLocalHostPort)
|
||||
|
||||
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
|
||||
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
|
||||
list += index
|
||||
}
|
||||
}
|
||||
|
||||
allPendingTasks += index
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host port (hyper local), or an empty list if
|
||||
// there is no map entry for that host port
|
||||
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
|
||||
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks (rack level) list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Number of pending tasks for a given host (which would be data local)
|
||||
def numPendingTasksForHost(hostPort: String): Int = {
|
||||
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
// Number of pending rack local tasks for a given host
|
||||
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
|
||||
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
|
||||
// Dequeue a pending task from the given list and return its index.
|
||||
// Return None if the list is empty.
|
||||
// This method also cleans up any tasks in the list that have already
|
||||
|
@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
}
|
||||
|
||||
// Return a speculative task for a given host if any are available. The task should not have an
|
||||
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
|
||||
// task must have a preference for this host (or no preferred locations at all).
|
||||
private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
val hostsAlive = sched.hostsAlive
|
||||
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
|
||||
// task must have a preference for this host/rack/no preferred locations at all.
|
||||
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
|
||||
assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL))
|
||||
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
|
||||
val localTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = tasks(index).preferredLocations.toSet & hostsAlive
|
||||
val attemptLocs = taskAttempts(index).map(_.host)
|
||||
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
|
||||
|
||||
if (speculatableTasks.size > 0) {
|
||||
val localTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (localTask != None) {
|
||||
speculatableTasks -= localTask.get
|
||||
return localTask
|
||||
}
|
||||
if (localTask != None) {
|
||||
speculatableTasks -= localTask.get
|
||||
return localTask
|
||||
}
|
||||
if (!localOnly && speculatableTasks.size > 0) {
|
||||
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
|
||||
if (nonLocalTask != None) {
|
||||
speculatableTasks -= nonLocalTask.get
|
||||
return nonLocalTask
|
||||
|
||||
// check for rack locality
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (rackTask != None) {
|
||||
speculatableTasks -= rackTask.get
|
||||
return rackTask
|
||||
}
|
||||
}
|
||||
|
||||
// Any task ...
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
// Check for attemptLocs also ?
|
||||
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
|
||||
if (nonLocalTask != None) {
|
||||
speculatableTasks -= nonLocalTask.get
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
}
|
||||
return None
|
||||
|
@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
|
||||
// Dequeue a pending task for a given node and return its index.
|
||||
// If localOnly is set to false, allow non-local tasks as well.
|
||||
private def findTask(host: String, localOnly: Boolean): Option[Int] = {
|
||||
val localTask = findTaskFromList(getPendingTasksForHost(host))
|
||||
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
|
||||
if (localTask != None) {
|
||||
return localTask
|
||||
}
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
|
||||
if (rackLocalTask != None) {
|
||||
return rackLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Look for no-pref tasks AFTER rack local tasks - this has the side effect that we will get to failed tasks later rather than sooner.
|
||||
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
|
||||
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
|
||||
if (noPrefTask != None) {
|
||||
return noPrefTask
|
||||
}
|
||||
if (!localOnly) {
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
val nonLocalTask = findTaskFromList(allPendingTasks)
|
||||
if (nonLocalTask != None) {
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if all else has failed, find a speculative task
|
||||
return findSpeculativeTask(host, localOnly)
|
||||
return findSpeculativeTask(hostPort, locality)
|
||||
}
|
||||
|
||||
// Does a host count as a preferred location for a task? This is true if
|
||||
// either the task has preferred locations and this host is one, or it has
|
||||
// no preferred locations (in which case we still count the launch as preferred).
|
||||
private def isPreferredLocation(task: Task[_], host: String): Boolean = {
|
||||
private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
val locs = task.preferredLocations
|
||||
return (locs.contains(host) || locs.isEmpty)
|
||||
// DEBUG code
|
||||
locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
|
||||
|
||||
if (locs.contains(hostPort) || locs.isEmpty) return true
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
locs.contains(host)
|
||||
}
|
||||
|
||||
// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
|
||||
// This is true if the task has preferred locations and this host is in the same rack
|
||||
// as one of those preferred locations.
|
||||
def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
|
||||
val locs = task.preferredLocations
|
||||
|
||||
// DEBUG code
|
||||
locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
|
||||
|
||||
val preferredRacks = new HashSet[String]()
|
||||
for (preferredHost <- locs) {
|
||||
val rack = sched.getRackForHost(preferredHost)
|
||||
if (None != rack) preferredRacks += rack.get
|
||||
}
|
||||
|
||||
if (preferredRacks.isEmpty) return false
|
||||
|
||||
val hostRack = sched.getRackForHost(hostPort)
|
||||
|
||||
return None != hostRack && preferredRacks.contains(hostRack.get)
|
||||
}
|
||||
|
||||
// Respond to an offer of a single slave from the scheduler by finding a task
|
||||
def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
|
||||
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
|
||||
val time = System.currentTimeMillis
|
||||
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
|
||||
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
|
||||
|
||||
findTask(host, localOnly) match {
|
||||
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
|
||||
// If explicitly specified, use that
|
||||
val locality = if (overrideLocality != null) overrideLocality else {
|
||||
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
|
||||
val time = System.currentTimeMillis
|
||||
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY
|
||||
}
|
||||
|
||||
findTask(hostPort, locality) match {
|
||||
case Some(index) => {
|
||||
// Found a task; do some bookkeeping and return a Mesos task for it
|
||||
val task = tasks(index)
|
||||
val taskId = sched.newTaskId()
|
||||
// Figure out whether this should count as a preferred launch
|
||||
val preferred = isPreferredLocation(task, host)
|
||||
val prefStr = if (preferred) {
|
||||
"preferred"
|
||||
} else {
|
||||
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
|
||||
}
|
||||
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
|
||||
taskSet.id, index, taskId, execId, host, prefStr))
|
||||
val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else
|
||||
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY
|
||||
val prefStr = taskLocality.toString
|
||||
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
|
||||
taskSet.id, index, taskId, execId, hostPort, prefStr))
|
||||
// Do various bookkeeping
|
||||
copiesRunning(index) += 1
|
||||
val info = new TaskInfo(taskId, index, time, execId, host, preferred)
|
||||
val time = System.currentTimeMillis
|
||||
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
|
||||
taskInfos(taskId) = info
|
||||
taskAttempts(index) = info :: taskAttempts(index)
|
||||
if (preferred) {
|
||||
if (TaskLocality.HOST_LOCAL == taskLocality) {
|
||||
lastPreferredLaunchTime = time
|
||||
}
|
||||
// Serialize and return the task
|
||||
|
@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
sched.taskSetFinished(this)
|
||||
}
|
||||
|
||||
def executorLost(execId: String, hostname: String) {
|
||||
def executorLost(execId: String, hostPort: String) {
|
||||
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
|
||||
val newHostsAlive = sched.hostsAlive
|
||||
// If some task has preferred locations only on hostname, and there are no more executors there,
|
||||
// put it in the no-prefs list to avoid the wait from delay scheduling
|
||||
if (!newHostsAlive.contains(hostname)) {
|
||||
for (index <- getPendingTasksForHost(hostname)) {
|
||||
val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
|
||||
if (newLocs.isEmpty) {
|
||||
pendingTasksWithNoPrefs += index
|
||||
}
|
||||
for (index <- getPendingTasksForHostPort(hostPort)) {
|
||||
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true)
|
||||
if (newLocs.isEmpty) {
|
||||
assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty)
|
||||
pendingTasksWithNoPrefs += index
|
||||
}
|
||||
}
|
||||
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
|
||||
|
@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
!speculatableTasks.contains(index)) {
|
||||
logInfo(
|
||||
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
|
||||
taskSet.id, index, info.host, threshold))
|
||||
taskSet.id, index, info.hostPort, threshold))
|
||||
speculatableTasks += index
|
||||
foundTasks = true
|
||||
}
|
||||
|
@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
|
|||
}
|
||||
return foundTasks
|
||||
}
|
||||
|
||||
def hasPendingTasks(): Boolean = {
|
||||
numTasks > 0 && tasksFinished < numTasks
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,5 +4,5 @@ package spark.scheduler.cluster
|
|||
* Represents free resources available on an executor.
|
||||
*/
|
||||
private[spark]
|
||||
class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
|
||||
class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
|
||||
}
|
||||
|
|
|
@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap
|
|||
import spark._
|
||||
import spark.executor.ExecutorURLClassLoader
|
||||
import spark.scheduler._
|
||||
import spark.scheduler.cluster.TaskInfo
|
||||
import spark.scheduler.cluster.{TaskLocality, TaskInfo}
|
||||
|
||||
/**
|
||||
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
|
||||
|
@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
|||
|
||||
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
|
||||
logInfo("Running " + task)
|
||||
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
|
||||
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL)
|
||||
// Set the Spark execution environment for the worker thread
|
||||
SparkEnv.set(env)
|
||||
try {
|
||||
|
|
|
@ -37,17 +37,27 @@ class BlockManager(
|
|||
maxMemory: Long)
|
||||
extends Logging {
|
||||
|
||||
class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
|
||||
var pending: Boolean = true
|
||||
var size: Long = -1L
|
||||
var failed: Boolean = false
|
||||
private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
|
||||
@volatile var pending: Boolean = true
|
||||
@volatile var size: Long = -1L
|
||||
@volatile var initThread: Thread = null
|
||||
@volatile var failed = false
|
||||
|
||||
setInitThread()
|
||||
|
||||
private def setInitThread() {
|
||||
// Set current thread as init thread - waitForReady will not block this thread
|
||||
// (in case there is non trivial initialization which ends up calling waitForReady as part of
|
||||
// initialization itself)
|
||||
this.initThread = Thread.currentThread()
|
||||
}
|
||||
|
||||
/**
|
||||
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
|
||||
* Return true if the block is available, false otherwise.
|
||||
*/
|
||||
def waitForReady(): Boolean = {
|
||||
if (pending) {
|
||||
if (initThread != Thread.currentThread() && pending) {
|
||||
synchronized {
|
||||
while (pending) this.wait()
|
||||
}
|
||||
|
@ -57,19 +67,26 @@ class BlockManager(
|
|||
|
||||
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
|
||||
def markReady(sizeInBytes: Long) {
|
||||
assert (pending)
|
||||
size = sizeInBytes
|
||||
initThread = null
|
||||
failed = false
|
||||
initThread = null
|
||||
pending = false
|
||||
synchronized {
|
||||
pending = false
|
||||
failed = false
|
||||
size = sizeInBytes
|
||||
this.notifyAll()
|
||||
}
|
||||
}
|
||||
|
||||
/** Mark this BlockInfo as ready but failed */
|
||||
def markFailure() {
|
||||
assert (pending)
|
||||
size = 0
|
||||
initThread = null
|
||||
failed = true
|
||||
initThread = null
|
||||
pending = false
|
||||
synchronized {
|
||||
failed = true
|
||||
pending = false
|
||||
this.notifyAll()
|
||||
}
|
||||
}
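
The readiness handshake above is easy to lose track of inside the diff; below is a minimal, self-contained sketch of the same pattern (not Spark code): the creating thread is exempt from waiting, and every other thread blocks until markReady or markFailure calls notifyAll.

// Standalone sketch of the BlockInfo readiness pattern; names are illustrative.
class ReadyLatch {
  @volatile private var pending = true
  @volatile private var initThread: Thread = Thread.currentThread()

  def waitForReady() {
    // The thread that created this object never blocks here, even while pending.
    if (initThread != Thread.currentThread() && pending) {
      synchronized { while (pending) wait() }
    }
  }

  def markReady() {
    initThread = null
    synchronized { pending = false; notifyAll() }
  }
}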
|
||||
|
@ -101,7 +118,7 @@ class BlockManager(
|
|||
|
||||
val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
|
||||
|
||||
val host = System.getProperty("spark.hostname", Utils.localHostName())
|
||||
val hostPort = Utils.localHostPort()
|
||||
|
||||
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
|
||||
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
|
||||
|
@ -212,9 +229,12 @@ class BlockManager(
|
|||
* Tell the master about the current storage status of a block. This will send a block update
|
||||
* message reflecting the current status, *not* the desired storage level in its block info.
|
||||
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
|
||||
*
|
||||
* droppedMemorySize exists to account for when block is dropped from memory to disk (so it is still valid).
|
||||
* This ensures that update in master will compensate for the increase in memory on slave.
|
||||
*/
|
||||
def reportBlockStatus(blockId: String, info: BlockInfo) {
|
||||
val needReregister = !tryToReportBlockStatus(blockId, info)
|
||||
def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
|
||||
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
|
||||
if (needReregister) {
|
||||
logInfo("Got told to reregister updating block " + blockId)
|
||||
// Reregistering will report our new block for free.
|
||||
|
@ -228,7 +248,7 @@ class BlockManager(
|
|||
* which will be true if the block was successfully recorded and false if
|
||||
* the slave needs to re-register.
|
||||
*/
|
||||
private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
|
||||
private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
|
||||
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
|
||||
info.level match {
|
||||
case null =>
|
||||
|
@ -237,7 +257,7 @@ class BlockManager(
|
|||
val inMem = level.useMemory && memoryStore.contains(blockId)
|
||||
val onDisk = level.useDisk && diskStore.contains(blockId)
|
||||
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
|
||||
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
|
||||
val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
|
||||
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
|
||||
(storageLevel, memSize, diskSize, info.tellMaster)
|
||||
}
|
||||
|
@ -257,7 +277,7 @@ class BlockManager(
|
|||
def getLocations(blockId: String): Seq[String] = {
|
||||
val startTimeMs = System.currentTimeMillis
|
||||
var managers = master.getLocations(blockId)
|
||||
val locations = managers.map(_.ip)
|
||||
val locations = managers.map(_.hostPort)
|
||||
logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
|
||||
return locations
|
||||
}
|
||||
|
@ -267,7 +287,7 @@ class BlockManager(
|
|||
*/
|
||||
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
|
||||
val startTimeMs = System.currentTimeMillis
|
||||
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
|
||||
val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray
|
||||
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
|
||||
return locations
|
||||
}
|
||||
|
@ -339,6 +359,8 @@ class BlockManager(
|
|||
case Some(bytes) =>
|
||||
// Put a copy of the block back in memory before returning it. Note that we can't
|
||||
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
|
||||
// The use of rewind assumes this.
|
||||
assert (0 == bytes.position())
|
||||
val copyForMemory = ByteBuffer.allocate(bytes.limit)
|
||||
copyForMemory.put(bytes)
|
||||
memoryStore.putBytes(blockId, copyForMemory, level)
|
||||
|
@ -411,6 +433,7 @@ class BlockManager(
|
|||
// Read it as a byte buffer into memory first, then return it
|
||||
diskStore.getBytes(blockId) match {
|
||||
case Some(bytes) =>
|
||||
assert (0 == bytes.position())
|
||||
if (level.useMemory) {
|
||||
if (level.deserialized) {
|
||||
memoryStore.putBytes(blockId, bytes, level)
|
||||
|
@ -450,7 +473,7 @@ class BlockManager(
|
|||
for (loc <- locations) {
|
||||
logDebug("Getting remote block " + blockId + " from " + loc)
|
||||
val data = BlockManagerWorker.syncGetBlock(
|
||||
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
|
||||
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
|
||||
if (data != null) {
|
||||
return Some(dataDeserialize(blockId, data))
|
||||
}
|
||||
|
@ -501,17 +524,17 @@ class BlockManager(
|
|||
throw new IllegalArgumentException("Storage level is null or invalid")
|
||||
}
|
||||
|
||||
val oldBlock = blockInfo.get(blockId).orNull
|
||||
if (oldBlock != null && oldBlock.waitForReady()) {
|
||||
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
|
||||
return oldBlock.size
|
||||
}
|
||||
|
||||
// Remember the block's storage level so that we can correctly drop it to disk if it needs
|
||||
// to be dropped right after it got put into memory. Note, however, that other threads will
|
||||
// not be able to get() this block until we call markReady on its BlockInfo.
|
||||
val myInfo = new BlockInfo(level, tellMaster)
|
||||
blockInfo.put(blockId, myInfo)
|
||||
// Do atomically !
|
||||
val oldBlockOpt = blockInfo.putIfAbsent(blockId, myInfo)
|
||||
|
||||
if (oldBlockOpt.isDefined && oldBlockOpt.get.waitForReady()) {
|
||||
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
|
||||
return oldBlockOpt.get.size
|
||||
}
|
||||
|
||||
val startTimeMs = System.currentTimeMillis
|
||||
|
||||
|
@ -531,6 +554,7 @@ class BlockManager(
|
|||
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
|
||||
+ " to get into synchronized block")
|
||||
|
||||
var marked = false
|
||||
try {
|
||||
if (level.useMemory) {
|
||||
// Save it just to memory first, even if it also has useDisk set to true; we will later
|
||||
|
@ -555,20 +579,20 @@ class BlockManager(
|
|||
|
||||
// Now that the block is in either the memory or disk store, let other threads read it,
|
||||
// and tell the master about it.
|
||||
marked = true
|
||||
myInfo.markReady(size)
|
||||
if (tellMaster) {
|
||||
reportBlockStatus(blockId, myInfo)
|
||||
}
|
||||
} catch {
|
||||
} finally {
|
||||
// If we failed at putting the block to memory/disk, notify other possible readers
|
||||
// that it has failed, and then remove it from the block info map.
|
||||
case e: Exception => {
|
||||
if (! marked) {
|
||||
// Note that the remove must happen before markFailure otherwise another thread
|
||||
// could've inserted a new BlockInfo before we remove it.
|
||||
blockInfo.remove(blockId)
|
||||
myInfo.markFailure()
|
||||
logWarning("Putting block " + blockId + " failed", e)
|
||||
throw e
|
||||
logWarning("Putting block " + blockId + " failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -611,16 +635,17 @@ class BlockManager(
|
|||
throw new IllegalArgumentException("Storage level is null or invalid")
|
||||
}
|
||||
|
||||
if (blockInfo.contains(blockId)) {
|
||||
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
|
||||
return
|
||||
}
|
||||
|
||||
// Remember the block's storage level so that we can correctly drop it to disk if it needs
|
||||
// to be dropped right after it got put into memory. Note, however, that other threads will
|
||||
// not be able to get() this block until we call markReady on its BlockInfo.
|
||||
val myInfo = new BlockInfo(level, tellMaster)
|
||||
blockInfo.put(blockId, myInfo)
|
||||
// Do atomically !
|
||||
val prevInfo = blockInfo.putIfAbsent(blockId, myInfo)
|
||||
if (prevInfo != null) {
|
||||
// Should we check for prevInfo.waitForReady() here ?
|
||||
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
|
||||
return
|
||||
}
|
||||
|
||||
val startTimeMs = System.currentTimeMillis
|
||||
|
||||
|
@ -639,6 +664,7 @@ class BlockManager(
|
|||
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
|
||||
+ " to get into synchronized block")
|
||||
|
||||
var marked = false
|
||||
try {
|
||||
if (level.useMemory) {
|
||||
// Store it only in memory at first, even if useDisk is also set to true
|
||||
|
@ -649,22 +675,24 @@ class BlockManager(
|
|||
diskStore.putBytes(blockId, bytes, level)
|
||||
}
|
||||
|
||||
// assert (0 == bytes.position(), "" + bytes)
|
||||
|
||||
// Now that the block is in either the memory or disk store, let other threads read it,
|
||||
// and tell the master about it.
|
||||
marked = true
|
||||
myInfo.markReady(bytes.limit)
|
||||
if (tellMaster) {
|
||||
reportBlockStatus(blockId, myInfo)
|
||||
}
|
||||
} catch {
|
||||
} finally {
|
||||
// If we failed at putting the block to memory/disk, notify other possible readers
|
||||
// that it has failed, and then remove it from the block info map.
|
||||
case e: Exception => {
|
||||
if (! marked) {
|
||||
// Note that the remove must happen before markFailure otherwise another thread
|
||||
// could've inserted a new BlockInfo before we remove it.
|
||||
blockInfo.remove(blockId)
|
||||
myInfo.markFailure()
|
||||
logWarning("Putting block " + blockId + " failed", e)
|
||||
throw e
|
||||
logWarning("Putting block " + blockId + " failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -698,7 +726,7 @@ class BlockManager(
|
|||
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
|
||||
+ data.limit() + " Bytes. To node: " + peer)
|
||||
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
|
||||
new ConnectionManagerId(peer.ip, peer.port))) {
|
||||
new ConnectionManagerId(peer.host, peer.port))) {
|
||||
logError("Failed to call syncPutBlock to " + peer)
|
||||
}
|
||||
logDebug("Replicated BlockId " + blockId + " once used " +
|
||||
|
@ -730,6 +758,14 @@ class BlockManager(
|
|||
val info = blockInfo.get(blockId).orNull
|
||||
if (info != null) {
|
||||
info.synchronized {
|
||||
// required ? As of now, this will be invoked only for blocks which are ready
|
||||
// But in case this changes in future, adding for consistency sake.
|
||||
if (! info.waitForReady() ) {
|
||||
// If we get here, the block write failed.
|
||||
logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
|
||||
return
|
||||
}
|
||||
|
||||
val level = info.level
|
||||
if (level.useDisk && !diskStore.contains(blockId)) {
|
||||
logInfo("Writing block " + blockId + " to disk")
|
||||
|
@ -740,12 +776,13 @@ class BlockManager(
|
|||
diskStore.putBytes(blockId, bytes, level)
|
||||
}
|
||||
}
|
||||
val droppedMemorySize = memoryStore.getSize(blockId)
|
||||
val blockWasRemoved = memoryStore.remove(blockId)
|
||||
if (!blockWasRemoved) {
|
||||
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
|
||||
}
|
||||
if (info.tellMaster) {
|
||||
reportBlockStatus(blockId, info)
|
||||
reportBlockStatus(blockId, info, droppedMemorySize)
|
||||
}
|
||||
if (!level.useDisk) {
|
||||
// The block is completely gone from this node; forget it so we can put() it again later.
|
||||
|
@ -938,8 +975,8 @@ class BlockFetcherIterator(
|
|||
|
||||
def sendRequest(req: FetchRequest) {
|
||||
logDebug("Sending request for %d blocks (%s) from %s".format(
|
||||
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
|
||||
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
|
||||
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
|
||||
val cmId = new ConnectionManagerId(req.address.host, req.address.port)
|
||||
val blockMessageArray = new BlockMessageArray(req.blocks.map {
|
||||
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
|
||||
})
|
||||
|
|
|
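
The put() and putBytes() hunks above replace a separate existence check plus blockInfo.put with a single blockInfo.putIfAbsent call, so two threads storing the same block cannot both register a fresh BlockInfo. A minimal sketch of the difference against a plain ConcurrentHashMap (the Info case class and method names are illustrative, not the actual BlockInfo API):

import java.util.concurrent.ConcurrentHashMap

object PutIfAbsentSketch {
  // Hypothetical stand-in for BlockInfo; only what the sketch needs.
  case class Info(size: Long)

  private val infos = new ConcurrentHashMap[String, Info]()

  // Racy: another thread can slip in between containsKey() and put().
  def registerRacy(id: String, info: Info): Option[Info] =
    if (infos.containsKey(id)) Some(infos.get(id)) else { infos.put(id, info); None }

  // Atomic: putIfAbsent inserts and reports the previous value in one step,
  // so exactly one caller wins and both observe the same Info afterwards.
  def registerAtomic(id: String, info: Info): Option[Info] =
    Option(infos.putIfAbsent(id, info))
}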
@@ -2,6 +2,7 @@ package spark.storage

import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
import spark.Utils

/**
* This class represent an unique identifier for a BlockManager.

@@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
private var ip_ : String,
private var host_ : String,
private var port_ : Int
) extends Externalizable {

@@ -21,32 +22,45 @@ private[spark] class BlockManagerId private (

def executorId: String = executorId_

def ip: String = ip_
if (null != host_){
Utils.checkHost(host_, "Expected hostname")
assert (port_ > 0)
}

def hostPort: String = {
// DEBUG code
Utils.checkHost(host)
assert (port > 0)

host + ":" + port
}

def host: String = host_

def port: Int = port_

override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
out.writeUTF(ip_)
out.writeUTF(host_)
out.writeInt(port_)
}

override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
ip_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
}

@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)

override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port)

override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port

override def equals(that: Any) = that match {
case id: BlockManagerId =>
executorId == id.executorId && port == id.port && ip == id.ip
executorId == id.executorId && port == id.port && host == id.host
case _ =>
false
}

@@ -55,8 +69,8 @@ private[spark] class BlockManagerId private (

private[spark] object BlockManagerId {

def apply(execId: String, ip: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
def apply(execId: String, host: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, host, port))

def apply(in: ObjectInput) = {
val obj = new BlockManagerId()

@@ -67,11 +81,7 @@ private[spark] object BlockManagerId {
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()

def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
if (blockManagerIdCache.containsKey(id)) {
blockManagerIdCache.get(id)
} else {
blockManagerIdCache.put(id, id)
id
}
blockManagerIdCache.putIfAbsent(id, id)
blockManagerIdCache.get(id)
}
}
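
getCachedBlockManagerId is rewritten so concurrent callers converge on one canonical BlockManagerId instance instead of racing through containsKey/put. A rough sketch of the same interning idiom with a generic type (the Interner class is an assumption for illustration, not Spark code):

import java.util.concurrent.ConcurrentHashMap

// Returns one canonical instance per equivalence class, so deserialized
// copies that compare equal all collapse onto a single shared object.
class Interner[T <: AnyRef] {
  private val cache = new ConcurrentHashMap[T, T]()

  def intern(value: T): T = {
    cache.putIfAbsent(value, value) // atomic; a losing racer leaves the winner in place
    cache.get(value)                // every caller reads back the single cached instance
  }
}

readResolve() in the class above plays the same role for Java serialization, routing every deserialized id through this kind of cache.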
@@ -332,8 +332,8 @@ object BlockManagerMasterActor {
// Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus]

logInfo("Registering block manager %s:%d with %s RAM".format(
blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.memoryBytesToString(maxMem)))

def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()

@@ -358,13 +358,13 @@ object BlockManagerMasterActor {
_blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
if (storageLevel.useMemory) {
_remainingMem -= memSize
logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
logInfo("Added %s on disk on %s:%d (size: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
logInfo("Added %s on disk on %s (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.

@@ -372,13 +372,13 @@ object BlockManagerMasterActor {
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
_remainingMem += blockStatus.memSize
logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (blockStatus.storageLevel.useDisk) {
logInfo("Removed %s on %s:%d on disk (size: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
logInfo("Removed %s on %s on disk (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
}
}
@@ -115,6 +115,7 @@ private[spark] object BlockMessageArray {
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
assert (0 == buffer.position())
newBuffer.put(buffer)
buffer.rewind()
})
@@ -20,6 +20,9 @@ import spark.Utils
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) {

private val mapMode = MapMode.READ_ONLY
private var mapOpenMode = "r"

val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt

@@ -35,7 +38,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).length()
}

override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets !
// duplicate does not copy buffer, so inexpensive
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)

@@ -49,6 +55,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
}

private def getFileBytes(file: File): ByteBuffer = {
val length = file.length()
val channel = new RandomAccessFile(file, mapOpenMode).getChannel()
val buffer = try {
channel.map(mapMode, 0, length)
} finally {
channel.close()
}

buffer
}

override def putValues(
blockId: String,
values: ArrayBuffer[Any],

@@ -70,9 +88,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)

if (returnValues) {
// Return a byte buffer for the contents of the file
val channel = new RandomAccessFile(file, "r").getChannel()
val buffer = channel.map(MapMode.READ_ONLY, 0, length)
channel.close()
val buffer = getFileBytes(file)
PutResult(length, Right(buffer))
} else {
PutResult(length, null)

@@ -81,10 +97,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)

override def getBytes(blockId: String): Option[ByteBuffer] = {
val file = getFile(blockId)
val length = file.length().toInt
val channel = new RandomAccessFile(file, "r").getChannel()
val bytes = channel.map(MapMode.READ_ONLY, 0, length)
channel.close()
val bytes = getFileBytes(file)
Some(bytes)
}

@@ -96,7 +109,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = getFile(blockId)
if (file.exists()) {
file.delete()
true
} else {
false
}

@@ -175,11 +187,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}

private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) )
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
try {
localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir))
} catch {
case t: Throwable => logError("Exception while deleting local spark dirs", t)
}
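
The new getFileBytes helper centralizes memory-mapping and guarantees the channel is closed even when map() throws; the two call sites that previously opened and closed a RandomAccessFile inline now share it. A self-contained sketch of the pattern (the standalone helper name is illustrative):

import java.io.{File, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode

// Map an entire file read-only; the finally block releases the channel
// whether or not map() succeeds.
def mapReadOnly(file: File): ByteBuffer = {
  val channel = new RandomAccessFile(file, "r").getChannel()
  try {
    channel.map(MapMode.READ_ONLY, 0, file.length())
  } finally {
    channel.close()
  }
}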
@@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}

override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
if (level.deserialized) {
val values = blockManager.dataDeserialize(blockId, bytes)
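
MemoryStore.putBytes now drains a duplicate() of the incoming buffer so the caller's position and limit survive the call. A small sketch of why duplicate() matters (the consume helper is hypothetical):

import java.nio.ByteBuffer

object DuplicateSketch extends App {
  // duplicate() shares the bytes but keeps independent position/limit marks,
  // so draining the copy leaves the original buffer readable.
  def consume(buffer: ByteBuffer) {
    val copy = buffer.duplicate()
    copy.rewind()
    while (copy.hasRemaining) copy.get()
  }

  val original = ByteBuffer.wrap(Array[Byte](1, 2, 3))
  consume(original)
  assert(original.position() == 0) // caller's view is unchanged
}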
@@ -123,11 +123,7 @@ object StorageLevel {
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()

private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
if (storageLevelCache.containsKey(level)) {
storageLevelCache.get(level)
} else {
storageLevelCache.put(level, level)
level
}
storageLevelCache.putIfAbsent(level, level)
storageLevelCache.get(level)
}
}
@@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService}
import cc.spray.can.server.HttpServer
import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler
import akka.dispatch.Await
import spark.SparkException
import spark.{Utils, SparkException}
import java.util.concurrent.TimeoutException

/**

@@ -31,7 +31,10 @@ private[spark] object AkkaUtils {
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
// 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt

val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]

@@ -45,8 +48,9 @@ private[spark] object AkkaUtils {
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
akka.remote.log-remote-lifecycle-events = %s
akka.remote.netty.write-timeout = %ds
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
if (lifecycleEvents) "on" else "off"))
lifecycleEvents, akkaWriteTimeout))

val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)

@@ -60,6 +64,7 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Returns the bound port or throws a SparkException on failure.
* TODO: Not changing ip to host here - is it required ?
*/
def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
name: String = "HttpServer"): ActorRef = {
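
lifecycleEvents is now pre-rendered as the string "on" or "off" so it can be spliced directly into the parsed config text, alongside the new akka.remote.netty.write-timeout setting. A trimmed sketch of that substitution (only these two settings shown; property names follow the diff):

object AkkaConfSketch extends App {
  // Translate a boolean system property into the "on"/"off" form the Akka
  // config expects, then format it into the config string.
  val lifecycleEvents =
    if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
  val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt

  val conf = """
    akka.remote.log-remote-lifecycle-events = %s
    akka.remote.netty.write-timeout = %ds
  """.format(lifecycleEvents, akkaWriteTimeout)

  println(conf)
}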
@@ -3,6 +3,7 @@ package spark.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.mutable.Map
import spark.scheduler.MapStatus

/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion

@@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
this
}

// Should we return previous value directly or as Option ?
def putIfAbsent(key: A, value: B): Option[B] = {
val prev = internalMap.putIfAbsent(key, (value, currentTime))
if (prev != null) Some(prev._1) else None
}

override def -= (key: A): this.type = {
internalMap.remove(key)
this
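
TimeStampedHashMap gains a putIfAbsent that surfaces the previous value as an Option, which is what the BlockManager.put hunk above relies on. A usage sketch against a simplified timestamped map (types trimmed; not the full TimeStampedHashMap):

import java.util.concurrent.ConcurrentHashMap

class SimpleTimestampedMap[A, B] {
  private val internalMap = new ConcurrentHashMap[A, (B, Long)]()

  // Insert only if absent; report what was already there, if anything.
  def putIfAbsent(key: A, value: B): Option[B] = {
    val prev = internalMap.putIfAbsent(key, (value, System.currentTimeMillis()))
    if (prev != null) Some(prev._1) else None
  }
}

object TimestampedMapSketch extends App {
  val m = new SimpleTimestampedMap[String, Int]
  assert(m.putIfAbsent("block-0", 1) == None)    // first writer wins
  assert(m.putIfAbsent("block-0", 2) == Some(1)) // later callers see the winner's value
}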
@@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils

@spark.common.html.layout(title = "Spark Master on " + state.host) {
@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {

<!-- Cluster Details -->
<div class="row">
@@ -1,7 +1,7 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils

@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {

<!-- Worker Details -->
<div class="row">
@@ -12,7 +12,7 @@
<tbody>
@for(status <- workersStatusList) {
<tr>
<td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
<td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td>
<td>
@(Utils.memoryBytesToString(status.memUsed(prefix)))
(@(Utils.memoryBytesToString(status.memRemaining)) Total Available)
@@ -153,7 +153,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
val bytes = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
GetBlock(blockId), ConnectionManagerId(id.host, id.port))
val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
})
@@ -18,6 +18,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
nums.saveAsTextFile(outputDir)
println("outputDir = " + outputDir)
// Read the plain text file and check it's OK
val outputFile = new File(outputDir, "part-00000")
val content = Source.fromFile(outputFile).mkString
@@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}