Merge branch 'dev' of github.com:mesos/spark into dev

Matei Zaharia 2012-10-02 17:31:01 -07:00
commit 97cbd699d7
2 changed files with 66 additions and 39 deletions

SparkContext.scala

@@ -55,7 +55,7 @@ class SparkContext(
val sparkHome: String,
val jars: Seq[String])
extends Logging {
def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
// Ensure logging is initialized before we spawn any threads
@@ -78,30 +78,30 @@ class SparkContext(
true,
isLocal)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
val addedFiles = HashMap[String, Long]()
val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val LOCAL_CLUSTER_REGEX = """local-cluster\[([0-9]+),([0-9]+),([0-9]+)]""".r
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
master match {
case "local" =>
case "local" =>
new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
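The patterns above are plain Scala Regex extractors, so the match on master pulls the thread and retry counts straight out of strings such as "local[4]" or "local[4, 3]". A minimal standalone sketch of that behaviour:

    // How the master-string patterns above behave, in isolation.
    val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r

    "local[4]" match {
      case LOCAL_N_REGEX(threads) => println(threads.toInt)  // 4
      case _ => println("no match")
    }
    "local[4, 3]" match {
      case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
        println(threads.toInt + " threads, " + maxFailures.toInt + " task retries")
      case _ => println("no match")
    }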
@@ -112,10 +112,21 @@ class SparkContext(
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
scheduler
- case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerlave) =>
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ val sparkMemEnv = System.getenv("SPARK_MEM")
+ val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
+ if (sparkMemEnvInt > memoryPerSlaveInt) {
+   throw new SparkException(
+     "Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
+       memoryPerSlaveInt, sparkMemEnvInt))
+ }
val scheduler = new ClusterScheduler(this)
- val localCluster = new LocalSparkCluster(numSlaves.toInt, coresPerSlave.toInt, memoryPerlave.toInt)
+ val localCluster = new LocalSparkCluster(
+   numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
scheduler.initialize(backend)
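With the relaxed LOCAL_CLUSTER_REGEX and the SPARK_MEM guard above, a local cluster can be spun up from a master string alone, using the two-argument constructor shown at the top of this file. A short usage sketch (the sizes are arbitrary examples):

    // Two in-process workers with 1 core and 512 MB each; SPARK_MEM must not exceed
    // 512 MB or the check above throws a SparkException before the cluster starts.
    val sc = new SparkContext("local-cluster[2, 1, 512]", "example")
    try {
      println(sc.parallelize(1 to 100, 4).count())  // 100
    } finally {
      sc.stop()
    }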
@@ -140,13 +151,13 @@ class SparkContext(
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
}
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
parallelize(seq, numSlices)
}
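parallelize and makeRDD above are equivalent, with numSlices defaulting to defaultParallelism. A quick usage sketch, assuming an already constructed sc:

    val data = sc.parallelize(Seq(1, 2, 3, 4), 2)  // explicit two partitions
    val same = sc.makeRDD(Seq(1, 2, 3, 4))         // alias for parallelize, default parallelism
    println(data.reduce(_ + _))                    // 10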
@@ -187,14 +198,14 @@ class SparkContext(
}
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
hadoopFile(path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
minSplits)
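With the manifest-based overload above, the key, value and InputFormat classes are taken from the type parameters rather than passed explicitly. A sketch using the classic TextInputFormat (the HDFS path is a placeholder):

    import org.apache.hadoop.io.{LongWritable, Text}
    import org.apache.hadoop.mapred.TextInputFormat

    // Equivalent to spelling out classOf[TextInputFormat], classOf[LongWritable], classOf[Text].
    val lines = sc.hadoopFile[LongWritable, Text, TextInputFormat]("hdfs://namenode/data/events.log", 4)
    println(lines.map(_._2.toString).first())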
@@ -215,7 +226,7 @@ class SparkContext(
new Configuration)
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -231,7 +242,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@@ -257,14 +268,14 @@ class SparkContext(
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter.
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
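In practice the converter machinery described above stays out of sight: with the implicit conversions from the SparkContext companion object in scope, primitive and String keys and values can be read directly. A usage sketch (the path is a placeholder; the import assumes the pre-Apache package name spark):

    import spark.SparkContext._  // brings the implicit WritableConverter factories into scope

    // Reads a SequenceFile[IntWritable, Text] as plain Scala types; the implicit
    // functions described above supply the Int and String converters.
    val pairs = sc.sequenceFile[Int, String]("hdfs://namenode/data/counts.seq")
    println(pairs.take(5).mkString(", "))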
@@ -289,7 +300,7 @@ class SparkContext(
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T: ClassManifest](
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
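objectFile reads back what the corresponding RDD save method writes into NullWritable/BytesWritable SequenceFile records; a round-trip sketch, assuming that writer is saveAsObjectFile and using a placeholder path:

    val nums = sc.parallelize(1 to 1000, 4)
    nums.saveAsObjectFile("/tmp/nums-objects")              // assumed companion writer on RDD
    val reloaded = sc.objectFile[Int]("/tmp/nums-objects")
    println(reloaded.count())                               // 1000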
@@ -318,7 +329,7 @@ class SparkContext(
/**
* Create an accumulator from a "mutable collection" type.
*
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
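A sketch of the mutable-collection accumulator this doc comment describes, gathering values from tasks into a Set on the driver (the method name accumulableCollection is assumed from the comment above):

    import scala.collection.mutable

    val seen = sc.accumulableCollection(mutable.HashSet[String]())  // assumed method name
    sc.parallelize(Seq("a", "b", "a"), 2).foreach { x => seen += x }
    println(seen.value)  // Set(a, b)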
@@ -329,7 +340,7 @@ class SparkContext(
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
// Adds a file dependency to all Tasks executed in the future.
def addFile(path: String) {
val uri = new URI(path)
@@ -338,11 +349,11 @@ class SparkContext(
case _ => path
}
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File("."))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
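Because addFile fetches each file into the working directory (on the driver here, and on executors before tasks run), tasks can open an added file by its bare name. A usage sketch with a placeholder path:

    sc.addFile("hdfs://namenode/config/lookup.txt")  // placeholder path

    // Each task reads the fetched copy from its working directory by file name.
    val lineCounts = sc.parallelize(1 to 4, 2).map { _ =>
      scala.io.Source.fromFile("lookup.txt").getLines().size
    }
    println(lineCounts.collect().toSeq)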
@@ -350,7 +361,7 @@ class SparkContext(
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
// Adds a jar dependency to all Tasks executed in the future.
def addJar(path: String) {
val uri = new URI(path)
@@ -366,7 +377,7 @@ class SparkContext(
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
// Stop the SparkContext
def stop() {
dagScheduler.stop()
@@ -400,7 +411,7 @@ class SparkContext(
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
* whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
@@ -419,13 +430,13 @@ class SparkContext(
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
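runJob is the primitive that actions such as count() and collect() reduce to: it ships func to the chosen partitions and returns one result per partition. A sketch using the subset overload shown above, with allowLocal = false so the job goes through the scheduler:

    val rdd = sc.parallelize(1 to 100, 4)

    // Sum only partitions 0 and 2; one Int comes back per requested partition.
    val partialSums: Array[Int] = sc.runJob(
      rdd,
      (iter: Iterator[Int]) => iter.sum,
      Seq(0, 2),
      false)
    println(partialSums.toSeq)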
@@ -472,7 +483,7 @@ class SparkContext(
private[spark] def newShuffleId(): Int = {
nextShuffleId.getAndIncrement()
}
private var nextRddId = new AtomicInteger(0)
// Register a new RDD, returning its RDD ID
@@ -500,7 +511,7 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
@@ -521,7 +532,7 @@ object SparkContext {
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
@@ -532,7 +543,7 @@ object SparkContext {
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
@@ -576,7 +587,7 @@ object SparkContext {
Nil
}
}
// Find the JAR that contains the class of a particular object
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
}

DistributedSuite.scala

@@ -18,9 +18,9 @@ import storage.StorageLevel
class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
val clusterUrl = "local-cluster[2,1,512]"
@transient var sc: SparkContext = _
after {
if (sc != null) {
sc.stop()
@@ -28,6 +28,22 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
}
}
test("local-cluster format") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
assert(sc.parallelize(1 to 2, 2).count == 2)
sc.stop()
sc = new SparkContext("local-cluster[2 , 1 , 512]", "test")
assert(sc.parallelize(1 to 2, 2).count == 2)
sc.stop()
sc = new SparkContext("local-cluster[2, 1, 512]", "test")
assert(sc.parallelize(1 to 2, 2).count == 2)
sc.stop()
sc = new SparkContext("local-cluster[ 2, 1, 512 ]", "test")
assert(sc.parallelize(1 to 2, 2).count == 2)
sc.stop()
sc = null
}
test("simple groupByKey") {
sc = new SparkContext(clusterUrl, "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 5)
@@ -38,7 +54,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("accumulators") {
sc = new SparkContext(clusterUrl, "test")
val accum = sc.accumulator(0)