Moved Spark home detection to SparkContext and added a setSparkHome
method for setting it programmatically.
Matei Zaharia 2010-10-16 10:02:22 -07:00
parent 47b38fd207
commit a4953c5051
2 changed files with 81 additions and 51 deletions
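For orientation, a minimal usage sketch of the API this commit introduces follows; the constructor, setSparkHome, getSparkHome, and stop come from the diff below, while the package import, master string, app name, and path are assumed placeholders:

    import spark.SparkContext

    // Hypothetical driver program; "local", "MyApp", and "/opt/spark" are placeholder values.
    object SparkHomeExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "MyApp")
        sc.setSparkHome("/opt/spark")   // explicit value, consulted before spark.home and SPARK_HOME
        println(sc.getSparkHome())      // prints Some(/opt/spark)
        sc.stop()
      }
    }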

View file

@@ -13,19 +13,12 @@ import scala.collection.JavaConversions._
import mesos.{Scheduler => MScheduler}
import mesos._
// The main Scheduler implementation, which talks to Mesos. Clients are expected
// to first call start(), then submit tasks through the runTasks method.
//
// This implementation is currently a little quick and dirty. The following
// improvements need to be made to it:
// 1) Right now, the scheduler uses a linear scan through the tasks to find a
// local one for a given node. It would be faster to have a separate list of
// pending tasks for each node.
// 2) Presenting a single slave in Job.slaveOffer makes it
// difficult to balance tasks across nodes. It would be better to pass
// all the offers to the Job and have it load-balance.
/**
* The main Scheduler implementation, which runs jobs on Mesos. Clients should
* first call start(), then submit tasks through the runTasks method.
*/
private class MesosScheduler(
master: String, frameworkName: String, execArg: Array[Byte])
sc: SparkContext, master: String, frameworkName: String, execArg: Array[Byte])
extends MScheduler with spark.Scheduler with Logging
{
// Environment variables to pass to our executors
@@ -77,21 +70,15 @@ extends MScheduler with spark.Scheduler with Logging
override def getFrameworkName(d: SchedulerDriver): String = frameworkName
// Get Spark's home location from either the spark.home Java property
// or the SPARK_HOME environment variable (in that order of preference).
// If neither of these is set, throws an exception.
def getSparkHome(): String = {
if (System.getProperty("spark.home") != null)
System.getProperty("spark.home")
else if (System.getenv("SPARK_HOME") != null)
System.getenv("SPARK_HOME")
else
throw new SparkException("Spark home is not set; either set the " +
"spark.home system property or the SPARK_HOME environment variable")
}
override def getExecutorInfo(d: SchedulerDriver): ExecutorInfo = {
val execScript = new File(getSparkHome, "spark-executor").getCanonicalPath
val sparkHome = sc.getSparkHome match {
case Some(path) => path
case None =>
throw new SparkException("Spark home is not set; either set the " +
"spark.home system property or the SPARK_HOME environment variable " +
"or call SparkContext.setSparkHome")
}
val execScript = new File(sparkHome, "spark-executor").getCanonicalPath
val params = new JHashMap[String, String]
for (key <- ENV_VARS_TO_SEND_TO_EXECUTORS) {
if (System.getenv(key) != null)
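Taken together, the change above reduces to the following standalone sketch; SparkException, sc.getSparkHome, and the spark-executor script name are taken from the diff, while the object and method names are hypothetical and the imports assume both classes live in the spark package:

    import java.io.File
    import spark.{SparkContext, SparkException}

    // Hypothetical helper mirroring what getExecutorInfo now does to locate the executor script.
    object ExecutorScript {
      def path(sc: SparkContext): String = {
        val sparkHome = sc.getSparkHome().getOrElse(
          throw new SparkException("Spark home is not set; either set the spark.home " +
            "system property or the SPARK_HOME environment variable or call SparkContext.setSparkHome"))
        new File(sparkHome, "spark-executor").getCanonicalPath
      }
    }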

View file

@@ -4,17 +4,43 @@ import java.io._
import java.util.UUID
import scala.collection.mutable.ArrayBuffer
import scala.actors.Actor._
class SparkContext(master: String, frameworkName: String) extends Logging {
private[spark] var scheduler: Scheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
master match {
case "local" =>
new LocalScheduler(1)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt)
case _ =>
System.loadLibrary("mesos")
new MesosScheduler(this, master, frameworkName, createExecArg())
}
}
private val local = scheduler.isInstanceOf[LocalScheduler]
scheduler.start()
Broadcast.initialize(true)
private var sparkHome: Option[String] = None
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int) =
new ParallelArray[T](this, seq, numSlices)
def parallelize[T: ClassManifest](seq: Seq[T]): ParallelArray[T] =
parallelize(seq, scheduler.numCores)
def textFile(path: String) = new HdfsTextFile(this, path)
// Methods for creating shared variables
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
@@ -22,21 +48,7 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
def broadcast[T](value: T) = new CentralizedHDFSBroadcast(value, local)
//def broadcast[T](value: T) = new ChainedStreamingBroadcast(value, local)
def textFile(path: String) = new HdfsTextFile(this, path)
val LOCAL_REGEX = """local\[([0-9]+)\]""".r
private[spark] var scheduler: Scheduler = master match {
case "local" => new LocalScheduler(1)
case LOCAL_REGEX(threads) => new LocalScheduler(threads.toInt)
case _ => { System.loadLibrary("mesos");
new MesosScheduler(master, frameworkName, createExecArg()) }
}
private val local = scheduler.isInstanceOf[LocalScheduler]
scheduler.start()
// Create and serialize an executor argument to use when running on Mesos
private def createExecArg(): Array[Byte] = {
// Our executor arg is an array containing all the spark.* system properties
val props = new ArrayBuffer[(String, String)]
@@ -50,10 +62,45 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
return Utils.serialize(props.toArray)
}
// Stop the SparkContext
def stop() {
scheduler.stop()
scheduler = null
}
// Wait for the scheduler to be registered
def waitForRegister() {
scheduler.waitForRegister()
}
// Set the Spark home location
def setSparkHome(path: String) {
if (path == null)
throw new NullPointerException("Path passed to setSparkHome was null")
sparkHome = Some(path)
}
// Get Spark's home location from either a value set through setSparkHome,
// or the spark.home Java property, or the SPARK_HOME environment variable
// (in that order of preference). If neither of these is set, return None.
def getSparkHome(): Option[String] = {
if (sparkHome != None)
sparkHome
else if (System.getProperty("spark.home") != null)
Some(System.getProperty("spark.home"))
else if (System.getenv("SPARK_HOME") != null)
Some(System.getenv("SPARK_HOME"))
else
None
}
// Submit an array of tasks (passed as functions) to the scheduler
def runTasks[T: ClassManifest](tasks: Array[() => T]): Array[T] = {
runTaskObjects(tasks.map(f => new FunctionTask(f)))
}
// Run an array of spark.Task objects
private[spark] def runTaskObjects[T: ClassManifest](tasks: Seq[Task[T]])
: Array[T] = {
logInfo("Running " + tasks.length + " tasks in parallel")
@@ -62,15 +109,6 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
logInfo("Tasks finished in " + (System.nanoTime - start) / 1e9 + " s")
return result
}
def stop() {
scheduler.stop()
scheduler = null
}
def waitForRegister() {
scheduler.waitForRegister()
}
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
@@ -80,6 +118,11 @@ class SparkContext(master: String, frameworkName: String) extends Logging {
}
}
/**
* The SparkContext object contains a number of implicit conversions and
* parameters for use with various Spark features.
*/
object SparkContext {
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
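To summarize the lookup order getSparkHome now implements, a small self-contained sketch follows; the spark.home property and SPARK_HOME variable come from the diff, and the object itself is illustrative rather than part of the commit:

    // Precedence: an explicitly set value wins, then the spark.home system
    // property, then the SPARK_HOME environment variable; otherwise None.
    object SparkHomeLookup {
      def resolve(explicitHome: Option[String]): Option[String] =
        explicitHome
          .orElse(Option(System.getProperty("spark.home")))
          .orElse(Option(System.getenv("SPARK_HOME")))
    }

For example, with no explicit value and -Dspark.home=/opt/spark on the JVM command line this returns Some(/opt/spark); with nothing set at all it returns None, which MesosScheduler turns into a SparkException.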