Checkpoint commit: compiles and passes most tests, though not all; looking into FileSuite issues

Mridul Muralidharan 2013-04-15 18:12:11 +05:30
parent 6798a09df8
commit d90d2af103
66 changed files with 3489 additions and 478 deletions


@@ -0,0 +1,18 @@
package spark.deploy
/**
* Contains utility methods to interact with Hadoop from Spark.
*/
object SparkHadoopUtil {
def getUserNameFromEnvironment(): String = {
// defaulting to -D ...
System.getProperty("user.name")
}
def runAsUser(func: (Product) => Unit, args: Product) {
// Add impersonation support here if it exists; for now, simply run func.
func(args)
}
}
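For orientation, here is a minimal sketch, not part of this commit, showing how the default (non-YARN) stub above might be exercised; the example object name and the Tuple1 argument are made up for illustration.

package spark.deploy

// Illustrative sketch only: exercising the non-YARN SparkHadoopUtil stub above.
object SparkHadoopUtilStubExample {
  def main(args: Array[String]) {
    val user = SparkHadoopUtil.getUserNameFromEnvironment()
    // In this default implementation, runAsUser performs no impersonation and
    // simply invokes the supplied function in the current JVM.
    SparkHadoopUtil.runAsUser(p => println("running func as " + user + " with args " + p), Tuple1("hello"))
  }
}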


@@ -0,0 +1,59 @@
package spark.deploy
import collection.mutable.HashMap
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import java.security.PrivilegedExceptionAction
/**
* Contains utility methods to interact with Hadoop from Spark.
*/
object SparkHadoopUtil {
val yarnConf = new YarnConfiguration(new Configuration())
def getUserNameFromEnvironment(): String = {
// defaulting to env if -D is not present ...
val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
// If nothing found, default to user we are running as
if (retval == null) System.getProperty("user.name") else retval
}
def runAsUser(func: (Product) => Unit, args: Product) {
runAsUser(func, args, getUserNameFromEnvironment())
}
def runAsUser(func: (Product) => Unit, args: Product, user: String) {
// println("running as user " + jobUserName)
UserGroupInformation.setConfiguration(yarnConf)
val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
def run: AnyRef = {
func(args)
// no return value ...
null
}
})
}
// Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
def isYarnMode(): Boolean = {
val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
java.lang.Boolean.valueOf(yarnMode)
}
// Set an env variable indicating we are running in YARN mode.
// Note that anything with SPARK prefix gets propagated to all (remote) processes
def setYarnMode() {
System.setProperty("SPARK_YARN_MODE", "true")
}
def setYarnMode(env: HashMap[String, String]) {
env("SPARK_YARN_MODE") = "true"
}
}
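A small sketch, again hypothetical and not part of this commit, of the YARN-mode toggles defined above; it relies only on the SPARK_YARN_MODE property and the HashMap overload shown in this file.

package spark.deploy

import collection.mutable.HashMap

// Illustrative only: exercising the YARN-mode helpers defined above.
object YarnModeToggleExample {
  def main(args: Array[String]) {
    println("yarn mode initially: " + SparkHadoopUtil.isYarnMode())
    // Sets the SPARK_YARN_MODE system property; anything with a SPARK prefix is
    // propagated to remote processes, which is why a SPARK-prefixed key is used.
    SparkHadoopUtil.setYarnMode()
    println("yarn mode after set: " + SparkHadoopUtil.isYarnMode())
    // The overload below records the flag in a launch environment map instead.
    val env = new HashMap[String, String]()
    SparkHadoopUtil.setYarnMode(env)
    println("launch env now contains: " + env)
  }
}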


@@ -0,0 +1,342 @@
package spark.deploy.yarn
import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import scala.collection.JavaConversions._
import spark.{SparkContext, Logging, Utils}
import org.apache.hadoop.security.UserGroupInformation
import java.security.PrivilegedExceptionAction
class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
def this(args: ApplicationMasterArguments) = this(args, new Configuration())
private var rpc: YarnRPC = YarnRPC.create(conf)
private var resourceManager: AMRMProtocol = null
private var appAttemptId: ApplicationAttemptId = null
private var userThread: Thread = null
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
private var yarnAllocator: YarnAllocationHandler = null
def run() {
// Initialization
val jobUserName = Utils.getUserNameFromEnvironment()
logInfo("running as user " + jobUserName)
// run as user ...
UserGroupInformation.setConfiguration(yarnConf)
val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
def run: AnyRef = {
runImpl()
return null
}
})
}
private def runImpl() {
appAttemptId = getApplicationAttemptId()
resourceManager = registerWithResourceManager()
val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
// Compute number of threads for akka
val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
if (minimumMemory > 0) {
val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
if (numCore > 0) {
// do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
// TODO: Uncomment when hadoop is on a version which has this fixed.
// args.workerCores = numCore
}
}
// Workaround until hadoop moves to something which has
// https://issues.apache.org/jira/browse/HADOOP-8406
// ignore result
// This does not, unfortunately, always work reliably ... but it alleviates the bug most of the time.
// Hence args.workerCores = numCore disabled above. Any better option ?
// org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
ApplicationMaster.register(this)
// Start the user's JAR
userThread = startUserClass()
// This is a bit hacky, but we need to wait until the spark.master.port property has
// been set by the Thread executing the user class.
waitForSparkMaster()
// Allocate all containers
allocateWorkers()
// Wait for the user class to finish
userThread.join()
// Finish the ApplicationMaster
finishApplicationMaster()
// TODO: Exit based on success/failure
System.exit(0)
}
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
val containerId = ConverterUtils.toContainerId(containerIdString)
val appAttemptId = containerId.getApplicationAttemptId()
logInfo("ApplicationAttemptId: " + appAttemptId)
return appAttemptId
}
private def registerWithResourceManager(): AMRMProtocol = {
val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
YarnConfiguration.RM_SCHEDULER_ADDRESS,
YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
logInfo("Connecting to ResourceManager at " + rmAddress)
return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
logInfo("Registering the ApplicationMaster")
val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
.asInstanceOf[RegisterApplicationMasterRequest]
appMasterRequest.setApplicationAttemptId(appAttemptId)
// Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
// Users can then monitor stderr/stdout on that node if required.
appMasterRequest.setHost(Utils.localHostName())
appMasterRequest.setRpcPort(0)
// What do we provide here ? Might make sense to expose something sensible later ?
appMasterRequest.setTrackingUrl("")
return resourceManager.registerApplicationMaster(appMasterRequest)
}
private def waitForSparkMaster() {
logInfo("Waiting for spark master to be reachable.")
var masterUp = false
while(!masterUp) {
val masterHost = System.getProperty("spark.master.host")
val masterPort = System.getProperty("spark.master.port")
try {
val socket = new Socket(masterHost, masterPort.toInt)
socket.close()
logInfo("Master now available: " + masterHost + ":" + masterPort)
masterUp = true
} catch {
case e: Exception =>
logError("Failed to connect to master at " + masterHost + ":" + masterPort)
Thread.sleep(100)
}
}
}
private def startUserClass(): Thread = {
logInfo("Starting the user JAR in a separate Thread")
val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
.getMethod("main", classOf[Array[String]])
val t = new Thread {
override def run() {
var mainArgs: Array[String] = null
var startIndex = 0
// I am sure there is a better 'scala' way to do this .... but I am just trying to get things to work right now !
if (args.userArgs.isEmpty || args.userArgs.get(0) != "yarn-standalone") {
// ensure that first param is ALWAYS "yarn-standalone"
mainArgs = new Array[String](args.userArgs.size() + 1)
mainArgs.update(0, "yarn-standalone")
startIndex = 1
}
else {
mainArgs = new Array[String](args.userArgs.size())
}
args.userArgs.copyToArray(mainArgs, startIndex, args.userArgs.size())
mainMethod.invoke(null, mainArgs)
}
}
t.start()
return t
}
private def allocateWorkers() {
logInfo("Waiting for spark context initialization")
try {
var sparkContext: SparkContext = null
ApplicationMaster.sparkContextRef.synchronized {
var count = 0
while (ApplicationMaster.sparkContextRef.get() == null) {
logInfo("Waiting for spark context initialization ... " + count)
count = count + 1
ApplicationMaster.sparkContextRef.wait(10000L)
}
sparkContext = ApplicationMaster.sparkContextRef.get()
assert(sparkContext != null)
this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
}
logInfo("Allocating " + args.numWorkers + " workers.")
// Wait until all containers have finished
// TODO: This is a bit ugly. Can we make it nicer?
// TODO: Handle container failure
while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
// If the user thread exits, stop allocating.
userThread.isAlive) {
this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
ApplicationMaster.incrementAllocatorLoop(1)
Thread.sleep(100)
}
} finally {
// in case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
// so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
}
logInfo("All workers have launched.")
// Launch a progress reporter thread; otherwise the app will get killed after the AM expiry interval (default: 10 mins) elapses.
if (userThread.isAlive){
// ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
// must be <= timeoutInterval / 2.
// On the other hand, also ensure that we are reasonably responsive without causing too many requests to the RM,
// so at least 1 minute or timeoutInterval / 10, whichever is higher.
val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
launchReporterThread(interval)
}
}
// TODO: We might want to extend this to allocate more containers in case they die !
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
val t = new Thread {
override def run() {
while (userThread.isAlive){
val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
if (missingWorkerCount > 0) {
logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
yarnAllocator.allocateContainers(missingWorkerCount)
}
else sendProgress()
Thread.sleep(sleepTime)
}
}
}
// setting to daemon status, though this is usually not a good idea.
t.setDaemon(true)
t.start()
logInfo("Started progress reporter thread - sleep time : " + sleepTime)
return t
}
private def sendProgress() {
logDebug("Sending progress")
// simulated with an allocate request with no nodes requested ...
yarnAllocator.allocateContainers(0)
}
/*
def printContainers(containers: List[Container]) = {
for (container <- containers) {
logInfo("Launching shell command on a new container."
+ ", containerId=" + container.getId()
+ ", containerNode=" + container.getNodeId().getHost()
+ ":" + container.getNodeId().getPort()
+ ", containerNodeURI=" + container.getNodeHttpAddress()
+ ", containerState" + container.getState()
+ ", containerResourceMemory"
+ container.getResource().getMemory())
}
}
*/
def finishApplicationMaster() {
val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
.asInstanceOf[FinishApplicationMasterRequest]
finishReq.setAppAttemptId(appAttemptId)
// TODO: Check if the application has failed or succeeded
finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
resourceManager.finishApplicationMaster(finishReq)
}
}
object ApplicationMaster {
// Number of times to wait for the allocator loop to complete.
// Each loop iteration waits for 100ms, so this is a maximum of about 3 seconds.
// This is to ensure that we have a reasonable number of containers before we start.
// TODO: Currently, the task-to-container mapping is computed once (TaskSetManager), which need not be optimal as more
// containers become available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
def incrementAllocatorLoop(by: Int) {
val count = yarnAllocatorLoop.getAndAdd(by)
if (count >= ALLOCATOR_LOOP_WAIT_COUNT){
yarnAllocatorLoop.synchronized {
// to wake threads off wait ...
yarnAllocatorLoop.notifyAll()
}
}
}
private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
def register(master: ApplicationMaster) {
applicationMasters.add(master)
}
val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
def sparkContextInitialized(sc: SparkContext): Boolean = {
var modified = false
sparkContextRef.synchronized {
modified = sparkContextRef.compareAndSet(null, sc)
sparkContextRef.notifyAll()
}
// Add a shutdown hook - as a best-effort measure in case users do not call sc.stop or do System.exit.
// Should not really have to do this, but it helps YARN evict resources earlier,
// not to mention it prevents the Client from declaring failure even though we exited properly.
if (modified) {
Runtime.getRuntime().addShutdownHook(new Thread with Logging {
// This is not just for logging, but also to ensure that the logging system is initialized for this instance before the hook is actually run.
logInfo("Adding shutdown hook for context " + sc)
override def run() {
logInfo("Invoking sc stop from shutdown hook")
sc.stop()
// best case ...
for (master <- applicationMasters) master.finishApplicationMaster
}
} )
}
// Wait for initialization to complete and for at least 'some' nodes to get allocated
yarnAllocatorLoop.synchronized {
while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
yarnAllocatorLoop.wait(1000L)
}
}
modified
}
def main(argStrings: Array[String]) {
val args = new ApplicationMasterArguments(argStrings)
new ApplicationMaster(args).run()
}
}
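For context, a hedged sketch of the hand-off this object expects from the driver side: in yarn-standalone mode, something in SparkContext is presumably expected to call sparkContextInitialized(sc) once the context is up, which publishes the context to allocateWorkers() and then blocks until a few allocation rounds have run. The helper below is hypothetical and only illustrates that wiring.

// Illustrative sketch only; notifyApplicationMaster is a hypothetical helper, not part of this commit.
object ApplicationMasterHandOffSketch {
  def notifyApplicationMaster(sc: spark.SparkContext) {
    if (spark.deploy.SparkHadoopUtil.isYarnMode()) {
      // Publishes the context via ApplicationMaster.sparkContextRef, wakes allocateWorkers(),
      // and waits until yarnAllocatorLoop has advanced past ALLOCATOR_LOOP_WAIT_COUNT
      // (i.e. some containers have been requested) before returning to user code.
      spark.deploy.yarn.ApplicationMaster.sparkContextInitialized(sc)
    }
  }
}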


@@ -0,0 +1,78 @@
package spark.deploy.yarn
import spark.util.IntParam
import collection.mutable.ArrayBuffer
class ApplicationMasterArguments(val args: Array[String]) {
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
var workerMemory = 1024
var workerCores = 1
var numWorkers = 2
parseArgs(args.toList)
private def parseArgs(inputArgs: List[String]): Unit = {
val userArgsBuffer = new ArrayBuffer[String]()
var args = inputArgs
while (! args.isEmpty) {
args match {
case ("--jar") :: value :: tail =>
userJar = value
args = tail
case ("--class") :: value :: tail =>
userClass = value
args = tail
case ("--args") :: value :: tail =>
userArgsBuffer += value
args = tail
case ("--num-workers") :: IntParam(value) :: tail =>
numWorkers = value
args = tail
case ("--worker-memory") :: IntParam(value) :: tail =>
workerMemory = value
args = tail
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
}
case _ =>
printUsageAndExit(1, args)
}
}
userArgs = userArgsBuffer.readOnly
}
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
}
System.err.println(
"Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required)\n" +
" --class CLASS_NAME Name of your application's main class (required)\n" +
" --args ARGS Arguments to be passed to your application's main class.\n" +
" Mutliple invocations are possible, each will be passed in order.\n" +
" Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" +
" --num-workers NUM Number of workers to start (Default: 2)\n" +
" --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
System.exit(exitCode)
}
}
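A quick illustration, not part of this commit, of how the parser above behaves on a sample command line; the jar path and class name are hypothetical. Repeated --args flags accumulate in order, and the --num-workers / --worker-* flags expect plain integers here (via IntParam).

package spark.deploy.yarn

// Illustrative only: exercising ApplicationMasterArguments on a made-up command line.
object ApplicationMasterArgumentsExample {
  def main(argStrings: Array[String]) {
    val sample = Array(
      "--jar", "/tmp/user-app.jar",       // hypothetical path
      "--class", "org.example.MyApp",     // hypothetical class
      "--args", "input-dir",
      "--args", "output-dir",
      "--num-workers", "4",
      "--worker-memory", "2048",
      "--worker-cores", "2")
    val parsed = new ApplicationMasterArguments(sample)
    // Expected: class=org.example.MyApp, workers=4, userArgs=input-dir,output-dir
    println("class=" + parsed.userClass + ", workers=" + parsed.numWorkers +
      ", userArgs=" + parsed.userArgs.mkString(","))
  }
}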


@@ -0,0 +1,326 @@
package spark.deploy.yarn
import java.net.{InetSocketAddress, URI}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import spark.{Logging, Utils}
import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import spark.deploy.SparkHadoopUtil
class Client(conf: Configuration, args: ClientArguments) extends Logging {
def this(args: ClientArguments) = this(new Configuration(), args)
var applicationsManager: ClientRMProtocol = null
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
def run() {
connectToASM()
logClusterResourceDetails()
val newApp = getNewApplication()
val appId = newApp.getApplicationId()
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
val localResources = prepareLocalResources(appId, "spark")
val env = setupLaunchEnv(localResources)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
appContext.setQueue(args.amQueue)
appContext.setAMContainerSpec(amContainer)
appContext.setUser(args.amUser)
submitApp(appContext)
monitorApplication(appId)
System.exit(0)
}
def connectToASM() {
val rmAddress: InetSocketAddress = NetUtils.createSocketAddr(
yarnConf.get(YarnConfiguration.RM_ADDRESS, YarnConfiguration.DEFAULT_RM_ADDRESS)
)
logInfo("Connecting to ResourceManager at" + rmAddress)
applicationsManager = rpc.getProxy(classOf[ClientRMProtocol], rmAddress, conf)
.asInstanceOf[ClientRMProtocol]
}
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = getYarnClusterMetrics
logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
/*
val clusterNodeReports: List[NodeReport] = getNodeReports
logDebug("Got Cluster node info from ASM")
for (node <- clusterNodeReports) {
logDebug("Got node report from ASM for, nodeId=" + node.getNodeId + ", nodeAddress=" + node.getHttpAddress +
", nodeRackName=" + node.getRackName + ", nodeNumContainers=" + node.getNumContainers + ", nodeHealthStatus=" + node.getNodeHealthStatus)
}
*/
val queueInfo: QueueInfo = getQueueInfo(args.amQueue)
logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
", queueChildQueueCount=" + queueInfo.getChildQueues.size)
}
def getYarnClusterMetrics: YarnClusterMetrics = {
val request: GetClusterMetricsRequest = Records.newRecord(classOf[GetClusterMetricsRequest])
val response: GetClusterMetricsResponse = applicationsManager.getClusterMetrics(request)
return response.getClusterMetrics
}
def getNodeReports: List[NodeReport] = {
val request: GetClusterNodesRequest = Records.newRecord(classOf[GetClusterNodesRequest])
val response: GetClusterNodesResponse = applicationsManager.getClusterNodes(request)
return response.getNodeReports.toList
}
def getQueueInfo(queueName: String): QueueInfo = {
val request: GetQueueInfoRequest = Records.newRecord(classOf[GetQueueInfoRequest])
request.setQueueName(queueName)
request.setIncludeApplications(true)
request.setIncludeChildQueues(false)
request.setRecursive(false)
Records.newRecord(classOf[GetQueueInfoRequest])
return applicationsManager.getQueueInfo(request).getQueueInfo
}
def getNewApplication(): GetNewApplicationResponse = {
logInfo("Requesting new Application")
val request = Records.newRecord(classOf[GetNewApplicationRequest])
val response = applicationsManager.getNewApplication(request)
logInfo("Got new ApplicationId: " + response.getApplicationId())
return response
}
def verifyClusterResources(app: GetNewApplicationResponse) = {
val maxMem = app.getMaximumResourceCapability().getMemory()
logInfo("Max mem capabililty of resources in this cluster " + maxMem)
// If the cluster does not have enough memory resources, exit.
val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
if (requestedMem > maxMem) {
logError("Cluster cannot satisfy memory resource request of " + requestedMem)
System.exit(1)
}
}
def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
logInfo("Setting up application submission context for ASM")
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
appContext.setApplicationId(appId)
appContext.setApplicationName("Spark")
return appContext
}
def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
val locaResources = HashMap[String, LocalResource]()
// Upload Spark and the application JAR to the remote file system
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
val src = new Path(localPath)
val pathSuffix = appName + "/" + appId.getId() + destName
val dst = new Path(fs.getHomeDirectory(), pathSuffix)
logInfo("Uploading " + src + " to " + dst)
fs.copyFromLocalFile(false, true, src, dst)
val destStatus = fs.getFileStatus(dst)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(LocalResourceType.FILE)
amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
amJarRsrc.setTimestamp(destStatus.getModificationTime())
amJarRsrc.setSize(destStatus.getLen())
locaResources(destName) = amJarRsrc
}
}
return locaResources
}
def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
logInfo("Setting up the launch environment")
val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
val env = new HashMap[String, String]()
Apps.addToEnvironment(env, Environment.USER.name, args.amUser)
// If log4j present, ensure ours overrides all others
if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
Client.populateHadoopClasspath(yarnConf, env)
SparkHadoopUtil.setYarnMode(env)
env("SPARK_YARN_JAR_PATH") =
localResources("spark.jar").getResource().getScheme.toString() + "://" +
localResources("spark.jar").getResource().getFile().toString()
env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
env("SPARK_YARN_USERJAR_PATH") =
localResources("app.jar").getResource().getScheme.toString() + "://" +
localResources("app.jar").getResource().getFile().toString()
env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
if (log4jConfLocalRes != null) {
env("SPARK_YARN_LOG4J_PATH") =
log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
}
// Add each SPARK-* key to the environment
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
return env
}
def userArgsToString(clientArgs: ClientArguments): String = {
val prefix = " --args "
val args = clientArgs.userArgs
val retval = new StringBuilder()
for (arg <- args){
retval.append(prefix).append(" '").append(arg).append("' ")
}
retval.toString
}
def createContainerLaunchContext(newApp: GetNewApplicationResponse,
localResources: HashMap[String, LocalResource],
env: HashMap[String, String]): ContainerLaunchContext = {
logInfo("Setting up container launch context")
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
amContainer.setLocalResources(localResources)
amContainer.setEnvironment(env)
val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
(if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
// Extra options for the JVM
var JAVA_OPTS = ""
// Add Xmx for am memory
JAVA_OPTS += "-Xmx" + amMemory + "m "
// Commenting this out for now so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
// The context: the default GC for server-class machines ends up using all cores to do GC, so if there are multiple containers on the same
// node, Spark's GC affects the performance of all other containers (which may themselves be other Spark containers).
// Instead of using this, rely on cpusets in YARN to enforce that Spark behaves 'properly' in multi-tenant environments. Not sure how the default Java GC behaves if it is
// limited to a subset of cores on a node.
if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
// In our experiments, using the (default) throughput collector has severe performance ramifications on multi-tenant machines
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
// Command for the ApplicationMaster
val commands = List[String]("java " +
" -server " +
JAVA_OPTS +
" spark.deploy.yarn.ApplicationMaster" +
" --class " + args.userClass +
" --jar " + args.userJar +
userArgsToString(args) +
" --worker-memory " + args.workerMemory +
" --worker-cores " + args.workerCores +
" --num-workers " + args.numWorkers +
" 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Command for the ApplicationMaster: " + commands(0))
amContainer.setCommands(commands)
val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
// Memory for the ApplicationMaster
capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
amContainer.setResource(capability)
return amContainer
}
def submitApp(appContext: ApplicationSubmissionContext) = {
// Create the request to send to the applications manager
val appRequest = Records.newRecord(classOf[SubmitApplicationRequest])
.asInstanceOf[SubmitApplicationRequest]
appRequest.setApplicationSubmissionContext(appContext)
// Submit the application to the applications manager
logInfo("Submitting application to ASM")
applicationsManager.submitApplication(appRequest)
}
def monitorApplication(appId: ApplicationId): Boolean = {
while(true) {
Thread.sleep(1000)
val reportRequest = Records.newRecord(classOf[GetApplicationReportRequest])
.asInstanceOf[GetApplicationReportRequest]
reportRequest.setApplicationId(appId)
val reportResponse = applicationsManager.getApplicationReport(reportRequest)
val report = reportResponse.getApplicationReport()
logInfo("Application report from ASM: \n" +
"\t application identifier: " + appId.toString() + "\n" +
"\t appId: " + appId.getId() + "\n" +
"\t clientToken: " + report.getClientToken() + "\n" +
"\t appDiagnostics: " + report.getDiagnostics() + "\n" +
"\t appMasterHost: " + report.getHost() + "\n" +
"\t appQueue: " + report.getQueue() + "\n" +
"\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
"\t appStartTime: " + report.getStartTime() + "\n" +
"\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
"\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
"\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
"\t appUser: " + report.getUser()
)
val state = report.getYarnApplicationState()
val dsStatus = report.getFinalApplicationStatus()
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
return true
}
}
return true
}
}
object Client {
def main(argStrings: Array[String]) {
val args = new ClientArguments(argStrings)
SparkHadoopUtil.setYarnMode()
new Client(args).run
}
// Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
}
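A hedged sketch of driving the client entry point above programmatically; the jar and class names are hypothetical, and it assumes SPARK_JAR is set in the environment since prepareLocalResources and setupLaunchEnv read it. In practice this would normally be launched from a shell wrapper, so treat it purely as an illustration of the expected arguments. Note that ClientArguments parses --worker-memory with MemoryParam, so a suffixed value such as 1g is used.

package spark.deploy.yarn

// Illustrative only: submitting an application through Client.main defined above.
object ClientSubmitExample {
  def main(args: Array[String]) {
    // Assumes the SPARK_JAR env variable points at the assembled Spark jar.
    val sample = Array(
      "--jar", "/tmp/user-app.jar",       // hypothetical application jar
      "--class", "org.example.MyApp",     // hypothetical main class
      "--num-workers", "3",
      "--worker-memory", "1g",
      "--worker-cores", "1",
      "--queue", "default")
    // Sets SPARK_YARN_MODE, uploads resources, submits to the ResourceManager and
    // then polls the application report until a terminal state is reached.
    Client.main(sample)
  }
}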


@@ -0,0 +1,104 @@
package spark.deploy.yarn
import spark.util.MemoryParam
import spark.util.IntParam
import collection.mutable.{ArrayBuffer, HashMap}
import spark.scheduler.{InputFormatInfo, SplitInfo}
// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
class ClientArguments(val args: Array[String]) {
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
var workerMemory = 1024
var workerCores = 1
var numWorkers = 2
var amUser = System.getProperty("user.name")
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
parseArgs(args.toList)
private def parseArgs(inputArgs: List[String]): Unit = {
val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
var args = inputArgs
while (! args.isEmpty) {
args match {
case ("--jar") :: value :: tail =>
userJar = value
args = tail
case ("--class") :: value :: tail =>
userClass = value
args = tail
case ("--args") :: value :: tail =>
userArgsBuffer += value
args = tail
case ("--master-memory") :: MemoryParam(value) :: tail =>
amMemory = value
args = tail
case ("--num-workers") :: IntParam(value) :: tail =>
numWorkers = value
args = tail
case ("--worker-memory") :: MemoryParam(value) :: tail =>
workerMemory = value
args = tail
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
case ("--user") :: value :: tail =>
amUser = value
args = tail
case ("--queue") :: value :: tail =>
amQueue = value
args = tail
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
}
case _ =>
printUsageAndExit(1, args)
}
}
userArgs = userArgsBuffer.readOnly
inputFormatInfo = inputFormatMap.values.toList
}
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
}
System.err.println(
"Usage: spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required)\n" +
" --class CLASS_NAME Name of your application's main class (required)\n" +
" --args ARGS Arguments to be passed to your application's main class.\n" +
" Mutliple invocations are possible, each will be passed in order.\n" +
" Note that first argument will ALWAYS be yarn-standalone : will be added if missing.\n" +
" --num-workers NUM Number of workers to start (Default: 2)\n" +
" --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
" --user USERNAME Run the ApplicationMaster as a different user\n"
)
System.exit(exitCode)
}
}


@@ -0,0 +1,171 @@
package spark.deploy.yarn
import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import spark.{Logging, Utils}
class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
extends Runnable with Logging {
var rpc: YarnRPC = YarnRPC.create(conf)
var cm: ContainerManager = null
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
def run = {
logInfo("Starting Worker Container")
cm = connectToCM
startContainer
}
def startContainer = {
logInfo("Setting up ContainerLaunchContext")
val ctx = Records.newRecord(classOf[ContainerLaunchContext])
.asInstanceOf[ContainerLaunchContext]
ctx.setContainerId(container.getId())
ctx.setResource(container.getResource())
val localResources = prepareLocalResources
ctx.setLocalResources(localResources)
val env = prepareEnvironment
ctx.setEnvironment(env)
// Extra options for the JVM
var JAVA_OPTS = ""
// Set the JVM memory
val workerMemoryString = workerMemory + "m"
JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
// Commenting this out for now so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
// The context: the default GC for server-class machines ends up using all cores to do GC, so if there are multiple containers on the same
// node, Spark's GC affects the performance of all other containers (which may themselves be other Spark containers).
// Instead of using this, rely on cpusets in YARN to enforce that Spark behaves 'properly' in multi-tenant environments. Not sure how the default Java GC behaves if it is
// limited to a subset of cores on a node.
/*
else {
// If no java_opts specified, default to using -XX:+CMSIncrementalMode
// It might be possible that other modes/config are being set in SPARK_JAVA_OPTS, so we don't want to mess with it.
// In our experiments, using the (default) throughput collector has severe performance ramifications on multi-tenant machines.
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
*/
ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
val commands = List[String]("java " +
" -server " +
// Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
// Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
" spark.executor.StandaloneExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
workerCores +
" 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Setting up worker with commands: " + commands)
ctx.setCommands(commands)
// Send the start request to the ContainerManager
val startReq = Records.newRecord(classOf[StartContainerRequest])
.asInstanceOf[StartContainerRequest]
startReq.setContainerLaunchContext(ctx)
cm.startContainer(startReq)
}
def prepareLocalResources: HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
val locaResources = HashMap[String, LocalResource]()
// Spark JAR
val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
sparkJarResource.setType(LocalResourceType.FILE)
sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
locaResources("spark.jar") = sparkJarResource
// User JAR
val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
userJarResource.setType(LocalResourceType.FILE)
userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
locaResources("app.jar") = userJarResource
// Log4j conf - if available
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
log4jConfResource.setType(LocalResourceType.FILE)
log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
locaResources("log4j.properties") = log4jConfResource
}
logInfo("Prepared Local resources " + locaResources)
return locaResources
}
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
// should we add this ?
Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())
// If log4j present, ensure ours overrides all others
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
// Which is correct ?
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
Client.populateHadoopClasspath(yarnConf, env)
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
return env
}
def connectToCM: ContainerManager = {
val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
logInfo("Connecting to ContainerManager at " + cmHostPortStr)
return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
}
}
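As a reading aid, a small sketch (the helper object is hypothetical, not part of this commit) of the environment contract between Client.setupLaunchEnv and the prepareLocalResources method above; the variable names are the ones used in this commit.

package spark.deploy.yarn

// Illustrative only: the env vars WorkerRunnable.prepareLocalResources expects the
// Client to have set, plus a trivial check for missing ones.
object WorkerEnvContract {
  val required = Seq(
    "SPARK_YARN_JAR_PATH", "SPARK_YARN_JAR_TIMESTAMP", "SPARK_YARN_JAR_SIZE",
    "SPARK_YARN_USERJAR_PATH", "SPARK_YARN_USERJAR_TIMESTAMP", "SPARK_YARN_USERJAR_SIZE")
  // The log4j configuration is optional; the variables below are only read when
  // SPARK_YARN_LOG4J_PATH is present.
  val optional = Seq(
    "SPARK_YARN_LOG4J_PATH", "SPARK_YARN_LOG4J_TIMESTAMP", "SPARK_YARN_LOG4J_SIZE")
  def missingRequired(): Seq[String] = required.filter(k => System.getenv(k) == null)
}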


@@ -0,0 +1,547 @@
package spark.deploy.yarn
import spark.{Logging, Utils}
import spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.yarn.api.AMRMProtocol
import collection.JavaConversions._
import collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.hadoop.conf.Configuration
import java.util.{Collections, Set => JSet}
import java.lang.{Boolean => JBoolean}
object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
type AllocationType = Value
val HOST, RACK, ANY = Value
}
// too many params ? refactor it 'somehow' ?
// needs to be mt-safe
// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive : should make it
// more proactive and decoupled.
// Note that right now, we assume all node asks are uniform in terms of capabilities and priority
// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
// on how we are requesting for containers.
private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
val appAttemptId: ApplicationAttemptId,
val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
val preferredHostToCount: Map[String, Int],
val preferredRackToCount: Map[String, Int])
extends Logging {
// These three are locked on allocatedHostToContainersMap. Complementary data structures
// allocatedHostToContainersMap : containers which are running : host, Set<containerid>
// allocatedContainerToHostMap: container to host mapping
private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
// allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
// As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
private val allocatedRackCount = new HashMap[String, Int]()
// containers which have been released.
private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
// containers to be released in next request to RM
private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
private val numWorkersRunning = new AtomicInteger()
// Used to generate a unique id per worker
private val workerIdCounter = new AtomicInteger()
private val lastResponseId = new AtomicInteger()
def getNumWorkersRunning: Int = numWorkersRunning.intValue
def isResourceConstraintSatisfied(container: Container): Boolean = {
container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
}
def allocateContainers(workersToRequest: Int) {
// We need to send the request only once from what I understand ... but for now, not modifying this much.
// Keep polling the Resource Manager for containers
val amResp = allocateWorkerResources(workersToRequest).getAMResponse
val _allocatedContainers = amResp.getAllocatedContainers()
if (_allocatedContainers.size > 0) {
logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
", pendingReleaseContainers : " + pendingReleaseContainers)
logDebug("Cluster Resources: " + amResp.getAvailableResources)
val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
// ignore if not satisfying constraints {
for (container <- _allocatedContainers) {
if (isResourceConstraintSatisfied(container)) {
// allocatedContainers += container
val host = container.getNodeId.getHost
val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
containers += container
}
// Add all ignored containers to released list
else releasedContainerList.add(container.getId())
}
// Find the appropriate containers to use
// Slightly non trivial groupBy I guess ...
val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
for (candidateHost <- hostToContainers.keySet)
{
val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
assert(remainingContainers != null)
if (requiredHostCount >= remainingContainers.size){
// Since we got <= required containers, add all to dataLocalContainers
dataLocalContainers.put(candidateHost, remainingContainers)
// all consumed
remainingContainers = null
}
else if (requiredHostCount > 0) {
// container list has more containers than we need for data locality.
// Split into two : data local container count of (remainingContainers.size - requiredHostCount)
// and rest as remainingContainer
val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
dataLocalContainers.put(candidateHost, dataLocal)
// remainingContainers = remaining
// yarn has nasty habit of allocating a tonne of containers on a host - discourage this :
// add remaining to release list. If we have insufficient containers, next allocation cycle
// will reallocate (but wont treat it as data local)
for (container <- remaining) releasedContainerList.add(container.getId())
remainingContainers = null
}
// now rack local
if (remainingContainers != null){
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
if (rack != null){
val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
rackLocalContainers.get(rack).getOrElse(List()).size
if (requiredRackCount >= remainingContainers.size){
// Add all to dataLocalContainers
dataLocalContainers.put(rack, remainingContainers)
// all consumed
remainingContainers = null
}
else if (requiredRackCount > 0) {
// container list has more containers than we need for data locality.
// Split into two : data local container count of (remainingContainers.size - requiredRackCount)
// and rest as remainingContainer
val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
existingRackLocal ++= rackLocal
remainingContainers = remaining
}
}
}
// If still not consumed, then it is off rack host - add to that list.
if (remainingContainers != null){
offRackContainers.put(candidateHost, remainingContainers)
}
}
// Now that we have split the containers into various groups, go through them in order :
// first host local, then rack local and then off rack (everything else).
// Note that the list we create below tries to ensure that not all containers end up within a host
// if there are sufficiently large number of hosts/containers.
val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
// Run each of the allocated containers
for (container <- allocatedContainers) {
val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
val workerHostname = container.getNodeId.getHost
val containerId = container.getId
assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
if (numWorkersRunningNow > maxWorkers) {
logInfo("Ignoring container " + containerId + " at host " + workerHostname +
" .. we already have required number of containers")
releasedContainerList.add(containerId)
// reset counter back to old value.
numWorkersRunning.decrementAndGet()
}
else {
// deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
val workerId = workerIdCounter.incrementAndGet().toString
val masterUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.master.host"), System.getProperty("spark.master.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
pendingReleaseContainers.remove(containerId)
val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
allocatedHostToContainersMap.synchronized {
val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
containerSet += containerId
allocatedContainerToHostMap.put(containerId, workerHostname)
if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
}
new Thread(
new WorkerRunnable(container, conf, masterUrl, workerId,
workerHostname, workerMemory, workerCores)
).start()
}
}
logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
_allocatedContainers.size + "), current count " + numWorkersRunning.get() +
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
}
val completedContainers = amResp.getCompletedContainersStatuses()
if (completedContainers.size > 0){
logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
for (completedContainer <- completedContainers){
val containerId = completedContainer.getContainerId
// Was this released by us ? If yes, then simply remove from containerSet and move on.
if (pendingReleaseContainers.containsKey(containerId)) {
pendingReleaseContainers.remove(containerId)
}
else {
// simply decrement count - next iteration of ReporterThread will take care of allocating !
numWorkersRunning.decrementAndGet()
logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
" httpaddress: " + completedContainer.getDiagnostics)
}
allocatedHostToContainersMap.synchronized {
if (allocatedContainerToHostMap.containsKey(containerId)) {
val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
assert (host != null)
val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
assert (containerSet != null)
containerSet -= containerId
if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
else allocatedHostToContainersMap.update(host, containerSet)
allocatedContainerToHostMap -= containerId
// doing this within locked context, sigh ... move to outside ?
val rack = YarnAllocationHandler.lookupRack(conf, host)
if (rack != null) {
val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
else allocatedRackCount.remove(rack)
}
}
}
}
logDebug("After completed " + completedContainers.size + " containers, current count " +
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
", pendingReleaseContainers : " + pendingReleaseContainers)
}
}
def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
// First generate modified racks and new set of hosts under it : then issue requests
val rackToCounts = new HashMap[String, Int]()
// Within this lock - used to read/write to the rack related maps too.
for (container <- hostContainers) {
val candidateHost = container.getHostName
val candidateNumContainers = container.getNumContainers
assert(YarnAllocationHandler.ANY_HOST != candidateHost)
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
if (rack != null) {
var count = rackToCounts.getOrElse(rack, 0)
count += candidateNumContainers
rackToCounts.put(rack, count)
}
}
val requestedContainers: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](rackToCounts.size)
for ((rack, count) <- rackToCounts){
requestedContainers +=
createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
}
requestedContainers.toList
}
def allocatedContainersOnHost(host: String): Int = {
var retval = 0
allocatedHostToContainersMap.synchronized {
retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
}
retval
}
def allocatedContainersOnRack(rack: String): Int = {
var retval = 0
allocatedHostToContainersMap.synchronized {
retval = allocatedRackCount.getOrElse(rack, 0)
}
retval
}
private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
var resourceRequests: List[ResourceRequest] = null
// default.
if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
resourceRequests = List(
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
}
else {
// request for all hosts in preferred nodes and for numWorkers -
// candidates.size, request by default allocation policy.
val hostContainerRequests: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
for ((candidateHost, candidateCount) <- preferredHostToCount) {
val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
if (requiredCount > 0) {
hostContainerRequests +=
createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
}
}
val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
val anyContainerRequests: ResourceRequest =
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
val containerRequests: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)
containerRequests ++= hostContainerRequests
containerRequests ++= rackContainerRequests
containerRequests += anyContainerRequests
resourceRequests = containerRequests.toList
}
val req = Records.newRecord(classOf[AllocateRequest])
req.setResponseId(lastResponseId.incrementAndGet)
req.setApplicationAttemptId(appAttemptId)
req.addAllAsks(resourceRequests)
val releasedContainerList = createReleasedContainerList()
req.addAllReleases(releasedContainerList)
if (numWorkers > 0) {
logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
}
else {
logDebug("Empty allocation req .. release : " + releasedContainerList)
}
for (req <- resourceRequests) {
logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
}
resourceManager.allocate(req)
}
private def createResourceRequest(requestType: AllocationType.AllocationType,
resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
// If a hostname is specified, we need at least two requests - node local and rack local.
// There must be a third request - which is ANY : that will be specially handled.
requestType match {
case AllocationType.HOST => {
assert (YarnAllocationHandler.ANY_HOST != resource)
val hostname = resource
val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
// add to host->rack mapping
YarnAllocationHandler.populateRackInfo(conf, hostname)
nodeLocal
}
case AllocationType.RACK => {
val rack = resource
createResourceRequestImpl(rack, numWorkers, priority)
}
case AllocationType.ANY => {
createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
}
case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
}
}
private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
val memCapability = Records.newRecord(classOf[Resource])
// There probably is some overhead here, let's reserve a bit more memory.
memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
rsrcRequest.setCapability(memCapability)
val pri = Records.newRecord(classOf[Priority])
pri.setPriority(priority)
rsrcRequest.setPriority(pri)
rsrcRequest.setHostName(hostname)
rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
rsrcRequest
}
def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
val retval = new ArrayBuffer[ContainerId](1)
// iterator on COW list ...
for (container <- releasedContainerList.iterator()){
retval += container
}
// remove from the original list.
if (! retval.isEmpty) {
releasedContainerList.removeAll(retval)
for (v <- retval) pendingReleaseContainers.put(v, true)
logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
pendingReleaseContainers)
}
retval
}
}
object YarnAllocationHandler {
val ANY_HOST = "*"
// All requests are issued with the same priority: we do not (yet) have any distinction between
// request types (like map/reduce in Hadoop, for example)
val PRIORITY = 1
// Additional memory overhead - in mb
val MEMORY_OVERHEAD = 384
// host to rack map - saved from allocation requests
// We expect this not to change.
// Note that it is possible for it to change: the RM would indicate that to us via an update
// response to allocate. But we are punting on handling that for now.
private val hostToRack = new ConcurrentHashMap[String, String]()
private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
def newAllocator(conf: Configuration,
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
args: ApplicationMasterArguments,
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
args.workerMemory, args.workerCores, hostToCount, rackToCount)
}
def newAllocator(conf: Configuration,
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
maxWorkers: Int, workerMemory: Int, workerCores: Int,
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
workerMemory, workerCores, hostToCount, rackToCount)
}
// A simple method to copy the split info map.
private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
// host to count, rack to count
(Map[String, Int], Map[String, Int]) = {
if (input == null) return (Map[String, Int](), Map[String, Int]())
val hostToCount = new HashMap[String, Int]
val rackToCount = new HashMap[String, Int]
for ((host, splits) <- input) {
val hostCount = hostToCount.getOrElse(host, 0)
hostToCount.put(host, hostCount + splits.size)
val rack = lookupRack(conf, host)
if (rack != null){
val rackCount = rackToCount.getOrElse(rack, 0)
rackToCount.put(rack, rackCount + splits.size)
}
}
(hostToCount.toMap, rackToCount.toMap)
}
def lookupRack(conf: Configuration, host: String): String = {
if (!hostToRack.containsKey(host)) populateRackInfo(conf, host)
hostToRack.get(host)
}
def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
val set = rackToHostSet.get(rack)
if (set == null) return None
// No better way to get a Set[String] from JSet ?
val convertedSet: collection.mutable.Set[String] = set
Some(convertedSet.toSet)
}
def populateRackInfo(conf: Configuration, hostname: String) {
Utils.checkHost(hostname)
if (!hostToRack.containsKey(hostname)) {
// If there are repeated failures to resolve, add the host to an ignore list ?
val rackInfo = RackResolver.resolve(conf, hostname)
if (rackInfo != null && rackInfo.getNetworkLocation != null) {
val rack = rackInfo.getNetworkLocation
hostToRack.put(hostname, rack)
if (! rackToHostSet.containsKey(rack)) {
rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
}
rackToHostSet.get(rack).add(hostname)
// Since RackResolver caches, we are disabling this for now ...
} /* else {
// right ? Else we will keep calling the rack resolver whenever we can't resolve rack info ...
hostToRack.put(hostname, null)
} */
}
}
}
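A minimal sketch of how the rack-cache helpers above are meant to be exercised (hypothetical hostname, not part of this commit; assumes a topology script or the default single-rack mapping is configured): populateRackInfo resolves and caches the rack, lookupRack reads the cache, and fetchCachedHostsForRack gives the reverse mapping.

// Hypothetical usage sketch - illustrative only.
import org.apache.hadoop.yarn.conf.YarnConfiguration
import spark.deploy.yarn.YarnAllocationHandler

object RackLookupSketch {
  def main(args: Array[String]) {
    val conf = new YarnConfiguration()
    val host = "node1.example.com"  // assumed hostname
    // Resolves the rack via RackResolver and caches both directions of the mapping.
    YarnAllocationHandler.populateRackInfo(conf, host)
    val rack = YarnAllocationHandler.lookupRack(conf, host)
    println("rack for " + host + " = " + rack)
    if (rack != null) {
      // Hosts cached so far for that rack, if any.
      println("cached hosts on " + rack + " : " + YarnAllocationHandler.fetchCachedHostsForRack(rack))
    }
  }
}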

View file

@ -0,0 +1,42 @@
package spark.scheduler.cluster
import spark._
import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
import org.apache.hadoop.conf.Configuration
/**
*
* This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of ApplicationMaster, etc is done
*/
private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
def this(sc: SparkContext) = this(sc, new Configuration())
// Nothing else for now ... initialize the application master, which needs the SparkContext to determine how to allocate.
// Note that only the first creation of a SparkContext influences this (and ideally there should be only one SparkContext, right ?).
// Subsequent creations are ignored - since nodes are already allocated by then.
// By default, rack is unknown
override def getRackForHost(hostPort: String): Option[String] = {
val host = Utils.parseHostPort(hostPort)._1
val retval = YarnAllocationHandler.lookupRack(conf, host)
if (retval != null) Some(retval) else None
}
// By default, if rack is unknown, return nothing
override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
if (rack == null) return None
YarnAllocationHandler.fetchCachedHostsForRack(rack)
}
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
// Wait a few seconds for the slaves to bootstrap and register with the master - best-effort attempt
Thread.sleep(3000L)
}
logInfo("YarnClusterScheduler.postStartHook done")
}
}

View file

@ -0,0 +1,18 @@
package spark.deploy
/**
* Contains util methods to interact with Hadoop from spark.
*/
object SparkHadoopUtil {
def getUserNameFromEnvironment(): String = {
// defaulting to -D ...
System.getProperty("user.name")
}
def runAsUser(func: (Product) => Unit, args: Product) {
// Add support, if exists - for now, simply run func !
func(args)
}
}

View file

@ -8,12 +8,20 @@ import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream(
cls.getName.replaceFirst("^.*\\.", "") + ".class"))
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
val resourceStream = cls.getResourceAsStream(className)
// todo: Fixme - continuing with earlier behavior ...
if (resourceStream == null) return new ClassReader(resourceStream)
val baos = new ByteArrayOutputStream(128)
Utils.copyStream(resourceStream, baos, true)
new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
// Check whether a class represents a Scala closure

View file

@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId
private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
taskEndReason: TaskEndReason,
message: String,
cause: Throwable)
extends Exception {
override def getMessage(): String =
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
cause)
def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(null, shuffleId, -1, reduceId),
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
override def getMessage(): String = message
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
def toTaskEndReason: TaskEndReason = taskEndReason
}

View file

@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}
protected def isTraceEnabled(): Boolean = {
log.isTraceEnabled
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }

View file

@ -1,7 +1,6 @@
package spark
import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
@ -12,8 +11,7 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
@ -40,10 +38,12 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
private[spark] class MapOutputTracker extends Logging {
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@ -52,7 +52,7 @@ private[spark] class MapOutputTracker extends Logging {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@ -60,7 +60,6 @@ private[spark] class MapOutputTracker extends Logging {
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val timeout = 10.seconds
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
@ -77,10 +76,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.get(shuffleId) != None) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@ -101,8 +99,9 @@ private[spark] class MapOutputTracker extends Logging {
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = mapStatuses(shuffleId)
if (array != null) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
@ -115,13 +114,14 @@ private[spark] class MapOutputTracker extends Logging {
}
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@ -132,31 +132,49 @@ private[spark] class MapOutputTracker extends Logging {
case e: InterruptedException =>
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val host = System.getProperty("spark.hostname", Utils.localHostName)
// This try-finally prevents hangs due to timeouts:
var fetchedStatuses: Array[MapStatus] = null
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
// This try-finally prevents hangs due to timeouts:
var fetchedStatuses: Array[MapStatus] = null
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
} else {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}
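The block above implements a single-flight fetch: the first thread to miss registers the shuffleId in `fetching`, performs the remote fetch, and wakes the waiters; everyone else blocks until the winner finishes and then reads the cached result. A stripped-down sketch of the same pattern, with hypothetical names and no Spark types, for reference only:

import scala.collection.mutable.{HashMap, HashSet}

// Generic "first caller fetches, the rest wait" cache - illustrative only.
class SingleFlightCache[K, V](load: K => V) {
  private val cache = new HashMap[K, V]
  private val inFlight = new HashSet[K]

  def get(key: K): V = {
    synchronized {
      // Someone else is already loading this key; wait for them to finish.
      while (inFlight.contains(key)) wait()
      cache.get(key) match {
        case Some(v) => return v       // already cached, or loaded while we waited
        case None => inFlight += key   // we won the race; we will do the load
      }
    }
    try {
      val v = load(key)                // the expensive fetch happens outside the lock
      synchronized { cache.put(key, v) }
      v
    } finally {
      synchronized { inFlight -= key; notifyAll() }
    }
  }
}

MapOutputTracker layers generation tracking, MapStatus conversion and FetchFailedException on top of this basic shape.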
@ -194,7 +212,8 @@ private[spark] class MapOutputTracker extends Logging {
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
generation = newGen
}
}
@ -232,10 +251,13 @@ private[spark] class MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(statuses)
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
objOut.close()
out.toByteArray
}
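The serialize path above is plain Java serialization wrapped in GZIP; a self-contained round trip of the same trick (hypothetical, using strings in place of MapStatus) looks like this:

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

object GzipRoundTripSketch {
  def serialize(statuses: Array[String]): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
    objOut.writeObject(statuses)   // compresses well when entries repeat (e.g. the same hostname)
    objOut.close()
    out.toByteArray
  }

  def deserialize(bytes: Array[Byte]): Array[String] = {
    val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    objIn.readObject().asInstanceOf[Array[String]]
  }

  def main(args: Array[String]) {
    val locs = Array.fill(1000)("host-1:7077")
    val bytes = serialize(locs)
    println("serialized " + locs.length + " entries into " + bytes.length + " bytes")
    assert(deserialize(bytes).sameElements(locs))
  }
}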
@ -243,7 +265,10 @@ private[spark] class MapOutputTracker extends Logging {
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().asInstanceOf[Array[MapStatus]]
objIn.readObject().
// // drop all nulls from status - not sure why they are occurring though. Causes an NPE downstream in the slave if present
// commented out - nulls could be due to a missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
@ -253,14 +278,11 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
def convertMapStatuses(
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
if (statuses == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
assert (statuses != null)
statuses.map {
status =>
if (status == null) {

View file

@ -37,7 +37,7 @@ import spark.partial.PartialResult
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import spark.scheduler._
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import spark.storage.BlockManagerUI
import spark.util.{MetadataCleaner, TimeStampedHashMap}
@ -59,7 +59,10 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
val environment: Map[String, String] = Map())
val environment: Map[String, String] = Map(),
// This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
// This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@ -67,7 +70,7 @@ class SparkContext(
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
System.setProperty("spark.driver.host", Utils.localIpAddress)
System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
@ -99,7 +102,7 @@ class SparkContext(
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
if (jars != null) jars.foreach { addJar(_) }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
@ -111,7 +114,7 @@ class SparkContext(
executorEnvs(key) = value
}
}
executorEnvs ++= environment
if (environment != null) executorEnvs ++= environment
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@ -164,6 +167,22 @@ class SparkContext(
}
scheduler
case "yarn-standalone" =>
val scheduler = try {
val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(this).asInstanceOf[ClusterScheduler]
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
case th: Throwable => {
throw new SparkException("YARN mode not available ?", th)
}
}
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
@ -183,7 +202,7 @@ class SparkContext(
}
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
@volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
@ -207,6 +226,9 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
// Post init
taskScheduler.postStartHook()
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@ -471,7 +493,7 @@ class SparkContext(
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
@ -527,10 +549,13 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
if (dagScheduler != null) {
// Do this only if not stopped already - best-effort.
// Prevents an NPE if stop() is called more than once.
val dagSchedulerCopy = dagScheduler
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
dagScheduler.stop()
dagScheduler = null
dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@ -546,6 +571,7 @@ class SparkContext(
}
}
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable

View file

@ -72,6 +72,16 @@ object SparkEnv extends Logging {
System.setProperty("spark.driver.port", boundPort.toString)
}
// set only if unset until now.
if (System.getProperty("spark.hostPort", null) == null) {
if (!isDriver){
// unexpected
Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
}
Utils.checkHost(hostname)
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
}
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@ -88,9 +98,10 @@ object SparkEnv extends Logging {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
Utils.checkHost(driverHost, "Expected hostname")
val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}

View file

@ -1,18 +1,18 @@
package spark
import java.io._
import java.net._
import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
import spark.serializer.SerializerInstance
import spark.deploy.SparkHadoopUtil
/**
* Various utility methods used by Spark.
@ -68,6 +68,41 @@ private object Utils extends Logging {
return buf
}
private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths += absolutePath
}
}
// Is the path already registered to be deleted via a shutdown hook ?
def hasShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths.contains(absolutePath)
}
}
// Note: if the file is a child of some registered path (while not equal to it), return true; else false.
// This ensures that two shutdown hooks do not try to delete each other's paths - which would result in an IOException
// and incomplete cleanup.
def hasRootAsShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
val retval = shutdownDeletePaths.synchronized {
shutdownDeletePaths.find(path => ! absolutePath.equals(path) && absolutePath.startsWith(path) ).isDefined
}
if (retval) logInfo("path = " + file + ", already present as root for deletion.")
retval
}
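A quick illustration of the intended truth table for the helpers above (hypothetical paths, assuming /tmp/spark-abc was registered via registerShutdownDeleteDir):

// Utils.hasShutdownDeleteDir(new File("/tmp/spark-abc"))             // true  - exact match
// Utils.hasRootAsShutdownDeleteDir(new File("/tmp/spark-abc"))       // false - equal to a root, not a strict child
// Utils.hasRootAsShutdownDeleteDir(new File("/tmp/spark-abc/work"))  // true  - child of a registered root, so its own hook should skip deleting it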
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@ -86,10 +121,14 @@ private object Utils extends Logging {
}
} catch { case e: IOException => ; }
}
registerShutdownDeleteDir(dir)
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
Utils.deleteRecursively(dir)
// Attempt to delete only if no parent of this path is already registered for deletion.
if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
return dir
@ -227,8 +266,10 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
* Note, this is typically not used from within core spark.
*/
lazy val localIpAddress: String = findLocalIpAddress()
lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
private def findLocalIpAddress(): String = {
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@ -266,6 +307,8 @@ private object Utils extends Logging {
* hostname it reports to the master.
*/
def setCustomHostname(hostname: String) {
// DEBUG code
Utils.checkHost(hostname)
customHostname = Some(hostname)
}
@ -273,7 +316,90 @@ private object Utils extends Logging {
* Get the local machine's hostname.
*/
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
// customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
customHostname.getOrElse(localIpAddressHostname)
}
def getAddressHostName(address: String): String = {
InetAddress.getByName(address).getHostName
}
def localHostPort(): String = {
val retval = System.getProperty("spark.hostPort", null)
if (retval == null) {
logErrorWithStack("spark.hostPort not set but invoking localHostPort")
return localHostName()
}
retval
}
// Used by DEBUG code : remove when all testing done
def checkHost(host: String, message: String = "") {
// Currently catches only the ipv4 pattern; this is just a debugging tool - not rigorous !
if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
}
if (Utils.parseHostPort(host)._2 != 0){
Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
}
}
// Used by DEBUG code : remove when all testing done
def checkHostPort(hostPort: String, message: String = "") {
val (host, port) = Utils.parseHostPort(hostPort)
checkHost(host)
if (port <= 0){
Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
}
}
def getUserNameFromEnvironment(): String = {
SparkHadoopUtil.getUserNameFromEnvironment
}
// Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
// temp code for debug
System.exit(-1)
}
// Typically, this will be on the order of the number of nodes in the cluster.
// If not, we should change it to an LRUCache or something.
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
def parseHostPort(hostPort: String): (String, Int) = {
{
// Check cache first.
var cached = hostPortParseResults.get(hostPort)
if (cached != null) return cached
}
val indx: Int = hostPort.lastIndexOf(':')
// This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
// For now, we assume that if a port exists, then it is valid - we do not check whether it is an int > 0
if (-1 == indx) {
val retval = (hostPort, 0)
hostPortParseResults.put(hostPort, retval)
return retval
}
val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
hostPortParseResults.putIfAbsent(hostPort, retval)
hostPortParseResults.get(hostPort)
}
def addIfNoPort(hostPort: String, port: Int): String = {
if (port <= 0) throw new IllegalArgumentException("Invalid port specified " + port)
// This is potentially broken - when dealing with ipv6 addresses for example, sigh ... but then hadoop does not support ipv6 right now.
// For now, we assume that if a port exists, then it is valid - we do not check whether it is an int > 0
val indx: Int = hostPort.lastIndexOf(':')
if (-1 != indx) return hostPort
hostPort + ":" + port
}
private[spark] val daemonThreadFactory: ThreadFactory =
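For reference, a sketch of the contract of the parseHostPort and addIfNoPort helpers above (hypothetical inputs; the parsed result is cached and reused on later calls):

// Illustrative only - not part of this diff.
val (h1, p1) = Utils.parseHostPort("node1.example.com:51000")      // ("node1.example.com", 51000)
val (h2, p2) = Utils.parseHostPort("node1.example.com")            // ("node1.example.com", 0) - no port present
val withPort = Utils.addIfNoPort("node1.example.com", 7077)        // "node1.example.com:7077"
val untouched = Utils.addIfNoPort("node1.example.com:51000", 7077) // already has a port, returned as-is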

View file

@ -278,6 +278,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])

View file

@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
import spark.Utils
private[spark] sealed trait DeployMessage extends Serializable
@ -19,7 +20,10 @@ case class RegisterWorker(
memory: Int,
webUiPort: Int,
publicAddress: String)
extends DeployMessage
extends DeployMessage {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}
private[spark]
case class ExecutorStateChanged(
@ -58,7 +62,9 @@ private[spark]
case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
Utils.checkHostPort(hostPort, "Required hostport")
}
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
@ -81,6 +87,9 @@ private[spark]
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
def uri = "spark://" + host + ":" + port
}
@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
private[spark]
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}

View file

@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
"port" -> JsNumber(obj.port),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),

View file

@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
private val localIpAddress = Utils.localIpAddress
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localIpAddress + ":" + masterPort
val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}

View file

@ -59,10 +59,10 @@ private[spark] class Client(
markDisconnected()
context.stop(self)
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
listener.executorAdded(fullId, workerId, host, cores, memory)
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
val fullId = appId + "/" + id

View file

@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}

View file

@ -16,7 +16,7 @@ private[spark] object TestClient {
System.exit(0)
}
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
}

View file

@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
@ -35,9 +35,11 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
var firstApp: Option[ApplicationInfo] = None
Utils.checkHost(host, "Expected hostname")
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else ip
if (envVar != null) envVar else host
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
@ -46,7 +48,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + ip + ":" + port)
logInfo("Starting Spark master at spark://" + host + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
@ -145,7 +147,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
}
case RequestMasterState => {
sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@ -307,7 +309,7 @@ private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
actorSystem.awaitTermination()
}

View file

@ -7,13 +7,13 @@ import spark.Utils
* Command-line parser for the master.
*/
private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localHostName()
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
// Check for settings in environment variables
if (System.getenv("SPARK_MASTER_IP") != null) {
ip = System.getenv("SPARK_MASTER_IP")
if (System.getenv("SPARK_MASTER_HOST") != null) {
host = System.getenv("SPARK_MASTER_HOST")
}
if (System.getenv("SPARK_MASTER_PORT") != null) {
port = System.getenv("SPARK_MASTER_PORT").toInt
@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
ip = value
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
host = value
parse(tail)
case ("--host" | "-h") :: value :: tail =>
Utils.checkHost(value, "Please use hostname " + value)
host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
" --webui-port PORT Port for web UI (default: 8080)")
System.exit(exitCode)

View file

@ -2,6 +2,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
import spark.Utils
private[spark] class WorkerInfo(
val id: String,
@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
val webUiPort: Int,
val publicAddress: String) {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
def hostPort: String = {
assert (port > 0)
host + ":" + port
}
def addExecutor(exec: ExecutorInfo) {
executors(exec.fullId) = exec
coresUsed += exec.cores

View file

@ -21,11 +21,13 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
val hostname: String,
val hostPort: String,
val sparkHome: File,
val workDir: File)
extends Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@ -68,7 +70,7 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname
case "{{HOSTPORT}}" => hostPort
case "{{CORES}}" => cores.toString
case other => other
}

View file

@ -16,7 +16,7 @@ import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
ip: String,
host: String,
port: Int,
webUiPort: Int,
cores: Int,
@ -25,6 +25,9 @@ private[spark] class Worker(
workDirPath: String = null)
extends Actor with Logging {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
@ -39,7 +42,7 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else ip
if (envVar != null) envVar else host
}
var coresUsed = 0
@ -64,7 +67,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
host, port, cores, Utils.memoryMegabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
@ -75,7 +78,7 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
master = context.actorFor(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
}
@ -106,7 +109,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@ -141,7 +144,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
sender ! WorkerState(ip, port, workerId, executors.values.toList,
sender ! WorkerState(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@ -156,7 +159,7 @@ private[spark] class Worker(
}
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
}
override def postStop() {
@ -167,7 +170,7 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}

View file

@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
* Command-line parser for the master.
*/
private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localHostName()
var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
ip = value
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
host = value
parse(tail)
case ("--host" | "-h") :: value :: tail =>
Utils.checkHost(value, "Please use hostname " + value)
host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
System.exit(exitCode)

View file

@ -27,6 +27,11 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
initLogging()
// No ip or host:port - just hostname
Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
// must not have port specified.
assert (0 == Utils.parseHostPort(slaveHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)

View file

@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
import spark.Utils
import spark.deploy.SparkHadoopUtil
private[spark] class StandaloneExecutorBackend(
driverUrl: String,
executorId: String,
hostname: String,
hostPort: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
var driver: ActorRef = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
driver ! RegisterExecutor(executorId, hostname, cores)
driver ! RegisterExecutor(executorId, hostPort, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
override def receive = {
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
executor = new Executor(executorId, hostname, sparkProperties)
// Make this host instead of hostPort ?
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@ -63,11 +68,29 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
}
// This will be run 'as' the user
def run0(args: Product) {
assert(4 == args.productArity)
runImpl(args.productElement(0).asInstanceOf[String],
args.productElement(1).asInstanceOf[String],
args.productElement(2).asInstanceOf[String],
args.productElement(3).asInstanceOf[Int])
}
private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
// Debug code
Utils.checkHost(hostname)
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
Props(new StandaloneExecutorBackend(driverUrl, executorId, hostname, cores)),
Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
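The run / run0 / runImpl indirection above exists only so that SparkHadoopUtil.runAsUser can invoke the typed entry point through a (Product) => Unit. A standalone sketch of the same dispatch, with hypothetical names and the no-op runAsUser of the non-YARN build:

object ProductDispatchSketch {
  // Stubbed runAsUser: simply invokes func, as in the non-YARN SparkHadoopUtil.
  def runAsUser(func: (Product) => Unit, args: Product) { func(args) }

  def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
    println("would start executor " + executorId + " on " + hostname + " with " + cores + " cores, driver = " + driverUrl)
  }

  // Unpacks the Product back into typed arguments - note each element uses its own index.
  def run0(args: Product) {
    assert(4 == args.productArity)
    runImpl(args.productElement(0).asInstanceOf[String],
            args.productElement(1).asInstanceOf[String],
            args.productElement(2).asInstanceOf[String],
            args.productElement(3).asInstanceOf[Int])
  }

  def main(argv: Array[String]) {
    runAsUser(run0, Tuple4[Any, Any, Any, Any]("akka://spark@driver:7077/user/StandaloneScheduler", "1", "worker1", 4))
  }
}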

View file

@ -13,7 +13,7 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
val socketRemoteConnectionManagerId: ConnectionManagerId) extends Logging {
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
@ -33,15 +33,42 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
// Read channels typically do not register for write and write does not for read
// Now, we do have write registering for read too (temporarily), but this is to detect
// channel close NOT to actually read/consume data on it !
// How does this work if/when we move to SSL ?
// Which interest to register with the selector when we want this connection to be selected
def registerInterest()
// Which interest to register with the selector when we want this connection to be de-selected.
// Traditionally 0 - but in our case (for example, for the close-detection hack on SendingConnection) it will be
// SelectionKey.OP_READ (until we fix it properly)
def unregisterInterest()
// On receiving a read event, should we change the interest for this channel or not ?
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
def changeInterestForWrite(): Boolean
def getRemoteConnectionManagerId(): ConnectionManagerId = {
socketRemoteConnectionManagerId
}
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
def read() {
// Returns whether we have to register for further reads or not.
def read(): Boolean = {
throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
}
def write() {
// Returns whether we have to register for further writes or not.
def write(): Boolean = {
throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
}
@ -64,7 +91,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
logError("Error in connection to " + remoteConnectionManagerId +
logError("Error in connection to " + getRemoteConnectionManagerId() +
" and OnExceptionCallback not registered", e)
}
}
@ -73,7 +100,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
logWarning("Connection to " + remoteConnectionManagerId +
logWarning("Connection to " + getRemoteConnectionManagerId() +
" closed and OnExceptionCallback not registered")
}
@ -122,7 +149,7 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.synchronized{
/*messages += message*/
messages.enqueue(message)
logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
logDebug("Added [" + message + "] to outbox for sending to [" + getRemoteConnectionManagerId() + "]")
}
}
@ -149,9 +176,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
/*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@ -170,15 +197,15 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
logDebug("Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
logTrace("Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
return chunk
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@ -187,26 +214,39 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
val outbox = new Outbox(1)
private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
override def getRemoteAddress() = address
val DEFAULT_INTEREST = SelectionKey.OP_READ
override def registerInterest() {
// Registering read too - does not really help in most cases, but for some
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
}
override def unregisterInterest() {
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
registerInterest()
}
}
}
// MUST be called within the selector loop
def connect() {
try{
channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
channel.connect(address)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
@ -216,20 +256,33 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
def finishConnect() {
def finishConnect(force: Boolean): Boolean = {
try {
channel.finishConnect
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
// Typically, this should finish immediately since it was triggered by a connect
// selection - though need not necessarily always complete successfully.
val connected = channel.finishConnect
if (!force && !connected) {
logInfo("finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
return false
}
// Fall back to previous behavior - assume finishConnect completed.
// This will happen only when finishConnect has failed repeatedly (10 times or so), which
// is highly unlikely unless there was an unclean close of the socket, etc.
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
return true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
// ignore
return true
}
}
}
override def write() {
override def write(): Boolean = {
try{
while(true) {
if (currentBuffers.size == 0) {
@ -239,8 +292,9 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers ++= chunk.buffers
}
case None => {
changeConnectionKeyInterest(SelectionKey.OP_READ)
return
// changeConnectionKeyInterest(0)
/*key.interestOps(0)*/
return false
}
}
}
@ -254,38 +308,53 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
return
// re-register for write.
return true
}
}
}
} catch {
case e: Exception => {
logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
return false
}
}
// should not happen - to keep scala compiler happy
return true
}
override def read() {
// This is a hack to determine whether the remote socket was closed or not.
// A SendingConnection does NOT expect to receive any data - if it does, it is an error.
// In many cases, read will return -1 when the remote socket is closed; hence we
// register for reads to detect that.
override def read(): Boolean = {
// We don't expect the other side to send anything; so, we just read to detect an error or EOF.
try {
val length = channel.read(ByteBuffer.allocate(1))
if (length == -1) { // EOF
close()
} else if (length > 0) {
logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
logWarning("Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
}
false
}
override def changeInterestForRead(): Boolean = false
override def changeInterestForWrite(): Boolean = true
}
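For orientation, a hypothetical fragment of the selector loop that consumes these flags (the real ConnectionManager is not part of this hunk; names and structure are assumed and simplified):

// Assumed, simplified dispatch - illustrates the documented contract only.
def afterReadEvent(conn: Connection) {
  if (conn.changeInterestForRead()) conn.unregisterInterest()  // stop selecting while the read is processed
  val rearm = conn.read()                                      // true => this connection wants further read events
  if (rearm) conn.registerInterest()
}

def afterWriteEvent(conn: Connection) {
  if (conn.changeInterestForWrite()) conn.unregisterInterest()
  val rearm = conn.write()                                     // true => more chunks pending, re-register for OP_WRITE
  if (rearm) conn.registerInterest()
}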
// Must be created within selector loop - else deadlock
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
@ -298,13 +367,13 @@ extends Connection(channel_, selector_) {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
logDebug("Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
val message = messages.getOrElseUpdate(header.id, createNewMessage)
logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
logTrace("Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
@ -317,6 +386,26 @@ extends Connection(channel_, selector_) {
}
}
@volatile private var inferredRemoteManagerId: ConnectionManagerId = null
override def getRemoteConnectionManagerId(): ConnectionManagerId = {
val currId = inferredRemoteManagerId
if (currId != null) currId else super.getRemoteConnectionManagerId()
}
// The remote address seen by the receiver is the ephemeral local socket on the remote side, which is NOT that node's connection manager id.
// We infer the real id from the message headers received on this socket (see the sketch after this method).
private def processConnectionManagerId(header: MessageChunkHeader) {
val currId = inferredRemoteManagerId
if (header.address == null || currId != null) return
val managerId = ConnectionManagerId.fromSocketAddress(header.address)
if (managerId != null) {
inferredRemoteManagerId = managerId
}
}
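// Illustrative sketch, not part of this commit: why the peer address of the accepted
// socket cannot serve as the remote ConnectionManagerId. The connecting side binds an
// ephemeral local port, so its advertised listening port has to travel in the message
// header instead (which is what processConnectionManagerId above reads). All names and
// values below are hypothetical; only java.net is used.
import java.net.InetSocketAddress

object InferredRemoteIdSketch {
  def main(args: Array[String]) {
    // What accept() observes: the peer's ephemeral source port - not stable across connections.
    val socketRemoteAddress = new InetSocketAddress("node1", 53124)
    // What the chunk header carries: the peer's advertised listening address.
    val headerAddress = new InetSocketAddress("node1", 7077)
    // The stable id is derived from the header, never from the socket.
    val inferredId = (headerAddress.getHostName, headerAddress.getPort)
    println("socket peer = " + socketRemoteAddress + ", inferred id = " + inferredId)
  }
}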
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
@ -324,17 +413,18 @@ extends Connection(channel_, selector_) {
channel.register(selector, SelectionKey.OP_READ)
override def read() {
override def read(): Boolean = {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
return
return false
}
if (headerBuffer.remaining > 0) {
return
// re-register for read event ...
return true
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
@ -342,6 +432,9 @@ extends Connection(channel_, selector_) {
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
processConnectionManagerId(header)
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
@ -349,7 +442,8 @@ extends Connection(channel_, selector_) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
return
// re-register for read event ...
return true
} else {
currentChunk = inbox.getChunk(header).orNull
}
@ -362,10 +456,11 @@ extends Connection(channel_, selector_) {
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
return
// re-register for read event ...
return true
} else if (bytesRead == -1) {
close()
return
return false
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
@ -376,7 +471,7 @@ extends Connection(channel_, selector_) {
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
logDebug("Finished receiving [" + bufferMessage + "] from [" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@ -387,12 +482,31 @@ extends Connection(channel_, selector_) {
}
} catch {
case e: Exception => {
logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
return false
}
}
// should not happen - to keep scala compiler happy
return true
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
override def changeInterestForRead(): Boolean = true
override def changeInterestForWrite(): Boolean = {
throw new IllegalStateException("Unexpected invocation right now")
}
override def registerInterest() {
// Register for reads too - it does not really help in most cases, but for some
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_READ)
}
override def unregisterInterest() {
changeConnectionKeyInterest(0)
}
}

View file

@ -6,12 +6,12 @@ import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
import java.util.concurrent.Executors
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
@ -19,6 +19,10 @@ import akka.util.Duration
import akka.util.duration._
private[spark] case class ConnectionManagerId(host: String, port: Int) {
// DEBUG code
Utils.checkHost(host)
assert (port > 0)
def toSocketAddress() = new InetSocketAddress(host, port)
}
@ -42,19 +46,37 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
val selector = SelectorProvider.provider.openSelector()
val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val handleReadWriteExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.io.threads.min","4").toInt,
System.getProperty("spark.core.connection.io.threads.max","32").toInt,
System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
// Use a separate, smaller thread pool - used infrequently, for very short-lived tasks which should be executed asap (a sketch of this property-driven sizing follows below)
private val handleConnectExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
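// Minimal sketch, not part of this commit: building a ThreadPoolExecutor whose
// min/max/keepalive come from system properties, mirroring the pools above.
// The property prefix "example.pool" is made up for illustration.
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

object PropertySizedPoolSketch {
  def poolFromProperties(prefix: String, defMin: Int, defMax: Int, defKeepAliveSecs: Int): ThreadPoolExecutor = {
    new ThreadPoolExecutor(
      System.getProperty(prefix + ".threads.min", defMin.toString).toInt,
      System.getProperty(prefix + ".threads.max", defMax.toString).toInt,
      System.getProperty(prefix + ".threads.keepalive", defKeepAliveSecs.toString).toInt, TimeUnit.SECONDS,
      new LinkedBlockingDeque[Runnable]())
  }

  def main(args: Array[String]) {
    val pool = poolFromProperties("example.pool", 4, 32, 60)
    pool.execute(new Runnable { override def run() { println("ran on " + Thread.currentThread.getName) } })
    pool.shutdown()
  }
}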
private val serverChannel = ServerSocketChannel.open()
private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
private val messageStatuses = new HashMap[Int, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
@ -66,45 +88,138 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
val selectorThread = new Thread("connection-manager-thread") {
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
private def run() {
private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerWrite(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
writeRunnableStarted.synchronized {
// So that we do not trigger more write events while processing this one.
// The write method will re-register when done.
if (conn.changeInterestForWrite()) conn.unregisterInterest()
if (writeRunnableStarted.contains(key)) {
// key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
return
}
writeRunnableStarted += key
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
var register: Boolean = false
try {
register = conn.write()
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
if (register && conn.changeInterestForWrite()) {
conn.registerInterest()
}
}
}
}
} )
}
private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerRead(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
readRunnableStarted.synchronized {
// So that we do not trigger more read events while processing this one.
// The read method will re-register when done.
if (conn.changeInterestForRead()) conn.unregisterInterest()
if (readRunnableStarted.contains(key)) {
return
}
readRunnableStarted += key
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
var register: Boolean = false
try {
register = conn.read()
} finally {
readRunnableStarted.synchronized {
readRunnableStarted -= key
if (register && conn.changeInterestForRead()) {
conn.registerInterest()
}
}
}
}
} )
}
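// Illustrative sketch, not part of this commit: the hand-off pattern used by
// triggerWrite/triggerRead above - unregister interest, process the event on a worker
// pool, and re-register only when the handler asks for it. Keys are modelled as plain
// Ints and the connection as a simple trait; all names here are hypothetical.
import java.util.concurrent.Executors
import scala.collection.mutable.HashSet

object HandOffSketch {
  trait Handler {
    def unregisterInterest()
    def registerInterest()
    def process(): Boolean  // true => wants to be re-registered
  }

  private val started = new HashSet[Int]()
  private val pool = Executors.newFixedThreadPool(2)

  def trigger(key: Int, handler: Handler) {
    started.synchronized {
      handler.unregisterInterest()        // no more events for this key while we work on it
      if (started.contains(key)) return   // already being processed
      started += key
    }
    pool.execute(new Runnable {
      override def run() {
        var register = false
        try {
          register = handler.process()
        } finally {
          started.synchronized {
            started -= key
            if (register) handler.registerInterest()
          }
        }
      }
    })
  }

  def main(args: Array[String]) {
    trigger(1, new Handler {
      def unregisterInterest() { println("unregistered") }
      def registerInterest() { println("re-registered") }
      def process(): Boolean = { println("processed"); true }
    })
    pool.shutdown()
  }
}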
private def triggerConnect(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
if (conn == null) return
// prevent other events from being triggered
// Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
conn.changeConnectionKeyInterest(0)
handleConnectExecutor.execute(new Runnable {
override def run() {
var tries: Int = 10
while (tries >= 0) {
if (conn.finishConnect(false)) return
// Sleep ?
Thread.sleep(1)
tries -= 1
}
// Fall back to previous behavior: we should not really get here, since this method was
// triggered because the channel became connectable. But at times the first finishConnect does
// not succeed - hence the loop above to retry a few times (a standalone sketch of this retry shape follows this method).
conn.finishConnect(true)
}
} )
}
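// Minimal sketch, not part of this commit: the retry-then-force shape used by
// triggerConnect above. A non-blocking connect may not be finishable on the very
// first attempt even after the channel is reported connectable, so we poll a few
// times and finally force the fallback path. attemptFinish is a hypothetical
// stand-in for SendingConnection.finishConnect(force).
object FinishConnectRetrySketch {
  def finishWithRetries(attemptFinish: Boolean => Boolean, maxTries: Int = 10) {
    var tries = maxTries
    while (tries >= 0) {
      if (attemptFinish(false)) return
      Thread.sleep(1)
      tries -= 1
    }
    // Fallback: behave as if the connect completed (force = true).
    attemptFinish(true)
  }

  def main(args: Array[String]) {
    var calls = 0
    finishWithRetries(force => { calls += 1; force || calls > 3 })
    println("finishConnect attempts: " + calls)  // prints 4
  }
}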
def run() {
try {
while(!selectorThread.isInterrupted) {
for ((connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while (!sendMessageRequests.isEmpty) {
val (message, connection) = sendMessageRequests.dequeue
connection.send(message)
}
while (! registerRequests.isEmpty) {
val conn: SendingConnection = registerRequests.dequeue
addListeners(conn)
conn.connect()
addConnection(conn)
}
while (!keyInterestChangeRequests.isEmpty) {
while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
val connection = connectionsByKey(key)
val lastOps = key.interestOps()
key.interestOps(ops)
val connection = connectionsByKey.getOrElse(key, null)
if (connection != null) {
val lastOps = key.interestOps()
key.interestOps(ops)
def intToOpStr(op: Int): String = {
val opStrs = ArrayBuffer[String]()
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
// Hot loop - avoid materializing the string if trace is not enabled.
if (isTraceEnabled()) {
def intToOpStr(op: Int): String = {
val opStrs = ArrayBuffer[String]()
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
}
logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
}
logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
val selectedKeysCount = selector.select()
@ -123,12 +238,15 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
} else if (key.isConnectable) {
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
} else if (key.isReadable) {
connectionsByKey(key).read()
} else if (key.isWritable) {
connectionsByKey(key).write()
} else
if (key.isConnectable) {
triggerConnect(key)
} else
if (key.isReadable) {
triggerRead(key)
} else
if (key.isWritable) {
triggerWrite(key)
}
}
}
@ -138,94 +256,116 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
private def acceptConnection(key: SelectionKey) {
def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
val newChannel = serverChannel.accept()
val newConnection = new ReceivingConnection(newChannel, selector)
newConnection.onReceive(receiveMessage)
newConnection.onClose(removeConnection)
addConnection(newConnection)
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
var newChannel = serverChannel.accept()
// Accept them all in a tight loop. Non-blocking accept with no processing should be fine.
while (newChannel != null) {
try {
val newConnection = new ReceivingConnection(newChannel, selector)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
} catch {
// might happen in case of issues with registering with selector
case e: Exception => logError("Error in accept loop", e)
}
newChannel = serverChannel.accept()
}
}
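// Illustrative sketch, not part of this commit: draining a non-blocking accept in a
// tight loop, as acceptConnection above does. With configureBlocking(false), accept()
// returns null once there are no more pending connections, so the loop terminates on
// its own. The bind-to-any-free-port below is purely for illustration.
import java.net.InetSocketAddress
import java.nio.channels.ServerSocketChannel

object AcceptLoopSketch {
  def main(args: Array[String]) {
    val server = ServerSocketChannel.open()
    server.configureBlocking(false)
    server.socket.bind(new InetSocketAddress(0))  // any free port
    var channel = server.accept()
    while (channel != null) {       // accept everything that is already pending
      println("accepted " + channel.socket.getRemoteSocketAddress)
      channel.close()
      channel = server.accept()
    }
    server.close()
  }
}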
private def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
}
private def addListeners(connection: Connection) {
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
private def removeConnection(connection: Connection) {
def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
}
def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
try {
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
})
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
if (sendingConnectionManagerId == null) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
val sendingConnection = connectionsById(sendingConnectionManagerId)
sendingConnection.close()
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
if (! sendingConnectionOpt.isDefined) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
val sendingConnection = sendingConnectionOpt.get
connectionsById -= remoteConnectionManagerId
sendingConnection.close()
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
assert (sendingConnectionManagerId == remoteConnectionManagerId)
messageStatuses.synchronized {
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
}
} finally {
// So that the selection keys can be removed.
wakeupSelector()
}
}
private def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
removeConnection(connection)
}
private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
// so that registrations happen !
wakeupSelector()
}
private def receiveMessage(connection: Connection, message: Message) {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
@ -293,18 +433,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
new SendingConnection(inetSocketAddress, selector, connectionManagerId))
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
registerRequests.enqueue(newConnection)
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
// The lookupKey logic was removed as part of the merge ... should it be re-added ? We did not find it useful in our test env ...
// If we do re-add it, we should use it consistently everywhere, I guess ?
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
/*connection.send(message)*/
sendMessageRequests.synchronized {
sendMessageRequests += ((message, connection))
}
connection.send(message)
wakeupSelector()
}
private def wakeupSelector() {
selector.wakeup()
}
@ -337,6 +481,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
handleReadWriteExecutor.shutdown()
handleConnectExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}

View file

@ -17,6 +17,7 @@ private[spark] class MessageChunkHeader(
val other: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this; at 'use' time we do a reverse lookup of the hostname. Refer to network.Connection.
val ip = address.getAddress.getAddress()
val port = address.getPort()
ByteBuffer.

View file

@ -50,6 +50,11 @@ class DAGScheduler(
eventQueue.put(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
override def executorGained(execId: String, hostPort: String) {
eventQueue.put(ExecutorGained(execId, hostPort))
}
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
@ -113,7 +118,7 @@ class DAGScheduler(
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
cacheLocs(rdd.id) = blockManagerMaster.getLocations(blockIds).map {
locations => locations.map(_.ip).toList
locations => locations.map(_.hostPort).toList
}.toArray
}
cacheLocs(rdd.id)
@ -293,6 +298,9 @@ class DAGScheduler(
submitStage(finalStage)
}
case ExecutorGained(execId, hostPort) =>
handleExecutorGained(execId, hostPort)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
@ -631,6 +639,14 @@ class DAGScheduler(
}
}
private def handleExecutorGained(execId: String, hostPort: String) {
// remove from failedGeneration(execId) ?
if (failedGeneration.contains(execId)) {
logInfo("Host gained which was in lost list earlier: " + hostPort)
failedGeneration -= execId
}
}
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.

View file

@ -32,6 +32,10 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
private[spark] case class ExecutorGained(execId: String, hostPort: String) extends DAGSchedulerEvent {
Utils.checkHostPort(hostPort, "Required hostport")
}
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent

View file

@ -0,0 +1,156 @@
package spark.scheduler
import spark.Logging
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.mapreduce.Job
import org.apache.hadoop.conf.Configuration
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.collection.JavaConversions._
/**
* Parses and holds information about inputFormat (and files) specified as a parameter.
*/
class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_],
val path: String) extends Logging {
var mapreduceInputFormat: Boolean = false
var mapredInputFormat: Boolean = false
validate()
override def toString(): String = {
"InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path
}
override def hashCode(): Int = {
var hashCode = inputFormatClazz.hashCode
hashCode = hashCode * 31 + path.hashCode
hashCode
}
// Since we are not canonicalizing the path, this can be wrong : e.g. relative vs absolute path
// .. which is fine; this is a best-effort attempt to remove duplicates.
override def equals(other: Any): Boolean = other match {
case that: InputFormatInfo => {
// not checking config - that should be fine, right ?
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path
}
case _ => false
}
private def validate() {
logDebug("validate InputFormatInfo : " + inputFormatClazz + ", path " + path)
try {
if (classOf[org.apache.hadoop.mapreduce.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
logDebug("inputformat is from mapreduce package")
mapreduceInputFormat = true
}
else if (classOf[org.apache.hadoop.mapred.InputFormat[_, _]].isAssignableFrom(inputFormatClazz)) {
logDebug("inputformat is from mapred package")
mapredInputFormat = true
}
else {
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz +
" is NOT a supported input format ? does not implement either of the supported hadoop api's")
}
}
catch {
case e: ClassNotFoundException => {
throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e)
}
}
}
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
val conf = new JobConf(configuration)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[
org.apache.hadoop.mapreduce.InputFormat[_, _]]
val job = new Job(conf)
val retval = new ArrayBuffer[SplitInfo]()
val list = instance.getSplits(job)
for (split <- list) {
retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, split)
}
return retval.toSet
}
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
val jobConf = new JobConf(configuration)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], jobConf).asInstanceOf[
org.apache.hadoop.mapred.InputFormat[_, _]]
val retval = new ArrayBuffer[SplitInfo]()
instance.getSplits(jobConf, jobConf.getNumMapTasks()).foreach(
elem => retval ++= SplitInfo.toSplitInfo(inputFormatClazz, path, elem)
)
return retval.toSet
}
private def findPreferredLocations(): Set[SplitInfo] = {
logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat +
", inputFormatClazz : " + inputFormatClazz)
if (mapreduceInputFormat) {
return prefLocsFromMapreduceInputFormat()
}
else {
assert(mapredInputFormat)
return prefLocsFromMapredInputFormat()
}
}
}
object InputFormatInfo {
/**
Computes the preferred locations based on the input(s) and returns a location -> set of splits map.
Typical use of this method for allocation would follow an algorithm like this
(which is what we currently do in the YARN branch; a standalone sketch of one round follows this object) :
a) For each host, count the number of splits hosted on that host.
b) Subtract the containers currently allocated on that host.
c) Compute rack info for each host and update the rack -> count map based on (b).
d) Allocate nodes based on (c).
e) On the allocation result, ensure that we don't allocate "too many" containers on a single node
(even if data locality on it is very high) : this is to keep the job from becoming fragile if a single
host (or a small set of hosts) goes down.
Go to (a) until the required number of nodes is allocated.
If a node 'dies', follow the same procedure.
*/
def computePreferredLocations(formats: Seq[InputFormatInfo]): HashMap[String, HashSet[SplitInfo]] = {
val nodeToSplit = new HashMap[String, HashSet[SplitInfo]]
for (inputSplit <- formats) {
val splits = inputSplit.findPreferredLocations()
for (split <- splits){
val location = split.hostLocation
val set = nodeToSplit.getOrElseUpdate(location, new HashSet[SplitInfo])
set += split
}
}
nodeToSplit
}
}
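// Hedged sketch, not part of this commit: one round of the host-counting allocation
// idea described in the comment above - count splits per host, subtract containers
// already allocated there, and request the remaining hosts in descending order of
// outstanding splits. All data and names below are made up for illustration.
object AllocationRoundSketch {
  def nextHostsToRequest(splitsPerHost: Map[String, Int],
                         alreadyAllocated: Map[String, Int],
                         maxPerHost: Int): List[String] = {
    splitsPerHost.toSeq
      .map { case (host, splits) => (host, splits - alreadyAllocated.getOrElse(host, 0)) }
      .filter { case (_, outstanding) => outstanding > 0 }
      .sortBy { case (_, outstanding) => -outstanding }
      // cap the request per host so a small set of hosts cannot dominate the job
      .map { case (host, outstanding) => (host, math.min(outstanding, maxPerHost)) }
      .flatMap { case (host, count) => Seq.fill(count)(host) }
      .toList
  }

  def main(args: Array[String]) {
    val splits = Map("h1" -> 5, "h2" -> 3, "h3" -> 2)
    val allocated = Map("h1" -> 4)
    println(nextHostsToRequest(splits, allocated, maxPerHost = 2))
    // List(h2, h2, h3, h3, h1)
  }
}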

View file

@ -70,6 +70,14 @@ private[spark] class ResultTask[T, U](
rdd.partitions(partition)
}
// Data locality is on a per-host basis, not specific to a container (host:port). Reduced to a unique set of hosts.
val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
{
// DEBUG code
preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
}
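// Tiny illustration, not part of this commit: reducing host:port task locations to a
// unique set of hosts, which is what the preferredLocs computation above does via
// Utils.parseHostPort. The parsing here is simplified and the locations are made up.
object HostPortToHostSketch {
  def main(args: Array[String]) {
    val locs = Seq("node1:45123", "node1:45987", "node2:45200")
    val hosts = locs.map(_.split(":")(0)).toSet.toSeq
    println(hosts.mkString(", "))  // node1, node2
  }
}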
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId)
metrics = Some(context.taskMetrics)
@ -80,7 +88,7 @@ private[spark] class ResultTask[T, U](
}
}
override def preferredLocations: Seq[String] = locs
override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"

View file

@ -77,13 +77,21 @@ private[spark] class ShuffleMapTask(
var rdd: RDD[_],
var dep: ShuffleDependency[_,_],
var partition: Int,
@transient var locs: Seq[String])
@transient private var locs: Seq[String])
extends Task[MapStatus](stageId)
with Externalizable
with Logging {
protected def this() = this(0, null, null, 0, null)
// Data locality is on a per-host basis, not specific to a container (host:port). Reduced to a unique set of hosts.
private val preferredLocs: Seq[String] = if (locs == null) Nil else locs.map(loc => Utils.parseHostPort(loc)._1).toSet.toSeq
{
// DEBUG code
preferredLocs.foreach (host => Utils.checkHost(host, "preferredLocs : " + preferredLocs))
}
var split = if (rdd == null) {
null
} else {
@ -154,7 +162,7 @@ private[spark] class ShuffleMapTask(
}
}
override def preferredLocations: Seq[String] = locs
override def preferredLocations: Seq[String] = preferredLocs
override def toString = "ShuffleMapTask(%d, %d)".format(stageId, partition)
}

View file

@ -0,0 +1,61 @@
package spark.scheduler
import collection.mutable.ArrayBuffer
// Information about a specific input split instance : handles both the mapred and mapreduce split types,
// so that we do not need to worry about the differences between them.
class SplitInfo(val inputFormatClazz: Class[_], val hostLocation: String, val path: String,
val length: Long, val underlyingSplit: Any) {
override def toString(): String = {
"SplitInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz +
", hostLocation : " + hostLocation + ", path : " + path +
", length : " + length + ", underlyingSplit " + underlyingSplit
}
override def hashCode(): Int = {
var hashCode = inputFormatClazz.hashCode
hashCode = hashCode * 31 + hostLocation.hashCode
hashCode = hashCode * 31 + path.hashCode
// Ignore overflow; it is a hash code anyway.
hashCode = hashCode * 31 + (length & 0x7fffffff).toInt
hashCode
}
// This is practically useless since most of the Split implementations don't seem to implement equals :-(
// So unless there is identity equality between underlyingSplits, it will always fail even if they
// point to the same block.
override def equals(other: Any): Boolean = other match {
case that: SplitInfo => {
this.hostLocation == that.hostLocation &&
this.inputFormatClazz == that.inputFormatClazz &&
this.path == that.path &&
this.length == that.length &&
// other split specific checks (like start for FileSplit)
this.underlyingSplit == that.underlyingSplit
}
case _ => false
}
}
object SplitInfo {
def toSplitInfo(inputFormatClazz: Class[_], path: String,
mapredSplit: org.apache.hadoop.mapred.InputSplit): Seq[SplitInfo] = {
val retval = new ArrayBuffer[SplitInfo]()
val length = mapredSplit.getLength
for (host <- mapredSplit.getLocations) {
retval += new SplitInfo(inputFormatClazz, host, path, length, mapredSplit)
}
retval
}
def toSplitInfo(inputFormatClazz: Class[_], path: String,
mapreduceSplit: org.apache.hadoop.mapreduce.InputSplit): Seq[SplitInfo] = {
val retval = new ArrayBuffer[SplitInfo]()
val length = mapreduceSplit.getLength
for (host <- mapreduceSplit.getLocations) {
retval += new SplitInfo(inputFormatClazz, host, path, length, mapreduceSplit)
}
retval
}
}

View file

@ -10,6 +10,10 @@ package spark.scheduler
private[spark] trait TaskScheduler {
def start(): Unit
// Invoked after the system has successfully initialized (typically in the spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registrations, etc.
def postStartHook() { }
// Disconnect from the cluster.
def stop(): Unit

View file

@ -14,6 +14,9 @@ private[spark] trait TaskSchedulerListener {
def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
// A node was added to the cluster.
def executorGained(execId: String, hostPort: String): Unit
// A node was lost from the cluster.
def executorLost(execId: String): Unit

View file

@ -1,6 +1,6 @@
package spark.scheduler.cluster
import java.io.{File, FileInputStream, FileOutputStream}
import java.lang.{Boolean => JBoolean}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@ -25,6 +25,30 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
// Threshold above which we warn user initial TaskSet may be starved
val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
// How often to revive offers when there are pending tasks - that is, how often to try to get
// tasks scheduled when nodes are available. The default of 0 disables it, preserving existing behavior.
// Note that this is needed because of the delays introduced by data locality waits, etc.
// TODO: rename property ?
val TASK_REVIVAL_INTERVAL = System.getProperty("spark.tasks.revive.interval", "0").toLong
/*
This property controls how aggressively we relax the wait for host-local task scheduling.
To elaborate, there is currently a time limit (3 seconds by default) during which spark waits for host locality of a task before
scheduling it on other nodes. In the yarn branch we have changed this so that offers to a task set happen in prioritized order :
host-local, rack-local and then others.
But once all available host-local (and no-preference) tasks are scheduled, instead of waiting 3 seconds before
scheduling on other nodes (which degrades performance for time-sensitive tasks and on larger clusters), we can
relax the constraint : to also allow rack-local nodes or any node. The default is still HOST_LOCAL, so previous behavior is
maintained. This lets users tune the tension between pulling rdd data off-node and scheduling computation asap.
TODO: rename property ? The value is one of
- HOST_LOCAL (default, no change w.r.t. current behavior),
- RACK_LOCAL and
- ANY
Note that this property makes more sense when used in conjunction with spark.tasks.revive.interval > 0 : else it is not very effective.
A usage sketch follows the property definition below.
*/
val TASK_SCHEDULING_AGGRESSION = TaskLocality.parse(System.getProperty("spark.tasks.schedule.aggression", "HOST_LOCAL"))
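// Hedged sketch, not part of this commit: how these knobs could be set from driver
// code before the scheduler starts. The values are examples only; the property names
// are the ones read above (spark.tasks.revive.interval, spark.tasks.schedule.aggression).
object SchedulingKnobsSketch {
  def main(args: Array[String]) {
    // Revive offers every second so pending tasks get another chance to be scheduled.
    System.setProperty("spark.tasks.revive.interval", "1000")
    // Allow falling back to rack-local scheduling once host-local offers are exhausted.
    System.setProperty("spark.tasks.schedule.aggression", "RACK_LOCAL")
    println(System.getProperty("spark.tasks.schedule.aggression"))
  }
}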
val activeTaskSets = new HashMap[String, TaskSetManager]
var activeTaskSetsQueue = new ArrayBuffer[TaskSetManager]
@ -33,9 +57,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val taskIdToExecutorId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
var hasReceivedTask = false
var hasLaunchedTask = false
val starvationTimer = new Timer(true)
@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
// Incrementing Mesos task IDs
val nextTaskId = new AtomicLong(0)
@ -43,11 +67,16 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
// Which executor IDs we have executors on
val activeExecutorIds = new HashSet[String]
// TODO: We might want to remove this and merge it with execId datastructures - but later.
// Which hosts in the cluster are alive (contains hostPort's)
private val hostPortsAlive = new HashSet[String]
private val hostToAliveHostPorts = new HashMap[String, HashSet[String]]
// The set of executors we have on each host; this is used to compute hostsAlive, which
// in turn is used to decide when we can attain data locality on a given host
val executorsByHost = new HashMap[String, HashSet[String]]
val executorsByHostPort = new HashMap[String, HashSet[String]]
val executorIdToHost = new HashMap[String, String]
val executorIdToHostPort = new HashMap[String, String]
// JAR server, if any JARs were added by the user to the SparkContext
var jarServer: HttpServer = null
@ -75,11 +104,12 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
override def start() {
backend.start()
if (System.getProperty("spark.speculation", "false") == "true") {
if (JBoolean.getBoolean("spark.speculation")) {
new Thread("ClusterScheduler speculation check") {
setDaemon(true)
override def run() {
logInfo("Starting speculative execution thread")
while (true) {
try {
Thread.sleep(SPECULATION_INTERVAL)
@ -91,6 +121,27 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}.start()
}
// Change to always run with some default if TASK_REVIVAL_INTERVAL <= 0 ?
if (TASK_REVIVAL_INTERVAL > 0) {
new Thread("ClusterScheduler task offer revival check") {
setDaemon(true)
override def run() {
logInfo("Starting speculative task offer revival thread")
while (true) {
try {
Thread.sleep(TASK_REVIVAL_INTERVAL)
} catch {
case e: InterruptedException => {}
}
if (hasPendingTasks()) backend.reviveOffers()
}
}
}.start()
}
}
override def submitTasks(taskSet: TaskSet) {
@ -139,22 +190,92 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
SparkEnv.set(sc.env)
// Mark each slave as alive and remember its hostname
for (o <- offers) {
executorIdToHost(o.executorId) = o.hostname
if (!executorsByHost.contains(o.hostname)) {
executorsByHost(o.hostname) = new HashSet()
// DEBUG Code
Utils.checkHostPort(o.hostPort)
executorIdToHostPort(o.executorId) = o.hostPort
if (! executorsByHostPort.contains(o.hostPort)) {
executorsByHostPort(o.hostPort) = new HashSet[String]()
}
hostPortsAlive += o.hostPort
hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(o.hostPort)._1, new HashSet[String]).add(o.hostPort)
executorGained(o.executorId, o.hostPort)
}
// Build a list of tasks to assign to each slave
val tasks = offers.map(o => new ArrayBuffer[TaskDescription](o.cores))
val availableCpus = offers.map(o => o.cores).toArray
var launchedTask = false
for (manager <- activeTaskSetsQueue.sortBy(m => (m.taskSet.priority, m.taskSet.stageId))) {
// Split offers based on host local, rack local and off-rack tasks.
val hostLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
val rackLocalOffers = new HashMap[String, ArrayBuffer[Int]]()
val otherOffers = new HashMap[String, ArrayBuffer[Int]]()
for (i <- 0 until offers.size) {
val hostPort = offers(i).hostPort
// DEBUG code
Utils.checkHostPort(hostPort)
val host = Utils.parseHostPort(hostPort)._1
val numHostLocalTasks = math.max(0, math.min(manager.numPendingTasksForHost(hostPort), availableCpus(i)))
if (numHostLocalTasks > 0){
val list = hostLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
for (j <- 0 until numHostLocalTasks) list += i
}
val numRackLocalTasks = math.max(0,
// Remove host-local tasks (which are also rack-local, by the way) from this count
math.min(manager.numRackLocalPendingTasksForHost(hostPort) - numHostLocalTasks, availableCpus(i)))
if (numRackLocalTasks > 0){
val list = rackLocalOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
for (j <- 0 until numRackLocalTasks) list += i
}
if (numHostLocalTasks <= 0 && numRackLocalTasks <= 0){
// add to others list - spread even this across cluster.
val list = otherOffers.getOrElseUpdate(host, new ArrayBuffer[Int])
list += i
}
}
val offersPriorityList = new ArrayBuffer[Int](
hostLocalOffers.size + rackLocalOffers.size + otherOffers.size)
// First host local, then rack, then others
val numHostLocalOffers = {
val hostLocalPriorityList = ClusterScheduler.prioritizeContainers(hostLocalOffers)
offersPriorityList ++= hostLocalPriorityList
hostLocalPriorityList.size
}
val numRackLocalOffers = {
val rackLocalPriorityList = ClusterScheduler.prioritizeContainers(rackLocalOffers)
offersPriorityList ++= rackLocalPriorityList
rackLocalPriorityList.size
}
offersPriorityList ++= ClusterScheduler.prioritizeContainers(otherOffers)
var lastLoop = false
val lastLoopIndex = TASK_SCHEDULING_AGGRESSION match {
case TaskLocality.HOST_LOCAL => numHostLocalOffers
case TaskLocality.RACK_LOCAL => numRackLocalOffers + numHostLocalOffers
case TaskLocality.ANY => offersPriorityList.size
}
do {
launchedTask = false
for (i <- 0 until offers.size) {
var loopCount = 0
for (i <- offersPriorityList) {
val execId = offers(i).executorId
val host = offers(i).hostname
manager.slaveOffer(execId, host, availableCpus(i)) match {
val hostPort = offers(i).hostPort
// If last loop and within the lastLoopIndex, expand scope - else use null (which will use default/existing)
val overrideLocality = if (lastLoop && loopCount < lastLoopIndex) TASK_SCHEDULING_AGGRESSION else null
// If last loop, override waiting for host locality - we scheduled all local tasks already and there might be more available ...
loopCount += 1
manager.slaveOffer(execId, hostPort, availableCpus(i), overrideLocality) match {
case Some(task) =>
tasks(i) += task
val tid = task.taskId
@ -162,15 +283,31 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
taskSetTaskIds(manager.taskSet.id) += tid
taskIdToExecutorId(tid) = execId
activeExecutorIds += execId
executorsByHost(host) += execId
executorsByHostPort(hostPort) += execId
availableCpus(i) -= 1
launchedTask = true
case None => {}
}
}
// Loop once more - when lastLoop is true, we try to schedule tasks on all nodes irrespective of
// data locality (we still go in order of priority, but that would not change anything since
// if data-local tasks had been available, we would have scheduled them already)
if (lastLoop) {
// prevent more looping
launchedTask = false
} else if (!lastLoop && !launchedTask) {
// Do this only if TASK_SCHEDULING_AGGRESSION != HOST_LOCAL
if (TASK_SCHEDULING_AGGRESSION != TaskLocality.HOST_LOCAL) {
// fudge launchedTask to ensure we loop once more
launchedTask = true
// don't loop anymore
lastLoop = true
}
}
} while (launchedTask)
}
if (tasks.size > 0) {
hasLaunchedTask = true
}
@ -256,10 +393,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (jarServer != null) {
jarServer.stop()
}
// Sleep for an arbitrary 5 seconds to ensure that outstanding messages are sent out.
// TODO: Do something better !
Thread.sleep(5000L)
}
override def defaultParallelism() = backend.defaultParallelism()
// Check for speculatable tasks in all our active jobs.
def checkSpeculatableTasks() {
var shouldRevive = false
@ -273,12 +415,20 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
// Check for pending tasks in all our active jobs.
def hasPendingTasks(): Boolean = {
synchronized {
activeTaskSetsQueue.exists( _.hasPendingTasks() )
}
}
def executorLost(executorId: String, reason: ExecutorLossReason) {
var failedExecutor: Option[String] = None
synchronized {
if (activeExecutorIds.contains(executorId)) {
val host = executorIdToHost(executorId)
logError("Lost executor %s on %s: %s".format(executorId, host, reason))
val hostPort = executorIdToHostPort(executorId)
logError("Lost executor %s on %s: %s".format(executorId, hostPort, reason))
removeExecutor(executorId)
failedExecutor = Some(executorId)
} else {
@ -296,19 +446,95 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
/** Get a list of hosts that currently have executors */
def hostsAlive: scala.collection.Set[String] = executorsByHost.keySet
/** Remove an executor from all our data structures and mark it as lost */
private def removeExecutor(executorId: String) {
activeExecutorIds -= executorId
val host = executorIdToHost(executorId)
val execs = executorsByHost.getOrElse(host, new HashSet)
val hostPort = executorIdToHostPort(executorId)
if (hostPortsAlive.contains(hostPort)) {
// DEBUG Code
Utils.checkHostPort(hostPort)
hostPortsAlive -= hostPort
hostToAliveHostPorts.getOrElseUpdate(Utils.parseHostPort(hostPort)._1, new HashSet[String]).remove(hostPort)
}
val execs = executorsByHostPort.getOrElse(hostPort, new HashSet)
execs -= executorId
if (execs.isEmpty) {
executorsByHost -= host
executorsByHostPort -= hostPort
}
executorIdToHost -= executorId
activeTaskSetsQueue.foreach(_.executorLost(executorId, host))
executorIdToHostPort -= executorId
activeTaskSetsQueue.foreach(_.executorLost(executorId, hostPort))
}
def executorGained(execId: String, hostPort: String) {
listener.executorGained(execId, hostPort)
}
def getExecutorsAliveOnHost(host: String): Option[Set[String]] = {
val retval = hostToAliveHostPorts.get(host)
if (retval.isDefined) {
return Some(retval.get.toSet)
}
None
}
// By default, rack is unknown
def getRackForHost(value: String): Option[String] = None
// By default, the (cached) hosts for a rack are unknown
def getCachedHostsForRack(rack: String): Option[Set[String]] = None
}
object ClusterScheduler {
// Used to 'spray' available containers across the available hosts, to ensure that too many containers on the same host
// are not used up. Used in yarn mode and in task scheduling (when there are multiple containers available
// to execute a task).
// For example: yarn can return more containers than we requested under ANY; this method
// prioritizes how to use the allocated containers.
// It flattens the map such that the array buffer entries are spread out across the returned value.
// Given <host, list[container]> == <h1, [c1 .. c5]>, <h2, [c1 .. c3]>, <h3, [c1, c2]>, <h4, [c1]>, <h5, [c1]>,
// the return value would be : h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
// (a runnable sketch of this on the same example follows this object).
// We then 'use' the containers in this order (consuming only the top K from this list, where
// K = number to be used). This ensures that, when we have multiple eligible allocations,
// they don't end up placing all containers on a small number of hosts - which would increase the probability of
// multiple container failures when a host goes down.
// Note, there is a deliberate bias for keys with a higher number of entries to be picked first.
// Also note that a single invocation of this method is expected to receive containers of the same 'type'
// (host-local, rack-local, off-rack) and not a mix of types, so the reordering is simply a better ordering of
// the available list - everything else being the same.
// That is, we first consume data-local, then rack-local and finally off-rack nodes. So the
// prioritization from this method applies within each category.
def prioritizeContainers[K, T] (map: HashMap[K, ArrayBuffer[T]]): List[T] = {
val _keyList = new ArrayBuffer[K](map.size)
_keyList ++= map.keys
// order keyList based on population of value in map
val keyList = _keyList.sortWith(
(left, right) => map.get(left).getOrElse(Set()).size > map.get(right).getOrElse(Set()).size
)
val retval = new ArrayBuffer[T](keyList.size * 2)
var index = 0
var found = true
while (found){
found = false
for (key <- keyList) {
val containerList: ArrayBuffer[T] = map.get(key).getOrElse(null)
assert(containerList != null)
// Get the index'th entry for this host - if present
if (index < containerList.size){
retval += containerList.apply(index)
found = true
}
}
index += 1
}
retval.toList
}
}
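// Self-contained sketch, not part of this commit: the 'spray' flattening that
// prioritizeContainers performs, run on the example from the comment above
// (<h1, [c1..c5]>, <h2, [c1..c3]>, <h3, [c1, c2]>, <h4, [c1]>, <h5, [c1]>).
// The reimplementation below is simplified (a Seq of pairs instead of a HashMap).
import scala.collection.mutable.ArrayBuffer

object PrioritizeContainersSketch {
  // Hosts with more containers come first; then emit the i-th container of each host in turn.
  def prioritize[T](offers: Seq[(String, Seq[T])]): List[T] = {
    val ordered = offers.sortWith((l, r) => l._2.size > r._2.size)
    val maxSize = if (ordered.isEmpty) 0 else ordered.map(_._2.size).max
    val result = new ArrayBuffer[T]()
    for (i <- 0 until maxSize; (_, containers) <- ordered; if i < containers.size) {
      result += containers(i)
    }
    result.toList
  }

  def main(args: Array[String]) {
    val offers = Seq(
      "h1" -> Seq("h1c1", "h1c2", "h1c3", "h1c4", "h1c5"),
      "h2" -> Seq("h2c1", "h2c2", "h2c3"),
      "h3" -> Seq("h3c1", "h3c2"),
      "h4" -> Seq("h4c1"),
      "h5" -> Seq("h5c1"))
    println(prioritize(offers).mkString(", "))
    // h1c1, h2c1, h3c1, h4c1, h5c1, h1c2, h2c2, h3c2, h1c3, h2c3, h1c4, h1c5
  }
}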

View file

@ -27,7 +27,7 @@ private[spark] class SparkDeploySchedulerBackend(
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTPORT}}", "{{CORES}}")
val command = Command("spark.executor.StandaloneExecutorBackend", args, sc.executorEnvs)
val sparkHome = sc.getSparkHome().getOrElse(
throw new IllegalArgumentException("must supply spark home for spark standalone"))
@ -57,9 +57,9 @@ private[spark] class SparkDeploySchedulerBackend(
}
}
override def executorAdded(executorId: String, workerId: String, host: String, cores: Int, memory: Int) {
logInfo("Granted executor ID %s on host %s with %d cores, %s RAM".format(
executorId, host, cores, Utils.memoryMegabytesToString(memory)))
override def executorAdded(executorId: String, workerId: String, hostPort: String, cores: Int, memory: Int) {
logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format(
executorId, hostPort, cores, Utils.memoryMegabytesToString(memory)))
}
override def executorRemoved(executorId: String, message: String, exitStatus: Option[Int]) {

View file

@ -3,6 +3,7 @@ package spark.scheduler.cluster
import spark.TaskState.TaskState
import java.nio.ByteBuffer
import spark.util.SerializableBuffer
import spark.Utils
private[spark] sealed trait StandaloneClusterMessage extends Serializable
@ -19,8 +20,10 @@ case class RegisterExecutorFailed(message: String) extends StandaloneClusterMess
// Executors to driver
private[spark]
case class RegisterExecutor(executorId: String, host: String, cores: Int)
extends StandaloneClusterMessage
case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
extends StandaloneClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
private[spark]
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, data: SerializableBuffer)

View file

@ -5,8 +5,9 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import akka.util.duration._
import akka.pattern.ask
import akka.util.Duration
import spark.{SparkException, Logging, TaskState}
import spark.{Utils, SparkException, Logging, TaskState}
import akka.dispatch.Await
import java.util.concurrent.atomic.AtomicInteger
import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
@ -24,12 +25,12 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
var totalCoreCount = new AtomicInteger(0)
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor {
val executorActor = new HashMap[String, ActorRef]
val executorAddress = new HashMap[String, Address]
val executorHost = new HashMap[String, String]
val freeCores = new HashMap[String, Int]
val actorToExecutorId = new HashMap[ActorRef, String]
val addressToExecutorId = new HashMap[Address, String]
private val executorActor = new HashMap[String, ActorRef]
private val executorAddress = new HashMap[String, Address]
private val executorHostPort = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
private val actorToExecutorId = new HashMap[ActorRef, String]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@ -37,7 +38,8 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
}
def receive = {
case RegisterExecutor(executorId, host, cores) =>
case RegisterExecutor(executorId, hostPort, cores) =>
Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorActor.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
} else {
@ -45,7 +47,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
sender ! RegisteredExecutor(sparkProperties)
context.watch(sender)
executorActor(executorId) = sender
executorHost(executorId) = host
executorHostPort(executorId) = hostPort
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
actorToExecutorId(sender) = executorId
@ -85,13 +87,13 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
// Make fake resource offers on all executors
def makeOffers() {
launchTasks(scheduler.resourceOffers(
executorHost.toArray.map {case (id, host) => new WorkerOffer(id, host, freeCores(id))}))
executorHostPort.toArray.map {case (id, hostPort) => new WorkerOffer(id, hostPort, freeCores(id))}))
}
// Make fake resource offers on just one executor
def makeOffers(executorId: String) {
launchTasks(scheduler.resourceOffers(
Seq(new WorkerOffer(executorId, executorHost(executorId), freeCores(executorId)))))
Seq(new WorkerOffer(executorId, executorHostPort(executorId), freeCores(executorId)))))
}
// Launch tasks returned by a set of resource offers
@ -110,9 +112,9 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
executorHost -= executorId
executorHostPort -= executorId
freeCores -= executorId
executorHost -= executorId
executorHostPort -= executorId
totalCoreCount.addAndGet(-numCores)
scheduler.executorLost(executorId, SlaveLost(reason))
}
@ -128,7 +130,7 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
while (iterator.hasNext) {
val entry = iterator.next
val (key, value) = (entry.getKey.toString, entry.getValue.toString)
if (key.startsWith("spark.")) {
if (key.startsWith("spark.") && !key.equals("spark.hostPort")) {
properties += ((key, value))
}
}
@ -136,10 +138,11 @@ class StandaloneSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Actor
Props(new DriverActor(properties)), name = StandaloneSchedulerBackend.ACTOR_NAME)
}
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
override def stop() {
try {
if (driverActor != null) {
val timeout = 5.seconds
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
}

View file

@ -1,5 +1,7 @@
package spark.scheduler.cluster
import spark.Utils
/**
* Information about a running task attempt inside a TaskSet.
*/
@ -9,8 +11,11 @@ class TaskInfo(
val index: Int,
val launchTime: Long,
val executorId: String,
val host: String,
val preferred: Boolean) {
val hostPort: String,
val taskLocality: TaskLocality.TaskLocality) {
Utils.checkHostPort(hostPort, "Expected hostport")
var finishTime: Long = 0
var failed = false

View file

@ -1,7 +1,6 @@
package spark.scheduler.cluster
import java.util.Arrays
import java.util.{HashMap => JHashMap}
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@ -14,6 +13,36 @@ import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
private[spark] object TaskLocality extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") with Logging {
val HOST_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
constraint match {
case TaskLocality.HOST_LOCAL => condition == TaskLocality.HOST_LOCAL
case TaskLocality.RACK_LOCAL => condition == TaskLocality.HOST_LOCAL || condition == TaskLocality.RACK_LOCAL
// For anything else, allow
case _ => true
}
}
def parse(str: String): TaskLocality = {
// better way to do this ?
try {
TaskLocality.withName(str)
} catch {
case nEx: NoSuchElementException => {
logWarning("Invalid task locality specified '" + str + "', defaulting to HOST_LOCAL");
// default to preserve earlier behavior
HOST_LOCAL
}
}
}
}
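// Small sketch, not part of this commit: the containment relation that isAllowed above
// encodes - a HOST_LOCAL constraint only accepts host-local tasks, RACK_LOCAL also
// accepts host-local ones, and ANY accepts everything. Reimplemented standalone here
// with a plain Enumeration so it runs on its own.
object TaskLocalitySketch extends Enumeration("HOST_LOCAL", "RACK_LOCAL", "ANY") {
  val HOST_LOCAL, RACK_LOCAL, ANY = Value

  def isAllowed(constraint: Value, condition: Value): Boolean = constraint match {
    case HOST_LOCAL => condition == HOST_LOCAL
    case RACK_LOCAL => condition == HOST_LOCAL || condition == RACK_LOCAL
    case _ => true
  }

  def main(args: Array[String]) {
    for (constraint <- values; condition <- values) {
      println(constraint + " allows " + condition + " ? " + isAllowed(constraint, condition))
    }
  }
}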
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
@ -47,14 +76,22 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
// List of pending tasks for each node. These collections are actually
// List of pending tasks for each node (specific to a container, i.e. host:port). These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node.
// Essentially, similar to pendingTasksForHostPort, except at host level
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node based on rack locality.
// Essentially, similar to pendingTasksForHost, except at rack level
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
@ -96,26 +133,117 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
addPendingTask(i)
}
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler, rackLocal: Boolean = false): ArrayBuffer[String] = {
// DEBUG code
_taskPreferredLocations.foreach(h => Utils.checkHost(h, "taskPreferredLocation " + _taskPreferredLocations))
val taskPreferredLocations = if (! rackLocal) _taskPreferredLocations else {
// Expand set to include all 'seen' rack local hosts.
// This works since container allocation/management happens within the master - so any rack locality information is updated in the master.
// Best effort, and maybe a bit of a kludge for now ... rework it later?
val hosts = new HashSet[String]
_taskPreferredLocations.foreach(h => {
val rackOpt = scheduler.getRackForHost(h)
if (rackOpt.isDefined) {
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
if (hostsOpt.isDefined) {
hosts ++= hostsOpt.get
}
}
// Ensure that, irrespective of what the scheduler says, the host itself is always added!
hosts += h
})
hosts
}
val retval = new ArrayBuffer[String]
scheduler.synchronized {
for (prefLocation <- taskPreferredLocations) {
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(prefLocation)
if (aliveLocationsOpt.isDefined) {
retval ++= aliveLocationsOpt.get
}
}
}
retval
}
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
val locations = tasks(index).preferredLocations.toSet & sched.hostsAlive
if (locations.size == 0) {
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
// hostPort <-> host conversion). But not doing it here, for simplicity's sake. If this becomes a performance issue, modify it.
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched)
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
if (rackLocalLocations.size == 0) {
// Current impl ensures this.
assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
for (host <- locations) {
val list = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
// host locality
for (hostPort <- hostLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
hostPortList += index
val host = Utils.parseHostPort(hostPort)._1
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
hostList += index
}
// rack locality
for (rackLocalHostPort <- rackLocalLocations) {
// DEBUG Code
Utils.checkHostPort(rackLocalHostPort)
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
allPendingTasks += index
}
// Return the pending tasks list for a given host port (hyper local), or an empty list if
// there is no map entry for that host port
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
// DEBUG Code
Utils.checkHostPort(hostPort)
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
}
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(host: String): ArrayBuffer[Int] = {
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
// Return the pending tasks (rack level) list for a given host, or an empty list if
// there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
}
// Number of pending tasks for a given host (which would be data local)
def numPendingTasksForHost(hostPort: String): Int = {
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending rack local tasks for a given host
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
@ -132,26 +260,49 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if localOnly is set, the
// task must have a preference for this host (or no preferred locations at all).
private def findSpeculativeTask(host: String, localOnly: Boolean): Option[Int] = {
val hostsAlive = sched.hostsAlive
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack, or have no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
assert (TaskLocality.isAllowed(locality, TaskLocality.HOST_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
val localTask = speculatableTasks.find {
index =>
val locations = tasks(index).preferredLocations.toSet & hostsAlive
val attemptLocs = taskAttempts(index).map(_.host)
(locations.size == 0 || locations.contains(host)) && !attemptLocs.contains(host)
if (speculatableTasks.size > 0) {
val localTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched)
val attemptLocs = taskAttempts(index).map(_.hostPort)
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
if (!localOnly && speculatableTasks.size > 0) {
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.host).contains(host))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
// check for rack locality
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, true)
val attemptLocs = taskAttempts(index).map(_.hostPort)
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
}
if (rackTask != None) {
speculatableTasks -= rackTask.get
return rackTask
}
}
// Any task ...
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
// Check for attemptLocs also ?
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
}
return None
@ -159,59 +310,103 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
private def findTask(host: String, localOnly: Boolean): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(host))
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
if (rackLocalTask != None) {
return rackLocalTask
}
}
// Look for no-pref tasks AFTER rack local tasks - this has the side effect that we will get to failed tasks later rather than sooner.
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
if (!localOnly) {
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(host, localOnly)
return findSpeculativeTask(hostPort, locality)
}
// Does a host count as a preferred location for a task? This is true if
// either the task has preferred locations and this host is one, or it has
// no preferred locations (in which case we still count the launch as preferred).
private def isPreferredLocation(task: Task[_], host: String): Boolean = {
private def isPreferredLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
return (locs.contains(host) || locs.isEmpty)
// DEBUG code
locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
if (locs.contains(hostPort) || locs.isEmpty) return true
val host = Utils.parseHostPort(hostPort)._1
locs.contains(host)
}
// Does a host count as a rack local preferred location for a task? (assumes the host is NOT a preferred location).
// This is true if the task has preferred locations and the rack of this host matches the rack of at least one of
// them; a task with no preferred locations is not counted as rack local.
def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
// DEBUG code
locs.foreach(h => Utils.checkHost(h, "preferredLocation " + locs))
val preferredRacks = new HashSet[String]()
for (preferredHost <- locs) {
val rack = sched.getRackForHost(preferredHost)
if (None != rack) preferredRacks += rack.get
}
if (preferredRacks.isEmpty) return false
val hostRack = sched.getRackForHost(hostPort)
return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(execId: String, host: String, availableCpus: Double): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
val time = System.currentTimeMillis
val localOnly = (time - lastPreferredLaunchTime < LOCALITY_WAIT)
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
findTask(host, localOnly) match {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
// If explicitly specified, use that
val locality = if (overrideLocality != null) overrideLocality else {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.HOST_LOCAL else TaskLocality.ANY
}
findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val preferred = isPreferredLocation(task, host)
val prefStr = if (preferred) {
"preferred"
} else {
"non-preferred, not one of " + task.preferredLocations.mkString(", ")
}
logInfo("Starting task %s:%d as TID %s on executor %s: %s (%s)".format(
taskSet.id, index, taskId, execId, host, prefStr))
val taskLocality = if (isPreferredLocation(task, hostPort)) TaskLocality.HOST_LOCAL else
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else TaskLocality.ANY
val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val info = new TaskInfo(taskId, index, time, execId, host, preferred)
val time = System.currentTimeMillis
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (preferred) {
if (TaskLocality.HOST_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
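The locality chosen above is a simple delay-scheduling rule; here is a hedged sketch of just that decision, pulled out of slaveOffer (the helper name is invented for illustration, and LOCALITY_WAIT / lastPreferredLaunchTime correspond to the TaskSetManager fields):

// Illustrative only: if a host-local task was launched within the last
// localityWait ms, keep insisting on host locality; otherwise stop waiting
// and accept any task. An explicit override from the caller wins.
def effectiveLocality(
    overrideLocality: TaskLocality.TaskLocality,
    lastPreferredLaunchTime: Long,
    localityWait: Long): TaskLocality.TaskLocality = {
  if (overrideLocality != null) {
    overrideLocality
  } else if (System.currentTimeMillis - lastPreferredLaunchTime < localityWait) {
    TaskLocality.HOST_LOCAL
  } else {
    TaskLocality.ANY
  }
}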
@ -355,17 +550,15 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
sched.taskSetFinished(this)
}
def executorLost(execId: String, hostname: String) {
def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
val newHostsAlive = sched.hostsAlive
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
if (!newHostsAlive.contains(hostname)) {
for (index <- getPendingTasksForHost(hostname)) {
val newLocs = tasks(index).preferredLocations.toSet & newHostsAlive
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
for (index <- getPendingTasksForHostPort(hostPort)) {
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, true)
if (newLocs.isEmpty) {
assert (findPreferredLocations(tasks(index).preferredLocations, sched).isEmpty)
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
@ -419,7 +612,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.host, threshold))
taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
@ -427,4 +620,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
}
return foundTasks
}
def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksFinished < numTasks
}
}

View file

@ -4,5 +4,5 @@ package spark.scheduler.cluster
* Represents free resources available on an executor.
*/
private[spark]
class WorkerOffer(val executorId: String, val hostname: String, val cores: Int) {
class WorkerOffer(val executorId: String, val hostPort: String, val cores: Int) {
}

View file

@ -7,7 +7,7 @@ import scala.collection.mutable.HashMap
import spark._
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
import spark.scheduler.cluster.TaskInfo
import spark.scheduler.cluster.{TaskLocality, TaskInfo}
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@ -53,7 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.HOST_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {

View file

@ -37,17 +37,27 @@ class BlockManager(
maxMemory: Long)
extends Logging {
class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
var pending: Boolean = true
var size: Long = -1L
var failed: Boolean = false
private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
@volatile var pending: Boolean = true
@volatile var size: Long = -1L
@volatile var initThread: Thread = null
@volatile var failed = false
setInitThread()
private def setInitThread() {
// Set current thread as init thread - waitForReady will not block this thread
// (in case there is non-trivial initialization which ends up calling waitForReady as part of
// the initialization itself)
this.initThread = Thread.currentThread()
}
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
if (pending) {
if (initThread != Thread.currentThread() && pending) {
synchronized {
while (pending) this.wait()
}
@ -57,19 +67,26 @@ class BlockManager(
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
assert (pending)
size = sizeInBytes
initThread = null
failed = false
initThread = null
pending = false
synchronized {
pending = false
failed = false
size = sizeInBytes
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
assert (pending)
size = 0
initThread = null
failed = true
initThread = null
pending = false
synchronized {
failed = true
pending = false
this.notifyAll()
}
}
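The pending/initThread fields above implement a plain wait/notifyAll handshake in which the thread that created the BlockInfo is exempt from blocking. A self-contained analogue of the same pattern (illustrative only, not the actual BlockManager code):

// Illustrative only: one writer publishes a result, readers block until it does,
// and the creating (writer) thread itself never blocks in waitForReady().
object WaitForReadyDemo {
  class Pending {
    @volatile var pending = true
    @volatile var failed = false
    @volatile var initThread: Thread = Thread.currentThread()

    def waitForReady(): Boolean = {
      if (initThread != Thread.currentThread() && pending) {
        synchronized { while (pending) this.wait() }
      }
      !failed
    }

    def markReady() {
      initThread = null
      synchronized { pending = false; this.notifyAll() }
    }
  }

  def main(args: Array[String]) {
    val p = new Pending
    val reader = new Thread(new Runnable {
      def run() { println("ready? " + p.waitForReady()) }  // blocks until markReady
    })
    reader.start()
    Thread.sleep(100)  // simulate the block write
    p.markReady()      // wakes the reader; it prints "ready? true"
    reader.join()
  }
}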
@ -101,7 +118,7 @@ class BlockManager(
val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties
val host = System.getProperty("spark.hostname", Utils.localHostName())
val hostPort = Utils.localHostPort()
val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
@ -212,9 +229,12 @@ class BlockManager(
* Tell the master about the current storage status of a block. This will send a block update
* message reflecting the current status, *not* the desired storage level in its block info.
* For example, a block with MEMORY_AND_DISK set might have fallen out to be only on disk.
*
* droppedMemorySize exists to account for when a block is dropped from memory to disk (so it is still valid).
* This ensures that the update in the master will compensate for the increase in memory on the slave.
*/
def reportBlockStatus(blockId: String, info: BlockInfo) {
val needReregister = !tryToReportBlockStatus(blockId, info)
def reportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L) {
val needReregister = !tryToReportBlockStatus(blockId, info, droppedMemorySize)
if (needReregister) {
logInfo("Got told to reregister updating block " + blockId)
// Reregistering will report our new block for free.
@ -228,7 +248,7 @@ class BlockManager(
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
private def tryToReportBlockStatus(blockId: String, info: BlockInfo): Boolean = {
private def tryToReportBlockStatus(blockId: String, info: BlockInfo, droppedMemorySize: Long = 0L): Boolean = {
val (curLevel, inMemSize, onDiskSize, tellMaster) = info.synchronized {
info.level match {
case null =>
@ -237,7 +257,7 @@ class BlockManager(
val inMem = level.useMemory && memoryStore.contains(blockId)
val onDisk = level.useDisk && diskStore.contains(blockId)
val storageLevel = StorageLevel(onDisk, inMem, level.deserialized, level.replication)
val memSize = if (inMem) memoryStore.getSize(blockId) else 0L
val memSize = if (inMem) memoryStore.getSize(blockId) else droppedMemorySize
val diskSize = if (onDisk) diskStore.getSize(blockId) else 0L
(storageLevel, memSize, diskSize, info.tellMaster)
}
@ -257,7 +277,7 @@ class BlockManager(
def getLocations(blockId: String): Seq[String] = {
val startTimeMs = System.currentTimeMillis
var managers = master.getLocations(blockId)
val locations = managers.map(_.ip)
val locations = managers.map(_.hostPort)
logDebug("Got block locations in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@ -267,7 +287,7 @@ class BlockManager(
*/
def getLocations(blockIds: Array[String]): Array[Seq[String]] = {
val startTimeMs = System.currentTimeMillis
val locations = master.getLocations(blockIds).map(_.map(_.ip).toSeq).toArray
val locations = master.getLocations(blockIds).map(_.map(_.hostPort).toSeq).toArray
logDebug("Got multiple block location in " + Utils.getUsedTimeMs(startTimeMs))
return locations
}
@ -339,6 +359,8 @@ class BlockManager(
case Some(bytes) =>
// Put a copy of the block back in memory before returning it. Note that we can't
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
// The use of rewind assumes this.
assert (0 == bytes.position())
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
@ -411,6 +433,7 @@ class BlockManager(
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
assert (0 == bytes.position())
if (level.useMemory) {
if (level.deserialized) {
memoryStore.putBytes(blockId, bytes, level)
@ -450,7 +473,7 @@ class BlockManager(
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.ip, loc.port))
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
@ -501,17 +524,17 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
val oldBlock = blockInfo.get(blockId).orNull
if (oldBlock != null && oldBlock.waitForReady()) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return oldBlock.size
}
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = new BlockInfo(level, tellMaster)
blockInfo.put(blockId, myInfo)
// Do this atomically!
val oldBlockOpt = blockInfo.putIfAbsent(blockId, myInfo)
if (oldBlockOpt.isDefined && oldBlockOpt.get.waitForReady()) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return oldBlockOpt.get.size
}
val startTimeMs = System.currentTimeMillis
@ -531,6 +554,7 @@ class BlockManager(
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
@ -555,20 +579,20 @@ class BlockManager(
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
marked = true
myInfo.markReady(size)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
} catch {
} finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
case e: Exception => {
if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
logWarning("Putting block " + blockId + " failed", e)
throw e
logWarning("Putting block " + blockId + " failed")
}
}
}
@ -611,16 +635,17 @@ class BlockManager(
throw new IllegalArgumentException("Storage level is null or invalid")
}
if (blockInfo.contains(blockId)) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return
}
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = new BlockInfo(level, tellMaster)
blockInfo.put(blockId, myInfo)
// Do this atomically!
val prevInfo = blockInfo.putIfAbsent(blockId, myInfo)
if (prevInfo != null) {
// Should we check for prevInfo.waitForReady() here ?
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return
}
val startTimeMs = System.currentTimeMillis
@ -639,6 +664,7 @@ class BlockManager(
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
if (level.useMemory) {
// Store it only in memory at first, even if useDisk is also set to true
@ -649,22 +675,24 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
// assert (0 == bytes.position(), "" + bytes)
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
marked = true
myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
} catch {
} finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
case e: Exception => {
if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
logWarning("Putting block " + blockId + " failed", e)
throw e
logWarning("Putting block " + blockId + " failed")
}
}
}
@ -698,7 +726,7 @@ class BlockManager(
logDebug("Try to replicate BlockId " + blockId + " once; The size of the data is "
+ data.limit() + " Bytes. To node: " + peer)
if (!BlockManagerWorker.syncPutBlock(PutBlock(blockId, data, tLevel),
new ConnectionManagerId(peer.ip, peer.port))) {
new ConnectionManagerId(peer.host, peer.port))) {
logError("Failed to call syncPutBlock to " + peer)
}
logDebug("Replicated BlockId " + blockId + " once used " +
@ -730,6 +758,14 @@ class BlockManager(
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
// Required? As of now, this will be invoked only for blocks which are ready,
// but in case this changes in the future, adding it for consistency's sake.
if (! info.waitForReady() ) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure. Nothing to drop")
return
}
val level = info.level
if (level.useDisk && !diskStore.contains(blockId)) {
logInfo("Writing block " + blockId + " to disk")
@ -740,12 +776,13 @@ class BlockManager(
diskStore.putBytes(blockId, bytes, level)
}
}
val droppedMemorySize = memoryStore.getSize(blockId)
val blockWasRemoved = memoryStore.remove(blockId)
if (!blockWasRemoved) {
logWarning("Block " + blockId + " could not be dropped from memory as it does not exist")
}
if (info.tellMaster) {
reportBlockStatus(blockId, info)
reportBlockStatus(blockId, info, droppedMemorySize)
}
if (!level.useDisk) {
// The block is completely gone from this node; forget it so we can put() it again later.
@ -938,8 +975,8 @@ class BlockFetcherIterator(
def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
req.blocks.size, Utils.memoryBytesToString(req.size), req.address.hostPort))
val cmId = new ConnectionManagerId(req.address.host, req.address.port)
val blockMessageArray = new BlockMessageArray(req.blocks.map {
case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
})

View file

@ -2,6 +2,7 @@ package spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
import spark.Utils
/**
* This class represent an unique identifier for a BlockManager.
@ -13,7 +14,7 @@ import java.util.concurrent.ConcurrentHashMap
*/
private[spark] class BlockManagerId private (
private var executorId_ : String,
private var ip_ : String,
private var host_ : String,
private var port_ : Int
) extends Externalizable {
@ -21,32 +22,45 @@ private[spark] class BlockManagerId private (
def executorId: String = executorId_
def ip: String = ip_
if (null != host_) {
Utils.checkHost(host_, "Expected hostname")
assert (port_ > 0)
}
def hostPort: String = {
// DEBUG code
Utils.checkHost(host)
assert (port > 0)
host + ":" + port
}
def host: String = host_
def port: Int = port_
override def writeExternal(out: ObjectOutput) {
out.writeUTF(executorId_)
out.writeUTF(ip_)
out.writeUTF(host_)
out.writeInt(port_)
}
override def readExternal(in: ObjectInput) {
executorId_ = in.readUTF()
ip_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
}
@throws(classOf[IOException])
private def readResolve(): Object = BlockManagerId.getCachedBlockManagerId(this)
override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, ip, port)
override def toString = "BlockManagerId(%s, %s, %d)".format(executorId, host, port)
override def hashCode: Int = (executorId.hashCode * 41 + ip.hashCode) * 41 + port
override def hashCode: Int = (executorId.hashCode * 41 + host.hashCode) * 41 + port
override def equals(that: Any) = that match {
case id: BlockManagerId =>
executorId == id.executorId && port == id.port && ip == id.ip
executorId == id.executorId && port == id.port && host == id.host
case _ =>
false
}
@ -55,8 +69,8 @@ private[spark] class BlockManagerId private (
private[spark] object BlockManagerId {
def apply(execId: String, ip: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, ip, port))
def apply(execId: String, host: String, port: Int) =
getCachedBlockManagerId(new BlockManagerId(execId, host, port))
def apply(in: ObjectInput) = {
val obj = new BlockManagerId()
@ -67,11 +81,7 @@ private[spark] object BlockManagerId {
val blockManagerIdCache = new ConcurrentHashMap[BlockManagerId, BlockManagerId]()
def getCachedBlockManagerId(id: BlockManagerId): BlockManagerId = {
if (blockManagerIdCache.containsKey(id)) {
blockManagerIdCache.get(id)
} else {
blockManagerIdCache.put(id, id)
id
}
blockManagerIdCache.putIfAbsent(id, id)
blockManagerIdCache.get(id)
}
}
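The containsKey/put sequence that was here before had a small race: two threads could both miss on containsKey and cache different instances. The putIfAbsent-then-get form above is the standard atomic interning idiom (StorageLevel below gets the same treatment); a minimal self-contained sketch of it, with invented names:

import java.util.concurrent.ConcurrentHashMap

// Illustrative only: putIfAbsent is atomic, so two threads racing on the same
// key both end up holding the one canonical instance stored in the cache.
object Interner {
  private val cache = new ConcurrentHashMap[String, String]()

  def intern(value: String): String = {
    cache.putIfAbsent(value, value)  // no-op if the key is already mapped
    cache.get(value)                 // always the canonical instance
  }
}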

View file

@ -332,8 +332,8 @@ object BlockManagerMasterActor {
// Mapping from block id to its status.
private val _blocks = new JHashMap[String, BlockStatus]
logInfo("Registering block manager %s:%d with %s RAM".format(
blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(maxMem)))
logInfo("Registering block manager %s with %s RAM".format(
blockManagerId.hostPort, Utils.memoryBytesToString(maxMem)))
def updateLastSeenMs() {
_lastSeenMs = System.currentTimeMillis()
@ -358,13 +358,13 @@ object BlockManagerMasterActor {
_blocks.put(blockId, BlockStatus(storageLevel, memSize, diskSize))
if (storageLevel.useMemory) {
_remainingMem -= memSize
logInfo("Added %s in memory on %s:%d (size: %s, free: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
logInfo("Added %s in memory on %s (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (storageLevel.useDisk) {
logInfo("Added %s on disk on %s:%d (size: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
logInfo("Added %s on disk on %s (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
} else if (_blocks.containsKey(blockId)) {
// If isValid is not true, drop the block.
@ -372,13 +372,13 @@ object BlockManagerMasterActor {
_blocks.remove(blockId)
if (blockStatus.storageLevel.useMemory) {
_remainingMem += blockStatus.memSize
logInfo("Removed %s on %s:%d in memory (size: %s, free: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(memSize),
logInfo("Removed %s on %s in memory (size: %s, free: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(memSize),
Utils.memoryBytesToString(_remainingMem)))
}
if (blockStatus.storageLevel.useDisk) {
logInfo("Removed %s on %s:%d on disk (size: %s)".format(
blockId, blockManagerId.ip, blockManagerId.port, Utils.memoryBytesToString(diskSize)))
logInfo("Removed %s on %s on disk (size: %s)".format(
blockId, blockManagerId.hostPort, Utils.memoryBytesToString(diskSize)))
}
}
}

View file

@ -115,6 +115,7 @@ private[spark] object BlockMessageArray {
val newBuffer = ByteBuffer.allocate(totalSize)
newBuffer.clear()
bufferMessage.buffers.foreach(buffer => {
assert (0 == buffer.position())
newBuffer.put(buffer)
buffer.rewind()
})

View file

@ -20,6 +20,9 @@ import spark.Utils
private class DiskStore(blockManager: BlockManager, rootDirs: String)
extends BlockStore(blockManager) {
private val mapMode = MapMode.READ_ONLY
private var mapOpenMode = "r"
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
@ -35,7 +38,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
getFile(blockId).length()
}
override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
// So that we do not modify the input offsets!
// duplicate() does not copy the underlying buffer, so this is inexpensive.
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
@ -49,6 +55,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
blockId, Utils.memoryBytesToString(bytes.limit), (finishTime - startTime)))
}
private def getFileBytes(file: File): ByteBuffer = {
val length = file.length()
val channel = new RandomAccessFile(file, mapOpenMode).getChannel()
val buffer = try {
channel.map(mapMode, 0, length)
} finally {
channel.close()
}
buffer
}
override def putValues(
blockId: String,
values: ArrayBuffer[Any],
@ -70,9 +88,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
if (returnValues) {
// Return a byte buffer for the contents of the file
val channel = new RandomAccessFile(file, "r").getChannel()
val buffer = channel.map(MapMode.READ_ONLY, 0, length)
channel.close()
val buffer = getFileBytes(file)
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@ -81,10 +97,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
override def getBytes(blockId: String): Option[ByteBuffer] = {
val file = getFile(blockId)
val length = file.length().toInt
val channel = new RandomAccessFile(file, "r").getChannel()
val bytes = channel.map(MapMode.READ_ONLY, 0, length)
channel.close()
val bytes = getFileBytes(file)
Some(bytes)
}
@ -96,7 +109,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val file = getFile(blockId)
if (file.exists()) {
file.delete()
true
} else {
false
}
@ -175,11 +187,12 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir) )
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
try {
localDirs.foreach(localDir => Utils.deleteRecursively(localDir))
localDirs.foreach(localDir => if (! Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir))
} catch {
case t: Throwable => logError("Exception while deleting local spark dirs", t)
}

View file

@ -31,7 +31,9 @@ private class MemoryStore(blockManager: BlockManager, maxMemory: Long)
}
}
override def putBytes(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
override def putBytes(blockId: String, _bytes: ByteBuffer, level: StorageLevel) {
// Work on a duplicate - since the original input might be used elsewhere.
val bytes = _bytes.duplicate()
bytes.rewind()
if (level.deserialized) {
val values = blockManager.dataDeserialize(blockId, bytes)

View file

@ -123,11 +123,7 @@ object StorageLevel {
val storageLevelCache = new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
if (storageLevelCache.containsKey(level)) {
storageLevelCache.get(level)
} else {
storageLevelCache.put(level, level)
level
}
storageLevelCache.putIfAbsent(level, level)
storageLevelCache.get(level)
}
}

View file

@ -11,7 +11,7 @@ import cc.spray.{SprayCanRootService, HttpService}
import cc.spray.can.server.HttpServer
import cc.spray.io.pipelines.MessageHandlerDispatch.SingletonHandler
import akka.dispatch.Await
import spark.SparkException
import spark.{Utils, SparkException}
import java.util.concurrent.TimeoutException
/**
@ -31,7 +31,10 @@ private[spark] object AkkaUtils {
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
val akkaTimeout = System.getProperty("spark.akka.timeout", "20").toInt
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
val lifecycleEvents = System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean
val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
// 10 seconds is the default akka timeout, but in a cluster, we need a higher value by default.
val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
val akkaConf = ConfigFactory.parseString("""
akka.daemonic = on
akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
@ -45,8 +48,9 @@ private[spark] object AkkaUtils {
akka.remote.netty.execution-pool-size = %d
akka.actor.default-dispatcher.throughput = %d
akka.remote.log-remote-lifecycle-events = %s
akka.remote.netty.write-timeout = %ds
""".format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
if (lifecycleEvents) "on" else "off"))
lifecycleEvents, akkaWriteTimeout))
val actorSystem = ActorSystem(name, akkaConf, getClass.getClassLoader)
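Both new knobs are plain Java system properties read before the ActorSystem is built, so a deployment can raise them without code changes; a hedged example (property names as read above, values arbitrary):

// Illustrative only: tune the new Akka settings before Spark creates its ActorSystem.
System.setProperty("spark.akka.writeTimeout", "60")          // seconds; default is 30
System.setProperty("spark.akka.logLifecycleEvents", "true")  // rendered as akka "on"/"off"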
@ -60,6 +64,7 @@ private[spark] object AkkaUtils {
/**
* Creates a Spray HTTP server bound to a given IP and port with a given Spray Route object to
* handle requests. Returns the bound port or throws a SparkException on failure.
* TODO: Not changing ip to host here - is it required ?
*/
def startSprayServer(actorSystem: ActorSystem, ip: String, port: Int, route: Route,
name: String = "HttpServer"): ActorRef = {

View file

@ -3,6 +3,7 @@ package spark.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConversions
import scala.collection.mutable.Map
import spark.scheduler.MapStatus
/**
* This is a custom implementation of scala.collection.mutable.Map which stores the insertion
@ -42,6 +43,13 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with spark.Logging {
this
}
// Should we return the previous value directly or as an Option?
def putIfAbsent(key: A, value: B): Option[B] = {
val prev = internalMap.putIfAbsent(key, (value, currentTime))
if (prev != null) Some(prev._1) else None
}
override def -= (key: A): this.type = {
internalMap.remove(key)
this
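The Option wrapping above follows the usual Scala convention on top of ConcurrentHashMap's null-on-absent contract; a short illustration of what callers can expect (illustrative only; assumes TimeStampedHashMap is visible to the calling code):

// Illustrative only: the first insert wins, later inserts report the existing value.
val cache = new TimeStampedHashMap[String, Int]
assert(cache.putIfAbsent("a", 1) == None)     // inserted
assert(cache.putIfAbsent("a", 2) == Some(1))  // already present; value unchanged
assert(cache("a") == 1)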

View file

@ -2,7 +2,7 @@
@import spark.deploy.master._
@import spark.Utils
@spark.common.html.layout(title = "Spark Master on " + state.host) {
@spark.common.html.layout(title = "Spark Master on " + state.host + ":" + state.port) {
<!-- Cluster Details -->
<div class="row">

View file

@ -1,7 +1,7 @@
@(worker: spark.deploy.WorkerState)
@import spark.Utils
@spark.common.html.layout(title = "Spark Worker on " + worker.host) {
@spark.common.html.layout(title = "Spark Worker on " + worker.host + ":" + worker.port) {
<!-- Worker Details -->
<div class="row">

View file

@ -12,7 +12,7 @@
<tbody>
@for(status <- workersStatusList) {
<tr>
<td>@(status.blockManagerId.ip + ":" + status.blockManagerId.port)</td>
<td>@(status.blockManagerId.host + ":" + status.blockManagerId.port)</td>
<td>
@(Utils.memoryBytesToString(status.memUsed(prefix)))
(@(Utils.memoryBytesToString(status.memRemaining)) Total Available)

View file

@ -153,7 +153,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
blockManager.master.getLocations(blockId).foreach(id => {
val bytes = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(id.ip, id.port))
GetBlock(blockId), ConnectionManagerId(id.host, id.port))
val deserialized = blockManager.dataDeserialize(blockId, bytes).asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
})

View file

@ -18,6 +18,7 @@ class FileSuite extends FunSuite with LocalSparkContext {
val outputDir = new File(tempDir, "output").getAbsolutePath
val nums = sc.makeRDD(1 to 4)
nums.saveAsTextFile(outputDir)
println("outputDir = " + outputDir)
// Read the plain text file and check it's OK
val outputFile = new File(outputDir, "part-00000")
val content = Source.fromFile(outputFile).mkString

View file

@ -271,7 +271,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
// have the 2nd attempt pass
complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
// we can see both result blocks now
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) === Array("hostA", "hostB"))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
}