Merge branch 'dev' of github.com:radlab/spark into dev

Tathagata Das 2012-10-25 13:03:18 -07:00
commit 650d717544
281 changed files with 16874 additions and 2281 deletions

.gitignore

@ -13,6 +13,8 @@ third_party/libmesos.dylib
conf/java-opts
conf/spark-env.sh
conf/log4j.properties
docs/_site
docs/api
target/
reports/
.project
@ -28,3 +30,4 @@ project/plugins/lib_managed/
project/plugins/src_managed/
logs/
log/
spark-tests.log


@ -6,23 +6,21 @@ Lightning-Fast Cluster Computing - <http://www.spark-project.org/>
## Online Documentation
You can find the latest Spark documentation, including a programming
guide, on the project wiki at <http://github.com/mesos/spark/wiki>. This
file only contains basic setup instructions.
guide, on the project webpage at <http://spark-project.org/documentation.html>.
This README file only contains basic setup instructions.
## Building
Spark requires Scala 2.9.1. This version has been tested with 2.9.1.final.
Spark requires Scala 2.9.2. The project is built using Simple Build Tool (SBT),
which is packaged with it. To build Spark and its example programs, run:
The project is built using Simple Build Tool (SBT), which is packaged with it.
To build Spark and its example programs, run:
sbt/sbt package
sbt/sbt compile
To run Spark, you will need to have Scala's bin in your `PATH`, or you
will need to set the `SCALA_HOME` environment variable to point to where
To run Spark, you will need to have Scala's bin directory in your `PATH`, or
you will need to set the `SCALA_HOME` environment variable to point to where
you've installed Scala. Scala must be accessible through one of these
methods on Mesos slave nodes as well as on the master.
methods on your cluster's worker nodes as well as its master.
To run one of the examples, use `./run <class> <params>`. For example:
@ -32,12 +30,12 @@ will run the Logistic Regression example locally on 2 CPUs.
Each of the example programs prints usage help if no params are given.
All of the Spark samples take a `<host>` parameter that is the Mesos master
to connect to. This can be a Mesos URL, or "local" to run locally with one
thread, or "local[N]" to run locally with N threads.
All of the Spark samples take a `<host>` parameter that is the cluster URL
to connect to. This can be a mesos:// or spark:// URL, or "local" to run
locally with one thread, or "local[N]" to run locally with N threads.
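
The examples pass this `<host>` value on to a SparkContext. As a rough sketch of what that looks like in your own program (the application name and thread count below are purely illustrative):

    import spark.SparkContext
    // "local[2]" runs locally with 2 threads; a mesos:// or spark:// URL would target a cluster instead
    val sc = new SparkContext("local[2]", "My App")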
## A Note About Hadoop
## A Note About Hadoop Versions
Spark uses the Hadoop core library to talk to HDFS and other Hadoop-supported
storage systems. Because the HDFS API has changed in different versions of
@ -48,27 +46,23 @@ of `project/SparkBuild.scala`, then rebuilding Spark.
## Configuration
Spark can be configured through two files: `conf/java-opts` and
`conf/spark-env.sh`.
Please refer to the "Configuration" guide in the online documentation for a
full overview on how to configure Spark. At the minimum, you will need to
create a `conf/spark-env.sh` script (copy `conf/spark-env.sh.template`) and
set the following two variables:
In `java-opts`, you can add flags to be passed to the JVM when running Spark.
- `SCALA_HOME`: Location where Scala is installed.
In `spark-env.sh`, you can set any environment variables you wish to be available
when running Spark programs, such as `PATH`, `SCALA_HOME`, etc. There are also
several Spark-specific variables you can set:
- `MESOS_NATIVE_LIBRARY`: Your Mesos library (only needed if you want to run
on Mesos). For example, this might be `/usr/local/lib/libmesos.so` on Linux.
- `SPARK_CLASSPATH`: Extra entries to be added to the classpath, separated by ":".
- `SPARK_MEM`: Memory for Spark to use, in the format used by java's `-Xmx`
option (for example, `-Xmx200m` means 200 MB, `-Xmx1g` means 1 GB, etc).
## Contributing to Spark
- `SPARK_LIBRARY_PATH`: Extra entries to add to `java.library.path` for locating
shared libraries.
- `SPARK_JAVA_OPTS`: Extra options to pass to JVM.
- `MESOS_NATIVE_LIBRARY`: Your Mesos library, if you want to run on a Mesos
cluster. For example, this might be `/usr/local/lib/libmesos.so` on Linux.
Note that `spark-env.sh` must be a shell script (it must be executable and start
with a `#!` header to specify the shell to use).
Contributions via GitHub pull requests are gladly accepted from their original
author. Along with any pull requests, please state that the contribution is
your original work and that you license the work to the project under the
project's open source license. Whether or not you state this explicitly, by
submitting any copyrighted material via pull request, email, or other means
you agree to license the material under the project's open source license and
warrant that you have the legal authority to do so.


@ -8,6 +8,11 @@ import spark.bagel.Bagel._
import scala.xml.{XML,NodeSeq}
/**
* Run PageRank on XML Wikipedia dumps from http://wiki.freebase.com/wiki/WEX. Uses the "articles"
* files from there, which contain one line per wiki article in a tab-separated format
* (http://wiki.freebase.com/wiki/WEX/Documentation#articles).
*/
object WikipediaPageRank {
def main(args: Array[String]) {
if (args.length < 5) {


@ -1,6 +1,7 @@
package spark.bagel.examples
import spark._
import serializer.{DeserializationStream, SerializationStream, SerializerInstance}
import spark.SparkContext._
import spark.bagel._
@ -33,10 +34,10 @@ object WikipediaPageRankStandalone {
val partitioner = new HashPartitioner(sc.defaultParallelism)
val links =
if (usePartitioner)
input.map(parseArticle _).partitionBy(partitioner).cache
input.map(parseArticle _).partitionBy(partitioner).cache()
else
input.map(parseArticle _).cache
val n = links.count
input.map(parseArticle _).cache()
val n = links.count()
val defaultRank = 1.0 / n
val a = 0.15
@ -51,7 +52,7 @@ object WikipediaPageRankStandalone {
(ranks
.filter { case (id, rank) => rank >= threshold }
.map { case (id, rank) => "%s\t%s\n".format(id, rank) }
.collect.mkString)
.collect().mkString)
println(top)
val time = (System.currentTimeMillis - startTime) / 1000.0
@ -113,7 +114,7 @@ object WikipediaPageRankStandalone {
}
}
class WPRSerializer extends spark.Serializer {
class WPRSerializer extends spark.serializer.Serializer {
def newInstance(): SerializerInstance = new WPRSerializerInstance()
}
@ -142,7 +143,7 @@ class WPRSerializerInstance extends SerializerInstance {
class WPRSerializationStream(os: OutputStream) extends SerializationStream {
val dos = new DataOutputStream(os)
def writeObject[T](t: T): Unit = t match {
def writeObject[T](t: T): SerializationStream = t match {
case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
case links: Array[String] => {
dos.writeInt(0) // links
@ -151,17 +152,20 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream {
for (link <- links) {
dos.writeUTF(link)
}
this
}
case rank: Double => {
dos.writeInt(1) // rank
dos.writeUTF(id)
dos.writeDouble(rank)
this
}
}
case (id: String, rank: Double) => {
dos.writeInt(2) // rank without wrapper
dos.writeUTF(id)
dos.writeDouble(rank)
this
}
}


@ -1,8 +1,10 @@
# Set everything to be logged to the console
log4j.rootCategory=WARN, console
log4j.appender.console=org.apache.log4j.ConsoleAppender
log4j.appender.console.layout=org.apache.log4j.PatternLayout
log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
log4j.appender.file.file=spark-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %p %c{1}: %m%n
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN


@ -22,6 +22,8 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter {
sc.stop()
sc = null
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.master.port")
}
test("halting by voting") {


@ -1,5 +1,8 @@
#!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.


@ -1,5 +1,8 @@
#!/usr/bin/env bash
# This Spark deploy script is a modified version of the Apache Hadoop deploy
# script, available under the Apache 2 license:
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.


@ -14,7 +14,21 @@ if [ "$SPARK_MASTER_PORT" = "" ]; then
SPARK_MASTER_PORT=7077
fi
hostname=`hostname`
ip=`host "$hostname" | cut -d " " -f 4`
if [ "$SPARK_MASTER_IP" = "" ]; then
hostname=`hostname`
hostoutput=`host "$hostname"`
if [[ "$hostoutput" == *"not found"* ]]; then
echo $hostoutput
echo "Failed to identify the IP for the master."
echo "Set SPARK_MASTER_IP explicitly in configuration instead."
exit 1
fi
ip=`host "$hostname" | cut -d " " -f 4`
else
ip=$SPARK_MASTER_IP
fi
echo "Master IP: $ip"
"$bin"/spark-daemons.sh start spark.deploy.worker.Worker spark://$ip:$SPARK_MASTER_PORT


@ -1,21 +1,24 @@
#!/usr/bin/env bash
# Set Spark environment variables for your site in this file. Some useful
# variables to set are:
# This file contains environment variables required to run Spark. Copy it as
# spark-env.sh and edit that to configure Spark for your site. At a minimum,
# the following two variables should be set:
# - MESOS_NATIVE_LIBRARY, to point to your Mesos native library (libmesos.so)
# - SCALA_HOME, to point to your Scala installation
#
# If using the standalone deploy mode, you can also set variables for it:
# - SPARK_MASTER_IP, to bind the master to a different IP address
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
#
# Finally, Spark also relies on the following variables, but these can be set
# on just the *master* (i.e. in your driver program), and will automatically
# be propagated to workers:
# - SPARK_MEM, to change the amount of memory used per node (this should
# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g)
# - SPARK_CLASSPATH, to add elements to Spark's classpath
# - SPARK_JAVA_OPTS, to add JVM options
# - SPARK_MEM, to change the amount of memory used per node (this should
# be in the same format as the JVM's -Xmx option, e.g. 300m or 1g).
# - SPARK_LIBRARY_PATH, to add extra search paths for native libraries.
# Settings used by the scripts in the bin/ directory, apply to standalone mode only.
# Note that the same worker settings apply to all of the workers.
# - SPARK_MASTER_IP, to bind the master to a different ip address, for example a public one (Default: local ip address)
# - SPARK_MASTER_PORT, to start the spark master on a different port (Default: 7077)
# - SPARK_MASTER_WEBUI_PORT, to specify a different port for the Master WebUI (Default: 8080)
# - SPARK_WORKER_PORT, to start the spark worker on a specific port (Default: random)
# - SPARK_WORKER_CORES, to specify the number of cores to use (Default: all available cores)
# - SPARK_WORKER_MEMORY, to specify how much memory to use, e.g. 1000M, 2G (Default: MAX(Available - 1024MB, 512MB))
# - SPARK_WORKER_WEBUI_PORT, to specify a different port for the Worker WebUI (Default: 8081)

Binary file not shown.


@ -0,0 +1,7 @@
package org.apache.hadoop.mapred
trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
}


@ -0,0 +1,9 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
}


@ -0,0 +1,7 @@
package org.apache.hadoop.mapred
trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
}


@ -0,0 +1,10 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
import task.{TaskAttemptContextImpl, JobContextImpl}
trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
}


@ -3,18 +3,20 @@ package spark
import java.io._
import scala.collection.mutable.Map
import scala.collection.generic.Growable
/**
* A datatype that can be accumulated, i.e. has a commutative and associative +.
* A datatype that can be accumulated, i.e. has a commutative and associative "add" operation,
* but where the result type, `R`, may be different from the element type being added, `T`.
*
* You must define how to add data, and how to merge two of these together. For some datatypes, these might be
* the same operation (eg., a counter). In that case, you might want to use [[spark.AccumulatorParam]]. They won't
* always be the same, though -- eg., imagine you are accumulating a set. You will add items to the set, and you
* will union two sets together.
* You must define how to add data, and how to merge two of these together. For some datatypes,
* such as a counter, these might be the same operation. In that case, you can use the simpler
* [[spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
* accumulating a set. You will add items to the set, and you will union two sets together.
*
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `T`
* @tparam R the full accumulated data
* @param param helper object defining how to add elements of type `R` and `T`
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
@ -43,13 +45,29 @@ class Accumulable[R, T] (
* @param term the other Accumulable that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
* Access the accumulator's current value; only allowed on master.
*/
def value = {
if (!deserialized) value_
else throw new UnsupportedOperationException("Can't read accumulator value in task")
}
private[spark] def localValue = value_
/**
* Get the current value of this accumulator from within a task.
*
* This is NOT the global value of the accumulator. To get the global value after a
* completed operation on the dataset, call `value`.
*
* The typical use of this method is to directly mutate the local value, e.g., to add
* an element to a Set.
*/
def localValue = value_
/**
* Set the accumulator's value; only allowed on master.
*/
def value_= (r: R) {
if (!deserialized) value_ = r
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
@ -67,31 +85,64 @@ class Accumulable[R, T] (
}
/**
* Helper object defining how to accumulate values of a particular type.
* Helper object defining how to accumulate values of a particular type. An implicit
* AccumulableParam needs to be available when you create Accumulables of a specific type.
*
* @tparam R the full accumulated data
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
trait AccumulableParam[R, T] extends Serializable {
/**
* Add additional data to the accumulator value.
* Add additional data to the accumulator value. Is allowed to modify and return `r`
* for efficiency (to avoid allocating objects).
*
* @param r the current value of the accumulator
* @param t the data to be added to the accumulator
* @return the new value of the accumulator
*/
def addAccumulator(r: R, t: T) : R
def addAccumulator(r: R, t: T): R
/**
* Merge two accumulated values together
* Merge two accumulated values together. Is allowed to modify and return the first value
* for efficiency (to avoid allocating objects).
*
* @param r1 one set of accumulated data
* @param r2 another set of accumulated data
* @return both data sets merged together
*/
def addInPlace(r1: R, r2: R): R
/**
* Return the "zero" (identity) value for an accumulator type, given its initial value. For
* example, if R was a vector of N dimensions, this would return a vector of N zeroes.
*/
def zero(initialValue: R): R
}
private[spark]
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
def addAccumulator(growable: R, elem: T): R = {
growable += elem
growable
}
def addInPlace(t1: R, t2: R): R = {
t1 ++= t2
t1
}
def zero(initialValue: R): R = {
// We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
// Instead we'll serialize it to a buffer and load it back.
val ser = (new spark.JavaSerializer).newInstance()
val copy = ser.deserialize[R](ser.serialize(initialValue))
copy.clear() // In case it contained stuff
copy
}
}
/**
* A simpler version of [[spark.Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged.
@ -100,17 +151,18 @@ trait AccumulableParam[R, T] extends Serializable {
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
class Accumulator[T](
@transient initialValue: T,
param: AccumulatorParam[T]) extends Accumulable[T,T](initialValue, param)
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
extends Accumulable[T,T](initialValue, param)
/**
* A simpler version of [[spark.AccumulableParam]] where the only datatype you can add in is the same type
* as the accumulated value
* as the accumulated value. An implicit AccumulatorParam object needs to be available when you create
* Accumulators of a specific type.
*
* @tparam T type of value to accumulate
*/
trait AccumulatorParam[T] extends AccumulableParam[T, T] {
def addAccumulator(t1: T, t2: T) : T = {
def addAccumulator(t1: T, t2: T): T = {
addInPlace(t1, t2)
}
}
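
A rough sketch of how these pieces combine (the `+=` method comes from the unchanged part of `Accumulable` and is assumed here; normally you would obtain accumulators through `SparkContext` rather than construct one by hand):

    // A minimal AccumulatorParam: addInPlace and zero are required, addAccumulator comes for free.
    object DoubleAccumulatorParam extends AccumulatorParam[Double] {
      def addInPlace(r1: Double, r2: Double): Double = r1 + r2
      def zero(initialValue: Double): Double = 0.0
    }

    val acc = new Accumulator(0.0, DoubleAccumulatorParam)
    acc += 3.5           // assumed: += adds a value of type T on workers (not shown in this hunk)
    println(acc.value)   // reading value is only allowed on the master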


@ -1,7 +1,44 @@
package spark
class Aggregator[K, V, C] (
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
/** A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue functions.
*/
case class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C)
extends Serializable
val mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[(K, V)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for ((k, v) <- iter) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, createCombiner(v))
} else {
combiners.put(k, mergeValue(oldC, v))
}
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for ((k, c) <- iter) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, mergeCombiners(oldC, c))
}
}
combiners.iterator
}
}
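
A small self-contained sketch of this contract, collecting integer values into lists (all names are local to the example):

    val agg = new Aggregator[String, Int, List[Int]](
      (v: Int) => List(v),                            // createCombiner: start a list
      (c: List[Int], v: Int) => v :: c,               // mergeValue: prepend later values
      (c1: List[Int], c2: List[Int]) => c1 ::: c2)    // mergeCombiners: concatenate lists

    val combined = agg.combineValuesByKey(Iterator(("a", 1), ("a", 2), ("b", 3)))
    // yields ("a", List(2, 1)) and ("b", List(3)); ordering depends on the hash map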


@ -1,52 +1,43 @@
package spark
import java.io.EOFException
import java.net.URL
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.storage.BlockException
import spark.storage.BlockManagerId
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit) {
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
val addresses = SparkEnv.get.mapOutputTracker.getServerAddresses(shuffleId)
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[Int]]
for ((address, index) <- addresses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += index
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[String])] = splitsByAddress.toSeq.map {
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(i => "shuffleid_%d_%d_%d".format(shuffleId, i, reduceId)))
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
for ((blockId, blockOption) <- blockManager.getMultiple(blocksByAddress)) {
def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[(K, V)] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Some(block) => {
val values = block
for(value <- values) {
val v = value.asInstanceOf[(K, V)]
func(v._1, v._2)
}
block.asInstanceOf[Iterator[(K, V)]]
}
case None => {
val regex = "shuffleid_([0-9]*)_([0-9]*)_([0-9]*)".r
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, reduceId) =>
val addr = addresses(mapId.toInt)
throw new FetchFailedException(addr, shufId.toInt, mapId.toInt, reduceId.toInt, null)
case regex(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
@ -54,8 +45,6 @@ class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
}
}
}
logDebug("Fetching and merging outputs of shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
}
}


@ -9,7 +9,7 @@ import java.util.LinkedHashMap
* some cache entries have pointers to a shared object. Nonetheless, this Cache should work well
* when most of the space is used by arrays of primitives or of simple classes.
*/
class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
private[spark] class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
logInfo("BoundedMemoryCache.maxBytes = " + maxBytes)
def this() {
@ -104,9 +104,9 @@ class BoundedMemoryCache(maxBytes: Long) extends Cache with Logging {
}
// An entry in our map; stores a cached object and its size in bytes
case class Entry(value: Any, size: Long)
private[spark] case class Entry(value: Any, size: Long)
object BoundedMemoryCache {
private[spark] object BoundedMemoryCache {
/**
* Get maximum cache capacity from system configuration
*/


@ -2,9 +2,9 @@ package spark
import java.util.concurrent.atomic.AtomicInteger
sealed trait CachePutResponse
case class CachePutSuccess(size: Long) extends CachePutResponse
case class CachePutFailure() extends CachePutResponse
private[spark] sealed trait CachePutResponse
private[spark] case class CachePutSuccess(size: Long) extends CachePutResponse
private[spark] case class CachePutFailure() extends CachePutResponse
/**
* An interface for caches in Spark, to allow for multiple implementations. Caches are used to store
@ -22,7 +22,7 @@ case class CachePutFailure() extends CachePutResponse
* This abstract class handles the creation of key spaces, so that subclasses need only deal with
* keys that are unique across modules.
*/
abstract class Cache {
private[spark] abstract class Cache {
private val nextKeySpaceId = new AtomicInteger(0)
private def newKeySpaceId() = nextKeySpaceId.getAndIncrement()
@ -52,7 +52,7 @@ abstract class Cache {
/**
* A key namespace in a Cache.
*/
class KeySpace(cache: Cache, val keySpaceId: Int) {
private[spark] class KeySpace(cache: Cache, val keySpaceId: Int) {
def get(datasetId: Any, partition: Int): Any =
cache.get((keySpaceId, datasetId), partition)


@ -15,19 +15,20 @@ import scala.collection.mutable.HashSet
import spark.storage.BlockManager
import spark.storage.StorageLevel
sealed trait CacheTrackerMessage
case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
case class MemoryCacheLost(host: String) extends CacheTrackerMessage
case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
case object GetCacheStatus extends CacheTrackerMessage
case object GetCacheLocations extends CacheTrackerMessage
case object StopCacheTracker extends CacheTrackerMessage
private[spark] sealed trait CacheTrackerMessage
class CacheTrackerActor extends Actor with Logging {
private[spark] case class AddedToCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class DroppedFromCache(rddId: Int, partition: Int, host: String, size: Long = 0L)
extends CacheTrackerMessage
private[spark] case class MemoryCacheLost(host: String) extends CacheTrackerMessage
private[spark] case class RegisterRDD(rddId: Int, numPartitions: Int) extends CacheTrackerMessage
private[spark] case class SlaveCacheStarted(host: String, size: Long) extends CacheTrackerMessage
private[spark] case object GetCacheStatus extends CacheTrackerMessage
private[spark] case object GetCacheLocations extends CacheTrackerMessage
private[spark] case object StopCacheTracker extends CacheTrackerMessage
private[spark] class CacheTrackerActor extends Actor with Logging {
// TODO: Should probably store (String, CacheType) tuples
private val locs = new HashMap[Int, Array[List[String]]]
@ -43,8 +44,6 @@ class CacheTrackerActor extends Actor with Logging {
def receive = {
case SlaveCacheStarted(host: String, size: Long) =>
logInfo("Started slave cache (size %s) on %s".format(
Utils.memoryBytesToString(size), host))
slaveCapacity.put(host, size)
slaveUsage.put(host, 0)
sender ! true
@ -56,22 +55,12 @@ class CacheTrackerActor extends Actor with Logging {
case AddedToCache(rddId, partition, host, size) =>
slaveUsage.put(host, getCacheUsage(host) + size)
logInfo("Cache entry added: (%s, %s) on %s (size added: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
locs(rddId)(partition) = host :: locs(rddId)(partition)
sender ! true
case DroppedFromCache(rddId, partition, host, size) =>
logInfo("Cache entry removed: (%s, %s) on %s (size dropped: %s, available: %s)".format(
rddId, partition, host, Utils.memoryBytesToString(size),
Utils.memoryBytesToString(getCacheAvailable(host))))
slaveUsage.put(host, getCacheUsage(host) - size)
// Do a sanity check to make sure usage is greater than 0.
val usage = getCacheUsage(host)
if (usage < 0) {
logError("Cache usage on %s is negative (%d)".format(host, usage))
}
locs(rddId)(partition) = locs(rddId)(partition).filterNot(_ == host)
sender ! true
@ -101,7 +90,7 @@ class CacheTrackerActor extends Actor with Logging {
}
}
class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
private[spark] class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: BlockManager)
extends Logging {
// Tracker actor on the master, or remote reference to it on workers
@ -151,7 +140,6 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
logInfo("Registering RDD ID " + rddId + " with cache")
registeredRddIds += rddId
communicate(RegisterRDD(rddId, numPartitions))
logInfo(RegisterRDD(rddId, numPartitions) + " successful")
}
}
}
@ -169,9 +157,8 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
}
// For BlockManager.scala only
def notifyTheCacheTrackerFromBlockManager(t: AddedToCache) {
def notifyFromBlockManager(t: AddedToCache) {
communicate(t)
logInfo("notifyTheCacheTrackerFromBlockManager successful")
}
// Get a snapshot of the currently known locations
@ -181,7 +168,7 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// Gets or computes an RDD split
def getOrCompute[T](rdd: RDD[T], split: Split, storageLevel: StorageLevel): Iterator[T] = {
val key = "rdd:%d:%d".format(rdd.id, split.index)
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
@ -221,23 +208,19 @@ class CacheTracker(actorSystem: ActorSystem, isMaster: Boolean, blockManager: Bl
// TODO: fetch any remote copy of the split that may be available
// TODO: also register a listener for when it unloads
logInfo("Computing partition " + split)
val elements = new ArrayBuffer[Any]
elements ++= rdd.compute(split)
try {
// BlockManager will iterate over results from compute to create RDD
blockManager.put(key, rdd.compute(split), storageLevel, false)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
//future.apply() // Wait for the reply from the cache tracker
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logWarning("loading partition failed after computing it " + key)
return null
}
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
return elements.iterator.asInstanceOf[Iterator[T]]
}
}


@ -9,7 +9,7 @@ import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
object ClosureCleaner extends Logging {
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream(
@ -154,7 +154,7 @@ object ClosureCleaner extends Logging {
}
}
class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
@ -180,7 +180,7 @@ class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor
}
}
class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,


@ -1,22 +1,51 @@
package spark
abstract class Dependency[T](val rdd: RDD[T], val isShuffle: Boolean) extends Serializable
/**
* Base class for dependencies.
*/
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd, false) {
/**
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
* @param outputPartition a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(outputPartition: Int): Seq[Int]
}
class ShuffleDependency[K, V, C](
val shuffleId: Int,
/**
* Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
val aggregator: Aggregator[K, V, C],
val partitioner: Partitioner)
extends Dependency(rdd, true)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()
}
/**
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
/**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
* @param inStart the start of the range in the parent RDD
* @param outStart the start of the range in the child RDD
* @param length the length of the range
*/
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
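
As a sketch, writing a narrow dependency only requires implementing getParents; the class below is hypothetical and exists purely to illustrate the contract (it assumes some existing rdd):

    // Hypothetical: each output partition depends on the matching parent partition plus partition 0.
    class WithFirstPartitionDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
      override def getParents(outputPartition: Int): Seq[Int] = Seq(0, outputPartition)
    }

    // OneToOneDependency above simply maps each partition to itself:
    //   new OneToOneDependency(rdd).getParents(3) == List(3)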


@ -4,33 +4,49 @@ import spark.partial.BoundedDouble
import spark.partial.MeanEvaluator
import spark.partial.PartialResult
import spark.partial.SumEvaluator
import spark.util.StatCounter
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `spark.SparkContext._` at the top of your program to use these functions.
*/
class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
/** Add up the elements in this RDD. */
def sum(): Double = {
self.reduce(_ + _)
}
/**
* Return a [[spark.util.StatCounter]] object that captures the mean, variance and count
* of the RDD's elements in one operation.
*/
def stats(): StatCounter = {
self.mapPartitions(nums => Iterator(StatCounter(nums))).reduce((a, b) => a.merge(b))
}
/** Compute the mean of this RDD's elements. */
def mean(): Double = stats().mean
/** Compute the variance of this RDD's elements. */
def variance(): Double = stats().variance
/** Compute the standard deviation of this RDD's elements. */
def stdev(): Double = stats().stdev
/**
* Compute the sample standard deviation of this RDD's elements (which corrects for bias in
* estimating the standard deviation by dividing by N-1 instead of N).
*/
def sampleStdev(): Double = stats().stdev
/** (Experimental) Approximate operation to return the mean within a timeout. */
def meanApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new MeanEvaluator(self.splits.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val processPartition = (ctx: TaskContext, ns: Iterator[Double]) => StatCounter(ns)
val evaluator = new SumEvaluator(self.splits.size, confidence)
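
A quick usage sketch (assumes a SparkContext named sc; the implicit conversion comes from the `import spark.SparkContext._` mentioned in the class comment):

    import spark.SparkContext._

    val nums = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0))
    val s = nums.stats()       // single pass over the data
    println(s.mean)            // 2.5
    println(nums.sum())        // 10.0
    println(nums.variance())   // 1.25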


@ -2,7 +2,7 @@ package spark
import spark.storage.BlockManagerId
class FetchFailedException(
private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,


@ -16,11 +16,14 @@ import spark.Logging
import spark.SerializableWritable
/**
* Saves an RDD using a Hadoop OutputFormat as specified by a JobConf. The JobConf should also
* contain an output key class, an output value class, a filename to write to, etc exactly like in
* a Hadoop job.
* Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public
* because we need to access this class from the `spark` package to use some package-private Hadoop
* functions, but this class should not be used directly by users.
*
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializable {
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@ -42,7 +45,7 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl
setConfParams()
val jCtxt = getJobContext()
getOutputCommitter().setupJob(jCtxt)
getOutputCommitter().setupJob(jCtxt)
}
@ -126,14 +129,14 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with Serializabl
private def getJobContext(): JobContext = {
if (jobContext == null) {
jobContext = new JobContext(conf.value, jID.value)
jobContext = newJobContext(conf.value, jID.value)
}
return jobContext
}
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
taskContext = new TaskAttemptContext(conf.value, taID.value)
taskContext = newTaskAttemptContext(conf.value, taID.value)
}
return taskContext
}


@ -0,0 +1,47 @@
package spark
import java.io.{File, PrintWriter}
import java.net.URL
import scala.collection.mutable.HashMap
import org.apache.hadoop.fs.FileUtil
private[spark] class HttpFileServer extends Logging {
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(baseDir)
httpServer.start()
serverUri = httpServer.uri
}
def stop() {
httpServer.stop()
}
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
return serverUri + "/files/" + file.getName
}
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
return serverUri + "/jars/" + file.getName
}
def addFileToDir(file: File, dir: File) : String = {
Utils.copyFile(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
}
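
Because the class is private[spark], the sketch below is internal-use only, and the file path is made up:

    val fileServer = new HttpFileServer()
    fileServer.initialize()       // creates the files/ and jars/ subdirectories and starts the HTTP server
    val url = fileServer.addFile(new java.io.File("/tmp/lookup.txt"))
    // url looks like <serverUri>/files/lookup.txt and can be fetched over HTTP by worker nodes
    fileServer.stop()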


@ -12,14 +12,14 @@ import org.eclipse.jetty.util.thread.QueuedThreadPool
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
class ServerStateException(message: String) extends Exception(message)
private[spark] class ServerStateException(message: String) extends Exception(message)
/**
* An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
class HttpServer(resourceBase: File) extends Logging {
private[spark] class HttpServer(resourceBase: File) extends Logging {
private var server: Server = null
private var port: Int = -1


@ -3,16 +3,17 @@ package spark
import java.io._
import java.nio.ByteBuffer
import serializer.{Serializer, SerializerInstance, DeserializationStream, SerializationStream}
import spark.util.ByteBufferInputStream
class JavaSerializationStream(out: OutputStream) extends SerializationStream {
private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream {
val objOut = new ObjectOutputStream(out)
def writeObject[T](t: T) { objOut.writeObject(t) }
def writeObject[T](t: T): SerializationStream = { objOut.writeObject(t); this }
def flush() { objOut.flush() }
def close() { objOut.close() }
}
class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
private[spark] class JavaDeserializationStream(in: InputStream, loader: ClassLoader)
extends DeserializationStream {
val objIn = new ObjectInputStream(in) {
override def resolveClass(desc: ObjectStreamClass) =
@ -23,7 +24,7 @@ extends DeserializationStream {
def close() { objIn.close() }
}
class JavaSerializerInstance extends SerializerInstance {
private[spark] class JavaSerializerInstance extends SerializerInstance {
def serialize[T](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
@ -57,6 +58,9 @@ class JavaSerializerInstance extends SerializerInstance {
}
}
/**
* A Spark serializer that uses Java's built-in serialization.
*/
class JavaSerializer extends Serializer {
def newInstance(): SerializerInstance = new JavaSerializerInstance
}
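
A minimal round-trip sketch of the serializer API (the same pattern GrowableAccumulableParam.zero uses earlier in this commit):

    val ser = (new JavaSerializer).newInstance()
    val bytes = ser.serialize(List(1, 2, 3))        // returns a java.nio.ByteBuffer
    val back = ser.deserialize[List[Int]](bytes)    // back == List(1, 2, 3)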


@ -13,6 +13,7 @@ import com.esotericsoftware.kryo.serialize.ClassSerializer
import com.esotericsoftware.kryo.serialize.SerializableSerializer
import de.javakaffee.kryoserializers.KryoReflectionFactorySupport
import serializer.{SerializerInstance, DeserializationStream, SerializationStream}
import spark.broadcast._
import spark.storage._
@ -20,7 +21,7 @@ import spark.storage._
* Zig-zag encoder used to write object sizes to serialization streams.
* Based on Kryo's integer encoder.
*/
object ZigZag {
private[spark] object ZigZag {
def writeInt(n: Int, out: OutputStream) {
var value = n
if ((value & ~0x7F) == 0) {
@ -68,22 +69,25 @@ object ZigZag {
}
}
private[spark]
class KryoSerializationStream(kryo: Kryo, threadBuffer: ByteBuffer, out: OutputStream)
extends SerializationStream {
val channel = Channels.newChannel(out)
def writeObject[T](t: T) {
def writeObject[T](t: T): SerializationStream = {
kryo.writeClassAndObject(threadBuffer, t)
ZigZag.writeInt(threadBuffer.position(), out)
threadBuffer.flip()
channel.write(threadBuffer)
threadBuffer.clear()
this
}
def flush() { out.flush() }
def close() { out.close() }
}
private[spark]
class KryoDeserializationStream(objectBuffer: ObjectBuffer, in: InputStream)
extends DeserializationStream {
def readObject[T](): T = {
@ -94,7 +98,7 @@ extends DeserializationStream {
def close() { in.close() }
}
class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
val kryo = ks.kryo
val threadBuffer = ks.threadBuffer.get()
val objectBuffer = ks.objectBuffer.get()
@ -155,13 +159,21 @@ class KryoSerializerInstance(ks: KryoSerializer) extends SerializerInstance {
}
}
// Used by clients to register their own classes
/**
* Interface implemented by clients to register their classes with Kryo when using Kryo
* serialization.
*/
trait KryoRegistrator {
def registerClasses(kryo: Kryo): Unit
}
class KryoSerializer extends Serializer with Logging {
val kryo = createKryo()
/**
* A Spark serializer that uses the [[http://code.google.com/p/kryo/wiki/V1Documentation Kryo 1.x library]].
*/
class KryoSerializer extends spark.serializer.Serializer with Logging {
// Make this lazy so that it only gets called once we receive our first task on each executor,
// so we can pull out any custom Kryo registrator from the user's JARs.
lazy val kryo = createKryo()
val bufferSize = System.getProperty("spark.kryoserializer.buffer.mb", "32").toInt * 1024 * 1024
@ -192,8 +204,8 @@ class KryoSerializer extends Serializer with Logging {
(1, 1, 1), (1, 1, 1, 1), (1, 1, 1, 1, 1),
None,
ByteBuffer.allocate(1),
StorageLevel.MEMORY_ONLY_DESER,
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY_DESER),
StorageLevel.MEMORY_ONLY,
PutBlock("1", ByteBuffer.allocate(1), StorageLevel.MEMORY_ONLY),
GotBlock("1", ByteBuffer.allocate(1)),
GetBlock("1")
)
@ -256,7 +268,8 @@ class KryoSerializer extends Serializer with Logging {
val regCls = System.getProperty("spark.kryo.registrator")
if (regCls != null) {
logInfo("Running user registrator: " + regCls)
val reg = Class.forName(regCls).newInstance().asInstanceOf[KryoRegistrator]
val classLoader = Thread.currentThread.getContextClassLoader
val reg = Class.forName(regCls, true, classLoader).newInstance().asInstanceOf[KryoRegistrator]
reg.registerClasses(kryo)
}
kryo
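
A sketch of a user-side registrator (the Point class and package name are placeholders; spark.kryo.registrator is read above, while the spark.serializer property name is an assumption):

    import com.esotericsoftware.kryo.Kryo

    case class Point(x: Double, y: Double)   // placeholder for a user type

    class MyRegistrator extends KryoRegistrator {
      def registerClasses(kryo: Kryo) {
        kryo.register(classOf[Point])
      }
    }

    // Typically enabled via system properties before the SparkContext is created:
    //   System.setProperty("spark.serializer", "spark.KryoSerializer")          // assumed property name
    //   System.setProperty("spark.kryo.registrator", "mypackage.MyRegistrator")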


@ -15,7 +15,7 @@ trait Logging {
private var log_ : Logger = null
// Method to get or create the logger for this object
def log: Logger = {
protected def log: Logger = {
if (log_ == null) {
var className = this.getClass.getName
// Ignore trailing $'s in the class names for Scala objects
@ -28,48 +28,48 @@ trait Logging {
}
// Log methods that take only a String
def logInfo(msg: => String) {
protected def logInfo(msg: => String) {
if (log.isInfoEnabled) log.info(msg)
}
def logDebug(msg: => String) {
protected def logDebug(msg: => String) {
if (log.isDebugEnabled) log.debug(msg)
}
def logTrace(msg: => String) {
protected def logTrace(msg: => String) {
if (log.isTraceEnabled) log.trace(msg)
}
def logWarning(msg: => String) {
protected def logWarning(msg: => String) {
if (log.isWarnEnabled) log.warn(msg)
}
def logError(msg: => String) {
protected def logError(msg: => String) {
if (log.isErrorEnabled) log.error(msg)
}
// Log methods that take Throwables (Exceptions/Errors) too
def logInfo(msg: => String, throwable: Throwable) {
protected def logInfo(msg: => String, throwable: Throwable) {
if (log.isInfoEnabled) log.info(msg, throwable)
}
def logDebug(msg: => String, throwable: Throwable) {
protected def logDebug(msg: => String, throwable: Throwable) {
if (log.isDebugEnabled) log.debug(msg, throwable)
}
def logTrace(msg: => String, throwable: Throwable) {
protected def logTrace(msg: => String, throwable: Throwable) {
if (log.isTraceEnabled) log.trace(msg, throwable)
}
def logWarning(msg: => String, throwable: Throwable) {
protected def logWarning(msg: => String, throwable: Throwable) {
if (log.isWarnEnabled) log.warn(msg, throwable)
}
def logError(msg: => String, throwable: Throwable) {
protected def logError(msg: => String, throwable: Throwable) {
if (log.isErrorEnabled) log.error(msg, throwable)
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
def initLogging() { log }
protected def initLogging() { log }
}
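
A short sketch of mixing in the trait (the class is hypothetical); since the methods are now protected, they can only be called from within the class itself:

    class IndexBuilder extends Logging {     // hypothetical user of the trait
      def build(n: Int) {
        logInfo("Building index with " + n + " entries")   // message argument is by-name, so it is
        if (n < 0) logWarning("Negative size requested")   // only evaluated when the level is enabled
      }
    }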


@ -1,6 +1,6 @@
package spark
import java.io.{DataInputStream, DataOutputStream, ByteArrayOutputStream, ByteArrayInputStream}
import java.io._
import java.util.concurrent.ConcurrentHashMap
import akka.actor._
@ -14,16 +14,19 @@ import akka.util.duration._
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scheduler.MapStatus
import spark.storage.BlockManagerId
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
sealed trait MapOutputTrackerMessage
case class GetMapOutputLocations(shuffleId: Int) extends MapOutputTrackerMessage
case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
def receive = {
case GetMapOutputLocations(shuffleId: Int) =>
logInfo("Asked to get map output locations for shuffle " + shuffleId)
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker =>
@ -33,23 +36,23 @@ class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Loggin
}
}
class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logging {
val ip: String = System.getProperty("spark.master.host", "localhost")
val port: Int = System.getProperty("spark.master.port", "7077").toInt
val actorName: String = "MapOutputTracker"
val timeout = 10.seconds
var bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
var mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var generation: Long = 0
private var generationLock = new java.lang.Object
private val generationLock = new java.lang.Object
// Cache a serialized version of the output locations for each shuffle to send them out faster
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedLocs = new HashMap[Int, Array[Byte]]
val cachedSerializedStatuses = new HashMap[Int, Array[Byte]]
var trackerActor: ActorRef = if (isMaster) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
@ -80,31 +83,34 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (bmAddresses.get(shuffleId) != null) {
if (mapStatuses.get(shuffleId) != null) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
bmAddresses.put(shuffleId, new Array[BlockManagerId](numMaps))
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = bmAddresses.get(shuffleId)
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses.get(shuffleId)
array.synchronized {
array(mapId) = bmAddress
array(mapId) = status
}
}
def registerMapOutputs(shuffleId: Int, locs: Array[BlockManagerId], changeGeneration: Boolean = false) {
bmAddresses.put(shuffleId, Array[BlockManagerId]() ++ locs)
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeGeneration: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeGeneration) {
incrementGeneration()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = bmAddresses.get(shuffleId)
var array = mapStatuses.get(shuffleId)
if (array != null) {
array.synchronized {
if (array(mapId) == bmAddress) {
if (array(mapId).address == bmAddress) {
array(mapId) = null
}
}
@ -117,10 +123,10 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs for a given shuffle
def getServerAddresses(shuffleId: Int): Array[BlockManagerId] = {
val locs = bmAddresses.get(shuffleId)
if (locs == null) {
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId)
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
fetching.synchronized {
if (fetching.contains(shuffleId)) {
@ -129,34 +135,38 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
try {
fetching.wait()
} catch {
case _ =>
case e: InterruptedException =>
}
}
return bmAddresses.get(shuffleId)
return mapStatuses.get(shuffleId).map(status =>
(status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId))))
} else {
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val fetchedBytes = askTracker(GetMapOutputLocations(shuffleId)).asInstanceOf[Array[Byte]]
val fetchedLocs = deserializeLocations(fetchedBytes)
val host = System.getProperty("spark.hostname", Utils.localHostName)
val fetchedBytes = askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
val fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
bmAddresses.put(shuffleId, fetchedLocs)
mapStatuses.put(shuffleId, fetchedStatuses)
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
return fetchedLocs
return fetchedStatuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
} else {
return locs
return statuses.map(s =>
(s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId))))
}
}
def stop() {
communicate(StopMapOutputTracker)
bmAddresses.clear()
mapStatuses.clear()
trackerActor = null
}
@ -182,75 +192,83 @@ class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolean) extends Logg
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
bmAddresses = new ConcurrentHashMap[Int, Array[BlockManagerId]]
mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]
generation = newGen
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var locs: Array[BlockManagerId] = null
var statuses: Array[MapStatus] = null
var generationGotten: Long = -1
generationLock.synchronized {
if (generation > cacheGeneration) {
cachedSerializedLocs.clear()
cachedSerializedStatuses.clear()
cacheGeneration = generation
}
cachedSerializedLocs.get(shuffleId) match {
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
locs = bmAddresses.get(shuffleId)
statuses = mapStatuses.get(shuffleId)
generationGotten = generation
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the statuses as "statuses"; let's serialize and return that
val bytes = serializeLocations(locs)
val bytes = serializeStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the generation hasn't changed while we were working
generationLock.synchronized {
if (generation == generationGotten) {
cachedSerializedLocs(shuffleId) = bytes
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
}
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by grouping together the locations by block manager ID.
def serializeLocations(locs: Array[BlockManagerId]): Array[Byte] = {
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val dataOut = new DataOutputStream(out)
dataOut.writeInt(locs.length)
val grouped = locs.zipWithIndex.groupBy(_._1)
dataOut.writeInt(grouped.size)
for ((id, pairs) <- grouped if id != null) {
dataOut.writeUTF(id.ip)
dataOut.writeInt(id.port)
dataOut.writeInt(pairs.length)
for ((_, blockIndex) <- pairs) {
dataOut.writeInt(blockIndex)
}
}
dataOut.close()
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(statuses)
objOut.close()
out.toByteArray
}
// Opposite of serializeLocations.
def deserializeLocations(bytes: Array[Byte]): Array[BlockManagerId] = {
val dataIn = new DataInputStream(new ByteArrayInputStream(bytes))
val length = dataIn.readInt()
val array = new Array[BlockManagerId](length)
val numGroups = dataIn.readInt()
for (i <- 0 until numGroups) {
val ip = dataIn.readUTF()
val port = dataIn.readInt()
val id = new BlockManagerId(ip, port)
val numBlocks = dataIn.readInt()
for (j <- 0 until numBlocks) {
array(dataIn.readInt()) = id
}
}
array
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().asInstanceOf[Array[MapStatus]]
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
* sizes up to 35 GB with at most 10% error.
*/
def compressSize(size: Long): Byte = {
if (size <= 1L) {
0
} else {
math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
}
}
/**
* Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
*/
def decompressSize(compressedSize: Byte): Long = {
if (compressedSize == 0) {
1
} else {
math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
}
}
}
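// A small illustrative sketch of the size encoding above (the size value is made up;
// since the object is private[spark], this would only compile inside the spark package):
package spark

object CompressedSizeExample {
  def main(args: Array[String]) {
    val size = 123456789L                                   // ~118 MB of map output
    val encoded = MapOutputTracker.compressSize(size)       // fits in a single byte
    val decoded = MapOutputTracker.decompressSize(encoded)  // overestimates by at most ~10%
    println("original=" + size + ", encoded=" + encoded + ", decoded=" + decoded)
  }
}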

View file

@ -1,11 +1,6 @@
package spark
import java.io.EOFException
import java.net.URL
import java.io.ObjectInputStream
import java.util.concurrent.atomic.AtomicLong
import java.util.{HashMap => JHashMap}
import java.util.Date
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
import scala.collection.Map
@ -15,46 +10,66 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.TaskAttemptID
import org.apache.hadoop.mapreduce.TaskAttemptContext
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
import spark.SparkContext._
import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.rdd._
import spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `spark.SparkContext._` at the top of your program to use these functions.
*/
class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with HadoopMapReduceUtil
with Serializable {
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns an RDD[(K, V)] into a result of type RDD[(K, C)], for a "combined type" C.
* Note that V and C can be different -- for example, one might group an RDD of type
* (Int, Int) into an RDD of type (Int, Seq[Int]). Users provide three functions:
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
*/
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner): RDD[(K, C)] = {
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
new ShuffledRDD(self, aggregator, partitioner)
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V](self, partitioner)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
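// A minimal usage sketch of combineByKey: computing a per-key average. Assumes an
// existing SparkContext `sc`; the data and partition count are illustrative.
import spark.SparkContext._

val scores = sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("b", 4.0)))
val sumAndCount = scores.combineByKey[(Double, Int)](
  (v: Double) => (v, 1),                                                    // createCombiner
  (c: (Double, Int), v: Double) => (c._1 + v, c._2 + 1),                    // mergeValue
  (c1: (Double, Int), c2: (Double, Int)) => (c1._1 + c2._1, c1._2 + c2._2), // mergeCombiners
  new spark.HashPartitioner(2))
val averages = sumAndCount.mapValues { case (sum, count) => sum / count }
// averages.collect() => Array((a, 2.0), (b, 4.0)) in some order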
/**
* Simplified version of combineByKey that hash-partitions the output RDD.
*/
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
@ -62,10 +77,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
*/
def reduceByKey(partitioner: Partitioner, func: (V, V) => V): RDD[(K, V)] = {
combineByKey[V]((v: V) => v, func, func, partitioner)
}
/**
* Merge the values for each key using an associative reduce function, but return the results
* immediately to the master as a Map. This will also perform the merging locally on each mapper
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: (V, V) => V): Map[K, V] = {
def reducePartition(iter: Iterator[(K, V)]): Iterator[JHashMap[K, V]] = {
val map = new JHashMap[K, V]
@ -87,22 +112,34 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
self.mapPartitions(reducePartition).reduce(mergeMaps)
}
// Alias for backwards compatibility
/** Alias for reduceByKeyLocally */
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
// TODO: This should probably be a distributed version
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
// TODO: This should probably be a distributed version
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[Map[K, BoundedDouble]] = {
self.map(_._1).countByValueApprox(timeout, confidence)
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
*/
def reduceByKey(func: (V, V) => V, numSplits: Int): RDD[(K, V)] = {
reduceByKey(new HashPartitioner(numSplits), func)
}
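// A minimal word-count sketch using reduceByKey (assumes `sc`; the input path is a placeholder).
import spark.SparkContext._

val lines = sc.textFile("hdfs://namenode:9000/data/input.txt")
val counts = lines.flatMap(line => line.split(" "))
                  .map(word => (word, 1))
                  .reduceByKey(_ + _, 4)   // hash-partition the counts into 4 splits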
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
@ -112,19 +149,39 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD into `numSplits` partitions.
*/
def groupByKey(numSplits: Int): RDD[(K, Seq[V])] = {
groupByKey(new HashPartitioner(numSplits))
}
def partitionBy(partitioner: Partitioner): RDD[(K, V)] = {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
/**
* Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
* is true, Spark will group values of the same key together on the map side before the
* repartitioning, to only send each key over the network once. If a large number of
* duplicated keys are expected, and the size of the keys is large, `mapSideCombine` should
* be set to true.
*/
def partitionBy(partitioner: Partitioner, mapSideCombine: Boolean = false): RDD[(K, V)] = {
if (mapSideCombine) {
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
bufs.flatMapValues(buf => buf)
} else {
new ShuffledRDD[K, V](self, partitioner)
}
}
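// A minimal partitionBy sketch (assumes `sc`). Pre-partitioning a pair RDD with a known
// partitioner lets later operations that use the same partitioner (e.g. a join) reuse this
// RDD's partitioning instead of shuffling it again.
import spark.SparkContext._

val users = sc.parallelize(Seq((1, "alice"), (2, "bob"), (3, "carol")))
val partitioned = users.partitionBy(new spark.HashPartitioner(8))
// partitioned now carries the HashPartitioner(8) in its `partitioner` field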
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues {
case (vs, ws) =>
@ -132,6 +189,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues {
case (vs, ws) =>
@ -143,6 +206,12 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def rightOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues {
@ -155,56 +224,117 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
def combineByKey[C](createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) : RDD[(K, C)] = {
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
: RDD[(K, C)] = {
combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(self))
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
*/
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
reduceByKey(defaultPartitioner(self), func)
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the default parallelism level.
*/
def groupByKey(): RDD[(K, Seq[V])] = {
groupByKey(defaultPartitioner(self))
}
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: RDD[(K, W)]): RDD[(K, (V, W))] = {
join(other, defaultPartitioner(self, other))
}
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, W))] = {
join(other, new HashPartitioner(numSplits))
}
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the default level of parallelism.
*/
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, defaultPartitioner(self, other))
}
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* into `numSplits` partitions.
*/
def leftOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, new HashPartitioner(numSplits))
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the default parallelism level.
*/
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, defaultPartitioner(self, other))
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
def rightOuterJoin[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, new HashPartitioner(numSplits))
}
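// A small sketch contrasting the join flavors (assumes `sc`; values are illustrative).
import spark.SparkContext._

val left  = sc.parallelize(Seq((1, "a"), (2, "b")))
val right = sc.parallelize(Seq((1, "x"), (3, "z")))
left.join(right).collect()           // Array((1, (a, x)))
left.leftOuterJoin(right).collect()  // also contains (2, (b, None))
left.rightOuterJoin(right).collect() // also contains (3, (None, z))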
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
def collectAsMap(): Map[K, V] = HashMap(self.collect(): _*)
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: V => U): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new MappedValuesRDD(self, cleanF)
}
/**
* Pass each value in the key-value pair RDD through a flatMap function without changing the
* keys; this also retains the original RDD's partitioning.
*/
def flatMapValues[U](f: V => TraversableOnce[U]): RDD[(K, U)] = {
val cleanF = self.context.clean(f)
new FlatMappedValuesRDD(self, cleanF)
}
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (Seq[V], Seq[W]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]], other.asInstanceOf[RDD[(_, _)]]),
@ -215,12 +345,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
}
}
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], partitioner: Partitioner)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
val cg = new CoGroupedRDD[K](
Seq(self.asInstanceOf[RDD[(_, _)]],
other1.asInstanceOf[RDD[(_, _)]],
other2.asInstanceOf[RDD[(_, _)]]),
partitioner)
val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
@ -230,28 +364,46 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
}
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: RDD[(K, W)], numSplits: Int): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, new HashPartitioner(numSplits))
}
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)], numSplits: Int)
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, new HashPartitioner(numSplits))
}
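// A minimal cogroup sketch (assumes `sc`): gathers all values for each key from both RDDs.
import spark.SparkContext._

val a = sc.parallelize(Seq((1, "a1"), (1, "a2"), (2, "b")))
val b = sc.parallelize(Seq((1, "x")))
a.cogroup(b).collect()
// => Array((1, (Seq(a1, a2), Seq(x))), (2, (Seq(b), Seq()))) in some order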
/** Alias for cogroup. */
def groupWith[W](other: RDD[(K, W)]): RDD[(K, (Seq[V], Seq[W]))] = {
cogroup(other, defaultPartitioner(self, other))
}
/** Alias for cogroup. */
def groupWith[W1, W2](other1: RDD[(K, W1)], other2: RDD[(K, W2)])
: RDD[(K, (Seq[V], Seq[W1], Seq[W2]))] = {
cogroup(other1, other2, defaultPartitioner(self, other1, other2))
@ -268,6 +420,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
return new HashPartitioner(self.context.defaultParallelism)
}
/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
def lookup(key: K): Seq[V] = {
self.partitioner match {
case Some(p) =>
@ -286,14 +442,26 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
@ -302,6 +470,10 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, new Configuration)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
def saveAsNewAPIHadoopFile(
path: String,
keyClass: Class[_],
@ -323,7 +495,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = new TaskAttemptID(jobtrackerID,
stageId, false, context.splitId, attemptNumber)
val hadoopContext = new TaskAttemptContext(wrappedConf.value, attemptId)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
@ -342,13 +514,17 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* setupJob/commitJob, so we just use a dummy "map" task.
*/
val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = new TaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
jobCommitter.cleanupJob(jobTaskContext)
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
def saveAsHadoopFile(
path: String,
keyClass: Class[_],
@ -363,7 +539,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
}
/**
* Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
* that storage system. The JobConf should set an OutputFormat and any output paths required
* (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
* MapReduce job.
*/
def saveAsHadoopDataset(conf: JobConf) {
val outputFormatClass = conf.getOutputFormat
val keyClass = conf.getOutputKeyClass
@ -377,7 +559,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
if (valueClass == null) {
throw new SparkException("Output value class not set")
}
logInfo("Saving as hadoop file of type (" + keyClass.getSimpleName+ ", " + valueClass.getSimpleName+ ")")
val writer = new HadoopWriter(conf)
@ -390,14 +572,14 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.setup(context.stageId, context.splitId, attemptNumber)
writer.open()
var count = 0
while(iter.hasNext) {
val record = iter.next
count += 1
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
}
writer.close()
writer.commit()
}
@ -406,35 +588,42 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
writer.cleanup()
}
def getKeyClass() = implicitly[ClassManifest[K]].erasure
private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
def getValueClass() = implicitly[ClassManifest[V]].erasure
private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
}
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
* an implicit conversion. Import `spark.SparkContext._` at the top of your program to use these
* functions. They will work with any key type that has a `scala.math.Ordered` implementation.
*/
class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
val rangePartitionedRDD = self.partitionBy(new RangePartitioner(self.splits.size, self, ascending))
new SortedRDD(rangePartitionedRDD, ascending)
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(ascending: Boolean = true, numSplits: Int = self.splits.size): RDD[(K,V)] = {
val shuffled =
new ShuffledRDD[K, V](self, new RangePartitioner(numSplits, self, ascending))
shuffled.mapPartitions(iter => {
val buf = iter.toArray
if (ascending) {
buf.sortWith((x, y) => x._1 < y._1).iterator
} else {
buf.sortWith((x, y) => x._1 > y._1).iterator
}
}, true)
}
}
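// A minimal sortByKey sketch (assumes `sc`); works for any key type with an Ordered view.
import spark.SparkContext._

val data = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b")))
data.sortByKey().collect()           // Array((1, a), (2, b), (3, c))
data.sortByKey(false, 2).collect()   // descending, written into 2 splits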
class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
extends RDD[(K, V)](prev.context) {
override def splits = prev.splits
override val partitioner = prev.partitioner
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = {
prev.iterator(split).toArray
.sortWith((x, y) => if (ascending) x._1 < y._1 else x._1 > y._1).iterator
}
}
private[spark]
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
@ -442,9 +631,10 @@ class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)]
override def compute(split: Split) = prev.iterator(split).map{case (k, v) => (k, f(v))}
}
private[spark]
class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U])
extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override val partitioner = prev.partitioner
@ -454,6 +644,6 @@ class FlatMappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => TraversableOnce[U]
}
}
object Manifests {
private[spark] object Manifests {
val seqSeqManifest = classManifest[Seq[Seq[_]]]
}

View file

@ -3,7 +3,7 @@ package spark
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
class ParallelCollectionSplit[T: ClassManifest](
private[spark] class ParallelCollectionSplit[T: ClassManifest](
val rddId: Long,
val slice: Int,
values: Seq[T])
@ -21,7 +21,7 @@ class ParallelCollectionSplit[T: ClassManifest](
override val index: Int = slice
}
class ParallelCollection[T: ClassManifest](
private[spark] class ParallelCollection[T: ClassManifest](
sc: SparkContext,
@transient data: Seq[T],
numSlices: Int)

View file

@ -1,10 +1,17 @@
package spark
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
* Maps each key to a partition ID, from 0 to `numPartitions - 1`.
*/
abstract class Partitioner extends Serializable {
def numPartitions: Int
def getPartition(key: Any): Int
}
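// A minimal sketch of a user-defined Partitioner: even integer keys go to partition 0,
// odd ones to partition 1 (purely illustrative; a real implementation would usually also
// override equals so Spark can tell when two RDDs share the same partitioning).
class EvenOddPartitioner extends spark.Partitioner {
  def numPartitions = 2
  def getPartition(key: Any): Int = key match {
    case i: Int => if (i % 2 == 0) 0 else 1
    case _ => 0
  }
}
// Usage: pairs.partitionBy(new EvenOddPartitioner)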
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
@ -29,6 +36,10 @@ class HashPartitioner(partitions: Int) extends Partitioner {
}
}
/**
* A [[spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
* Determines the ranges by sampling the RDD passed in.
*/
class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
partitions: Int,
@transient rdd: RDD[(K,V)],
@ -41,9 +52,9 @@ class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
Array()
} else {
val rddSize = rdd.count()
val maxSampleSize = partitions * 10.0
val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(true, frac, 1).map(_._1).collect().sortWith(_ < _)
val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
Array()
} else {

View file

@ -31,51 +31,86 @@ import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
import spark.partial.PartialResult
import spark.rdd.BlockRDD
import spark.rdd.CartesianRDD
import spark.rdd.FilteredRDD
import spark.rdd.FlatMappedRDD
import spark.rdd.GlommedRDD
import spark.rdd.MappedRDD
import spark.rdd.MapPartitionsRDD
import spark.rdd.MapPartitionsWithSplitRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.UnionRDD
import spark.storage.StorageLevel
import SparkContext._
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
* partitioned collection of elements that can be operated on in parallel.
* partitioned collection of elements that can be operated on in parallel. This class contains the
* basic operations available on all RDDs, such as `map`, `filter`, and `persist`. In addition,
* [[spark.PairRDDFunctions]] contains operations available only on RDDs of key-value pairs, such
* as `groupByKey` and `join`; [[spark.DoubleRDDFunctions]] contains operations available only on
* RDDs of Doubles; and [[spark.SequenceFileRDDFunctions]] contains operations available on RDDs
* that can be saved as SequenceFiles. These operations are automatically available on any RDD of
* the right type (e.g. RDD[(Int, Int)]) through implicit conversions when you
* `import spark.SparkContext._`.
*
* Each RDD is characterized by five main properties:
* - A list of splits (partitions)
* - A function for computing each split
* - A list of dependencies on other RDDs
* - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
* - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
* HDFS)
* Internally, each RDD is characterized by five main properties:
*
* All the scheduling and execution in Spark is done based on these methods, allowing each RDD to
* implement its own way of computing itself.
* - A list of splits (partitions)
* - A function for computing each split
* - A list of dependencies on other RDDs
* - Optionally, a Partitioner for key-value RDDs (e.g. to say that the RDD is hash-partitioned)
* - Optionally, a list of preferred locations to compute each split on (e.g. block locations for
* an HDFS file)
*
* This class also contains transformation methods available on all RDDs (e.g. map and filter). In
* addition, PairRDDFunctions contains extra methods available on RDDs of key-value pairs, and
* SequenceFileRDDFunctions contains extra methods for saving RDDs to Hadoop SequenceFiles.
* All of the scheduling and execution in Spark is done based on these methods, allowing each RDD
* to implement its own way of computing itself. Indeed, users can implement custom RDDs (e.g. for
* reading data from a new storage system) by overriding these functions. Please refer to the
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serializable {
// Methods that must be implemented by subclasses
// Methods that must be implemented by subclasses:
/** Set of partitions in this RDD. */
def splits: Array[Split]
/** Function for computing a given partition. */
def compute(split: Split): Iterator[T]
/** How this RDD depends on any parent RDDs. */
@transient val dependencies: List[Dependency[_]]
// Methods available on all RDDs:
// Optionally overridden by subclasses to specify how they are partitioned
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
/** Optionally overridden by subclasses to specify how they are partitioned. */
val partitioner: Option[Partitioner] = None
// Optionally overridden by subclasses to specify placement preferences
/** Optionally overridden by subclasses to specify placement preferences. */
def preferredLocations(split: Split): Seq[String] = Nil
/** The [[spark.SparkContext]] that this RDD was created on. */
def context = sc
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
// Get a unique ID for this RDD
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
// Variables relating to persistence
private var storageLevel: StorageLevel = StorageLevel.NONE
// Change this RDD's storage level
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
if (storageLevel != StorageLevel.NONE && newLevel != storageLevel) {
@ -86,22 +121,23 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
this
}
// Turn on the default caching level for this RDD
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY_DESER)
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def persist(): RDD[T] = persist(StorageLevel.MEMORY_ONLY)
// Turn on the default caching level for this RDD
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
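// A minimal caching sketch (assumes `sc`; the path is a placeholder). The RDD is
// materialized the first time an action uses it and served from memory afterwards.
val logs = sc.textFile("hdfs://namenode:9000/logs")
logs.cache()                                            // same as persist(MEMORY_ONLY)
val errors = logs.filter(_.contains("ERROR")).count()   // computes and caches the RDD
val warnings = logs.filter(_.contains("WARN")).count()  // reads from the cache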
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
def checkpoint(level: StorageLevel = StorageLevel.DISK_AND_MEMORY_DESER_2): RDD[T] = {
private[spark] def checkpoint(level: StorageLevel = StorageLevel.MEMORY_AND_DISK_2): RDD[T] = {
if (!level.useDisk && level.replication < 2) {
throw new Exception("Cannot checkpoint without using disk or replication (level requested was " + level + ")")
}
// This is a hack. Ideally this should re-use the code used by the CacheTracker
// to generate the key.
def getSplitKey(split: Split) = "rdd:%d:%d".format(this.id, split.index)
def getSplitKey(split: Split) = "rdd_%d_%d".format(this.id, split.index)
persist(level)
sc.runJob(this, (iter: Iterator[T]) => {} )
@ -113,7 +149,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
}
// Read this RDD; will read from cache if applicable, or otherwise compute
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
final def iterator(split: Split): Iterator[T] = {
if (storageLevel != StorageLevel.NONE) {
SparkEnv.get.cacheTracker.getOrCompute[T](this, split, storageLevel)
@ -124,15 +164,32 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
// Transformations (return a new RDD)
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
new FlatMappedRDD(this, sc.clean(f))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: T => Boolean): RDD[T] = new FilteredRDD(this, sc.clean(f))
def distinct(): RDD[T] = map(x => (x, "")).reduceByKey((x, y) => x).map(_._1)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numSplits: Int = splits.size): RDD[T] =
map(x => (x, null)).reduceByKey((x, y) => x, numSplits).map(_._1)
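// A minimal sketch of the basic transformations above (assumes `sc`).
val nums = sc.parallelize(Seq(1, 2, 2, 3, 3, 3))
nums.map(_ * 2).collect()          // Array(2, 4, 4, 6, 6, 6)
nums.filter(_ % 2 == 1).collect()  // Array(1, 3, 3, 3)
nums.distinct().collect()          // Array(1, 2, 3) in some order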
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): RDD[T] =
new SampledRDD(this, withReplacement, fraction, seed)
@ -143,8 +200,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
var initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE) {
maxSelected = Integer.MAX_VALUE
if (initialCount > Integer.MAX_VALUE - 1) {
maxSelected = Integer.MAX_VALUE - 1
} else {
maxSelected = initialCount.toInt
}
@ -159,56 +216,109 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
total = num
}
var samples = this.sample(withReplacement, fraction, seed).collect()
val rand = new Random(seed)
var samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
while (samples.length < total) {
samples = this.sample(withReplacement, fraction, seed).collect()
samples = this.sample(withReplacement, fraction, rand.nextInt).collect()
}
val arr = samples.take(total)
return arr
Utils.randomizeInPlace(samples, rand).take(total)
}
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def ++(other: RDD[T]): RDD[T] = this.union(other)
/**
* Return an RDD created by coalescing all elements within each partition into an array.
*/
def glom(): RDD[Array[T]] = new GlommedRDD(this)
/**
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassManifest](f: T => K, numSplits: Int): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(numSplits)
}
/**
* Return an RDD of grouped items.
*/
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism)
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: String): RDD[String] = new PipedRDD(this, command)
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
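// A minimal pipe sketch (assumes `sc` and a Unix `tr` binary on the workers): each element
// is written to the process's stdin, one per line, and each stdout line becomes an element.
val upper = sc.parallelize(Seq("spark", "mesos")).pipe("tr a-z A-Z")
upper.collect()   // Array(SPARK, MESOS)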
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f))
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
new MapPartitionsWithSplitRDD(this, sc.clean(f))
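// A minimal sketch of mapPartitions and mapPartitionsWithSplit (assumes `sc`).
val nums = sc.parallelize(1 to 6, 2)
nums.mapPartitions(iter => Iterator(iter.sum)).collect()   // Array(6, 15)
nums.mapPartitionsWithSplit((idx, iter) => iter.map(x => (idx, x))).collect()
// => Array((0,1), (0,2), (0,3), (1,4), (1,5), (1,6))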
// Actions (launch a job to return a value to the user program)
/**
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
/**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): Array[T] = {
val results = sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
Array.concat(results: _*)
}
/**
* Return an array that contains all of the elements in this RDD.
*/
def toArray(): Array[T] = collect()
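// A minimal sketch of the common actions in this section (assumes `sc`).
val nums = sc.parallelize(1 to 100)
nums.reduce(_ + _)   // 5050
nums.count()         // 100
nums.first()         // 1
nums.take(3)         // Array(1, 2, 3)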
/**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: (T, T) => T): T = {
val cleanF = sc.clean(f)
val reducePartition: Iterator[T] => Option[T] = iter => {
@ -257,7 +367,10 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
(iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp))
return results.fold(zeroValue)(cleanCombOp)
}
/**
* Return the number of elements in the RDD.
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
var result = 0L
@ -270,7 +383,8 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
* Approximate version of count() that returns a potentially incomplete result after a timeout.
* (Experimental) Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long, confidence: Double = 0.95): PartialResult[BoundedDouble] = {
val countElements: (TaskContext, Iterator[T]) => Long = { (ctx, iter) =>
@ -286,12 +400,11 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
* Count elements equal to each value, returning a map of (value, count) pairs. The final combine
* step happens locally on the master, equivalent to running a single reduce task.
*
* TODO: This should perhaps be distributed by default.
* Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
// TODO: This should perhaps be distributed by default.
def countPartition(iter: Iterator[T]): Iterator[OLMap[T]] = {
val map = new OLMap[T]
while (iter.hasNext) {
@ -313,7 +426,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
}
/**
* Approximate version of countByValue().
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(
timeout: Long,
@ -353,18 +466,27 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
return buf.toArray
}
/**
* Return the first element in this RDD.
*/
def first(): T = take(1) match {
case Array(t) => t
case _ => throw new UnsupportedOperationException("empty collection")
}
/**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
this.map(x => (NullWritable.get(), new Text(x.toString)))
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) {
this.glom
this.mapPartitions(iter => iter.grouped(10).map(_.toArray))
.map(x => (NullWritable.get(), new BytesWritable(Utils.serialize(x))))
.saveAsSequenceFile(path)
}
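// A minimal save sketch (assumes `sc`; the output path is a placeholder). Each partition
// is written as a separate part-X file under the given directory.
val counts = sc.parallelize(Seq(("a", 1), ("b", 2)))
counts.map { case (k, v) => k + "\t" + v }.saveAsTextFile("hdfs://namenode:9000/out/counts")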
@ -374,45 +496,3 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
}
}
class MappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => U)
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).map(f)
}
class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).flatMap(f)
}
class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = prev.iterator(split).filter(f)
}
class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = Array(prev.iterator(split).toArray).iterator
}
class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
f: Iterator[T] => Iterator[U])
extends RDD[U](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
override def compute(split: Split) = f(prev.iterator(split))
}

View file

@ -23,19 +23,21 @@ import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.Text
import SparkContext._
import spark.SparkContext._
/**
* Extra functions available on RDDs of (key, value) pairs to create a Hadoop SequenceFile,
* through an implicit conversion. Note that this can't be part of PairRDDFunctions because
* we need more implicit parameters to convert our keys and values to Writable.
*
* Users should import `spark.SparkContext._` at the top of their program to use these functions.
*/
class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
self: RDD[(K, V)])
extends Logging
with Serializable {
def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
@ -47,6 +49,13 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
c.asInstanceOf[Class[_ <: Writable]]
}
/**
* Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key
* and value types. If the key or value types are Writable, then we use their classes directly;
* otherwise we map primitive types such as Int and Double to IntWritable, DoubleWritable, etc,
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
def saveAsSequenceFile(path: String) {
def anyToWritable[U <% Writable](u: U): Writable = u

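// A minimal saveAsSequenceFile sketch (assumes `sc`; the path is a placeholder). Int keys
// and String values are converted to IntWritable and Text as described above.
import spark.SparkContext._

val pairs = sc.parallelize(Seq((1, "one"), (2, "two")))
pairs.saveAsSequenceFile("hdfs://namenode:9000/out/pairs")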
View file

@ -1,10 +1,12 @@
package spark
abstract class ShuffleFetcher {
// Fetch the shuffle outputs for a given ShuffleDependency, calling func exactly
// once on each key-value pair obtained.
def fetch[K, V](shuffleId: Int, reduceId: Int, func: (K, V) => Unit)
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
// Stop the fetcher
/** Stop the fetcher */
def stop() {}
}

View file

@ -1,98 +0,0 @@
package spark
import java.io._
import java.net.URL
import java.util.UUID
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.{ArrayBuffer, HashMap}
import spark._
class ShuffleManager extends Logging {
private var nextShuffleId = new AtomicLong(0)
private var shuffleDir: File = null
private var server: HttpServer = null
private var serverUri: String = null
initialize()
private def initialize() {
// TODO: localDir should be created by some mechanism common to Spark
// so that it can be shared among shuffle, broadcast, etc
val localDirRoot = System.getProperty("spark.local.dir", "/tmp")
var tries = 0
var foundLocalDir = false
var localDir: File = null
var localDirUuid: UUID = null
while (!foundLocalDir && tries < 10) {
tries += 1
try {
localDirUuid = UUID.randomUUID
localDir = new File(localDirRoot, "spark-local-" + localDirUuid)
if (!localDir.exists) {
localDir.mkdirs()
foundLocalDir = true
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
}
}
if (!foundLocalDir) {
logError("Failed 10 attempts to create local dir in " + localDirRoot)
System.exit(1)
}
shuffleDir = new File(localDir, "shuffle")
shuffleDir.mkdirs()
logInfo("Shuffle dir: " + shuffleDir)
// Add a shutdown hook to delete the local dir
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dir") {
override def run() {
Utils.deleteRecursively(localDir)
}
})
val extServerPort = System.getProperty(
"spark.localFileShuffle.external.server.port", "-1").toInt
if (extServerPort != -1) {
// We're using an external HTTP server; set URI relative to its root
var extServerPath = System.getProperty(
"spark.localFileShuffle.external.server.path", "")
if (extServerPath != "" && !extServerPath.endsWith("/")) {
extServerPath += "/"
}
serverUri = "http://%s:%d/%s/spark-local-%s".format(
Utils.localIpAddress, extServerPort, extServerPath, localDirUuid)
} else {
// Create our own server
server = new HttpServer(localDir)
server.start()
serverUri = server.uri
}
logInfo("Local URI: " + serverUri)
}
def stop() {
if (server != null) {
server.stop()
}
}
def getOutputFile(shuffleId: Long, inputId: Int, outputId: Int): File = {
val dir = new File(shuffleDir, shuffleId + "/" + inputId)
dir.mkdirs()
val file = new File(dir, "" + outputId)
return file
}
def getServerUri(): String = {
serverUri
}
def newShuffleId(): Long = {
nextShuffleId.getAndIncrement()
}
}

View file

@ -1,51 +0,0 @@
package spark
import java.util.{HashMap => JHashMap}
class ShuffledRDDSplit(val idx: Int) extends Split {
override val index = idx
override def hashCode(): Int = idx
}
class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
aggregator: Aggregator[K, V, C],
part : Partitioner)
extends RDD[(K, C)](parent.context) {
//override val partitioner = Some(part)
override val partitioner = Some(part)
@transient
val splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i))
override def splits = splits_
override def preferredLocations(split: Split) = Nil
val dep = new ShuffleDependency(context.newShuffleId, parent, aggregator, part)
override val dependencies = List(dep)
override def compute(split: Split): Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
def mergePair(k: K, c: C) {
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, aggregator.mergeCombiners(oldC, c))
}
}
val fetcher = SparkEnv.get.shuffleFetcher
fetcher.fetch[K, C](dep.shuffleId, split.index, mergePair)
return new Iterator[(K, C)] {
var iter = combiners.entrySet().iterator()
def hasNext: Boolean = iter.hasNext()
def next(): (K, C) = {
val entry = iter.next()
(entry.getKey, entry.getValue)
}
}
}
}

View file

@ -22,7 +22,7 @@ import it.unimi.dsi.fastutil.ints.IntOpenHashSet
* Based on the following JavaWorld article:
* http://www.javaworld.com/javaworld/javaqa/2003-12/02-qa-1226-sizeof.html
*/
object SizeEstimator extends Logging {
private[spark] object SizeEstimator extends Logging {
// Sizes of primitive types
private val BYTE_SIZE = 1
@ -77,22 +77,18 @@ object SizeEstimator extends Logging {
return System.getProperty("spark.test.useCompressedOops").toBoolean
}
try {
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic";
val server = ManagementFactory.getPlatformMBeanServer();
val hotSpotMBeanName = "com.sun.management:type=HotSpotDiagnostic"
val server = ManagementFactory.getPlatformMBeanServer()
val bean = ManagementFactory.newPlatformMXBeanProxy(server,
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean]);
hotSpotMBeanName, classOf[HotSpotDiagnosticMXBean])
return bean.getVMOption("UseCompressedOops").getValue.toBoolean
} catch {
case e: IllegalArgumentException => {
logWarning("Exception while trying to check if compressed oops is enabled", e)
// Fall back to checking if maxMemory < 32GB
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
}
case e: SecurityException => {
logWarning("No permission to create MBeanServer", e)
// Fall back to checking if maxMemory < 32GB
return Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
case e: Exception => {
// Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB
val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024)
val guessInWords = if (guess) "yes" else "no"
logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords)
return guess
}
}
}
@ -146,6 +142,10 @@ object SizeEstimator extends Logging {
val cls = obj.getClass
if (cls.isArray) {
visitArray(obj, cls, state)
} else if (obj.isInstanceOf[ClassLoader] || obj.isInstanceOf[Class[_]]) {
// Hadoop JobConfs created in the interpreter have a ClassLoader, which greatly confuses
// the size estimator since it references the whole REPL. Do nothing in this case. In
// general all ClassLoaders and Classes will be shared between objects anyway.
} else {
val classInfo = getClassInfo(cls)
state.size += classInfo.shellSize

View file

@ -5,7 +5,7 @@ import com.google.common.collect.MapMaker
/**
* An implementation of Cache that uses soft references.
*/
class SoftReferenceCache extends Cache {
private[spark] class SoftReferenceCache extends Cache {
val map = new MapMaker().softValues().makeMap[Any, Any]()
override def get(datasetId: Any, partition: Int): Any =

View file

@ -2,13 +2,15 @@ package spark
import java.io._
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import akka.actor.Actor
import akka.actor.Actor._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
@ -25,34 +27,59 @@ import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import spark.rdd.HadoopRDD
import spark.rdd.NewHadoopRDD
import spark.rdd.UnionRDD
import spark.scheduler.ShuffleMapTask
import spark.scheduler.DAGScheduler
import spark.scheduler.TaskScheduler
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import spark.storage.BlockManagerMaster
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI.
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
val master: String,
val frameworkName: String,
val jobName: String,
val sparkHome: String,
val jars: Seq[String])
val jars: Seq[String],
environment: Map[String, String])
extends Logging {
def this(master: String, frameworkName: String) = this(master, frameworkName, null, Nil)
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, jobName: String, sparkHome: String, jars: Seq[String]) =
this(master, jobName, sparkHome, jars, Map())
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
*/
def this(master: String, jobName: String) = this(master, jobName, null, Nil, Map())
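// A minimal construction sketch (the cluster URL, paths and JAR name are placeholders).
val sc = new spark.SparkContext(
  "spark://master-host:7077",     // or "local[4]", or a mesos:// URL
  "My Job",
  "/opt/spark",                   // sparkHome on the worker nodes
  Seq("target/my-job.jar"))       // shipped to the workers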
// Ensure logging is initialized before we spawn any threads
initLogging()
@ -68,46 +95,89 @@ class SparkContext(
private val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
val env = SparkEnv.createFromSystemProperties(
private[spark] val env = SparkEnv.createFromSystemProperties(
System.getProperty("spark.master.host"),
System.getProperty("spark.master.port").toInt,
true,
isLocal)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]()
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
for (key <- Seq("SPARK_MEM", "SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS",
"SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
}
}
executorEnvs ++= environment
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+),([0-9]+)\]""".r
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
master match {
case "local" =>
new LocalScheduler(1, 0)
case "local" =>
new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt)
new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, frameworkName)
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
scheduler.initialize(backend)
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure SPARK_MEM <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
val sparkMemEnv = System.getenv("SPARK_MEM")
val sparkMemEnvInt = if (sparkMemEnv != null) Utils.memoryStringToMb(sparkMemEnv) else 512
if (sparkMemEnvInt > memoryPerSlaveInt) {
throw new SparkException(
"Slave memory (%d MB) cannot be smaller than SPARK_MEM (%d MB)".format(
memoryPerSlaveInt, sparkMemEnvInt))
}
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, jobName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
scheduler
case _ =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, master, frameworkName)
new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
} else {
new MesosSchedulerBackend(scheduler, this, master, frameworkName)
new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, jobName)
}
scheduler.initialize(backend)
scheduler
@ -119,14 +189,20 @@ class SparkContext(
// Methods for creating RDDs
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollection[T](this, seq, numSlices)
}
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism ): RDD[T] = {
/** Distribute a local Scala collection to form an RDD. */
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits)
.map(pair => pair._2.toString)
@ -163,19 +239,31 @@ class SparkContext(
}
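As a minimal usage sketch of the two RDD-creation entry points documented above (parallelize and textFile); the master URL, job name, and HDFS path are placeholders, not values taken from this commit:

    // Hypothetical example: any Hadoop-supported URI works for textFile.
    import spark.SparkContext
    val sc = new SparkContext("local[2]", "ExampleJob")
    val nums = sc.parallelize(Seq(1, 2, 3, 4), 2)
    val lines = sc.textFile("hdfs://namenode:9000/data/input.txt")
    println(nums.count() + " numbers, " + lines.count() + " lines")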
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly.
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
hadoopFile(path,
fm.erasure.asInstanceOf[Class[F]],
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
minSplits)
}
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
@ -191,7 +279,7 @@ class SparkContext(
new Configuration)
}
/**
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@ -207,7 +295,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
}
/**
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
@ -219,7 +307,7 @@ class SparkContext(
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types */
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@ -229,18 +317,23 @@ class SparkContext(
hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter.
* Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter. For example, to access a SequenceFile where the keys are Text and the
* values are IntWritable, you could simply write
* {{{
* sparkContext.sequenceFile[String, Int](path, ...)
* }}}
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
* have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
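A small sketch of the converter machinery described above, assuming an existing SparkContext `sc` and a placeholder path: the implicit WritableConverters let the caller ask for plain Scala types instead of Text/IntWritable.

    // Hypothetical example: read a SequenceFile[Text, IntWritable] as (String, Int) pairs.
    import spark.SparkContext._
    val counts = sc.sequenceFile[String, Int]("hdfs://namenode:9000/data/counts")
    val total = counts.map(_._2).reduce(_ + _)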
@ -265,7 +358,7 @@ class SparkContext(
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T: ClassManifest](
path: String,
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
@ -275,43 +368,128 @@ class SparkContext(
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
/** Build the union of a list of RDDs. */
/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
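A minimal sketch of the accumulator pattern just documented, assuming an existing SparkContext `sc` and a placeholder input path: tasks add with `+=`, and only the master reads `value`.

    // Hypothetical example: count empty lines seen by the workers.
    import spark.SparkContext._   // brings the implicit AccumulatorParam[Int] into scope
    val emptyLines = sc.accumulator(0)
    sc.textFile("hdfs://namenode:9000/data/input.txt").foreach { line =>
      if (line.isEmpty) emptyLines += 1
    }
    println("Empty lines: " + emptyLines.value)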
/**
* Create an accumulable shared variable, with a `+=` method
* Create an [[spark.Accumulable]] shared variable, with a `+=` method
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
def accumulable[T,R](initialValue: T)(implicit param: AccumulableParam[T,R]) =
new Accumulable(initialValue, param)
/**
* Create an accumulator from a "mutable collection" type.
*
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
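A sketch of accumulableCollection, assuming an existing SparkContext `sc`; mutable.HashSet satisfies the Growable/TraversableOnce/Serializable bound quoted above.

    // Hypothetical example: gather distinct error codes from tasks into one set.
    import scala.collection.mutable
    val errorCodes = sc.accumulableCollection(mutable.HashSet[Int]())
    sc.parallelize(Seq(200, 404, 500, 404)).foreach { code =>
      if (code >= 400) errorCodes += code
    }
    println(errorCodes.value)   // e.g. Set(404, 500)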
// Keep around a weak hash map of values to Cached versions?
def broadcast[T](value: T) = SparkEnv.get.broadcastManager.newBroadcast[T] (value, isLocal)
/**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each node only once.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T] (value, isLocal)
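A minimal sketch of the broadcast API just documented, assuming an existing SparkContext `sc`: the table is shipped once and read through the handle's `value` inside tasks.

    // Hypothetical example: share a small lookup table with every task.
    val lookup = sc.broadcast(Map("a" -> 1, "b" -> 2))
    val mapped = sc.parallelize(Seq("a", "b", "a")).map(k => lookup.value(k))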
// Stop the SparkContext
/**
* Add a file to be downloaded into the working directory of this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addFile(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
}
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case the task is executed locally
val filename = new File(path.split("/").last)
Utils.fetchFile(path, new File("."))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
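A sketch of addFile, assuming an existing SparkContext `sc` and a placeholder HDFS path. It also assumes that, in this version, tasks can open the shipped file by its bare name in their working directory, which is where the fetch logic above places it.

    // Hypothetical example: ship a side file and read it from each task's working directory.
    sc.addFile("hdfs://namenode:9000/config/stopwords.txt")
    val sizes = sc.parallelize(1 to 4).map { _ =>
      scala.io.Source.fromFile("stopwords.txt").getLines().size
    }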
/**
* Return a map from each slave to both the maximum memory available to it for caching and the
* memory currently remaining for caching.
*/
def getSlavesMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
}
}
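A one-line sketch using the method as declared above, assuming an existing SparkContext `sc`; treating the values as bytes is an assumption, not something stated in this diff.

    // Hypothetical example: print (remaining / max) cache memory per slave.
    sc.getSlavesMemoryStatus.foreach { case (slave, (maxMem, remainingMem)) =>
      println(slave + ": " + remainingMem + " / " + maxMem)
    }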
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
addedFiles.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedFiles.clear()
}
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
/**
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
*/
def clearJars() {
addedJars.keySet.map(_.split("/").last).foreach { k => new File(k).delete() }
addedJars.clear()
}
/** Shut down the SparkContext. */
def stop() {
dagScheduler.stop()
dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
// Clean up locally linked files
clearFiles()
clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
logInfo("Successfully stopped SparkContext")
}
// Get Spark's home location from either a value set through the constructor,
// or the spark.home Java property, or the SPARK_HOME environment variable
// (in that order of preference). If neither of these is set, return None.
def getSparkHome(): Option[String] = {
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
* (in that order of preference). If none of these are set, return None.
*/
private[spark] def getSparkHome(): Option[String] = {
if (sparkHome != null) {
Some(sparkHome)
} else if (System.getProperty("spark.home") != null) {
@ -326,7 +504,7 @@ class SparkContext(
/**
* Run a function on a given set of partitions in an RDD and return the results. This is the main
* entry point to the scheduler, by which all actions get launched. The allowLocal flag specifies
* whether the scheduler can run the computation on the master rather than shipping it out to the
* whether the scheduler can run the computation on the master rather than shipping it out to the
* cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
@ -335,22 +513,27 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
logInfo("Starting job...")
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, allowLocal)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
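A sketch of the scheduler entry point just shown, assuming an existing SparkContext `sc`: run a function over a single partition and allow it to execute locally, the same short-action pattern the allowLocal flag exists for.

    // Hypothetical example: sum only partition 0, possibly running it on the master.
    val rdd = sc.parallelize(1 to 100, 4)
    val sums: Array[Int] = sc.runJob(rdd, (it: Iterator[Int]) => it.sum, Seq(0), true)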
/**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: Iterator[T] => U,
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
@ -358,6 +541,9 @@ class SparkContext(
runJob(rdd, func, 0 until rdd.splits.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.splits.size, false)
}
@ -371,38 +557,37 @@ class SparkContext(
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] = {
logInfo("Starting job...")
val callSite = Utils.getSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, timeout)
logInfo("Job finished in " + (System.nanoTime - start) / 1e9 + " s")
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
// Clean a closure to make it ready to serialized and send to tasks
// (removes unreferenced variables in $outer's, updates REPL variables)
/**
* Clean a closure to make it ready to be serialized and sent to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
// Default level of parallelism to use when not given by user (e.g. for reduce tasks)
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
def defaultParallelism: Int = taskScheduler.defaultParallelism
// Default min number of splits for Hadoop RDDs when not given by user
/** Default min number of splits for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private var nextShuffleId = new AtomicInteger(0)
private[spark] def newShuffleId(): Int = {
nextShuffleId.getAndIncrement()
}
private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
private var nextRddId = new AtomicInteger(0)
// Register a new RDD, returning its RDD ID
private[spark] def newRddId(): Int = {
nextRddId.getAndIncrement()
}
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
}
/**
@ -429,7 +614,7 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
@ -450,7 +635,7 @@ object SparkContext {
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
@ -461,7 +646,7 @@ object SparkContext {
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
@ -489,8 +674,10 @@ object SparkContext {
implicit def writableWritableConverter[T <: Writable]() =
new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
// Find the JAR from which a given class was loaded, to make it easy for users to pass
// their JARs to SparkContext
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to SparkContext
*/
def jarOfClass(cls: Class[_]): Seq[String] = {
val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
if (uri != null) {
@ -505,8 +692,8 @@ object SparkContext {
Nil
}
}
// Find the JAR that contains the class of a particular object
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
}
@ -518,7 +705,7 @@ object SparkContext {
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
class WritableConverter[T](
private[spark] class WritableConverter[T](
val writableClass: ClassManifest[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable

View file

@ -1,46 +1,57 @@
package spark
import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import akka.remote.RemoteActorRefProvider
import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.util.AkkaUtils
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv (
val actorSystem: ActorSystem,
val cache: Cache,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheTracker: CacheTracker,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer
) {
/** No-parameter constructor for unit tests. */
def this() = {
this(null, null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
this(null, new JavaSerializer, new JavaSerializer, null, null, null, null, null, null, null)
}
def stop() {
httpFileServer.stop()
mapOutputTracker.stop()
cacheTracker.stop()
shuffleFetcher.stop()
shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
actorSystem.awaitTermination()
}
}
object SparkEnv {
object SparkEnv extends Logging {
private val env = new ThreadLocal[SparkEnv]
def set(e: SparkEnv) {
@ -66,66 +77,55 @@ object SparkEnv {
System.setProperty("spark.master.port", boundPort.toString)
}
val serializerClass = System.getProperty("spark.serializer", "spark.KryoSerializer")
val serializer = Class.forName(serializerClass).newInstance().asInstanceOf[Serializer]
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = System.getProperty(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val blockManagerMaster = new BlockManagerMaster(actorSystem, isMaster, isLocal)
val blockManager = new BlockManager(blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
val shuffleManager = new ShuffleManager()
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isMaster)
val closureSerializerClass =
System.getProperty("spark.closure.serializer", "spark.JavaSerializer")
val closureSerializer =
Class.forName(closureSerializerClass).newInstance().asInstanceOf[Serializer]
val cacheClass = System.getProperty("spark.cache.class", "spark.BoundedMemoryCache")
val cache = Class.forName(cacheClass).newInstance().asInstanceOf[Cache]
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
val cacheTracker = new CacheTracker(actorSystem, isMaster, blockManager)
blockManager.cacheTracker = cacheTracker
val mapOutputTracker = new MapOutputTracker(actorSystem, isMaster)
val shuffleFetcherClass =
System.getProperty("spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val shuffleFetcher =
Class.forName(shuffleFetcherClass).newInstance().asInstanceOf[ShuffleFetcher]
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
/*
if (System.getProperty("spark.stream.distributed", "false") == "true") {
val blockManagerClass = classOf[spark.storage.BlockManager].asInstanceOf[Class[_]]
if (isLocal || !isMaster) {
(new Thread() {
override def run() {
println("Wait started")
Thread.sleep(60000)
println("Wait ended")
val receiverClass = Class.forName("spark.stream.TestStreamReceiver4")
val constructor = receiverClass.getConstructor(blockManagerClass)
val receiver = constructor.newInstance(blockManager)
receiver.asInstanceOf[Thread].start()
}
}).start()
}
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
"levels using the RDD.persist() method instead.")
}
*/
new SparkEnv(
actorSystem,
cache,
serializer,
closureSerializer,
cacheTracker,
mapOutputTracker,
shuffleFetcher,
shuffleManager,
broadcastManager,
blockManager,
connectionManager)
connectionManager,
httpFileServer)
}
}

View file

@ -7,10 +7,16 @@ import spark.storage.BlockManagerId
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
sealed trait TaskEndReason
private[spark] sealed trait TaskEndReason
case object Success extends TaskEndReason
private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
private[spark]
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
case class ExceptionFailure(exception: Throwable) extends TaskEndReason
case class OtherFailure(message: String) extends TaskEndReason
private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason

View file

@ -2,7 +2,7 @@ package spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
object TaskState
private[spark] object TaskState
extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value

View file

@ -1,18 +1,18 @@
package spark
import java.io._
import java.net.InetAddress
import java.net.{InetAddress, URL, URI}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import scala.util.Random
import java.util.{Locale, UUID}
import scala.io.Source
/**
* Various utility methods used by Spark.
*/
object Utils {
private object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@ -71,7 +71,7 @@ object Utils {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
throw new IOException("Failed to create a temp directory after " + maxAttempts +
throw new IOException("Failed to create a temp directory after " + maxAttempts +
" attempts!")
}
try {
@ -116,22 +116,84 @@ object Utils {
copyStream(in, out, true)
}
/** Download a file from a given URL to the local filesystem */
def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
}
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
val uri = new URI(url)
uri.getScheme match {
case "http" | "https" | "ftp" =>
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
case "file" | null =>
// Remove the file if it already exists
targetFile.delete()
// Symlink the file locally.
if (uri.isAbsolute) {
// url is absolute, i.e. it starts with "file:///". Extract the source
// file's absolute path from the url.
val sourceFile = new File(uri)
logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
} else {
// url is not absolute, i.e. it is itself the path to the source file.
logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
FileUtil.symLink(url, targetFile.getAbsolutePath)
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
}
// Make the file executable - That's necessary for scripts
FileUtil.chmod(filename, "a+x")
}
/**
* Shuffle the elements of a collection into a random order, returning the
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
def randomize[T](seq: TraversableOnce[T]): Seq[T] = {
val buf = new ArrayBuffer[T]()
buf ++= seq
val rand = new Random()
for (i <- (buf.size - 1) to 1 by -1) {
def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
randomizeInPlace(seq.toArray)
}
/**
* Shuffle the elements of an array into a random order, modifying the
* original array. Returns the original array.
*/
def randomizeInPlace[T](arr: Array[T], rand: Random = new Random): Array[T] = {
for (i <- (arr.length - 1) to 1 by -1) {
val j = rand.nextInt(i)
val tmp = buf(j)
buf(j) = buf(i)
buf(i) = tmp
val tmp = arr(j)
arr(j) = arr(i)
arr(i) = tmp
}
buf
arr
}
/**
@ -155,7 +217,7 @@ object Utils {
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
}
/**
* Returns a standard ThreadFactory except all threads are daemons.
*/
@ -179,10 +241,10 @@ object Utils {
return threadPool
}
/**
* Return the string to tell how long has passed in seconds. The passing parameter should be in
* millisecond.
* Return a string reporting how much time has passed, in milliseconds, since the given
* startTimeMs (which should itself be a timestamp in milliseconds).
*/
def getUsedTimeMs(startTimeMs: Long): String = {
return " " + (System.currentTimeMillis - startTimeMs) + " ms"
@ -294,4 +356,43 @@ object Utils {
def execute(command: Seq[String]) {
execute(command, new File("."))
}
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getSparkCallSite: String = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
// Keep crawling up the stack trace until we find the first function not inside of the spark
// package. We track the last (shallowest) contiguous Spark method. This might be an RDD
// transformation, a SparkContext function (such as parallelize), or anything else that leads
// to instantiation of an RDD. We also track the first (deepest) user method, file, and line.
var lastSparkMethod = "<unknown>"
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
for (el <- trace) {
if (!finished) {
if (el.getClassName.startsWith("spark.") && !el.getClassName.startsWith("spark.examples.")) {
lastSparkMethod = if (el.getMethodName == "<init>") {
// Spark method is a constructor; get its class name
el.getClassName.substring(el.getClassName.lastIndexOf('.') + 1)
} else {
el.getMethodName
}
}
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
finished = true
}
}
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
}
}

View file

@ -22,8 +22,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
import JavaDoubleRDD.fromRDD
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
// first() has to be overridden here in order for its return type to be Double instead of Object.
@ -31,36 +36,63 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numSplits: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numSplits))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
fromRDD(srdd.filter(x => f(x).booleanValue()))
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
fromRDD(srdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
// Double RDD functions
/** Return the sum of the elements in this RDD. */
def sum(): Double = srdd.sum()
/** Return a [[spark.StatCounter]] describing the elements in this RDD. */
def stats(): StatCounter = srdd.stats()
/** Return the mean of the elements in this RDD. */
def mean(): Double = srdd.mean()
/** Return the variance of the elements in this RDD. */
def variance(): Double = srdd.variance()
/** Return the standard deviation of the elements in this RDD. */
def stdev(): Double = srdd.stdev()
/** Return the approximate mean of the elements in this RDD. */
def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.meanApprox(timeout, confidence)
/** Return the approximate mean of the elements in this RDD. */
def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
/** Return the approximate sum of the elements in this RDD. */
def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.sumApprox(timeout, confidence)
/** Return the approximate sum of the elements in this RDD. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
}

View file

@ -1,13 +1,5 @@
package spark.api.java
import spark.SparkContext.rddToPairRDDFunctions
import spark.api.java.function.{Function2 => JFunction2}
import spark.api.java.function.{Function => JFunction}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.storage.StorageLevel
import spark._
import java.util.{List => JList}
import java.util.Comparator
@ -19,6 +11,17 @@ import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.conf.Configuration
import spark.api.java.function.{Function2 => JFunction2}
import spark.api.java.function.{Function => JFunction}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.OrderedRDDFunctions
import spark.storage.StorageLevel
import spark.HashPartitioner
import spark.Partitioner
import spark.RDD
import spark.SparkContext.rddToPairRDDFunctions
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
@ -31,21 +34,44 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
// Common RDD functions
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numSplits: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numSplits))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: Function[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
@ -56,7 +82,21 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def first(): (K, V) = rdd.first()
// Pair RDD functions
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
* "combined type" C. Note that V and C can be different -- for example, one might group an
* RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
*/
def combineByKey[C](createCombiner: Function[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
@ -71,50 +111,113 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
))
}
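A sketch of the three aggregation functions described above, written against the underlying Scala pair-RDD API rather than this Java wrapper, and assuming an existing SparkContext `sc`; the four-argument combineByKey call mirrors the delegation visible in the wrapper body.

    // Hypothetical example: build a per-key list of values.
    import spark.SparkContext._
    val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3)))
    val grouped = pairs.combineByKey(
      (v: Int) => List(v),                          // createCombiner
      (c: List[Int], v: Int) => v :: c,             // mergeValue
      (c1: List[Int], c2: List[Int]) => c1 ::: c2,  // mergeCombiners
      new spark.HashPartitioner(2))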
/**
* Simplified version of combineByKey that hash-partitions the output RDD.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
numSplits: Int): JavaPairRDD[K, C] =
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numSplits))
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
*/
def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(partitioner, func))
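The classic word count is the shortest illustration of the map-side merging described above; this sketch uses the underlying Scala API and assumes an existing SparkContext `sc` and a placeholder input path.

    // Hypothetical example: per-word counts, partially combined before the shuffle.
    import spark.SparkContext._
    val counts = sc.textFile("hdfs://namenode:9000/data/input.txt")
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKey(_ + _)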
/**
* Merge the values for each key using an associative reduce function, but return the results
* immediately to the master as a Map. This will also perform the merging locally on each mapper
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
mapAsJavaMap(rdd.reduceByKeyLocally(func))
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numSplits splits.
*/
def reduceByKey(func: JFunction2[V, V, V], numSplits: Int): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(func, numSplits))
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD into `numSplits` partitions.
*/
def groupByKey(numSplits: Int): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(numSplits)))
/**
* Return a copy of the RDD partitioned using the specified partitioner. If `mapSideCombine`
* is true, Spark will group values of the same key together on the map side before the
* repartitioning, to only send each key over the network once. If a large number of
* duplicated keys are expected, and the size of the keys is large, `mapSideCombine` should
* be set to true.
*/
def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] =
fromRDD(rdd.partitionBy(partitioner))
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD.
*/
def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, partitioner))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other, partitioner))
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other, partitioner))
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
@ -123,40 +226,94 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
*/
def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
val partitioner = rdd.defaultPartitioner(rdd)
fromRDD(reduceByKey(partitioner, func))
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the default parallelism level.
*/
def groupByKey(): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other))
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, numSplits))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the default level of parallelism.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* into `numSplits` partitions.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other, numSplits))
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the default parallelism level.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other))
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other, numSplits))
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: Function[V, U]): JavaPairRDD[K, U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
fromRDD(rdd.mapValues(f))
}
/**
* Pass each value in the key-value pair RDD through a flatMap function without changing the
* keys; this also retains the original RDD's partitioning.
*/
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
@ -165,37 +322,68 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.flatMapValues(fn))
}
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], numSplits: Int): JavaPairRDD[K, (JList[V], JList[W])]
= fromRDD(cogroupResultToJava(rdd.cogroup(other, numSplits)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numSplits: Int)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numSplits)))
/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
/** Alias for cogroup. */
def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key))
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
@ -205,6 +393,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
@ -213,6 +402,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
@ -222,6 +412,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
@ -230,21 +421,49 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/**
* Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
* that storage system. The JobConf should set an OutputFormat and any output paths required
* (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
* MapReduce job.
*/
def saveAsHadoopDataset(conf: JobConf) {
rdd.saveAsHadoopDataset(conf)
}
// Ordered RDD Functions
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements in
* ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
* ordered list of records (in the `save` case, they will be written to multiple `part-X` files
* in the filesystem, in order of the keys).
*/
def sortByKey(): JavaPairRDD[K, V] = sortByKey(true)
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
sortByKey(comp, ascending)
}
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true)
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
class KeyOrdering(val a: K) extends Ordered[K] {
override def compare(b: K) = comp.compare(a, b)
@ -274,4 +493,4 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
}
}

View file

@ -11,20 +11,43 @@ JavaRDDLike[T, JavaRDD[T]] {
// Common RDD functions
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numSplits: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numSplits))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
wrapRDD(rdd.filter((x => f(x).booleanValue())))
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
}

View file

@ -19,41 +19,71 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def rdd: RDD[T]
/** Set of partitions in this RDD. */
def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
/** The [[spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
/** A unique ID for this RDD (within its SparkContext). */
def id: Int = rdd.id
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel: StorageLevel = rdd.getStorageLevel
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
def iterator(split: Split): java.util.Iterator[T] = asJavaIterator(rdd.iterator(split))
// Transformations (return a new RDD)
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: JFunction[T, R]): JavaRDD[R] =
new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
@ -61,29 +91,50 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
JavaPairRDD[K, V] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
/**
* Return an RDD created by coalescing all elements within each partition into an array.
*/
def glom(): JavaRDD[JList[T]] =
new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
/**
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
other.classManifest)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
@ -92,6 +143,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K], numSplits: Int): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
@ -100,56 +155,114 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numSplits)(f.returnType)))(kcm, vcm)
}
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: String): JavaRDD[String] = rdd.pipe(command)
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: JList[String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command))
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env))
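For example, a hedged one-liner (again with the assumed `lines` RDD): each partition's elements are written to the command's stdin, and its stdout lines become the new RDD:

    val errors = lines.pipe("grep -i error")   // JavaRDD[String] of the external command's output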
// Actions (launch a job to return a value to the user program)
/**
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
val cleanF = rdd.context.clean(f)
rdd.foreach(cleanF)
}
/**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.collect().toSeq
new java.util.ArrayList(arr)
}
/**
* Reduces the elements of this RDD using the specified associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
rdd.fold(zeroValue)(f)
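A minimal sketch of the semantics via the underlying Scala RDD API, which this method delegates to (`sc` is an assumed SparkContext):

    val total = sc.parallelize(1 to 10).fold(0)(_ + _)   // 55; 0 is the neutral "zero value"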
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using
* the given combine functions and a neutral "zero value". This function can return a different
* result type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into
* a U and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions
* are allowed to modify and return their first argument instead of creating a new U to avoid
* memory allocation.
*/
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
combOp: JFunction2[U, U, U]): U =
rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
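A hedged sketch, again via the underlying Scala RDD API: the classic one-pass (sum, count) accumulator used to compute a mean, where the element type (Int) differs from the result type ((Int, Int)):

    val nums = sc.parallelize(1 to 100)
    val (sum, count) = nums.aggregate((0, 0))(
      (acc, x) => (acc._1 + x, acc._2 + 1),       // seqOp: fold one element into the accumulator
      (a, b)   => (a._1 + b._1, a._2 + b._2))     // combOp: merge two partition accumulators
    val mean = sum.toDouble / count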
/**
* Return the number of elements in the RDD.
*/
def count(): Long = rdd.count()
/**
* (Experimental) Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
rdd.countApprox(timeout, confidence)
/**
* (Experimental) Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long): PartialResult[BoundedDouble] =
rdd.countApprox(timeout)
/**
* Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
mapAsJavaMap(rdd.countByValue().map((x => (x._1, new lang.Long(x._2)))))
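A small sketch of the Scala equivalent (the Java wrapper above only boxes the counts into java.lang.Long):

    sc.parallelize(Seq("a", "b", "a")).countByValue()   // Map("a" -> 2, "b" -> 1)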
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout).map(mapAsJavaMap)
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
* it will be slow if a lot of partitions are required. In that case, use collect() to get the
* whole RDD instead.
*/
def take(num: Int): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.take(num).toSeq
@ -162,9 +275,18 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
new java.util.ArrayList(arr)
}
/**
* Return the first element in this RDD.
*/
def first(): T = rdd.first()
/**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
}
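Tying the two save methods above together, a hedged round-trip sketch (the output paths and the earlier `jsc`/`lines` names are illustrative):

    lines.saveAsTextFile("out-text")                     // one part-file per partition
    lines.saveAsObjectFile("out-objects")                // serialized objects in a SequenceFile
    val restored = jsc.objectFile[String]("out-objects") // reload with the matching objectFile call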
@ -1,43 +1,78 @@
package spark.api.java
import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import spark.{Accumulator, AccumulatorParam, RDD, SparkContext}
import spark.SparkContext.IntAccumulatorParam
import spark.SparkContext.DoubleAccumulatorParam
import spark.broadcast.Broadcast
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import scala.collection.JavaConversions
/**
* A Java-friendly version of [[spark.SparkContext]] that returns [[spark.api.java.JavaRDD]]s and
* works with Java collections instead of Scala ones.
*/
class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
def this(master: String, frameworkName: String) = this(new SparkContext(master, frameworkName))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
*/
def this(master: String, jobName: String) = this(new SparkContext(master, jobName))
def this(master: String, frameworkName: String, sparkHome: String, jarFile: String) =
this(new SparkContext(master, frameworkName, sparkHome, Seq(jarFile)))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, jobName: String, sparkHome: String, jarFile: String) =
this(new SparkContext(master, jobName, sparkHome, Seq(jarFile)))
def this(master: String, frameworkName: String, sparkHome: String, jars: Array[String]) =
this(new SparkContext(master, frameworkName, sparkHome, jars.toSeq))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, jobName: String, sparkHome: String, jars: Array[String]) =
this(new SparkContext(master, jobName, sparkHome, jars.toSeq))
val env = sc.env
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param jobName A name for your job, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes
*/
def this(master: String, jobName: String, sparkHome: String, jars: Array[String],
environment: JMap[String, String]) =
this(new SparkContext(master, jobName, sparkHome, jars.toSeq, environment))
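A hedged sketch of constructing the context with these overloads (master URLs, paths, and jar names are illustrative):

    val jsc = new JavaSparkContext("local[2]", "MyJob")
    // or, against a cluster, shipping a pre-built jar to the workers:
    // new JavaSparkContext("spark://host:7077", "MyJob", "/opt/spark", Array("target/my-job.jar"))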
private[spark] val env = sc.env
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
parallelize(list, sc.defaultParallelism)
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
implicit val kcm: ClassManifest[K] =
@ -47,21 +82,32 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
parallelizePairs(list, sc.defaultParallelism)
/** Distribute a local Scala collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
numSlices))
/** Distribute a local Scala collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
parallelizeDoubles(list, sc.defaultParallelism)
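A hedged sketch of the three parallelize flavours with small local Java collections (the values are illustrative):

    import java.util.Arrays

    val nums  = jsc.parallelize(Arrays.asList(1: Integer, 2: Integer, 3: Integer))
    val pairs = jsc.parallelizePairs(Arrays.asList(
      new Tuple2("a", 1: Integer), new Tuple2("b", 2: Integer)))
    val dbls  = jsc.parallelizeDoubles(Arrays.asList(1.0: java.lang.Double, 2.5: java.lang.Double))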
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String): JavaRDD[String] = sc.textFile(path)
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
/**Get an RDD for a Hadoop SequenceFile with given key and value types */
/**Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
@ -72,6 +118,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
/**Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
@ -92,6 +139,13 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
sc.objectFile(path, minSplits)(cm)
}
/**
* Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
* BytesWritable values that contain a serialized partition. This is still an experimental storage
* format and may not be supported exactly as is in future Spark releases. It will also be pretty
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
@ -115,6 +169,11 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
def hadoopRDD[K, V, F <: InputFormat[K, V]](
conf: JobConf,
inputFormatClass: Class[F],
@ -126,7 +185,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
/**Get an RDD for a Hadoop file with an arbitrary InputFormat */
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
@ -139,6 +198,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
@ -180,12 +240,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
/** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
implicit val cm: ClassManifest[T] = first.classManifest
sc.union(rdds)(cm)
}
/** Build the union of two or more RDDs. */
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
@ -195,26 +257,49 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
/** Build the union of two or more RDDs. */
override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
new JavaDoubleRDD(sc.union(rdds))
}
/**
* Create an [[spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`.
*/
def intAccumulator(initialValue: Int): Accumulator[Int] =
sc.accumulator(initialValue)(IntAccumulatorParam)
/**
* Create an [[spark.Accumulator]] double variable, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`.
*/
def doubleAccumulator(initialValue: Double): Accumulator[Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam)
/**
* Create an [[spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
/**
* Broadcast a read-only variable to the cluster, returning a [[spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
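A hedged sketch combining the two sharing primitives above (the `jsc` and `lines` names are carried over from the earlier examples): the broadcast value is read-only on the workers, while the accumulator is add-only on the workers and read on the master:

    import spark.api.java.function.VoidFunction

    val interesting = jsc.broadcast(Set("ERROR", "WARN"))   // shipped to each node at most once
    val hits = jsc.intAccumulator(0)
    lines.foreach(new VoidFunction[String] {
      def call(line: String) {
        if (interesting.value.contains(line)) hits += 1     // tasks only "add" to the accumulator
      }
    })
    println("matches: " + hits.value)                       // only the master reads the value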
/** Shut down the SparkContext. */
def stop() {
sc.stop()
}
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
* (in that order of preference). If none of these is set, return None.
*/
def getSparkHome(): Option[String] = sc.getSparkHome()
}
@ -0,0 +1,20 @@
package spark.api.java;
import spark.storage.StorageLevel;
/**
* Expose some commonly useful storage level constants.
*/
public class StorageLevels {
public static final StorageLevel NONE = new StorageLevel(false, false, false, 1);
public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1);
public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2);
public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1);
public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2);
public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1);
public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2);
public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1);
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
}
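For instance, a hedged sketch of choosing a non-default level from Java-facing code (assuming JavaRDD exposes a persist(StorageLevel) method mirroring the Scala RDD's):

    import spark.api.java.StorageLevels

    lines.persist(StorageLevels.MEMORY_AND_DISK_SER)   // serialized in memory, spilling to disk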
@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns zero or more records of type Double from each input record.
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
@ -5,6 +5,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns Doubles, and can be used to construct DoubleRDDs.
*/
// DoubleFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
@ -1,5 +1,8 @@
package spark.api.java.function
/**
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
@ -8,8 +8,9 @@ import java.io.Serializable;
/**
* Base class for functions whose return types do not have special RDDs; DoubleFunction is
* handled separately, to allow DoubleRDDs to be constructed when mapping RDDs to doubles.
* Base class for functions whose return types do not create special RDDs. PairFunction and
* DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
@ -6,6 +6,9 @@ import scala.runtime.AbstractFunction2;
import java.io.Serializable;
/**
* A two-argument function that takes arguments of type T1 and T2 and returns an R.
*/
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
@ -7,6 +7,10 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns zero or more key-value pair records from each input record. The
* key-value pairs are represented as scala.Tuple2 objects.
*/
// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and PairFlatMapFunction.
public abstract class PairFlatMapFunction<T, K, V>
@ -7,6 +7,9 @@ import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
@ -1,5 +1,8 @@
package spark.api.java.function
/**
* A function with no return value.
*/
// This allows Java users to write void methods without having to return Unit.
abstract class VoidFunction[T] extends Serializable {
@throws(classOf[Exception])
@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction1
* apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
* isn't marked to allow that).
*/
abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
@throws(classOf[Exception])
def call(t: T): R
@ -7,7 +7,7 @@ import scala.runtime.AbstractFunction2
* apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
* isn't marked to allow that).
*/
abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
@throws(classOf[Exception])
def call(t1: T1, t2: T2): R
@ -11,14 +11,17 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class BitTorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id)
with Logging
with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -45,7 +48,7 @@ extends Broadcast[T] with Logging with Serializable {
// Used only in Workers
@transient var ttGuide: TalkToGuide = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@ -53,7 +56,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
sendBroadcast()
}
def sendBroadcast() {
@ -106,20 +109,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(uuid,
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
@ -131,18 +136,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@ -254,8 +258,8 @@ extends Broadcast[T] with Logging with Serializable {
}
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@ -307,9 +311,11 @@ extends Broadcast[T] with Logging with Serializable {
var threadPool = Utils.newDaemonFixedThreadPool(MultiTracker.MaxChatSlots)
while (hasBlocks.get < totalBlocks) {
var numThreadsToCreate =
math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
var numThreadsToCreate = 0
listOfSources.synchronized {
numThreadsToCreate = math.min(listOfSources.size, MultiTracker.MaxChatSlots) -
threadPool.getActiveCount
}
while (hasBlocks.get < totalBlocks && numThreadsToCreate > 0) {
var peerToTalkTo = pickPeerToTalkToRandom
@ -722,7 +728,6 @@ extends Broadcast[T] with Logging with Serializable {
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
// Don't stop until there is a copy in HDFS
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
@ -730,14 +735,17 @@ extends Broadcast[T] with Logging with Serializable {
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far is done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
listOfSources.synchronized {
setOfCompletedSources.synchronized {
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
}
}
}
}
}
@ -760,7 +768,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid)
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@ -918,9 +926,7 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
logError("ServeMultipleRequests Timeout.")
}
case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection:" + clientSocket)
@ -1023,9 +1029,12 @@ extends Broadcast[T] with Logging with Serializable {
}
}
class BitTorrentBroadcastFactory
private[spark] class BitTorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new BitTorrentBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new BitTorrentBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}
@ -1,25 +1,20 @@
package spark.broadcast
import java.io._
import java.net._
import java.util.{BitSet, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import scala.collection.mutable.Map
import java.util.concurrent.atomic.AtomicLong
import spark._
trait Broadcast[T] extends Serializable {
val uuid = UUID.randomUUID
abstract class Broadcast[T](id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
override def toString = "spark.Broadcast(" + uuid + ")"
override def toString = "spark.Broadcast(" + id + ")"
}
private[spark]
class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializable {
private var initialized = false
@ -49,14 +44,10 @@ class BroadcastManager(val isMaster_ : Boolean) extends Logging with Serializabl
broadcastFactory.stop()
}
private def getBroadcastFactory: BroadcastFactory = {
if (broadcastFactory == null) {
throw new SparkException ("Broadcast.getBroadcastFactory called before initialize")
}
broadcastFactory
}
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) = broadcastFactory.newBroadcast[T](value_, isLocal)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isMaster = isMaster_
}
@ -6,8 +6,8 @@ package spark.broadcast
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
trait BroadcastFactory {
private[spark] trait BroadcastFactory {
def initialize(isMaster: Boolean): Unit
def newBroadcast[T](value_ : T, isLocal: Boolean): Broadcast[T]
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}
@ -12,44 +12,47 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import spark._
import spark.storage.StorageLevel
class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
if (!isLocal) {
HttpBroadcast.write(uuid, value_)
HttpBroadcast.write(id, value_)
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](uuid)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + uuid + " took " + time + " s")
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
}
}
class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) = HttpBroadcast.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new HttpBroadcast[T](value_, isLocal)
def stop() = HttpBroadcast.stop()
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isMaster: Boolean) { HttpBroadcast.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
}
private object HttpBroadcast extends Logging {
@ -65,7 +68,7 @@ private object HttpBroadcast extends Logging {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.compress", "false").toBoolean
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isMaster) {
createServer()
}
@ -76,9 +79,12 @@ private object HttpBroadcast extends Logging {
}
def stop() {
if (server != null) {
server.stop()
server = null
synchronized {
if (server != null) {
server.stop()
server = null
}
initialized = false
}
}
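Note that the default above changed: broadcast compression is now controlled by `spark.broadcast.compress` and is on by default. A hedged sketch of overriding it (set before the SparkContext is created):

    System.setProperty("spark.broadcast.compress", "false")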
@ -91,8 +97,8 @@ private object HttpBroadcast extends Logging {
logInfo("Broadcast server started at " + serverUri)
}
def write(uuid: UUID, value: Any) {
val file = new File(broadcastDir, "broadcast-" + uuid)
def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = if (compress) {
new LZFOutputStream(new FileOutputStream(file)) // Does its own buffering
} else {
@ -104,8 +110,8 @@ private object HttpBroadcast extends Logging {
serOut.close()
}
def read[T](uuid: UUID): T = {
val url = serverUri + "/broadcast-" + uuid
def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + id
var in = if (compress) {
new LZFInputStream(new URL(url).openStream()) // Does its own buffering
} else {
@ -2,8 +2,7 @@ package spark.broadcast
import java.io._
import java.net._
import java.util.{UUID, Random}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.Random
import scala.collection.mutable.Map
@ -18,7 +17,7 @@ extends Logging {
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[UUID, SourceInfo]()
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
@ -154,44 +153,44 @@ extends Logging {
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (uuid -> gInfo)
valueToGuideMap += (id -> gInfo)
}
logInfo ("New broadcast " + uuid + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(uuid) = SourceInfo("", SourceInfo.TxOverGoToDefault)
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + uuid + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive UUID
val uuid = ois.readObject.asInstanceOf[UUID]
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(uuid)) valueToGuideMap(uuid)
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + uuid + " : " + gInfo.listenPort)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
@ -224,12 +223,12 @@ extends Logging {
}
}
def getGuideInfo(variableUUID: UUID): SourceInfo = {
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxOverGoToDefault)
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
@ -247,8 +246,8 @@ extends Logging {
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send UUID and receive GuideInfo
oosTracker.writeObject(variableUUID)
// Send Long and receive GuideInfo
oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
@ -276,7 +275,7 @@ extends Logging {
return gInfo
}
def registerBroadcast(uuid: UUID, gInfo: SourceInfo) {
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@ -286,8 +285,8 @@ extends Logging {
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
@ -303,7 +302,7 @@ extends Logging {
socket.close()
}
def unregisterBroadcast(uuid: UUID) {
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.MasterHostAddress, MasterTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
@ -313,8 +312,8 @@ extends Logging {
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send UUID of this broadcast
oosST.writeObject(uuid)
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away
@ -383,10 +382,10 @@ extends Logging {
}
}
case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@ -7,7 +7,7 @@ import spark._
/**
* Used to keep and pass around information about the peers involved in a broadcast
*/
case class SourceInfo (hostAddress: String,
private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
@ -26,10 +26,11 @@ extends Comparable[SourceInfo] with Logging {
/**
* Helper Object of SourceInfo for its constants
*/
object SourceInfo {
// Constants for special values of listenPort
private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
val TxOverGoToDefault = 0
// Broadcast has already finished. Try default mechanism.
val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
@ -10,14 +10,15 @@ import scala.math
import spark._
import spark.storage.StorageLevel
class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean)
extends Broadcast[T] with Logging with Serializable {
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@ -35,7 +36,7 @@ extends Broadcast[T] with Logging with Serializable {
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var hostAddress = Utils.localIpAddress()
@transient var listenPort = -1
@transient var guidePort = -1
@ -43,7 +44,7 @@ extends Broadcast[T] with Logging with Serializable {
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast
sendBroadcast()
}
def sendBroadcast() {
@ -84,20 +85,22 @@ extends Broadcast[T] with Logging with Serializable {
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(uuid,
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(uuid.toString) match {
case Some(x) => x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + uuid)
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Master will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
@ -108,18 +111,17 @@ extends Broadcast[T] with Logging with Serializable {
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(uuid)
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
uuid.toString, value_, StorageLevel.MEMORY_ONLY_DESER, false)
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading Broadcasted variable " + uuid + " failed")
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading Broadcasted variable " + uuid + " took " + time + " s")
}
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
@ -136,14 +138,14 @@ extends Broadcast[T] with Logging with Serializable {
serveMR = null
hostAddress = Utils.localIpAddress
hostAddress = Utils.localIpAddress()
listenPort = -1
stopBroadcast = false
}
def receiveBroadcast(variableUUID: UUID): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableUUID)
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
@ -290,15 +292,17 @@ extends Broadcast[T] with Logging with Serializable {
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
logError("GuideMultipleRequests Timeout.")
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done.
// Comparing with listOfSources.size - 1, because the Guide itself
// is included
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
// everyone connected so far is done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
listOfSources.synchronized {
setOfCompletedSources.synchronized {
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
}
}
}
}
}
@ -316,7 +320,7 @@ extends Broadcast[T] with Logging with Serializable {
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(uuid)
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
@ -490,7 +494,7 @@ extends Broadcast[T] with Logging with Serializable {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => logError("ServeMultipleRequests Timeout.")
case e: Exception => { }
}
if (clientSocket != null) {
@ -570,9 +574,12 @@ extends Broadcast[T] with Logging with Serializable {
}
}
class TreeBroadcastFactory
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isMaster: Boolean) = MultiTracker.initialize(isMaster)
def newBroadcast[T](value_ : T, isLocal: Boolean) = new TreeBroadcast[T](value_, isLocal)
def stop() = MultiTracker.stop
def initialize(isMaster: Boolean) { MultiTracker.initialize(isMaster) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}
@ -2,7 +2,7 @@ package spark.deploy
import scala.collection.Map
case class Command(
private[spark] case class Command(
mainClass: String,
arguments: Seq[String],
environment: Map[String, String]) {
@ -7,13 +7,15 @@ import scala.collection.immutable.List
import scala.collection.mutable.HashMap
sealed trait DeployMessage extends Serializable
private[spark] sealed trait DeployMessage extends Serializable
// Worker to Master
private[spark]
case class RegisterWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int)
extends DeployMessage
private[spark]
case class ExecutorStateChanged(
jobId: String,
execId: Int,
@ -23,11 +25,11 @@ case class ExecutorStateChanged(
// Master to Worker
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
private[spark] case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
private[spark] case class RegisterWorkerFailed(message: String) extends DeployMessage
private[spark] case class KillExecutor(jobId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
private[spark] case class LaunchExecutor(
jobId: String,
execId: Int,
jobDesc: JobDescription,
@ -38,33 +40,42 @@ case class LaunchExecutor(
// Client to Master
case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
private[spark] case class RegisterJob(jobDescription: JobDescription) extends DeployMessage
// Master to Client
private[spark]
case class RegisteredJob(jobId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String])
private[spark]
case class JobKilled(message: String)
// Internal message in Client
case object StopClient
private[spark] case object StopClient
// MasterWebUI To Master
case object RequestMasterState
private[spark] case object RequestMasterState
// Master to MasterWebUI
private[spark]
case class MasterState(uri : String, workers: List[WorkerInfo], activeJobs: List[JobInfo],
completedJobs: List[JobInfo])
// WorkerWebUI to Worker
case object RequestWorkerState
private[spark] case object RequestWorkerState
// Worker to WorkerWebUI
private[spark]
case class WorkerState(uri: String, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
@ -1,6 +1,6 @@
package spark.deploy
object ExecutorState
private[spark] object ExecutorState
extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
@ -1,6 +1,6 @@
package spark.deploy
class JobDescription(
private[spark] class JobDescription(
val name: String,
val cores: Int,
val memoryPerSlave: Int,
@ -0,0 +1,58 @@
package spark.deploy
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import spark.deploy.worker.Worker
import spark.deploy.master.Master
import spark.util.AkkaUtils
import spark.{Logging, Utils}
import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numSlaves: Int, coresPerSlave: Int, memoryPerSlave: Int) extends Logging {
val localIpAddress = Utils.localIpAddress
var masterActor : ActorRef = _
var masterActorSystem : ActorSystem = _
var masterPort : Int = _
var masterUrl : String = _
val slaveActorSystems = ArrayBuffer[ActorSystem]()
val slaveActors = ArrayBuffer[ActorRef]()
def start() : String = {
logInfo("Starting a local Spark cluster with " + numSlaves + " slaves.")
/* Start the Master */
val (actorSystem, masterPort) = AkkaUtils.createActorSystem("sparkMaster", localIpAddress, 0)
masterActorSystem = actorSystem
masterUrl = "spark://" + localIpAddress + ":" + masterPort
val actor = masterActorSystem.actorOf(
Props(new Master(localIpAddress, masterPort, 0)), name = "Master")
masterActor = actor
/* Start the Slaves */
for (slaveNum <- 1 to numSlaves) {
val (actorSystem, boundPort) =
AkkaUtils.createActorSystem("sparkWorker" + slaveNum, localIpAddress, 0)
slaveActorSystems += actorSystem
val actor = actorSystem.actorOf(
Props(new Worker(localIpAddress, boundPort, 0, coresPerSlave, memoryPerSlave, masterUrl)),
name = "Worker")
slaveActors += actor
}
return masterUrl
}
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the slaves before the master so they don't get upset that it disconnected
slaveActorSystems.foreach(_.shutdown())
slaveActorSystems.foreach(_.awaitTermination())
masterActorSystem.shutdown()
masterActorSystem.awaitTermination()
}
}
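A hedged usage sketch; since the class is `private[spark]`, this is only callable from code inside the spark package (for example Spark's own test suites), and the sizes below are illustrative:

    val cluster = new LocalSparkCluster(2, 1, 512)      // 2 slaves, 1 core and 512 MB each
    val masterUrl = cluster.start()                     // e.g. "spark://<local-ip>:<port>"
    val sc = new SparkContext(masterUrl, "LocalClusterTest")
    try {
      println(sc.parallelize(1 to 100).count())         // runs against the in-process cluster
    } finally {
      sc.stop()
      cluster.stop()
    }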
@ -4,6 +4,7 @@ import spark.deploy._
import akka.actor._
import akka.pattern.ask
import akka.util.duration._
import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
import akka.remote.RemoteClientLifeCycleEvent
import akka.remote.RemoteClientShutdown
@ -16,7 +17,7 @@ import akka.dispatch.Await
* The main class used to talk to a Spark deploy cluster. Takes a master URL, a job description,
* and a listener for job events, and calls back the listener when various events occur.
*/
class Client(
private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
jobDescription: JobDescription,
@ -42,7 +43,6 @@ class Client(
val akkaUrl = "akka://spark@%s:%s/user/Master".format(masterHost, masterPort)
try {
master = context.actorFor(akkaUrl)
//master ! RegisterWorker(ip, port, cores, memory)
master ! RegisterJob(jobDescription)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
@ -101,9 +101,13 @@ class Client(
def stop() {
if (actor != null) {
val timeout = 1.seconds
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
try {
val timeout = 1.seconds
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {
case e: AskTimeoutException => // Ignore it, maybe master went away
}
actor = null
}
}
@ -7,7 +7,7 @@ package spark.deploy.client
*
* Users of this API should *not* block inside the callback methods.
*/
trait ClientListener {
private[spark] trait ClientListener {
def connected(jobId: String): Unit
def disconnected(): Unit
@ -4,7 +4,7 @@ import spark.util.AkkaUtils
import spark.{Logging, Utils}
import spark.deploy.{Command, JobDescription}
object TestClient {
private[spark] object TestClient {
class TestListener extends ClientListener with Logging {
def connected(id: String) {
@ -1,6 +1,6 @@
package spark.deploy.client
object TestExecutor {
private[spark] object TestExecutor {
def main(args: Array[String]) {
println("Hello world!")
while (true) {
@ -2,7 +2,7 @@ package spark.deploy.master
import spark.deploy.ExecutorState
class ExecutorInfo(
private[spark] class ExecutorInfo(
val id: Int,
val job: JobInfo,
val worker: WorkerInfo,
@ -5,6 +5,7 @@ import java.util.Date
import akka.actor.ActorRef
import scala.collection.mutable
private[spark]
class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, val actor: ActorRef) {
var state = JobState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@ -31,4 +32,13 @@ class JobInfo(val id: String, val desc: JobDescription, val submitDate: Date, va
}
def coresLeft: Int = desc.cores - coresGranted
private var _retryCount = 0
def retryCount = _retryCount
def incrementRetryCount = {
_retryCount += 1
_retryCount
}
}
@ -1,7 +1,9 @@
package spark.deploy.master
object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
private[spark] object JobState extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED") {
type JobState = Value
val WAITING, RUNNING, FINISHED, FAILED = Value
val MAX_NUM_RETRY = 10
}
@ -1,21 +1,20 @@
package spark.deploy.master
import akka.actor._
import akka.actor.Terminated
import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
import java.text.SimpleDateFormat
import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import akka.actor._
import spark.{Logging, Utils}
import spark.util.AkkaUtils
import java.text.SimpleDateFormat
import java.util.Date
import akka.remote.RemoteClientLifeCycleEvent
import spark.deploy._
import akka.remote.RemoteClientShutdown
import akka.remote.RemoteClientDisconnected
import spark.deploy.RegisterWorker
import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For job IDs
var nextJobNumber = 0
@ -81,12 +80,22 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
exec.state = state
exec.job.actor ! ExecutorUpdated(execId, state, message)
if (ExecutorState.isFinished(state)) {
val jobInfo = idToJob(jobId)
// Remove this executor from the worker and job
logInfo("Removing executor " + exec.fullId + " because it is " + state)
idToJob(jobId).removeExecutor(exec)
jobInfo.removeExecutor(exec)
exec.worker.removeExecutor(exec)
// TODO: the worker would probably want to restart the executor a few times
schedule()
// Only retry a certain number of times so we don't go into an infinite loop.
if (jobInfo.incrementRetryCount <= JobState.MAX_NUM_RETRY) {
schedule()
} else {
val e = new SparkException("Job %s with ID %s failed %d times.".format(
jobInfo.desc.name, jobInfo.id, jobInfo.retryCount))
logError(e.getMessage, e)
throw e
//System.exit(1)
}
}
}
case None =>
@ -112,7 +121,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
addressToWorker.get(address).foreach(removeWorker)
addressToJob.get(address).foreach(removeJob)
}
case RequestMasterState => {
sender ! MasterState(ip + ":" + port, workers.toList, jobs.toList, completedJobs.toList)
}
@ -203,7 +212,7 @@ class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
}
}
object Master {
private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
@ -6,7 +6,7 @@ import spark.Utils
/**
* Command-line parser for the master.
*/
class MasterArguments(args: Array[String]) {
private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 7077
var webUiPort = 8080
@ -51,7 +51,7 @@ class MasterArguments(args: Array[String]) {
*/
def printUsageAndExit(exitCode: Int) {
System.err.println(
"Usage: spark-master [options]\n" +
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
@ -10,6 +10,7 @@ import cc.spray.directives._
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy._
private[spark]
class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -22,7 +23,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
completeWith {
val future = master ? RequestMasterState
future.map {
masterState => masterui.html.index.render(masterState.asInstanceOf[MasterState])
masterState => spark.deploy.master.html.index.render(masterState.asInstanceOf[MasterState])
}
}
} ~
@ -36,7 +37,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
// A bit ugly and inefficient, but we won't have a number of jobs
// so large that it will make a significant difference.
(masterState.activeJobs ::: masterState.completedJobs).find(_.id == jobId) match {
case Some(job) => masterui.html.job_details.render(job)
case Some(job) => spark.deploy.master.html.job_details.render(job)
case _ => null
}
}
@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
class WorkerInfo(
private[spark] class WorkerInfo(
val id: String,
val host: String,
val port: Int,
@ -13,7 +13,7 @@ import spark.deploy.ExecutorStateChanged
/**
* Manages the execution of one executor process.
*/
class ExecutorRunner(
private[spark] class ExecutorRunner(
val jobId: String,
val execId: Int,
val jobDesc: JobDescription,
@ -29,12 +29,25 @@ class ExecutorRunner(
val fullId = jobId + "/" + execId
var workerThread: Thread = null
var process: Process = null
var shutdownHook: Thread = null
def start() {
workerThread = new Thread("ExecutorRunner for " + fullId) {
override def run() { fetchAndRunExecutor() }
}
workerThread.start()
// Shutdown hook that kills actors on shutdown.
shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
process.destroy()
process.waitFor()
}
}
}
Runtime.getRuntime.addShutdownHook(shutdownHook)
}
/** Stop this executor runner, including killing the process it launched */
@ -45,40 +58,10 @@ class ExecutorRunner(
if (process != null) {
logInfo("Killing process!")
process.destroy()
process.waitFor()
}
worker ! ExecutorStateChanged(jobId, execId, ExecutorState.KILLED, None)
}
}
/**
* Download a file requested by the executor. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*/
def fetchFile(url: String, targetDir: File) {
val filename = url.split("/").last
val targetFile = new File(targetDir, filename)
if (url.startsWith("http://") || url.startsWith("https://") || url.startsWith("ftp://")) {
// Use the java.net library to fetch it
logInfo("Fetching " + url + " to " + targetFile)
val in = new URL(url).openStream()
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
} else {
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(targetFile)
Utils.copyStream(in, out, true)
}
// Decompress the file if it's a .tar or .tar.gz
if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xzf", filename), targetDir)
} else if (filename.endsWith(".tar")) {
logInfo("Untarring " + filename)
Utils.execute(Seq("tar", "-xf", filename), targetDir)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@ -92,7 +75,8 @@ class ExecutorRunner(
def buildCommandSeq(): Seq[String] = {
val command = jobDesc.command
val runScript = new File(sparkHome, "run").getCanonicalPath
val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
val runScript = new File(sparkHome, script).getCanonicalPath
Seq(runScript, command.mainClass) ++ command.arguments.map(substituteVariables)
}
@ -101,7 +85,12 @@ class ExecutorRunner(
val out = new FileOutputStream(file)
new Thread("redirect output to " + file) {
override def run() {
Utils.copyStream(in, out, true)
try {
Utils.copyStream(in, out, true)
} catch {
case e: IOException =>
logInfo("Redirection to " + file + " closed: " + e.getMessage)
}
}
}.start()
}
@ -131,6 +120,9 @@ class ExecutorRunner(
}
env.put("SPARK_CORES", cores.toString)
env.put("SPARK_MEMORY", memory.toString)
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
// Redirect its stdout and stderr to files
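
The ExecutorRunner hunks above register a JVM shutdown hook so that a worker dying abruptly does not leave its executor process orphaned, and remove the hook again when the executor is killed cleanly. A stripped-down sketch of that pattern using only standard java.lang APIs; the "sleep 60" child command is just a placeholder:

    object ShutdownHookDemo {
      def main(args: Array[String]) {
        // Launch a long-running child process (placeholder command, assumes a Unix shell).
        val process = new ProcessBuilder("sleep", "60").start()

        // Runs if the JVM exits while the child is still alive.
        val hook = new Thread() {
          override def run() {
            println("Shutdown hook killing child process.")
            process.destroy()
            process.waitFor()
          }
        }
        Runtime.getRuntime.addShutdownHook(hook)

        // On the clean shutdown path, kill the child ourselves and drop the hook
        // so it does not run a second time.
        process.destroy()
        process.waitFor()
        Runtime.getRuntime.removeShutdownHook(hook)
      }
    }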

View file

@ -16,7 +16,14 @@ import spark.deploy.RegisterWorkerFailed
import akka.actor.Terminated
import java.io.File
class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, masterUrl: String)
private[spark] class Worker(
ip: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
masterUrl: String,
workDirPath: String = null)
extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
@ -37,7 +44,11 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def memoryFree: Int = memory - memoryUsed
def createWorkDir() {
workDir = new File(sparkHome, "work")
workDir = if (workDirPath != null) {
new File(workDirPath)
} else {
new File(sparkHome, "work")
}
try {
if (!workDir.exists() && !workDir.mkdirs()) {
logError("Failed to create work directory " + workDir)
@ -153,14 +164,19 @@ class Worker(ip: String, port: Int, webUiPort: Int, cores: Int, memory: Int, mas
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
}
override def postStop() {
executors.values.foreach(_.kill())
}
}
object Worker {
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", args.ip, args.port)
val actor = actorSystem.actorOf(
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory, args.master)),
Props(new Worker(args.ip, boundPort, args.webUiPort, args.cores, args.memory,
args.master, args.workDir)),
name = "Worker")
actorSystem.awaitTermination()
}
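
The Worker hunk above lets an explicitly configured directory override the default "work" directory under the Spark home, and fails fast if the directory cannot be created. A small sketch of that fallback-and-create logic; the paths and the WorkDirDemo object are hypothetical:

    import java.io.File

    object WorkDirDemo {
      // Use the explicit path when one is given, otherwise fall back to a
      // "work" directory under the given home directory.
      def createWorkDir(home: String, explicitPath: String): File = {
        val workDir =
          if (explicitPath != null) new File(explicitPath)
          else new File(home, "work")
        if (!workDir.exists() && !workDir.mkdirs()) {
          sys.error("Failed to create work directory " + workDir)
        }
        workDir
      }

      def main(args: Array[String]) {
        // Creates (if needed) and prints a work directory under the temp dir.
        println(createWorkDir(System.getProperty("java.io.tmpdir"), null))
      }
    }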

View file

@ -8,13 +8,14 @@ import java.lang.management.ManagementFactory
/**
* Command-line parser for the worker.
*/
class WorkerArguments(args: Array[String]) {
private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localIpAddress()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
var master: String = null
var workDir: String = null
// Check for settings in environment variables
if (System.getenv("SPARK_WORKER_PORT") != null) {
@ -29,6 +30,9 @@ class WorkerArguments(args: Array[String]) {
if (System.getenv("SPARK_WORKER_WEBUI_PORT") != null) {
webUiPort = System.getenv("SPARK_WORKER_WEBUI_PORT").toInt
}
if (System.getenv("SPARK_WORKER_DIR") != null) {
workDir = System.getenv("SPARK_WORKER_DIR")
}
parse(args.toList)
@ -49,6 +53,10 @@ class WorkerArguments(args: Array[String]) {
memory = value
parse(tail)
case ("--work-dir" | "-d") :: value :: tail =>
workDir = value
parse(tail)
case "--webui-port" :: IntParam(value) :: tail =>
webUiPort = value
parse(tail)
@ -77,13 +85,14 @@ class WorkerArguments(args: Array[String]) {
*/
def printUsageAndExit(exitCode: Int) {
System.err.println(
"Usage: spark-worker [options] <master>\n" +
"Usage: Worker [options] <master>\n" +
"\n" +
"Master must be a URL of the form spark://hostname:port\n" +
"\n" +
"Options:\n" +
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run jobs in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")

View file

@ -9,6 +9,7 @@ import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
@ -21,7 +22,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
completeWith{
val future = worker ? RequestWorkerState
future.map { workerState =>
workerui.html.index(workerState.asInstanceOf[WorkerState])
spark.deploy.worker.html.index(workerState.asInstanceOf[WorkerState])
}
}
} ~

View file

@ -1,10 +1,12 @@
package spark.executor
import java.io.{File, FileOutputStream}
import java.net.{URL, URLClassLoader}
import java.net.{URI, URL, URLClassLoader}
import java.util.concurrent._
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.FileUtil
import scala.collection.mutable.{ArrayBuffer, Map, HashMap}
import spark.broadcast._
import spark.scheduler._
@ -14,11 +16,16 @@ import java.nio.ByteBuffer
/**
* The Mesos executor for Spark.
*/
class Executor extends Logging {
var classLoader: ClassLoader = null
private[spark] class Executor extends Logging {
var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
initLogging()
@ -32,14 +39,14 @@ class Executor extends Logging {
System.setProperty(key, value)
}
// Create our ClassLoader and set it on this thread
urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(slaveHostname, 0, false, false)
SparkEnv.set(env)
// Create our ClassLoader (using spark properties) and set it on this thread
classLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(classLoader)
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
@ -54,15 +61,16 @@ class Executor extends Logging {
override def run() {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Thread.currentThread.setContextClassLoader(urlClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
try {
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(classLoader)
Accumulators.clear()
val task = ser.deserialize[Task[Any]](serializedTask, classLoader)
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val value = task.run(taskId.toInt)
@ -96,25 +104,15 @@ class Executor extends Logging {
* Create a ClassLoader for use in tasks, adding any JARs specified by the user or any classes
* created by the interpreter to the search path
*/
private def createClassLoader(): ClassLoader = {
private def createClassLoader(): ExecutorURLClassLoader = {
var loader = this.getClass.getClassLoader
// If any JAR URIs are given through spark.jar.uris, fetch them to the
// current directory and put them all on the classpath. We assume that
// each URL has a unique file name so that no local filenames will clash
// in this process. This is guaranteed by ClusterScheduler.
val uris = System.getProperty("spark.jar.uris", "")
val localFiles = ArrayBuffer[String]()
for (uri <- uris.split(",").filter(_.size > 0)) {
val url = new URL(uri)
val filename = url.getPath.split("/").last
downloadFile(url, filename)
localFiles += filename
}
if (localFiles.size > 0) {
val urls = localFiles.map(f => new File(f).toURI.toURL).toArray
loader = new URLClassLoader(urls, loader)
}
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
loader = new URLClassLoader(urls, loader)
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
@ -133,13 +131,31 @@ class Executor extends Logging {
}
}
return loader
return new ExecutorURLClassLoader(Array(), loader)
}
// Download a file from a given URL to the local filesystem
private def downloadFile(url: URL, localPath: String) {
val in = url.openStream()
val out = new FileOutputStream(localPath)
Utils.copyStream(in, out, true)
/**
* Download any missing dependencies if we receive a new set of files and JARs from the
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File("."))
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
val url = new File(".", localName).toURI.toURL
if (!urlClassLoader.getURLs.contains(url)) {
logInfo("Adding " + url + " to class loader")
urlClassLoader.addURL(url)
}
}
}
}
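
The Executor hunks above replace the old spark.jar.uris download path with per-dependency timestamps: a file or JAR is re-fetched only when the driver reports a newer timestamp than the one cached on this node, and newly fetched JARs are appended to the URL class loader. A minimal sketch of the timestamp check, with a stubbed fetch standing in for the real download:

    import scala.collection.mutable.HashMap

    object DependencyRefreshDemo {
      // Timestamps of the versions we have already fetched on this node.
      val currentFiles = new HashMap[String, Long]()

      // Placeholder for the real download step (e.g. copying from HTTP or HDFS).
      def fetch(name: String) { println("fetching " + name) }

      def updateDependencies(newFiles: HashMap[String, Long]) {
        // Only fetch entries whose reported timestamp is newer than what we have.
        for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
          fetch(name)
          currentFiles(name) = timestamp
        }
      }

      def main(args: Array[String]) {
        updateDependencies(HashMap("lib.jar" -> 1L, "data.txt" -> 1L))  // fetches both
        updateDependencies(HashMap("lib.jar" -> 1L, "data.txt" -> 2L))  // re-fetches data.txt only
      }
    }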

View file

@ -6,6 +6,6 @@ import spark.TaskState.TaskState
/**
* A pluggable interface used by the Executor to send updates to the cluster scheduler.
*/
trait ExecutorBackend {
private[spark] trait ExecutorBackend {
def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer)
}

View file

@ -0,0 +1,14 @@
package spark.executor
import java.net.{URLClassLoader, URL}
/**
* The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
*/
private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
extends URLClassLoader(urls, parent) {
override def addURL(url: URL) {
super.addURL(url)
}
}
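
The class above exists only to expose URLClassLoader.addURL, which the JDK declares protected, so JARs fetched at runtime can be appended to an already-created loader. A short usage sketch; the JAR path and the commented-out class name are placeholders:

    import java.io.File
    import java.net.{URL, URLClassLoader}

    // Same trick as in the diff: widen the protected addURL to public.
    class MutableURLClassLoader(urls: Array[URL], parent: ClassLoader)
      extends URLClassLoader(urls, parent) {
      override def addURL(url: URL) { super.addURL(url) }
    }

    object ClassLoaderDemo {
      def main(args: Array[String]) {
        val loader = new MutableURLClassLoader(Array(), getClass.getClassLoader)
        // Later, once a dependency has been downloaded (path is a placeholder):
        loader.addURL(new File("extra-lib.jar").toURI.toURL)
        // Classes from the new JAR could now be loaded reflectively by name, e.g.
        // Class.forName("com.example.Plugin", true, loader)
        println(loader.getURLs.mkString(", "))
      }
    }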

View file

@ -8,7 +8,7 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging}
import spark.TaskState
class MesosExecutorBackend(executor: Executor)
private[spark] class MesosExecutorBackend(executor: Executor)
extends MesosExecutor
with ExecutorBackend
with Logging {
@ -59,7 +59,7 @@ class MesosExecutorBackend(executor: Executor)
/**
* Entry point for Mesos executor.
*/
object MesosExecutorBackend {
private[spark] object MesosExecutorBackend {
def main(args: Array[String]) {
MesosNativeLibrary.load()
// Create a new Executor and start it running

View file

@ -14,7 +14,7 @@ import spark.scheduler.cluster.RegisterSlaveFailed
import spark.scheduler.cluster.RegisterSlave
class StandaloneExecutorBackend(
private[spark] class StandaloneExecutorBackend(
executor: Executor,
masterUrl: String,
slaveId: String,
@ -62,7 +62,7 @@ class StandaloneExecutorBackend(
}
}
object StandaloneExecutorBackend {
private[spark] object StandaloneExecutorBackend {
def run(masterUrl: String, slaveId: String, hostname: String, cores: Int) {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc

View file

@ -11,6 +11,7 @@ import java.nio.channels.spi._
import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector) extends Logging {
channel.configureBlocking(false)
@ -23,8 +24,8 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
var onExceptionCallback: (Connection, Exception) => Unit = null
var onKeyInterestChangeCallback: (Connection, Int) => Unit = null
lazy val remoteAddress = getRemoteAddress()
lazy val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
val remoteAddress = getRemoteAddress()
val remoteConnectionManagerId = ConnectionManagerId.fromSocketAddress(remoteAddress)
def key() = channel.keyFor(selector)
@ -39,7 +40,10 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
def close() {
key.cancel()
val k = key()
if (k != null) {
k.cancel()
}
channel.close()
callOnCloseCallback()
}
@ -99,7 +103,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector) ex
}
class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector)
extends Connection(SocketChannel.open, selector_) {
class Outbox(fair: Int = 0) {
@ -134,9 +138,12 @@ extends Connection(SocketChannel.open, selector_) {
if (!message.started) logDebug("Starting to send [" + message + "]")
message.started = true
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
}
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
}
}
None
@ -159,10 +166,11 @@ extends Connection(SocketChannel.open, selector_) {
}
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk
}
/*messages -= message*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "] in " + message.timeTaken )
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
"] in " + message.timeTaken )
}
}
}
None
@ -216,7 +224,7 @@ extends Connection(SocketChannel.open, selector_) {
while(true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk match {
outbox.getChunk() match {
case Some(chunk) => {
currentBuffers ++= chunk.buffers
}
@ -252,7 +260,7 @@ extends Connection(SocketChannel.open, selector_) {
}
class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
class Inbox() {
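
The Connection hunks above turn remoteAddress and remoteConnectionManagerId from lazy into eager vals, so both are captured when the connection object is constructed rather than at first use. A tiny sketch of that evaluation-time difference with made-up names:

    object LazyVsEagerDemo {
      class Endpoint {
        private var closed = false
        def close() { closed = true }

        // Captured once, at construction time.
        val eagerLabel: String = describe()
        // Not evaluated until first read, which may happen after close().
        lazy val lazyLabel: String = describe()

        private def describe(): String = if (closed) "<closed>" else "open endpoint"
      }

      def main(args: Array[String]) {
        val e = new Endpoint
        e.close()
        println(e.eagerLabel)  // "open endpoint" - captured before close()
        println(e.lazyLabel)   // "<closed>"      - evaluated lazily, after close()
      }
    }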

View file

@ -18,17 +18,17 @@ import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
case class ConnectionManagerId(host: String, port: Int) {
private[spark] case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
}
object ConnectionManagerId {
private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
}
}
class ConnectionManager(port: Int) extends Logging {
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
val message: Message,
@ -113,7 +113,7 @@ class ConnectionManager(port: Int) extends Logging {
val selectedKeysCount = selector.select()
if (selectedKeysCount == 0) {
logInfo("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
if (selectorThread.isInterrupted) {
logInfo("Selector thread was interrupted!")
@ -167,7 +167,6 @@ class ConnectionManager(port: Int) extends Logging {
}
def removeConnection(connection: Connection) {
/*logInfo("Removing connection")*/
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
@ -235,7 +234,7 @@ class ConnectionManager(port: Int) extends Logging {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logInfo("Received [" + message + "] from [" + connectionManagerId + "]")
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@ -276,15 +275,15 @@ class ConnectionManager(port: Int) extends Logging {
logDebug("Calling back")
onReceiveCallback(bufferMessage, connectionManagerId)
} else {
logWarning("Not calling back as callback is null")
logDebug("Not calling back as callback is null")
None
}
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logWarning("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logWarning("Response to " + bufferMessage + " does not have ack id set")
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
}
}
@ -349,7 +348,7 @@ class ConnectionManager(port: Int) extends Logging {
}
object ConnectionManager {
private[spark] object ConnectionManager {
def main(args: Array[String]) {

View file

@ -11,7 +11,7 @@ import java.net.InetAddress
import akka.dispatch.Await
import akka.util.duration._
object ConnectionManagerTest extends Logging{
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
if (args.length < 2) {
println("Usage: ConnectionManagerTest <mesos cluster> <slaves file>")

View file

@ -7,8 +7,9 @@ import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer
import java.net.InetAddress
import java.net.InetSocketAddress
import storage.BlockManager
class MessageChunkHeader(
private[spark] class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
@ -36,7 +37,7 @@ class MessageChunkHeader(
" and sizes " + totalSize + " / " + chunkSize + " bytes"
}
class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
@ -50,7 +51,7 @@ class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
}
abstract class Message(val typ: Long, val id: Int) {
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
var started = false
var startTime = -1L
@ -64,10 +65,10 @@ abstract class Message(val typ: Long, val id: Int) {
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = "" + this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
@ -97,10 +98,11 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
BlockManager.dispose(buffer)
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
buffer.duplicate
buffer.duplicate()
} else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
}
@ -147,11 +149,10 @@ extends Message(Message.BUFFER_MESSAGE, id_) {
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
}
}
}
object MessageChunkHeader {
private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
@ -172,7 +173,7 @@ object MessageChunkHeader {
}
}
object Message {
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
var lastId = 1
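
The BufferMessage code above carves large ByteBuffers into chunks of at most maxChunkSize bytes by duplicating or slicing the current buffer. A self-contained sketch of that slicing loop using only java.nio; the ten-byte payload and chunk size are arbitrary:

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    object ChunkingDemo {
      // Split `buffer` into views of at most maxChunkSize bytes each; the views
      // share the underlying data, and `buffer` is fully consumed afterwards.
      def chunk(buffer: ByteBuffer, maxChunkSize: Int): Seq[ByteBuffer] = {
        val chunks = new ArrayBuffer[ByteBuffer]()
        while (buffer.remaining > 0) {
          val newBuffer =
            if (buffer.remaining <= maxChunkSize) {
              buffer.duplicate()
            } else {
              buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
            }
          buffer.position(buffer.position + newBuffer.remaining)
          chunks += newBuffer
        }
        chunks
      }

      def main(args: Array[String]) {
        val data = ByteBuffer.wrap(new Array[Byte](10))
        println(chunk(data, 4).map(_.remaining).mkString(", "))  // 4, 4, 2
      }
    }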

View file

@ -3,7 +3,7 @@ package spark.network
import java.nio.ByteBuffer
import java.net.InetAddress
object ReceiverTest {
private[spark] object ReceiverTest {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)

Some files were not shown because too many files have changed in this diff.