Merge branch 'master' into streaming

Conflicts:
	.gitignore
Tathagata Das 2013-06-24 23:57:47 -07:00
commit c89af0a7f9
289 changed files with 13250 additions and 3293 deletions

.gitignore vendored
View file

@@ -37,3 +37,4 @@ dependency-reduced-pom.xml
.ensime
.ensime_lucene
checkpoint
derby.log

View file

@@ -12,11 +12,16 @@ This README file only contains basic setup instructions.
## Building
Spark requires Scala 2.9.2. The project is built using Simple Build Tool (SBT),
which is packaged with it. To build Spark and its example programs, run:
Spark requires Scala 2.9.2 (Scala 2.10 is not yet supported). The project is
built using Simple Build Tool (SBT), which is packaged with it. To build
Spark and its example programs, run:
sbt/sbt package
Spark also supports building with Maven. If you would like to build with Maven,
see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html)
in the Spark documentation.
To run Spark, you will need to have Scala's bin directory in your `PATH`, or
you will need to set the `SCALA_HOME` environment variable to point to where
you've installed Scala. Scala must be accessible through one of these

View file

@@ -3,8 +3,8 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
<artifactId>parent</artifactId>
<version>0.7.0-SNAPSHOT</version>
<artifactId>spark-parent</artifactId>
<version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -102,5 +102,42 @@
</plugins>
</build>
</profile>
<profile>
<id>hadoop2-yarn</id>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
<artifactId>spark-core</artifactId>
<version>${project.version}</version>
<classifier>hadoop2-yarn</classifier>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-api</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-common</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<classifier>hadoop2-yarn</classifier>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>

View file

@@ -4,8 +4,37 @@ import spark._
import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
import storage.StorageLevel
object Bagel extends Logging {
val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
/**
* Runs a Bagel program.
* @param sc [[spark.SparkContext]] to use for the program.
* @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be
* the vertex id.
* @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
* empty array, i.e. sc.parallelize(Array[(K, Message)]()).
* @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
* message before sending (which often involves network I/O).
* @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
* and provides the result to each vertex in the next superstep.
* @param partitioner [[spark.Partitioner]] partitions values by key
* @param numPartitions number of partitions across which to split the graph.
* Default is the default parallelism of the SparkContext
* @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
* Defaults to caching in memory.
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
* optional Aggregator and the current superstep,
* and returns a set of (Vertex, outgoing Messages) pairs
* @tparam K key
* @tparam V vertex type
* @tparam M message type
* @tparam C combiner
* @tparam A aggregator
* @return an RDD of (K, V) pairs representing the graph after completion of the program
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C: Manifest, A: Manifest](
sc: SparkContext,
@@ -14,7 +43,8 @@ object Bagel extends Logging {
combiner: Combiner[M, C],
aggregator: Option[Aggregator[V, A]],
partitioner: Partitioner,
numPartitions: Int
numPartitions: Int,
storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
)(
compute: (V, Option[C], Option[A], Int) => (V, Array[M])
): RDD[(K, V)] = {
@@ -32,8 +62,9 @@ object Bagel extends Logging {
val combinedMsgs = msgs.combineByKey(
combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
val grouped = combinedMsgs.groupWith(verts)
val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
val (processed, numMsgs, numActiveVerts) =
comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
@@ -50,6 +81,7 @@ object Bagel extends Logging {
verts
}
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
@@ -59,12 +91,29 @@ object Bagel extends Logging {
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, partitioner, numPartitions)(
sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]]
* and default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
@@ -73,13 +122,29 @@ object Bagel extends Logging {
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]]*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, part, numPartitions)(
sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]],
* [[spark.bagel.DefaultCombiner]] and the default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
@@ -87,10 +152,24 @@ object Bagel extends Logging {
numPartitions: Int
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]]
* and [[spark.bagel.DefaultCombiner]]
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, Array[M], Nothing](
sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions)(
sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, Array[M]](compute))
}
@@ -117,7 +196,8 @@ object Bagel extends Logging {
private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
grouped: RDD[(K, (Seq[C], Seq[V]))],
compute: (V, Option[C]) => (V, Array[M])
compute: (V, Option[C]) => (V, Array[M]),
storageLevel: StorageLevel
): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
@@ -135,7 +215,7 @@ object Bagel extends Logging {
numActiveVerts += 1
Some((newVert, newMsgs))
}.cache
}.persist(storageLevel)
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
@@ -166,6 +246,7 @@ trait Aggregator[V, A] {
def mergeAggregators(a: A, b: A): A
}
/** Default combiner that simply appends messages together (i.e. performs no aggregation) */
class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
def createCombiner(msg: M): Array[M] =
Array(msg)
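For reference, a minimal sketch of how the new storage-level parameter above can be passed to the simplest run overload (it mirrors the test added to BagelSuite below; sc is assumed to be an existing SparkContext, and TestVertex/TestMessage are the helper classes defined in that suite, with storage.StorageLevel imported):

val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
  (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
    (new TestVertex(superstep < 10, self.age + 1), Array[TestMessage]())
}

Intermediate RDDs in each superstep are then persisted with DISK_ONLY instead of the new DEFAULT_STORAGE_LEVEL (MEMORY_AND_DISK).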

View file

@@ -7,6 +7,7 @@ import org.scalatest.time.SpanSugar._
import scala.collection.mutable.ArrayBuffer
import spark._
import storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
@@ -22,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
test("halting by voting") {
@@ -79,4 +81,21 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
}
}
}
test("using non-default persistence level") {
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
}
}

View file

@@ -30,7 +30,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <args...>"
usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -48,6 +48,8 @@ startStop=$1
shift
command=$1
shift
instance=$1
shift
spark_rotate_log ()
{
@@ -92,10 +94,10 @@ if [ "$SPARK_PID_DIR" = "" ]; then
fi
# some variables
export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.log
export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.log
export SPARK_ROOT_LOGGER="INFO,DRFA"
log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.out
pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command.pid
log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out
pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid
# Set default scheduling priority
if [ "$SPARK_NICENESS" = "" ]; then

View file

@@ -2,7 +2,7 @@
# Run a Spark command on all slave hosts.
usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command args..."
usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then

View file

@@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
fi
fi
"$bin"/spark-daemon.sh start spark.deploy.master.Master --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT

View file

@@ -6,9 +6,10 @@ bin=`cd "$bin"; pwd`
# Set SPARK_PUBLIC_DNS so slaves can be linked in master web UI
if [ "$SPARK_PUBLIC_DNS" = "" ]; then
# If we appear to be running on EC2, use the public address by default:
if [[ `hostname` == *ec2.internal ]]; then
# NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname
if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
fi
fi
"$bin"/spark-daemon.sh start spark.deploy.worker.Worker $1
"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@"

View file

@@ -21,4 +21,13 @@ fi
echo "Master IP: $SPARK_MASTER_IP"
# Launch the slaves
exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
else
if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then
SPARK_WORKER_WEBUI_PORT=8081
fi
for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
"$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i ))
done
fi

View file

@@ -7,4 +7,4 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
"$bin"/spark-daemon.sh stop spark.deploy.master.Master
"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1

View file

@@ -7,4 +7,14 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1
else
for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 ))
done
fi

View file

@@ -12,6 +12,7 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine
#
# Finally, Spark also relies on the following variables, but these can be set
# on just the *master* (i.e. in your driver program), and will automatically

View file

@@ -3,8 +3,8 @@
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
<artifactId>parent</artifactId>
<version>0.7.0-SNAPSHOT</version>
<artifactId>spark-parent</artifactId>
<version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
@@ -32,8 +32,8 @@
<artifactId>compress-lzf</artifactId>
</dependency>
<dependency>
<groupId>asm</groupId>
<artifactId>asm-all</artifactId>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
@@ -73,7 +73,7 @@
</dependency>
<dependency>
<groupId>cc.spray</groupId>
<artifactId>spray-json_${scala.version}</artifactId>
<artifactId>spray-json_2.9.2</artifactId>
</dependency>
<dependency>
<groupId>org.tomdz.twirl</groupId>
@@ -81,13 +81,26 @@
</dependency>
<dependency>
<groupId>com.github.scala-incubator.io</groupId>
<artifactId>scala-io-file_${scala.version}</artifactId>
<artifactId>scala-io-file_2.9.2</artifactId>
</dependency>
<dependency>
<groupId>org.apache.mesos</groupId>
<artifactId>mesos</artifactId>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
</dependency>
<dependency>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</dependency>
<dependency>
<groupId>org.apache.derby</groupId>
<artifactId>derby</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalatest</groupId>
<artifactId>scalatest_${scala.version}</artifactId>
@@ -275,5 +288,72 @@
</plugins>
</build>
</profile>
<profile>
<id>hadoop2-yarn</id>
<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-api</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-common</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-yarn-client</artifactId>
<scope>provided</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>build-helper-maven-plugin</artifactId>
<executions>
<execution>
<id>add-source</id>
<phase>generate-sources</phase>
<goals>
<goal>add-source</goal>
</goals>
<configuration>
<sources>
<source>src/main/scala</source>
<source>src/hadoop2-yarn/scala</source>
</sources>
</configuration>
</execution>
<execution>
<id>add-scala-test-sources</id>
<phase>generate-test-sources</phase>
<goals>
<goal>add-test-source</goal>
</goals>
<configuration>
<sources>
<source>src/test/scala</source>
</sources>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-jar-plugin</artifactId>
<configuration>
<classifier>hadoop2-yarn</classifier>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>

View file

@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
jobId, isMap, taskId, attemptId)
}

View file

@@ -6,4 +6,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
jobId, isMap, taskId, attemptId)
}

View file

@@ -0,0 +1,23 @@
package spark.deploy
import org.apache.hadoop.conf.Configuration
/**
* Contains util methods to interact with Hadoop from spark.
*/
object SparkHadoopUtil {
def getUserNameFromEnvironment(): String = {
// defaulting to -D ...
System.getProperty("user.name")
}
def runAsUser(func: (Product) => Unit, args: Product) {
// Add support, if exists - for now, simply run func !
func(args)
}
// Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
def newConfiguration(): Configuration = new Configuration()
}
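A rough usage sketch for the default (non-YARN) SparkHadoopUtil above; the startWorker function and its tuple argument are hypothetical, invented only to illustrate the (Product) => Unit shape that runAsUser expects:

def startWorker(args: Product) {
  // In this default implementation runAsUser simply invokes the function in-process,
  // so this runs as whatever the "user.name" system property resolves to.
  val (host, port) = args.asInstanceOf[(String, Int)]
  println("running as " + SparkHadoopUtil.getUserNameFromEnvironment() + " on " + host + ":" + port)
}

SparkHadoopUtil.runAsUser(startWorker _, ("localhost", 7077))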

View file

@@ -0,0 +1,13 @@
package org.apache.hadoop.mapred
import org.apache.hadoop.mapreduce.TaskType
trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
}

View file

@@ -0,0 +1,13 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
import task.{TaskAttemptContextImpl, JobContextImpl}
trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
}

View file

@@ -0,0 +1,63 @@
package spark.deploy
import collection.mutable.HashMap
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import java.security.PrivilegedExceptionAction
/**
* Contains util methods to interact with Hadoop from spark.
*/
object SparkHadoopUtil {
val yarnConf = newConfiguration()
def getUserNameFromEnvironment(): String = {
// defaulting to env if -D is not present ...
val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))
// If nothing found, default to user we are running as
if (retval == null) System.getProperty("user.name") else retval
}
def runAsUser(func: (Product) => Unit, args: Product) {
runAsUser(func, args, getUserNameFromEnvironment())
}
def runAsUser(func: (Product) => Unit, args: Product, user: String) {
// println("running as user " + jobUserName)
UserGroupInformation.setConfiguration(yarnConf)
val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
def run: AnyRef = {
func(args)
// no return value ...
null
}
})
}
// Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
def isYarnMode(): Boolean = {
val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
java.lang.Boolean.valueOf(yarnMode)
}
// Set an env variable indicating we are running in YARN mode.
// Note that anything with SPARK prefix gets propagated to all (remote) processes
def setYarnMode() {
System.setProperty("SPARK_YARN_MODE", "true")
}
def setYarnMode(env: HashMap[String, String]) {
env("SPARK_YARN_MODE") = "true"
}
// Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
// Always create a new config; don't reuse yarnConf.
def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
}

View file

@@ -0,0 +1,329 @@
package spark.deploy.yarn
import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import scala.collection.JavaConversions._
import spark.{SparkContext, Logging, Utils}
import org.apache.hadoop.security.UserGroupInformation
import java.security.PrivilegedExceptionAction
class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
def this(args: ApplicationMasterArguments) = this(args, new Configuration())
private var rpc: YarnRPC = YarnRPC.create(conf)
private var resourceManager: AMRMProtocol = null
private var appAttemptId: ApplicationAttemptId = null
private var userThread: Thread = null
private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
private var yarnAllocator: YarnAllocationHandler = null
def run() {
// Initialization
val jobUserName = Utils.getUserNameFromEnvironment()
logInfo("running as user " + jobUserName)
// run as user ...
UserGroupInformation.setConfiguration(yarnConf)
val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
def run: AnyRef = {
runImpl()
return null
}
})
}
private def runImpl() {
appAttemptId = getApplicationAttemptId()
resourceManager = registerWithResourceManager()
val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
// Compute number of threads for akka
val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
if (minimumMemory > 0) {
val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
if (numCore > 0) {
// do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
// TODO: Uncomment when hadoop is on a version which has this fixed.
// args.workerCores = numCore
}
}
// Workaround until hadoop moves to something which has
// https://issues.apache.org/jira/browse/HADOOP-8406
// ignore result
// This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
// Hence args.workerCores = numCore disabled above. Any better option ?
// org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)
ApplicationMaster.register(this)
// Start the user's JAR
userThread = startUserClass()
// This is a bit hacky, but we need to wait until the spark.driver.port property has
// been set by the Thread executing the user class.
waitForSparkMaster()
// Allocate all containers
allocateWorkers()
// Wait for the user class to Finish
userThread.join()
// Finish the ApplicationMaster
finishApplicationMaster()
// TODO: Exit based on success/failure
System.exit(0)
}
private def getApplicationAttemptId(): ApplicationAttemptId = {
val envs = System.getenv()
val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
val containerId = ConverterUtils.toContainerId(containerIdString)
val appAttemptId = containerId.getApplicationAttemptId()
logInfo("ApplicationAttemptId: " + appAttemptId)
return appAttemptId
}
private def registerWithResourceManager(): AMRMProtocol = {
val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
YarnConfiguration.RM_SCHEDULER_ADDRESS,
YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
logInfo("Connecting to ResourceManager at " + rmAddress)
return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
logInfo("Registering the ApplicationMaster")
val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
.asInstanceOf[RegisterApplicationMasterRequest]
appMasterRequest.setApplicationAttemptId(appAttemptId)
// Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
// Users can then monitor stderr/stdout on that node if required.
appMasterRequest.setHost(Utils.localHostName())
appMasterRequest.setRpcPort(0)
// What do we provide here ? Might make sense to expose something sensible later ?
appMasterRequest.setTrackingUrl("")
return resourceManager.registerApplicationMaster(appMasterRequest)
}
private def waitForSparkMaster() {
logInfo("Waiting for spark driver to be reachable.")
var driverUp = false
while(!driverUp) {
val driverHost = System.getProperty("spark.driver.host")
val driverPort = System.getProperty("spark.driver.port")
try {
val socket = new Socket(driverHost, driverPort.toInt)
socket.close()
logInfo("Master now available: " + driverHost + ":" + driverPort)
driverUp = true
} catch {
case e: Exception =>
logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
Thread.sleep(100)
}
}
}
private def startUserClass(): Thread = {
logInfo("Starting the user JAR in a separate Thread")
val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
.getMethod("main", classOf[Array[String]])
val t = new Thread {
override def run() {
// Copy
var mainArgs: Array[String] = new Array[String](args.userArgs.size())
args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
mainMethod.invoke(null, mainArgs)
}
}
t.start()
return t
}
private def allocateWorkers() {
logInfo("Waiting for spark context initialization")
try {
var sparkContext: SparkContext = null
ApplicationMaster.sparkContextRef.synchronized {
var count = 0
while (ApplicationMaster.sparkContextRef.get() == null) {
logInfo("Waiting for spark context initialization ... " + count)
count = count + 1
ApplicationMaster.sparkContextRef.wait(10000L)
}
sparkContext = ApplicationMaster.sparkContextRef.get()
assert(sparkContext != null)
this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
}
logInfo("Allocating " + args.numWorkers + " workers.")
// Wait until all containers have finished
// TODO: This is a bit ugly. Can we make it nicer?
// TODO: Handle container failure
while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
// If the user thread exits, then quit !
userThread.isAlive) {
this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
ApplicationMaster.incrementAllocatorLoop(1)
Thread.sleep(100)
}
} finally {
// In case of exceptions, etc. - ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT,
// so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
}
logInfo("All workers have launched.")
// Launch a progress reporter thread, otherwise the app will get killed after the expiry timeout (default: 10 mins).
if (userThread.isAlive){
// Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
// Must be <= timeoutInterval / 2.
// On the other hand, also ensure that we are reasonably responsive without causing too many requests to the RM,
// so at least 1 minute or timeoutInterval / 10, whichever is higher.
val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
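// Illustrative arithmetic (not part of the original change): with the 120000 ms default above,
// interval = min(120000 / 2, max(120000 / 10, 60000)) = min(60000, 60000) = 60000 ms,
// i.e. progress is reported roughly once a minute.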
launchReporterThread(interval)
}
}
// TODO: We might want to extend this to allocate more containers in case they die !
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
val t = new Thread {
override def run() {
while (userThread.isAlive){
val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
if (missingWorkerCount > 0) {
logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
yarnAllocator.allocateContainers(missingWorkerCount)
}
else sendProgress()
Thread.sleep(sleepTime)
}
}
}
// setting to daemon status, though this is usually not a good idea.
t.setDaemon(true)
t.start()
logInfo("Started progress reporter thread - sleep time : " + sleepTime)
return t
}
private def sendProgress() {
logDebug("Sending progress")
// simulated with an allocate request with no nodes requested ...
yarnAllocator.allocateContainers(0)
}
/*
def printContainers(containers: List[Container]) = {
for (container <- containers) {
logInfo("Launching shell command on a new container."
+ ", containerId=" + container.getId()
+ ", containerNode=" + container.getNodeId().getHost()
+ ":" + container.getNodeId().getPort()
+ ", containerNodeURI=" + container.getNodeHttpAddress()
+ ", containerState" + container.getState()
+ ", containerResourceMemory"
+ container.getResource().getMemory())
}
}
*/
def finishApplicationMaster() {
val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
.asInstanceOf[FinishApplicationMasterRequest]
finishReq.setAppAttemptId(appAttemptId)
// TODO: Check if the application has failed or succeeded
finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
resourceManager.finishApplicationMaster(finishReq)
}
}
object ApplicationMaster {
// Number of times to wait for the allocator loop to complete.
// Each loop iteration waits for 100ms, so a maximum of 3 seconds.
// This is to ensure that we have a reasonable number of containers before we start.
// TODO: Currently, the task-to-container assignment is computed once (TaskSetManager) - which need not be optimal as more
// containers become available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
def incrementAllocatorLoop(by: Int) {
val count = yarnAllocatorLoop.getAndAdd(by)
if (count >= ALLOCATOR_LOOP_WAIT_COUNT){
yarnAllocatorLoop.synchronized {
// to wake threads off wait ...
yarnAllocatorLoop.notifyAll()
}
}
}
private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
def register(master: ApplicationMaster) {
applicationMasters.add(master)
}
val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
def sparkContextInitialized(sc: SparkContext): Boolean = {
var modified = false
sparkContextRef.synchronized {
modified = sparkContextRef.compareAndSet(null, sc)
sparkContextRef.notifyAll()
}
// Add a shutdown hook - as a best-effort measure in case users do not call sc.stop or do System.exit.
// Should not really have to do this, but it helps YARN evict resources earlier,
// not to mention it prevents the Client from declaring failure even though we exited properly.
if (modified) {
Runtime.getRuntime().addShutdownHook(new Thread with Logging {
// This is not just to log, but also to ensure that the log system is initialized for this instance when we are actually 'run'
logInfo("Adding shutdown hook for context " + sc)
override def run() {
logInfo("Invoking sc stop from shutdown hook")
sc.stop()
// best case ...
for (master <- applicationMasters) master.finishApplicationMaster
}
} )
}
// Wait for initialization to complete and at least 'some' nodes to get allocated
yarnAllocatorLoop.synchronized {
while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT){
yarnAllocatorLoop.wait(1000L)
}
}
modified
}
def main(argStrings: Array[String]) {
val args = new ApplicationMasterArguments(argStrings)
new ApplicationMaster(args).run()
}
}

View file

@@ -0,0 +1,77 @@
package spark.deploy.yarn
import spark.util.IntParam
import collection.mutable.ArrayBuffer
class ApplicationMasterArguments(val args: Array[String]) {
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
var workerMemory = 1024
var workerCores = 1
var numWorkers = 2
parseArgs(args.toList)
private def parseArgs(inputArgs: List[String]): Unit = {
val userArgsBuffer = new ArrayBuffer[String]()
var args = inputArgs
while (! args.isEmpty) {
args match {
case ("--jar") :: value :: tail =>
userJar = value
args = tail
case ("--class") :: value :: tail =>
userClass = value
args = tail
case ("--args") :: value :: tail =>
userArgsBuffer += value
args = tail
case ("--num-workers") :: IntParam(value) :: tail =>
numWorkers = value
args = tail
case ("--worker-memory") :: IntParam(value) :: tail =>
workerMemory = value
args = tail
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
}
case _ =>
printUsageAndExit(1, args)
}
}
userArgs = userArgsBuffer.readOnly
}
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
}
System.err.println(
"Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required)\n" +
" --class CLASS_NAME Name of your application's main class (required)\n" +
" --args ARGS Arguments to be passed to your application's main class.\n" +
" Mutliple invocations are possible, each will be passed in order.\n" +
" --num-workers NUM Number of workers to start (Default: 2)\n" +
" --worker-cores NUM Number of cores for the workers (Default: 1)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
System.exit(exitCode)
}
}
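As an illustration of the parser above (all values below are made up), an invocation such as:

val parsed = new ApplicationMasterArguments(Array(
  "--jar", "hdfs:///tmp/app.jar",
  "--class", "com.example.MyApp",
  "--args", "input.txt",
  "--args", "output",
  "--num-workers", "4",
  "--worker-memory", "2048",
  "--worker-cores", "2"))

would yield userJar = "hdfs:///tmp/app.jar", userClass = "com.example.MyApp", userArgs = Seq("input.txt", "output"), numWorkers = 4, workerMemory = 2048 and workerCores = 2.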

View file

@@ -0,0 +1,272 @@
package spark.deploy.yarn
import java.net.{InetSocketAddress, URI}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.client.YarnClientImpl
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import spark.{Logging, Utils}
import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import spark.deploy.SparkHadoopUtil
class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {
def this(args: ClientArguments) = this(new Configuration(), args)
var rpc: YarnRPC = YarnRPC.create(conf)
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
def run() {
init(yarnConf)
start()
logClusterResourceDetails()
val newApp = super.getNewApplication()
val appId = newApp.getApplicationId()
verifyClusterResources(newApp)
val appContext = createApplicationSubmissionContext(appId)
val localResources = prepareLocalResources(appId, "spark")
val env = setupLaunchEnv(localResources)
val amContainer = createContainerLaunchContext(newApp, localResources, env)
appContext.setQueue(args.amQueue)
appContext.setAMContainerSpec(amContainer)
appContext.setUser(args.amUser)
submitApp(appContext)
monitorApplication(appId)
System.exit(0)
}
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)
val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
", queueChildQueueCount=" + queueInfo.getChildQueues.size)
}
def verifyClusterResources(app: GetNewApplicationResponse) = {
val maxMem = app.getMaximumResourceCapability().getMemory()
logInfo("Max mem capabililty of resources in this cluster " + maxMem)
// If the cluster does not have enough memory resources, exit.
val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
if (requestedMem > maxMem) {
logError("Cluster cannot satisfy memory resource request of " + requestedMem)
System.exit(1)
}
}
def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
logInfo("Setting up application submission context for ASM")
val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
appContext.setApplicationId(appId)
appContext.setApplicationName("Spark")
return appContext
}
def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
val locaResources = HashMap[String, LocalResource]()
// Upload Spark and the application JAR to the remote file system
// Add them as local resources to the AM
val fs = FileSystem.get(conf)
Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
.foreach { case(destName, _localPath) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (! localPath.isEmpty()) {
val src = new Path(localPath)
val pathSuffix = appName + "/" + appId.getId() + destName
val dst = new Path(fs.getHomeDirectory(), pathSuffix)
logInfo("Uploading " + src + " to " + dst)
fs.copyFromLocalFile(false, true, src, dst)
val destStatus = fs.getFileStatus(dst)
val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
amJarRsrc.setType(LocalResourceType.FILE)
amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
amJarRsrc.setTimestamp(destStatus.getModificationTime())
amJarRsrc.setSize(destStatus.getLen())
locaResources(destName) = amJarRsrc
}
}
return locaResources
}
def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
logInfo("Setting up the launch environment")
val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)
val env = new HashMap[String, String]()
Apps.addToEnvironment(env, Environment.USER.name, args.amUser)
// If log4j present, ensure ours overrides all others
if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
Client.populateHadoopClasspath(yarnConf, env)
SparkHadoopUtil.setYarnMode(env)
env("SPARK_YARN_JAR_PATH") =
localResources("spark.jar").getResource().getScheme.toString() + "://" +
localResources("spark.jar").getResource().getFile().toString()
env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()
env("SPARK_YARN_USERJAR_PATH") =
localResources("app.jar").getResource().getScheme.toString() + "://" +
localResources("app.jar").getResource().getFile().toString()
env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()
if (log4jConfLocalRes != null) {
env("SPARK_YARN_LOG4J_PATH") =
log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
}
// Add each SPARK-* key to the environment
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
return env
}
def userArgsToString(clientArgs: ClientArguments): String = {
val prefix = " --args "
val args = clientArgs.userArgs
val retval = new StringBuilder()
for (arg <- args){
retval.append(prefix).append(" '").append(arg).append("' ")
}
retval.toString
}
def createContainerLaunchContext(newApp: GetNewApplicationResponse,
localResources: HashMap[String, LocalResource],
env: HashMap[String, String]): ContainerLaunchContext = {
logInfo("Setting up container launch context")
val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
amContainer.setLocalResources(localResources)
amContainer.setEnvironment(env)
val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()
var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
(if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD
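// Illustrative example (hypothetical numbers, not from the original change): if args.amMemory = 512
// and the cluster's minimum allocation (minResMemory) = 1024, the expression above rounds 512 up to
// 1024 and then subtracts YarnAllocationHandler.MEMORY_OVERHEAD, so the -Xmx value set below leaves
// headroom for non-heap memory inside the container.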
// Extra options for the JVM
var JAVA_OPTS = ""
// Add Xmx for am memory
JAVA_OPTS += "-Xmx" + amMemory + "m "
// Commenting it out for now - so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
// The context: the default GC for server-class machines ends up using all cores to do GC - hence if there are multiple containers on the same
// node, Spark's GC affects the performance of all other containers (which can also be other Spark containers).
// Instead of using this, rely on cpusets in YARN to enforce that Spark behaves 'properly' in multi-tenant environments. Not sure how the default Java GC behaves if it is
// limited to a subset of cores on a node.
if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
// In our experiments, using the (default) throughput collector has severe performance ramifications on multi-tenant machines
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
// Command for the ApplicationMaster
val commands = List[String]("java " +
" -server " +
JAVA_OPTS +
" spark.deploy.yarn.ApplicationMaster" +
" --class " + args.userClass +
" --jar " + args.userJar +
userArgsToString(args) +
" --worker-memory " + args.workerMemory +
" --worker-cores " + args.workerCores +
" --num-workers " + args.numWorkers +
" 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Command for the ApplicationMaster: " + commands(0))
amContainer.setCommands(commands)
val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
// Memory for the ApplicationMaster
capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
amContainer.setResource(capability)
return amContainer
}
def submitApp(appContext: ApplicationSubmissionContext) = {
// Submit the application to the applications manager
logInfo("Submitting application to ASM")
super.submitApplication(appContext)
}
def monitorApplication(appId: ApplicationId): Boolean = {
while(true) {
Thread.sleep(1000)
val report = super.getApplicationReport(appId)
logInfo("Application report from ASM: \n" +
"\t application identifier: " + appId.toString() + "\n" +
"\t appId: " + appId.getId() + "\n" +
"\t clientToken: " + report.getClientToken() + "\n" +
"\t appDiagnostics: " + report.getDiagnostics() + "\n" +
"\t appMasterHost: " + report.getHost() + "\n" +
"\t appQueue: " + report.getQueue() + "\n" +
"\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
"\t appStartTime: " + report.getStartTime() + "\n" +
"\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
"\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
"\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
"\t appUser: " + report.getUser()
)
val state = report.getYarnApplicationState()
val dsStatus = report.getFinalApplicationStatus()
if (state == YarnApplicationState.FINISHED ||
state == YarnApplicationState.FAILED ||
state == YarnApplicationState.KILLED) {
return true
}
}
return true
}
}
object Client {
def main(argStrings: Array[String]) {
val args = new ClientArguments(argStrings)
SparkHadoopUtil.setYarnMode()
new Client(args).run
}
// Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
}

View file

@@ -0,0 +1,105 @@
package spark.deploy.yarn
import spark.util.MemoryParam
import spark.util.IntParam
import collection.mutable.{ArrayBuffer, HashMap}
import spark.scheduler.{InputFormatInfo, SplitInfo}
// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware !
class ClientArguments(val args: Array[String]) {
var userJar: String = null
var userClass: String = null
var userArgs: Seq[String] = Seq[String]()
var workerMemory = 1024
var workerCores = 1
var numWorkers = 2
var amUser = System.getProperty("user.name")
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
parseArgs(args.toList)
private def parseArgs(inputArgs: List[String]): Unit = {
val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()
var args = inputArgs
while (! args.isEmpty) {
args match {
case ("--jar") :: value :: tail =>
userJar = value
args = tail
case ("--class") :: value :: tail =>
userClass = value
args = tail
case ("--args") :: value :: tail =>
userArgsBuffer += value
args = tail
case ("--master-memory") :: MemoryParam(value) :: tail =>
amMemory = value
args = tail
case ("--num-workers") :: IntParam(value) :: tail =>
numWorkers = value
args = tail
case ("--worker-memory") :: MemoryParam(value) :: tail =>
workerMemory = value
args = tail
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
case ("--user") :: value :: tail =>
amUser = value
args = tail
case ("--queue") :: value :: tail =>
amQueue = value
args = tail
case Nil =>
if (userJar == null || userClass == null) {
printUsageAndExit(1)
}
case _ =>
printUsageAndExit(1, args)
}
}
userArgs = userArgsBuffer.readOnly
inputFormatInfo = inputFormatMap.values.toList
}
def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
if (unknownParam != null) {
System.err.println("Unknown/unsupported param " + unknownParam)
}
System.err.println(
"Usage: spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
" --jar JAR_PATH Path to your application's JAR file (required)\n" +
" --class CLASS_NAME Name of your application's main class (required)\n" +
" --args ARGS Arguments to be passed to your application's main class.\n" +
" Mutliple invocations are possible, each will be passed in order.\n" +
" --num-workers NUM Number of workers to start (Default: 2)\n" +
" --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
" --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
" --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
" --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
" --user USERNAME Run the ApplicationMaster (and slaves) as a different user\n"
)
System.exit(exitCode)
}
}

View file

@@ -0,0 +1,171 @@
package spark.deploy.yarn
import java.net.URI
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import spark.{Logging, Utils}
class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
extends Runnable with Logging {
var rpc: YarnRPC = YarnRPC.create(conf)
var cm: ContainerManager = null
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
def run = {
logInfo("Starting Worker Container")
cm = connectToCM
startContainer
}
def startContainer = {
logInfo("Setting up ContainerLaunchContext")
val ctx = Records.newRecord(classOf[ContainerLaunchContext])
.asInstanceOf[ContainerLaunchContext]
ctx.setContainerId(container.getId())
ctx.setResource(container.getResource())
val localResources = prepareLocalResources
ctx.setLocalResources(localResources)
val env = prepareEnvironment
ctx.setEnvironment(env)
// Extra options for the JVM
var JAVA_OPTS = ""
// Set the JVM memory
val workerMemoryString = workerMemory + "m"
JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
}
// Commenting it out for now - so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
// The context: the default GC for server-class machines ends up using all cores to do GC - hence if there are multiple containers on the same
// node, Spark's GC affects the performance of all other containers (which can also be other Spark containers).
// Instead of using this, rely on cpusets in YARN to enforce that Spark behaves 'properly' in multi-tenant environments. Not sure how the default Java GC behaves if it is
// limited to a subset of cores on a node.
/*
else {
// If no java_opts specified, default to using -XX:+CMSIncrementalMode
// It might be possible that other modes/configs are being set in SPARK_JAVA_OPTS, so we don't want to mess with it.
// In our experiments, using the (default) throughput collector has severe performance ramifications on multi-tenant machines
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
JAVA_OPTS += " -XX:+CMSIncrementalMode "
JAVA_OPTS += " -XX:+CMSIncrementalPacing "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
}
*/
ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
val commands = List[String]("java " +
" -server " +
// Kill if OOM is raised - leverage YARN's failure handling to cause rescheduling.
// Not killing the task leaves various aspects of the worker and (to some extent) the JVM in an inconsistent state.
// TODO: If the OOM is not recoverable by rescheduling it on a different node, then do 'something' to fail the job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
" spark.executor.StandaloneExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
workerCores +
" 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
" 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
logInfo("Setting up worker with commands: " + commands)
ctx.setCommands(commands)
// Send the start request to the ContainerManager
val startReq = Records.newRecord(classOf[StartContainerRequest])
.asInstanceOf[StartContainerRequest]
startReq.setContainerLaunchContext(ctx)
cm.startContainer(startReq)
}
def prepareLocalResources: HashMap[String, LocalResource] = {
logInfo("Preparing Local resources")
val locaResources = HashMap[String, LocalResource]()
// Spark JAR
val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
sparkJarResource.setType(LocalResourceType.FILE)
sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
locaResources("spark.jar") = sparkJarResource
// User JAR
val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
userJarResource.setType(LocalResourceType.FILE)
userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
locaResources("app.jar") = userJarResource
// Log4j conf - if available
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
log4jConfResource.setType(LocalResourceType.FILE)
log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
locaResources("log4j.properties") = log4jConfResource
}
logInfo("Prepared Local resources " + locaResources)
return locaResources
}
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
// should we add this ?
Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())
// If log4j present, ensure ours overrides all others
if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
// Which is correct ?
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
Client.populateHadoopClasspath(yarnConf, env)
System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
return env
}
def connectToCM: ContainerManager = {
val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
logInfo("Connecting to ContainerManager at " + cmHostPortStr)
return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
}
}
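Illustrative sketch (not part of this commit): the allocation handler further below hands each granted container to a WorkerRunnable on its own thread, roughly as follows. Here `container` and `conf` are assumed to be in scope, and the worker id, memory and core values are hypothetical.
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
new Thread(
new WorkerRunnable(container, conf, driverUrl, "1", container.getNodeId.getHost, 1024, 2)
).start()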

View file

@ -0,0 +1,547 @@
package spark.deploy.yarn
import spark.{Logging, Utils}
import spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.yarn.api.AMRMProtocol
import collection.JavaConversions._
import collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.hadoop.conf.Configuration
import java.util.{Collections, Set => JSet}
import java.lang.{Boolean => JBoolean}
object AllocationType extends Enumeration ("HOST", "RACK", "ANY") {
type AllocationType = Value
val HOST, RACK, ANY = Value
}
// TODO: Too many params - refactor this at some point.
// Needs to be thread-safe.
// Needs restructuring to make it cleaner: right now all computation is reactive; it should be made
// more proactive and decoupled.
// Note that right now we assume all container requests are uniform in terms of capabilities and priority.
// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
// on how containers are requested.
private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
val appAttemptId: ApplicationAttemptId,
val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
val preferredHostToCount: Map[String, Int],
val preferredRackToCount: Map[String, Int])
extends Logging {
// These three complementary data structures are all locked on allocatedHostToContainersMap.
// allocatedHostToContainersMap: running containers, keyed host -> Set[ContainerId]
// allocatedContainerToHostMap: container -> host mapping
private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
// allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
// As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
private val allocatedRackCount = new HashMap[String, Int]()
// containers which have been released.
private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
// containers to be released in next request to RM
private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]
private val numWorkersRunning = new AtomicInteger()
// Used to generate a unique id per worker
private val workerIdCounter = new AtomicInteger()
private val lastResponseId = new AtomicInteger()
def getNumWorkersRunning: Int = numWorkersRunning.intValue
def isResourceConstraintSatisfied(container: Container): Boolean = {
container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
}
def allocateContainers(workersToRequest: Int) {
// We need to send the request only once from what I understand ... but for now, not modifying this much.
// Keep polling the Resource Manager for containers
val amResp = allocateWorkerResources(workersToRequest).getAMResponse
val _allocatedContainers = amResp.getAllocatedContainers()
if (_allocatedContainers.size > 0) {
logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
", pendingReleaseContainers : " + pendingReleaseContainers)
logDebug("Cluster Resources: " + amResp.getAvailableResources)
val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()
// Ignore containers that do not satisfy the resource constraints.
for (container <- _allocatedContainers) {
if (isResourceConstraintSatisfied(container)) {
// allocatedContainers += container
val host = container.getNodeId.getHost
val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())
containers += container
}
// Add all ignored containers to released list
else releasedContainerList.add(container.getId())
}
// Find the appropriate containers to use
// Slightly non trivial groupBy I guess ...
val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()
for (candidateHost <- hostToContainers.keySet) {
val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)
var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
assert(remainingContainers != null)
if (requiredHostCount >= remainingContainers.size){
// Since we got <= required containers, add all to dataLocalContainers
dataLocalContainers.put(candidateHost, remainingContainers)
// all consumed
remainingContainers = null
}
else if (requiredHostCount > 0) {
// container list has more containers than we need for data locality.
// Split into two : data local container count of (remainingContainers.size - requiredHostCount)
// and rest as remainingContainer
val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
dataLocalContainers.put(candidateHost, dataLocal)
// remainingContainers = remaining
// YARN has a nasty habit of allocating a ton of containers on a single host - discourage this by
// adding the remainder to the release list. If we end up with insufficient containers, the next
// allocation cycle will reallocate (but won't treat them as data local).
for (container <- remaining) releasedContainerList.add(container.getId())
remainingContainers = null
}
// now rack local
if (remainingContainers != null){
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
if (rack != null){
val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
rackLocalContainers.get(rack).getOrElse(List()).size
if (requiredRackCount >= remainingContainers.size){
// Add all to dataLocalContainers
dataLocalContainers.put(rack, remainingContainers)
// all consumed
remainingContainers = null
}
else if (requiredRackCount > 0) {
// container list has more containers than we need for data locality.
// Split into two : data local container count of (remainingContainers.size - requiredRackCount)
// and rest as remainingContainer
val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())
existingRackLocal ++= rackLocal
remainingContainers = remaining
}
}
}
// If still not consumed, then it is off rack host - add to that list.
if (remainingContainers != null){
offRackContainers.put(candidateHost, remainingContainers)
}
}
// Now that we have split the containers into various groups, go through them in order :
// first host local, then rack local and then off rack (everything else).
// Note that the list we create below tries to ensure that not all containers end up within a host
// if there are sufficiently large number of hosts/containers.
val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)
// Run each of the allocated containers
for (container <- allocatedContainers) {
val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
val workerHostname = container.getNodeId.getHost
val containerId = container.getId
assert (container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))
if (numWorkersRunningNow > maxWorkers) {
logInfo("Ignoring container " + containerId + " at host " + workerHostname +
" .. we already have required number of containers")
releasedContainerList.add(containerId)
// reset counter back to old value.
numWorkersRunning.decrementAndGet()
}
else {
// deallocate + allocate can result in reusing id's wrongly - so use a different counter (workerIdCounter)
val workerId = workerIdCounter.incrementAndGet().toString
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
StandaloneSchedulerBackend.ACTOR_NAME)
logInfo("launching container on " + containerId + " host " + workerHostname)
// just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
pendingReleaseContainers.remove(containerId)
val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
allocatedHostToContainersMap.synchronized {
val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())
containerSet += containerId
allocatedContainerToHostMap.put(containerId, workerHostname)
if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
}
new Thread(
new WorkerRunnable(container, conf, driverUrl, workerId,
workerHostname, workerMemory, workerCores)
).start()
}
}
logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
_allocatedContainers.size + "), current count " + numWorkersRunning.get() +
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
}
val completedContainers = amResp.getCompletedContainersStatuses()
if (completedContainers.size > 0){
logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
for (completedContainer <- completedContainers){
val containerId = completedContainer.getContainerId
// Was this released by us ? If yes, then simply remove from containerSet and move on.
if (pendingReleaseContainers.containsKey(containerId)) {
pendingReleaseContainers.remove(containerId)
}
else {
// simply decrement count - next iteration of ReporterThread will take care of allocating !
numWorkersRunning.decrementAndGet()
logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState +
" httpaddress: " + completedContainer.getDiagnostics)
}
allocatedHostToContainersMap.synchronized {
if (allocatedContainerToHostMap.containsKey(containerId)) {
val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
assert (host != null)
val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
assert (containerSet != null)
containerSet -= containerId
if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
else allocatedHostToContainersMap.update(host, containerSet)
allocatedContainerToHostMap -= containerId
// doing this within locked context, sigh ... move to outside ?
val rack = YarnAllocationHandler.lookupRack(conf, host)
if (rack != null) {
val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
else allocatedRackCount.remove(rack)
}
}
}
}
logDebug("After completed " + completedContainers.size + " containers, current count " +
numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
", pendingReleaseContainers : " + pendingReleaseContainers)
}
}
def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
// First generate the modified racks and the new set of hosts under them, then issue requests.
val rackToCounts = new HashMap[String, Int]()
// The rack-related maps are read/written from here as well (via lookupRack/populateRackInfo).
for (container <- hostContainers) {
val candidateHost = container.getHostName
val candidateNumContainers = container.getNumContainers
assert(YarnAllocationHandler.ANY_HOST != candidateHost)
val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
if (rack != null) {
var count = rackToCounts.getOrElse(rack, 0)
count += candidateNumContainers
rackToCounts.put(rack, count)
}
}
val requestedContainers: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](rackToCounts.size)
for ((rack, count) <- rackToCounts){
requestedContainers +=
createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
}
requestedContainers.toList
}
def allocatedContainersOnHost(host: String): Int = {
var retval = 0
allocatedHostToContainersMap.synchronized {
retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
}
retval
}
def allocatedContainersOnRack(rack: String): Int = {
var retval = 0
allocatedHostToContainersMap.synchronized {
retval = allocatedRackCount.getOrElse(rack, 0)
}
retval
}
private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {
var resourceRequests: List[ResourceRequest] = null
// default.
if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
resourceRequests = List(
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
}
else {
// Request containers for all hosts in the preferred node set; the remainder
// (numWorkers - candidates.size) is requested via the default allocation policy.
val hostContainerRequests: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
for ((candidateHost, candidateCount) <- preferredHostToCount) {
val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)
if (requiredCount > 0) {
hostContainerRequests +=
createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
}
}
val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)
val anyContainerRequests: ResourceRequest =
createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)
val containerRequests: ArrayBuffer[ResourceRequest] =
new ArrayBuffer[ResourceRequest](hostContainerRequests.size + rackContainerRequests.size + 1)
containerRequests ++= hostContainerRequests
containerRequests ++= rackContainerRequests
containerRequests += anyContainerRequests
resourceRequests = containerRequests.toList
}
val req = Records.newRecord(classOf[AllocateRequest])
req.setResponseId(lastResponseId.incrementAndGet)
req.setApplicationAttemptId(appAttemptId)
req.addAllAsks(resourceRequests)
val releasedContainerList = createReleasedContainerList()
req.addAllReleases(releasedContainerList)
if (numWorkers > 0) {
logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.")
}
else {
logDebug("Empty allocation req .. release : " + releasedContainerList)
}
for (req <- resourceRequests) {
logInfo("rsrcRequest ... host : " + req.getHostName + ", numContainers : " + req.getNumContainers +
", p = " + req.getPriority().getPriority + ", capability: " + req.getCapability)
}
resourceManager.allocate(req)
}
private def createResourceRequest(requestType: AllocationType.AllocationType,
resource:String, numWorkers: Int, priority: Int): ResourceRequest = {
// If a hostname is specified, we need at least two requests - node local and rack local.
// There must also be a third request - ANY - which is specially handled.
requestType match {
case AllocationType.HOST => {
assert (YarnAllocationHandler.ANY_HOST != resource)
val hostname = resource
val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)
// add to host->rack mapping
YarnAllocationHandler.populateRackInfo(conf, hostname)
nodeLocal
}
case AllocationType.RACK => {
val rack = resource
createResourceRequestImpl(rack, numWorkers, priority)
}
case AllocationType.ANY => {
createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
}
case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
}
}
private def createResourceRequestImpl(hostname:String, numWorkers: Int, priority: Int): ResourceRequest = {
val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
val memCapability = Records.newRecord(classOf[Resource])
// There probably is some overhead here, let's reserve a bit more memory.
memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
rsrcRequest.setCapability(memCapability)
val pri = Records.newRecord(classOf[Priority])
pri.setPriority(priority)
rsrcRequest.setPriority(pri)
rsrcRequest.setHostName(hostname)
rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
rsrcRequest
}
def createReleasedContainerList(): ArrayBuffer[ContainerId] = {
val retval = new ArrayBuffer[ContainerId](1)
// iterator on COW list ...
for (container <- releasedContainerList.iterator()){
retval += container
}
// remove from the original list.
if (! retval.isEmpty) {
releasedContainerList.removeAll(retval)
for (v <- retval) pendingReleaseContainers.put(v, true)
logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
pendingReleaseContainers)
}
retval
}
}
object YarnAllocationHandler {
val ANY_HOST = "*"
// All requests are issued with the same priority: we do not (yet) distinguish between
// request types (like map/reduce in Hadoop, for example).
val PRIORITY = 1
// Additional memory overhead, in MB.
val MEMORY_OVERHEAD = 384
// host to rack map - saved from allocation requests
// We are expecting this not to change.
// Note that it is possible for this to change : and RM will indicate that to us via update
// response to allocate. But we are punting on handling that for now.
private val hostToRack = new ConcurrentHashMap[String, String]()
private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()
def newAllocator(conf: Configuration,
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
args: ApplicationMasterArguments,
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
args.workerMemory, args.workerCores, hostToCount, rackToCount)
}
def newAllocator(conf: Configuration,
resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
maxWorkers: Int, workerMemory: Int, workerCores: Int,
map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {
val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)
new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
workerMemory, workerCores, hostToCount, rackToCount)
}
// A simple method to copy the split info map.
private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
// host to count, rack to count
(Map[String, Int], Map[String, Int]) = {
if (input == null) return (Map[String, Int](), Map[String, Int]())
val hostToCount = new HashMap[String, Int]
val rackToCount = new HashMap[String, Int]
for ((host, splits) <- input) {
val hostCount = hostToCount.getOrElse(host, 0)
hostToCount.put(host, hostCount + splits.size)
val rack = lookupRack(conf, host)
if (rack != null){
val rackCount = rackToCount.getOrElse(rack, 0)
rackToCount.put(rack, rackCount + splits.size)
}
}
(hostToCount.toMap, rackToCount.toMap)
}
def lookupRack(conf: Configuration, host: String): String = {
if (! hostToRack.containsKey(host)) populateRackInfo(conf, host)
hostToRack.get(host)
}
def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
val set = rackToHostSet.get(rack)
if (set == null) return None
// No better way to get a Set[String] from JSet ?
val convertedSet: collection.mutable.Set[String] = set
Some(convertedSet.toSet)
}
def populateRackInfo(conf: Configuration, hostname: String) {
Utils.checkHost(hostname)
if (!hostToRack.containsKey(hostname)) {
// If there are repeated failures to resolve, add the host to an ignore list?
val rackInfo = RackResolver.resolve(conf, hostname)
if (rackInfo != null && rackInfo.getNetworkLocation != null) {
val rack = rackInfo.getNetworkLocation
hostToRack.put(hostname, rack)
if (! rackToHostSet.containsKey(rack)) {
rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
}
rackToHostSet.get(rack).add(hostname)
// Since RackResolver caches, we are disabling this for now ...
} /* else {
// right ? Else we will keep calling rack resolver in case we cant resolve rack info ...
hostToRack.put(hostname, null)
} */
}
}
}
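Illustrative sketch (not part of this commit) of how an ApplicationMaster reporter loop might drive the handler above: build it with newAllocator, then keep requesting the shortfall until enough workers are running. `conf`, `resourceManager`, `appAttemptId` and `args` are assumed to be in scope; note that with MEMORY_OVERHEAD = 384, a 1024 MB worker request asks YARN for containers of at least 1408 MB.
// Hypothetical reporter-thread style loop driving YarnAllocationHandler.
val allocator = YarnAllocationHandler.newAllocator(conf, resourceManager, appAttemptId, args,
Map[String, collection.Set[SplitInfo]]())
while (allocator.getNumWorkersRunning < args.numWorkers) {
val missing = args.numWorkers - allocator.getNumWorkersRunning
allocator.allocateContainers(math.max(missing, 0))
Thread.sleep(100) // poll the RM periodically rather than spinning
}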

View file

@ -0,0 +1,42 @@
package spark.scheduler.cluster
import spark._
import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
import org.apache.hadoop.conf.Configuration
/**
* This is a simple extension to ClusterScheduler to ensure that the appropriate initialization of the ApplicationMaster, etc. is done.
*/
private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
def this(sc: SparkContext) = this(sc, new Configuration())
// Nothing else for now ... initialize the application master, which needs the SparkContext to determine how to allocate.
// Note that only the first SparkContext created influences allocation (ideally there should be only one SparkContext anyway);
// subsequent creations are ignored, since nodes are already allocated by then.
// By default, the rack is unknown.
override def getRackForHost(hostPort: String): Option[String] = {
val host = Utils.parseHostPort(hostPort)._1
val retval = YarnAllocationHandler.lookupRack(conf, host)
if (retval != null) Some(retval) else None
}
// By default, if rack is unknown, return nothing
override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
if (rack == null) return None
YarnAllocationHandler.fetchCachedHostsForRack(rack)
}
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
// Wait a few seconds for the slaves to bootstrap and register with the master - best-effort attempt.
Thread.sleep(3000L)
}
logInfo("YarnClusterScheduler.postStartHook done")
}
}
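A small hypothetical example of the rack-awareness hooks above (the host and rack names are made up, and `sc` is assumed to be an existing SparkContext):
val scheduler = new YarnClusterScheduler(sc) // uses a fresh Configuration by default
scheduler.getRackForHost("node1.example.com:5051") // Some("/rack-01") once the allocation handler has resolved the host
scheduler.getCachedHostsForRack("/rack-01") // Some(Set("node1.example.com", ...)) if populated, else None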

View file

@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
jobId, isMap, taskId, attemptId)
}

View file

@ -7,4 +7,7 @@ trait HadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)
def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
jobId, isMap, taskId, attemptId)
}

View file

@ -0,0 +1,23 @@
package spark.deploy
import org.apache.hadoop.conf.Configuration
/**
* Contains util methods to interact with Hadoop from spark.
*/
object SparkHadoopUtil {
def getUserNameFromEnvironment(): String = {
// Default to the user.name system property (i.e. whatever -Duser.name gives us).
System.getProperty("user.name")
}
def runAsUser(func: (Product) => Unit, args: Product) {
// Add real user-impersonation support if/when it exists - for now, simply run func.
func(args)
}
// Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop subsystems.
def newConfiguration(): Configuration = new Configuration()
}
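A hypothetical usage sketch of these helpers (not part of this commit): build a Hadoop Configuration and run a unit of work through runAsUser, which for now simply invokes the function with its arguments.
val hadoopConf = SparkHadoopUtil.newConfiguration()
def work(args: Product) { println("running as " + SparkHadoopUtil.getUserNameFromEnvironment()) }
SparkHadoopUtil.runAsUser(work, Tuple1(hadoopConf))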

View file

@ -0,0 +1,72 @@
package spark.network.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FileClient {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
private int connectTimeout = 60*1000; // 1 min
public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
this.connectTimeout = connectTimeout;
}
public void init() {
bootstrap = new Bootstrap();
bootstrap.group(new OioEventLoopGroup())
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
public void connect(String host, int port) {
try {
// Start the connection attempt.
channel = bootstrap.connect(host, port).sync().channel();
// ChannelFuture cf = channel.closeFuture();
//cf.addListener(new ChannelCloseListener(this));
} catch (InterruptedException e) {
close();
}
}
public void waitForClose() {
try {
channel.closeFuture().sync();
} catch (InterruptedException e) {
LOG.warn("FileClient interrupted", e);
}
}
public void sendRequest(String file) {
//assert(file == null);
//assert(channel == null);
channel.write(file + "\r\n");
}
public void close() {
if (channel != null) {
channel.close();
channel = null;
}
if (bootstrap != null) {
bootstrap.shutdown();
bootstrap = null;
}
}
}
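Illustrative sketch (assumptions only, not this commit's code) of the client lifecycle in Scala: `handler` is some concrete FileClientHandler implementation, and the host, port and block id are hypothetical.
val client = new FileClient(handler, 60 * 1000)
client.init()
client.connect("localhost", 12345)
client.sendRequest("shuffle_0_0_0") // write the block id; the handler receives the FileHeader and bytes
client.waitForClose() // returns once the handler closes the channel
client.close()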

View file

@ -0,0 +1,24 @@
package spark.network.netty;
import io.netty.buffer.BufType;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.string.StringEncoder;
class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
private FileClientHandler fhandler;
public FileClientChannelInitializer(FileClientHandler handler) {
fhandler = handler;
}
@Override
public void initChannel(SocketChannel channel) {
// file no more than 2G
channel.pipeline()
.addLast("encoder", new StringEncoder(BufType.BYTE))
.addLast("handler", fhandler);
}
}

View file

@ -0,0 +1,43 @@
package spark.network.netty;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
private volatile boolean handlerCalled = false;
public boolean isComplete() {
return handlerCalled;
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
// Use direct buffer if possible.
return ctx.alloc().ioBuffer();
}
@Override
public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
// get header
if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
}
// get file
if (in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
handlerCalled = true;
currentHeader = null;
ctx.close();
}
}
}

View file

@ -0,0 +1,86 @@
package spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Server that accepts the path of a file and echoes back its content.
*/
class FileServer {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private ServerBootstrap bootstrap = null;
private ChannelFuture channelFuture = null;
private int port = 0;
private Thread blockingThread = null;
public FileServer(PathResolver pResolver, int port) {
InetSocketAddress addr = new InetSocketAddress(port);
// Configure the server.
bootstrap = new ServerBootstrap();
bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
.channel(OioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 100)
.option(ChannelOption.SO_RCVBUF, 1500)
.childHandler(new FileServerChannelInitializer(pResolver));
// Start the server.
channelFuture = bootstrap.bind(addr);
try {
// Get the address we bound to.
InetSocketAddress boundAddress =
((InetSocketAddress) channelFuture.sync().channel().localAddress());
this.port = boundAddress.getPort();
} catch (InterruptedException ie) {
this.port = 0;
}
}
/**
* Start the file server asynchronously in a new thread.
*/
public void start() {
blockingThread = new Thread() {
public void run() {
try {
channelFuture.channel().closeFuture().sync();
LOG.info("FileServer exiting");
} catch (InterruptedException e) {
LOG.error("File server start got interrupted", e);
}
// NOTE: bootstrap is shutdown in stop()
}
};
blockingThread.setDaemon(true);
blockingThread.start();
}
public int getPort() {
return port;
}
public void stop() {
// Close the bound channel.
if (channelFuture != null) {
channelFuture.channel().close();
channelFuture = null;
}
// Shutdown bootstrap.
if (bootstrap != null) {
bootstrap.shutdown();
bootstrap = null;
}
// TODO: Shutdown all accepted channels as well ?
}
}

View file

@ -0,0 +1,25 @@
package spark.network.netty;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;
class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
PathResolver pResolver;
public FileServerChannelInitializer(PathResolver pResolver) {
this.pResolver = pResolver;
}
@Override
public void initChannel(SocketChannel channel) {
channel.pipeline()
.addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
.addLast("strDecoder", new StringDecoder())
.addLast("handler", new FileServerHandler(pResolver));
}
}

View file

@ -0,0 +1,65 @@
package spark.network.netty;
import java.io.File;
import java.io.FileInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
PathResolver pResolver;
public FileServerHandler(PathResolver pResolver){
this.pResolver = pResolver;
}
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockId) {
String path = pResolver.getAbsolutePath(blockId);
// if getAbsolutePath returns null, ignore the request
if (path == null) {
//ctx.close();
return;
}
File file = new File(path);
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
}

View file

@ -0,0 +1,12 @@
package spark.network.netty;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
}
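To tie the server-side pieces together, a hypothetical sketch (not part of this commit): a PathResolver that maps block ids to files under a local directory, used to start a FileServer on an ephemeral port. The directory layout is an assumption.
val resolver = new PathResolver {
override def getAbsolutePath(blockId: String): String =
new java.io.File("/tmp/spark-local", blockId).getAbsolutePath
}
val server = new FileServer(resolver, 0) // port 0 lets the OS pick an ephemeral port
server.start()
println("shuffle file server listening on port " + server.getPort)
// ... serve requests, then shut down ...
server.stop()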

View file

@ -3,18 +3,25 @@ package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import spark.executor.{ShuffleReadMetrics, TaskMetrics}
import spark.serializer.Serializer
import spark.storage.BlockManagerId
import spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
override def fetch[K, V](
shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
@ -45,6 +52,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
}
}
}
blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
})
}
}
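The CompletionIterator used above defers the shuffle-metrics update until the consumer has drained all fetched blocks. A toy sketch of the same pattern (illustrative only):
import spark.util.CompletionIterator
val it = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), {
println("all elements consumed - record shuffle metrics here")
})
it.foreach(println) // prints 1, 2, 3, then the completion message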

View file

@ -27,7 +27,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ =>}
try {loading.wait()} catch {case _ : Throwable =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail

View file

@ -5,15 +5,22 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
private[spark] object ClosureCleaner extends Logging {
// Get an ASM class reader for a given class from the JAR that loaded it
private def getClassReader(cls: Class[_]): ClassReader = {
new ClassReader(cls.getResourceAsStream(
cls.getName.replaceFirst("^.*\\.", "") + ".class"))
// Copy data over, before delegating to ClassReader - else we can run out of open file handles.
val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
val resourceStream = cls.getResourceAsStream(className)
// todo: Fixme - continuing with earlier behavior ...
if (resourceStream == null) return new ClassReader(resourceStream)
val baos = new ByteArrayOutputStream(128)
Utils.copyStream(resourceStream, baos, true)
new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
// Check whether a class represents a Scala closure
@ -154,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
}
}
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
return new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@ -180,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
}
}
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@ -190,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
return new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)

View file

@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
val partitioner: Partitioner)
val partitioner: Partitioner,
val serializerClass: String = null)
extends Dependency(rdd) {
val shuffleId: Int = rdd.context.newShuffleId()

View file

@ -3,18 +3,25 @@ package spark
import spark.storage.BlockManagerId
private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
taskEndReason: TaskEndReason,
message: String,
cause: Throwable)
extends Exception {
override def getMessage(): String =
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
cause)
def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(null, shuffleId, -1, reduceId),
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
override def getMessage(): String = message
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason =
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
def toTaskEndReason: TaskEndReason = taskEndReason
}

View file

@ -2,14 +2,10 @@ package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.net.URI
import java.util.Date
import spark.Logging
@ -24,7 +20,7 @@ import spark.SerializableWritable
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe
}
}
def commitJob() {
// always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}
def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}

View file

@ -157,27 +157,34 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {
// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {
//hack, look at https://groups.google.com/forum/#!msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ
private final val FAKE_REFERENCE = new Object()
override def write(
kryo: Kryo,
output: KryoOutput,
obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
kryo: Kryo,
output: KryoOutput,
obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
output.writeInt(map.size)
for ((k, v) <- map) {
kryo.writeClassAndObject(output, k)
kryo.writeClassAndObject(output, v)
}
}
override def read (
kryo: Kryo,
input: KryoInput,
cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
kryo: Kryo,
input: KryoInput,
cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
: Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
kryo.reference(FAKE_REFERENCE)
val size = input.readInt()
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
for (i <- 0 until size) {
val k = kryo.readClassAndObject(input)
val v = kryo.readClassAndObject(input)
elems(i) = (k, v)
}
buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}

View file

@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}
protected def isTraceEnabled(): Boolean = {
log.isTraceEnabled
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }

View file

@ -1,7 +1,6 @@
package spark
import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
@ -12,8 +11,7 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._
import spark.scheduler.MapStatus
import spark.storage.BlockManagerId
@ -38,11 +36,14 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}
private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
private[spark] class MapOutputTracker extends Logging {
val timeout = 10.seconds
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
@ -51,19 +52,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val actorName: String = "MapOutputTracker"
var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
@ -87,10 +76,9 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.get(shuffleId) != None) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
@ -111,8 +99,9 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = mapStatuses(shuffleId)
if (array != null) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
@ -125,13 +114,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}
// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
@ -142,31 +132,48 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
case e: InterruptedException =>
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val host = System.getProperty("spark.hostname", Utils.localHostName)
// This try-finally prevents hangs due to timeouts:
var fetchedStatuses: Array[MapStatus] = null
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
} else {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}
@ -204,7 +211,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
generation = newGen
}
}
@ -242,10 +250,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(statuses)
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
objOut.close()
out.toByteArray
}
@ -253,7 +264,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().asInstanceOf[Array[MapStatus]]
objIn.readObject().
// Dropping all nulls from the statuses (see the commented-out filter) was considered - not sure why
// they occur, but they cause an NPE downstream in the slave if present; nulls could be due to a missing location?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
@ -263,14 +277,11 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
def convertMapStatuses(
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
if (statuses == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
assert (statuses != null)
statuses.map {
status =>
if (status == null) {

View file

@ -1,5 +1,6 @@
package spark
import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@ -10,6 +11,8 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
@ -17,12 +20,13 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.rdd._
import spark.SparkContext._
import spark.Partitioner._
/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
@ -51,7 +55,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
mapSideCombine: Boolean = true,
serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")
@ -60,19 +65,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V](self, partitioner)
val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}
@ -87,6 +91,42 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
}
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
// When deserializing, use a lazy val to create just one instance of the serializer per task
lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
}
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, defaultPartitioner(self))(func)
}
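// Illustrative usage (not part of this file): with 0 and addition, foldByKey behaves like
// reduceByKey(_ + _). Assuming an existing SparkContext `sc`:
//   sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 3))).foldByKey(0)(_ + _).collect()
//   // => Array(("a", 3), ("b", 3)) (ordering may vary)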
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
@ -156,11 +196,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@ -248,8 +290,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
* Simplified version of combineByKey that hash-partitions the resulting RDD using the
* existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
: RDD[(K, C)] = {
@ -259,7 +301,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
* "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
* parallelism level.
*/
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
reduceByKey(defaultPartitioner(self), func)
@ -267,7 +310,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the default parallelism level.
* resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): RDD[(K, Seq[V])] = {
groupByKey(defaultPartitioner(self))
@ -295,7 +338,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the default level of parallelism.
* using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, defaultPartitioner(self, other))
@ -315,7 +358,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the default parallelism level.
* RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, defaultPartitioner(self, other))
@ -439,15 +482,21 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
* the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be no larger than this one.
*/
def defaultPartitioner(rdds: RDD[_]*): Partitioner = {
for (r <- rdds if r.partitioner != None) {
return r.partitioner.get
}
return new HashPartitioner(self.context.defaultParallelism)
}
def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V, W](self, other, p)
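// Usage sketch (illustrative, not part of this commit):
//   val left  = sc.parallelize(Seq(("a", 1), ("b", 2)))
//   val right = sc.parallelize(Seq(("b", 9)))
//   left.subtractByKey(right).collect()   // Array(("a", 1))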
/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
@ -479,6 +528,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress the result with the
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
}
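// Usage sketch (illustrative, not part of this commit; GzipCodec is just an example codec):
//   import org.apache.hadoop.io.compress.GzipCodec
//   import org.apache.hadoop.mapred.TextOutputFormat
//   rdd.saveAsHadoopFile[TextOutputFormat[String, Int]]("out", classOf[GzipCodec])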
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
@ -510,8 +569,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = new TaskAttemptID(jobtrackerID,
stageId, false, context.splitId, attemptNumber)
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@ -530,14 +588,29 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
jobCommitter.commitJob(jobTaskContext)
jobCommitter.cleanupJob(jobTaskContext)
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress with the supplied codec.
*/
def saveAsHadoopFile(
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
codec: Class[_ <: CompressionCodec]) {
saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
new JobConf(self.context.hadoopConfiguration), Some(codec))
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
@ -547,11 +620,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf: JobConf = new JobConf(self.context.hadoopConfiguration),
codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
for (c <- codec) {
conf.setCompressMapOutput(true)
conf.set("mapred.output.compress", "true")
conf.setMapOutputCompressorClass(c)
conf.set("mapred.output.compression.codec", c.getCanonicalName)
conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
}
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)
@ -602,6 +683,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}
self.context.runJob(self, writeToFile _)
writer.commitJob()
writer.cleanup()
}
@ -609,7 +691,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)
/**
* Return an RDD with the values of each tuple.
*/

View file

@ -9,6 +9,35 @@ abstract class Partitioner extends Serializable {
def getPartition(key: Any): Int
}
object Partitioner {
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
* If any of the RDDs already has a partitioner, choose that one.
*
* Otherwise, we use a default HashPartitioner. For the number of partitions, if
* spark.default.parallelism is set, then we'll use the value from SparkContext
* defaultParallelism, otherwise we'll use the max number of upstream partitions.
*
* Unless spark.default.parallelism is set, the number of partitions will be the
* same as the number of partitions in the largest upstream RDD, as this should
* be least likely to cause out-of-memory errors.
*
* We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
*/
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
if (System.getProperty("spark.default.parallelism") != null) {
return new HashPartitioner(rdd.context.defaultParallelism)
} else {
return new HashPartitioner(bySize.head.partitions.size)
}
}
}
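// Illustrative sketch (not part of this commit): with neither input pre-partitioned and
// spark.default.parallelism unset, the largest upstream partition count wins.
//   val a = sc.parallelize(1 to 100, 8).map(x => (x, x))
//   val b = sc.parallelize(1 to 100, 2).map(x => (x, x))
//   Partitioner.defaultPartitioner(a, b).numPartitions   // 8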
/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
*

View file

@ -1,21 +1,21 @@
package spark
import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.Random
import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator
@ -30,10 +30,14 @@ import spark.rdd.MapPartitionsRDD
import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.SubtractedRDD
import spark.rdd.ShuffledRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
import spark.rdd.ZippedPartitionsRDD2
import spark.rdd.ZippedPartitionsRDD3
import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
import spark.util.BoundedPriorityQueue
import SparkContext._
@ -102,7 +106,7 @@ abstract class RDD[T: ClassManifest](
// =======================================================================
/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
val id: Int = sc.newRddId()
/** A friendly name for this RDD */
var name: String = null
@ -113,9 +117,18 @@ abstract class RDD[T: ClassManifest](
this
}
/** User-defined generator of this RDD */
var generator = Utils.getCallSiteInfo.firstUserClass
/** Reset generator */
def setGenerator(_generator: String) = {
generator = _generator
}
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
* it is computed. This can only be used to assign a new storage level if the RDD does not
* have a storage level set yet.
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel
@ -135,6 +148,20 @@ abstract class RDD[T: ClassManifest](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
* @return This RDD.
*/
def unpersist(blocking: Boolean = true): RDD[T] = {
logInfo("Removing RDD " + id + " from persistence list")
sc.env.blockManager.master.removeRdd(id, blocking)
sc.persistentRdds.remove(id)
storageLevel = StorageLevel.NONE
this
}
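// Usage sketch (illustrative, not part of this commit):
//   val cached = sc.textFile("data.txt").cache()
//   cached.count()       // materializes and caches the blocks
//   cached.unpersist()   // removes the blocks; waits for removal since blocking defaults to true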
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel
@ -236,7 +263,14 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): RDD[T] = new CoalescedRDD(this, numPartitions)
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
if (shuffle) {
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(new ShuffledRDD(map(x => (x, null)), new HashPartitioner(numPartitions)), numPartitions).keys
} else {
new CoalescedRDD(this, numPartitions)
}
}
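// Usage sketch (illustrative, not part of this commit):
//   rdd.coalesce(1)                  // narrow dependency: upstream work collapses into one task
//   rdd.coalesce(1, shuffle = true)  // keeps upstream parallelism, then shuffles to one partition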
/**
* Return a sampled subset of this RDD.
@ -247,8 +281,8 @@ abstract class RDD[T: ClassManifest](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
var multiplier = 3.0
var initialCount = count()
val multiplier = 3.0
val initialCount = count()
var maxSelected = 0
if (initialCount > Integer.MAX_VALUE - 1) {
@ -300,19 +334,26 @@ abstract class RDD[T: ClassManifest](
*/
def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
* Return an RDD of grouped items.
*/
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
groupBy[K](f, defaultPartitioner(this))
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(numPartitions)
}
def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
groupBy(f, new HashPartitioner(numPartitions))
/**
* Return an RDD of grouped items.
*/
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism)
def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
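// Usage sketch (illustrative, not part of this commit):
//   sc.parallelize(1 to 10).groupBy(_ % 2).collect()
//   // e.g. Array((0, Seq(2, 4, 6, 8, 10)), (1, Seq(1, 3, 5, 7, 9)))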
/**
* Return an RDD created by piping elements to a forked external process.
@ -322,13 +363,36 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
*
* @param command command to run in forked process.
* @param env environment variables to set.
* @param printPipeContext Before piping elements, this function is called as an opportunity
* to pipe context data. Print line function (like out.println) will be
* passed as printPipeContext's parameter.
* @param printRDDElement Use this function to customize how to pipe elements. This function
* will be called with each RDD element as the 1st parameter, and the
* print line function (like out.println()) as the 2nd parameter.
* An example of piping the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concatenate all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
* @return the result RDD
*/
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
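// Usage sketch (illustrative, not part of this commit): stream each grouped value to `wc -l`
// without first building one large String per record.
//   val grouped: RDD[(String, Seq[String])] = sc.parallelize(Seq("x", "y", "x")).groupBy(identity)
//   grouped.pipe(Seq("wc", "-l"),
//     printRDDElement = (record: (String, Seq[String]), printFunc: String => Unit) =>
//       record._2.foreach(printFunc))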
/**
* Return a new RDD by applying a function to each partition of this RDD.
@ -350,12 +414,68 @@ abstract class RDD[T: ClassManifest](
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
@deprecated("use mapPartitionsWithIndex")
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.map(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
* FlatMaps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
* Applies f to each element of this RDD, where f takes an additional parameter of type A.
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructA: Int => A)
(f:(T, A) => Unit) {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.map(t => {f(t, a); t})
}
(new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
}
/**
* Filters this RDD with p, where p takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructA: Int => A)
(p:(T, A) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.filter(t => p(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
}
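// Usage sketch (illustrative, not part of this commit; assumes rdd: RDD[Double]): construct one
// Random per partition, seeded by the partition index, and reuse it for every element.
//   rdd.mapWith(index => new scala.util.Random(index))((x, rng) => x * rng.nextDouble())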
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
@ -364,6 +484,31 @@ abstract class RDD[T: ClassManifest](
*/
def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
* applying a function to the zipped partitions. Assumes that all the RDDs have the
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
def zipPartitions[B: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B]) => Iterator[V],
rdd2: RDD[B]): RDD[V] =
new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
rdd2: RDD[B],
rdd3: RDD[C]): RDD[V] =
new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
rdd2: RDD[B],
rdd3: RDD[C],
rdd4: RDD[D]): RDD[V] =
new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
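// Usage sketch (illustrative, not part of this commit): partition counts must match, element
// counts per partition need not.
//   val a = sc.parallelize(1 to 6, 3)
//   val b = sc.parallelize(Seq("a", "b", "c"), 3)
//   a.zipPartitions((xs: Iterator[Int], ys: Iterator[String]) =>
//     Iterator(xs.size + "/" + ys.size), b).collect()   // Array("2/1", "2/1", "2/1")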
// Actions (launch a job to return a value to the user program)
/**
@ -374,6 +519,14 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}
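// Usage sketch (illustrative, not part of this commit): set up one resource per partition
// rather than per element.
//   rdd.foreachPartition { iter =>
//     val writer = new java.io.PrintWriter(java.io.File.createTempFile("part-", ".txt"))
//     try iter.foreach(x => writer.println(x.toString)) finally writer.close()
//   }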
/**
* Return an array that contains all of the elements in this RDD.
*/
@ -396,7 +549,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be no larger than this one.
*/
@ -412,7 +565,23 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
if (partitioner == Some(p)) {
// Our partitioner knows how to handle T (which, since we have a partitioner, is
// really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
val p2 = new Partitioner() {
override def numPartitions = p.numPartitions
override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
}
// Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
// anyway, and when calling .keys, will not have a partitioner set, even though
// the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
// partitioned by the right/real keys (e.g. p).
this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
} else {
this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
}
}
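// Usage sketch (illustrative, not part of this commit):
//   sc.parallelize(1 to 5).subtract(sc.parallelize(3 to 10)).collect()   // elements 1 and 2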
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
@ -587,6 +756,24 @@ abstract class RDD[T: ClassManifest](
case _ => throw new UnsupportedOperationException("empty collection")
}
/**
* Returns the top K elements from this RDD as defined by
* the specified implicit Ordering[T].
* @param num the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
mapPartitions { items =>
val queue = new BoundedPriorityQueue[T](num)
queue ++= items
Iterator.single(queue)
}.reduce { (queue1, queue2) =>
queue1 ++= queue2
queue1
}.toArray
}
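// Usage sketch (illustrative, not part of this commit): each partition keeps at most `num`
// elements in a BoundedPriorityQueue, so only num * numPartitions items reach the driver.
//   sc.parallelize(Seq(10, 4, 2, 12, 3)).top(2)   // an Array containing 12 and 10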
/**
* Save this RDD as a text file, using string representations of elements.
*/
@ -595,6 +782,14 @@ abstract class RDD[T: ClassManifest](
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
this.map(x => (NullWritable.get(), new Text(x.toString)))
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}
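// Usage sketch (illustrative, not part of this commit; GzipCodec is just an example codec):
//   import org.apache.hadoop.io.compress.GzipCodec
//   sc.parallelize(1 to 100).saveAsTextFile("out-compressed", classOf[GzipCodec])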
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@ -653,7 +848,7 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]

View file

@ -1,6 +1,7 @@
package spark
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import rdd.{CheckpointRDD, CoalescedRDD}
import scheduler.{ResultTask, ShuffleMapTask}
@ -62,14 +63,20 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
}
}
// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val fs = path.getFileSystem(new Configuration())
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
}
// Save to file, and reload it as an RDD
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
val newRDD = new CheckpointRDD[T](rdd.context, path)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed

View file

@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
@ -36,17 +37,17 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
self: RDD[(K, V)])
extends Logging
with Serializable {
private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
} else {
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
// is of the form "java.lang.Object apply(java.lang.Object)"
// is not of the form "java.lang.Object apply(java.lang.Object)"
implicitly[T => Writable].getClass.getDeclaredMethods().filter(
m => m.getReturnType().toString != "java.lang.Object" &&
m => m.getReturnType().toString != "class java.lang.Object" &&
m.getName() == "apply")(0).getReturnType
}
@ -62,24 +63,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
def saveAsSequenceFile(path: String) {
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
val keyClass = getWritableClass[K]
val valueClass = getWritableClass[V]
val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass)
val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass)
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
self.saveAsHadoopFile(path, keyClass, valueClass, format)
self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
}
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
}
}
}

View file

@ -1,11 +1,16 @@
package spark
import spark.executor.TaskMetrics
import spark.serializer.Serializer
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]
/** Stop the fetcher */
def stop() {}

View file

@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging {
val elem = JArray.get(array, index)
size += SizeEstimator.estimate(elem, state.visited)
}
state.size += ((length / 100.0) * size).toLong
state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
}
}
}

View file

@ -1,52 +1,53 @@
package spark
import java.io._
import java.net.URI
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import java.lang.ref.WeakReference
import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import scala.util.DynamicVariable
import scala.collection.mutable.{ConcurrentMap, HashMap}
import akka.actor.Actor
import akka.actor.Actor._
import org.apache.hadoop.fs.{FileUtil, Path}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.FloatWritable
import org.apache.hadoop.io.DoubleWritable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.DoubleWritable
import org.apache.hadoop.io.FloatWritable
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import org.apache.mesos.MesosNativeLibrary
import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
import storage.{StorageStatus, StorageUtils, RDDInfo}
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -64,7 +65,10 @@ class SparkContext(
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
environment: Map[String, String] = Map())
val environment: Map[String, String] = Map(),
// This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
// This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
@ -72,7 +76,7 @@ class SparkContext(
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
System.setProperty("spark.driver.host", Utils.localIpAddress)
System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
@ -99,12 +103,14 @@ class SparkContext(
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Add each JAR given through the constructor
jars.foreach { addJar(_) }
if (jars != null) {
jars.foreach { addJar(_) }
}
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
@ -116,7 +122,9 @@ class SparkContext(
executorEnvs(key) = value
}
}
executorEnvs ++= environment
if (environment != null) {
executorEnvs ++= environment
}
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
@ -169,6 +177,22 @@ class SparkContext(
}
scheduler
case "yarn-standalone" =>
val scheduler = try {
val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(this).asInstanceOf[ClusterScheduler]
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
case th: Throwable => {
throw new SparkException("YARN mode not available ?", th)
}
}
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
@ -188,12 +212,12 @@ class SparkContext(
}
taskScheduler.start()
private var dagScheduler = new DAGScheduler(taskScheduler)
@volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val conf = new Configuration()
val conf = SparkHadoopUtil.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
@ -212,6 +236,22 @@ class SparkContext(
private[spark] var checkpointDir: Option[String] = None
// Thread Local variable that can be used by users to pass information down the stack
private val localProperties = new DynamicVariable[Properties](null)
def initLocalProperties() {
localProperties.value = new Properties()
}
def addLocalProperties(key: String, value: String) {
if(localProperties.value == null) {
localProperties.value = new Properties()
}
localProperties.value.setProperty(key,value)
}
// Post init
taskScheduler.postStartHook()
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
@ -466,13 +506,17 @@ class SparkContext(
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
def addSparkListener(listener: SparkListener) {
dagScheduler.sparkListeners += listener
}
/**
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.ip + ":" + blockManagerId.port, mem)
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
@ -480,14 +524,18 @@ class SparkContext(
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
def getRDDStorageInfo : Array[RDDInfo] = {
def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
def getStageInfo: Map[Stage,StageInfo] = {
dagScheduler.stageToInfos
}
/**
* Return information about blocks stored in all of the slaves
*/
def getExecutorStorageStatus : Array[StorageStatus] = {
def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
@ -505,13 +553,18 @@ class SparkContext(
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
if (null == path) {
logWarning("null specified as parameter to addJar",
new SparkException("null specified as parameter to addJar"))
} else {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
case _ => path
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
/**
@ -524,10 +577,13 @@ class SparkContext(
/** Shut down the SparkContext. */
def stop() {
if (dagScheduler != null) {
// Do this only if not already stopped - best effort.
// prevent NPE if stopped more than once.
val dagSchedulerCopy = dagScheduler
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
dagScheduler.stop()
dagScheduler = null
dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@ -543,6 +599,7 @@ class SparkContext(
}
}
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
@ -572,10 +629,10 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
@ -654,12 +711,11 @@ class SparkContext(
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long
): PartialResult[R] = {
val callSite = Utils.getSparkCallSite
timeout: Long): PartialResult[R] = {
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
@ -682,7 +738,7 @@ class SparkContext(
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val path = new Path(dir)
val fs = path.getFileSystem(new Configuration())
val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@ -693,7 +749,7 @@ class SparkContext(
checkpointDir = Some(dir)
}
/** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
def defaultParallelism: Int = taskScheduler.defaultParallelism
/** Default min number of partitions for Hadoop RDDs when not given by user */

View file

@ -1,15 +1,19 @@
package spark
import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import collection.mutable
import serializer.Serializer
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
import spark.api.python.PythonWorkerFactory
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@ -21,6 +25,7 @@ import spark.util.AkkaUtils
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
@ -30,10 +35,16 @@ class SparkEnv (
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String
) {
val sparkFilesDir: String,
// To be set only as part of initialization of SparkContext.
// (executorId, defaultHostPort) => executorHostPort
// If executorId is NOT found, return defaultHostPort
var executorIdToHostPort: Option[(String, String) => String]) {
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@ -45,6 +56,22 @@ class SparkEnv (
// down, but let's call it anyway in case it gets fixed in a later release
actorSystem.awaitTermination()
}
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
synchronized {
pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
val env = SparkEnv.get
if (env.executorIdToHostPort.isEmpty) {
// Default to using the host, not hostPort. Relevant to non-cluster modes.
return defaultHostPort
}
env.executorIdToHostPort.get(executorId, defaultHostPort)
}
}
object SparkEnv extends Logging {
@ -73,6 +100,16 @@ object SparkEnv extends Logging {
System.setProperty("spark.driver.port", boundPort.toString)
}
// set only if unset until now.
if (System.getProperty("spark.hostPort", null) == null) {
if (!isDriver){
// unexpected
Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
}
Utils.checkHost(hostname)
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
}
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
@ -82,24 +119,45 @@ object SparkEnv extends Logging {
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
val serializerManager = new SerializerManager
val driverIp: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
val blockManagerMaster = new BlockManagerMaster(
actorSystem, isDriver, isLocal, driverIp, driverPort)
val serializer = serializerManager.setDefault(
System.getProperty("spark.serializer", "spark.JavaSerializer"))
val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
Utils.checkHost(driverHost, "Expected hostname")
val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
}
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new spark.storage.BlockManagerMasterActor(isLocal)))
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isDriver)
val closureSerializer = instantiateClass[Serializer](
"spark.closure.serializer", "spark.JavaSerializer")
val cacheManager = new CacheManager(blockManager)
val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@ -126,6 +184,7 @@ object SparkEnv extends Logging {
new SparkEnv(
executorId,
actorSystem,
serializerManager,
serializer,
closureSerializer,
cacheManager,
@ -135,6 +194,7 @@ object SparkEnv extends Logging {
blockManager,
connectionManager,
httpFileServer,
sparkFilesDir)
sparkFilesDir,
None)
}
}

View file

@ -1,9 +1,14 @@
package spark
import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
class TaskContext(
val stageId: Int,
val splitId: Int,
val attemptId: Long,
val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]

View file

@ -14,9 +14,19 @@ private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
private[spark]
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
private[spark] case class FetchFailed(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
reduceId: Int)
extends TaskEndReason
private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
private[spark] case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement])
extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason
private[spark] case class TaskResultTooBigFailure() extends TaskEndReason

View file

@ -1,23 +1,29 @@
package spark
import java.io._
import java.net._
import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.regex.Pattern
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import spark.serializer.SerializerInstance
import spark.deploy.SparkHadoopUtil
/**
* Various utility methods used by Spark.
*/
private object Utils extends Logging {
/** Serialize an object using Java serialization */
def serialize[T](o: T): Array[Byte] = {
val bos = new ByteArrayOutputStream()
@ -68,6 +74,40 @@ private object Utils extends Logging {
return buf
}
private val shutdownDeletePaths = new collection.mutable.HashSet[String]()
// Register the path to be deleted via shutdown hook
def registerShutdownDeleteDir(file: File) {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths += absolutePath
}
}
// Is the path already registered to be deleted via a shutdown hook ?
def hasShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
shutdownDeletePaths.synchronized {
shutdownDeletePaths.contains(absolutePath)
}
}
// Note: if file is child of some registered path, while not equal to it, then return true;
// else false. This is to ensure that two shutdown hooks do not try to delete each other's
// paths - resulting in IOException and incomplete cleanup.
def hasRootAsShutdownDeleteDir(file: File): Boolean = {
val absolutePath = file.getAbsolutePath()
val retval = shutdownDeletePaths.synchronized {
shutdownDeletePaths.find { path =>
!absolutePath.equals(path) && absolutePath.startsWith(path)
}.isDefined
}
if (retval) {
logInfo("path = " + file + ", already present as root for deletion.")
}
retval
}
/** Create a temporary directory inside the given parent directory */
def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
var attempts = 0
@ -76,8 +116,8 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
throw new IOException("Failed to create a temp directory after " + maxAttempts +
" attempts!")
throw new IOException("Failed to create a temp directory (under " + root + ") after " +
maxAttempts + " attempts!")
}
try {
dir = new File(root, "spark-" + UUID.randomUUID.toString)
@ -86,13 +126,17 @@ private object Utils extends Logging {
}
} catch { case e: IOException => ; }
}
registerShutdownDeleteDir(dir)
// Add a shutdown hook to delete the temp dir when the JVM exits
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
override def run() {
Utils.deleteRecursively(dir)
// Only attempt to delete if no registered parent path already covers this directory.
if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
}
})
return dir
dir
}
/** Copy all data from an InputStream to an OutputStream */
@ -135,40 +179,35 @@ private object Utils extends Logging {
Utils.copyStream(in, out, true)
if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
tempFile.delete()
throw new SparkException("File " + targetFile + " exists and does not match contents of" +
" " + url)
throw new SparkException(
"File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
Files.move(tempFile, targetFile)
}
case "file" | null =>
val sourceFile = if (uri.isAbsolute) {
new File(uri)
} else {
new File(url)
}
if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
throw new SparkException("File " + targetFile + " exists and does not match contents of" +
" " + url)
} else {
// Remove the file if it already exists
targetFile.delete()
// Symlink the file locally.
if (uri.isAbsolute) {
// url is absolute, i.e. it starts with "file:///". Extract the source
// file's absolute path from the url.
val sourceFile = new File(uri)
logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
// In the case of a local file, copy the local file to the target directory.
// Note the difference between uri vs url.
val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
if (targetFile.exists) {
// If the target file already exists, make sure its contents match the source file.
if (!Files.equal(sourceFile, targetFile)) {
throw new SparkException(
"File " + targetFile + " exists and does not match contents of" + " " + url)
} else {
// url is not absolute, i.e. itself is the path to the source file.
logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
FileUtil.symLink(url, targetFile.getAbsolutePath)
// Do nothing if the file contents are the same, i.e. this file has been copied
// previously.
logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
+ targetFile.getAbsolutePath)
}
} else {
// The file does not exist in the target directory. Copy it there.
logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
Files.copy(sourceFile, targetFile)
}
case _ =>
// Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
val uri = new URI(url)
val conf = new Configuration()
val conf = SparkHadoopUtil.newConfiguration()
val fs = FileSystem.get(uri, conf)
val in = fs.open(new Path(uri))
val out = new FileOutputStream(tempFile)
@ -227,8 +266,10 @@ private object Utils extends Logging {
/**
* Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
* Note, this is typically not used from within core spark.
*/
lazy val localIpAddress: String = findLocalIpAddress()
lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)
private def findLocalIpAddress(): String = {
val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@ -266,6 +307,8 @@ private object Utils extends Logging {
* hostname it reports to the master.
*/
def setCustomHostname(hostname: String) {
// DEBUG code
Utils.checkHost(hostname)
customHostname = Some(hostname)
}
@ -273,7 +316,91 @@ private object Utils extends Logging {
* Get the local machine's hostname.
*/
def localHostName(): String = {
customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
customHostname.getOrElse(localIpAddressHostname)
}
def getAddressHostName(address: String): String = {
InetAddress.getByName(address).getHostName
}
def localHostPort(): String = {
val retval = System.getProperty("spark.hostPort", null)
if (retval == null) {
logErrorWithStack("spark.hostPort not set but invoking localHostPort")
return localHostName()
}
retval
}
/*
// Used by DEBUG code : remove when all testing done
private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
def checkHost(host: String, message: String = "") {
// Currently catches only the ipv4 pattern; this is just a debugging tool - not rigorous!
// if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
if (ipPattern.matcher(host).matches()) {
Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
}
if (Utils.parseHostPort(host)._2 != 0){
Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
}
}
// Used by DEBUG code : remove when all testing done
def checkHostPort(hostPort: String, message: String = "") {
val (host, port) = Utils.parseHostPort(hostPort)
checkHost(host)
if (port <= 0){
Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
}
}
// Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
// temp code for debug
System.exit(-1)
}
*/
// Once testing is complete in various modes, replace with this ?
def checkHost(host: String, message: String = "") {}
def checkHostPort(hostPort: String, message: String = "") {}
// Used by DEBUG code : remove when all testing done
def logErrorWithStack(msg: String) {
try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
}
def getUserNameFromEnvironment(): String = {
SparkHadoopUtil.getUserNameFromEnvironment
}
// Typically, this will be on the order of the number of nodes in the cluster.
// If not, we should change it to an LRUCache or something similar.
private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()
def parseHostPort(hostPort: String): (String, Int) = {
{
// Check cache first.
var cached = hostPortParseResults.get(hostPort)
if (cached != null) return cached
}
val indx: Int = hostPort.lastIndexOf(':')
// This is potentially broken - when dealing with ipv6 addresses for example, sigh ...
// but then hadoop does not support ipv6 right now.
// For now, we assume that if a port exists, it is valid - we do not check that it is an int > 0
if (-1 == indx) {
val retval = (hostPort, 0)
hostPortParseResults.put(hostPort, retval)
return retval
}
val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
hostPortParseResults.putIfAbsent(hostPort, retval)
hostPortParseResults.get(hostPort)
}
private[spark] val daemonThreadFactory: ThreadFactory =
@ -395,13 +522,14 @@ private object Utils extends Logging {
execute(command, new File("."))
}
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getSparkCallSite: String = {
def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@ -413,6 +541,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@ -427,13 +556,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
firstUserClass = el.getClassName
finished = true
}
}
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
def formatSparkCallSite = {
val callSiteInfo = getCallSiteInfo
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
callSiteInfo.firstUserLine)
}
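For illustration, if user code in a hypothetical MyJob.scala calls sc.textFile on line 42, the new call-site plumbing would record and render roughly:
// getCallSiteInfo     => CallSiteInfo("textFile", "MyJob.scala", 42, "com.example.MyJob")
// formatSparkCallSite => "textFile at MyJob.scala:42"
// i.e. the same string the old getSparkCallSite produced, with the user class kept alongside.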
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)

View file

@ -57,6 +57,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
fromRDD(srdd.coalesce(numPartitions, shuffle))
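A short sketch of the difference between the two overloads, shown on the underlying Scala RDD that the Java wrapper delegates to (sc is assumed to be an existing SparkContext):
val rdd = sc.parallelize(1 to 1000, 100)
// Without a shuffle, coalesce merges existing partitions and can only shrink the count.
val fewer = rdd.coalesce(10)
// With shuffle = true the data is redistributed, which also allows growing the count.
val more = rdd.coalesce(200, shuffle = true)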
/**
* Return an RDD with the elements from `this` that are not in `other`.
*

View file

@ -6,6 +6,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@ -19,6 +20,7 @@ import spark.OrderedRDDFunctions
import spark.storage.StorageLevel
import spark.HashPartitioner
import spark.Partitioner
import spark.Partitioner._
import spark.RDD
import spark.SparkContext.rddToPairRDDFunctions
@ -65,7 +67,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.coalesce(numPartitions))
def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
* Return a sampled subset of this RDD.
@ -159,6 +167,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue)(func))
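For illustration, the zero-value contract described above, on the Scala API that these wrappers call (sc assumed):
val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 5)))
// 0 is neutral for addition, so it may be folded in any number of times
// without changing the per-key sums.
val sums = pairs.foldByKey(0)(_ + _)   // ("a", 3), ("b", 5)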
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
@ -241,30 +273,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(rdd.rightOuterJoin(other, partitioner))
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
implicit val cm: ClassManifest[C] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
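As a sketch of what the three functions mean, the classic per-key average; with no partitioner argument the result is now hash-partitioned with defaultPartitioner(rdd), as in the wrapper above (Scala API, sc assumed):
val scores = sc.parallelize(Seq(("a", 10.0), ("a", 20.0), ("b", 6.0)))
val sumCounts = scores.combineByKey(
  (v: Double) => (v, 1),                                             // createCombiner
  (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),       // mergeValue
  (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2) // mergeCombiners
)
val averages = sumCounts.mapValues { case (sum, count) => sum / count }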
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
* "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
* parallelism level.
*/
def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
val partitioner = rdd.defaultPartitioner(rdd)
fromRDD(reduceByKey(partitioner, func))
fromRDD(reduceByKey(defaultPartitioner(rdd), func))
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the default parallelism level.
* resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
@ -289,7 +321,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the default level of parallelism.
* using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
fromRDD(rdd.leftOuterJoin(other))
@ -307,7 +339,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the default parallelism level.
* RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
fromRDD(rdd.rightOuterJoin(other))
@ -428,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
codec: Class[_ <: CompressionCodec]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,

View file

@ -14,12 +14,18 @@ JavaRDDLike[T, JavaRDD[T]] {
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
/**
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
* it is computed. This can only be used to assign a new storage level if the RDD does not
* have a storage level set yet.
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
// Transformations (return a new RDD)
/**
@ -31,7 +37,7 @@ JavaRDDLike[T, JavaRDD[T]] {
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
@ -43,12 +49,18 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
rdd.coalesce(numPartitions, shuffle)
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
@ -57,7 +69,7 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be no larger than this one.
*/
@ -74,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
}
object JavaRDD {

View file

@ -1,9 +1,10 @@
package spark.api.java
import java.util.{List => JList}
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@ -182,6 +183,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
}
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
* applying a function to the zipped partitions. Assumes that all the RDDs have the
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
def zipPartitions[U, V](
f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
other: JavaRDDLike[U, _]): JavaRDD[V] = {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
rdd.zipPartitions(fn, other.rdd)(other.classManifest, f.elementType()))(f.elementType())
}
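A rough sketch of the per-partition pairing this exposes, written against the Scala RDD method that the wrapper above calls (both RDDs must have the same number of partitions; sc assumed):
val nums  = sc.parallelize(1 to 6, 3)
val names = sc.parallelize(Seq("a", "b", "c", "d", "e", "f"), 3)
// Matching partitions are handed to the function as two iterators and combined element-wise.
val zipped = nums.zipPartitions(
  (is: Iterator[Int], ss: Iterator[String]) => is.zip(ss), names)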
// Actions (launch a job to return a value to the user program)
/**
@ -295,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
rdd.saveAsTextFile(path, codec)
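A minimal usage sketch, assuming Hadoop's bundled GzipCodec and a made-up output path:
import org.apache.hadoop.io.compress.GzipCodec
rdd.saveAsTextFile("hdfs:///tmp/words-gz", classOf[GzipCodec])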
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@ -336,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def toDebugString(): String = {
rdd.toDebugString
}
/**
* Returns the top K elements from this RDD as defined by
* the specified Comparator[T].
* @param num the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
def top(num: Int, comp: Comparator[T]): JList[T] = {
import scala.collection.JavaConversions._
val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
val arr: java.util.Collection[T] = topElems.toSeq
new java.util.ArrayList(arr)
}
/**
* Returns the top K elements from this RDD using the
* natural ordering for T.
* @param num the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
top(num, comp)
}
}
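For illustration, the two overloads on the underlying Scala RDD (the Java versions above adapt a Comparator to an Ordering; sc assumed):
val nums = sc.parallelize(Seq(3, 9, 1, 7, 5))
nums.top(2)                        // Array(9, 7): largest by the natural ordering
nums.top(2)(Ordering.Int.reverse)  // Array(1, 3): reversing the ordering yields the smallest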

View file

@ -31,8 +31,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param jarFile JAR file to send to the cluster. This can be a path on the local file system
* or an HDFS, HTTP, HTTPS, or FTP URL.
*/
def this(master: String, appName: String, sparkHome: String, jarFile: String) =
this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))

View file

@ -0,0 +1,11 @@
package spark.api.java.function
/**
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
@throws(classOf[Exception])
def call(a: A, b: B): java.lang.Iterable[C]
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}
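A hypothetical Scala implementation, of the kind a caller might pass to the zipPartitions wrapper added above; the element types and pairing logic are made up for illustration:
import java.{lang => jl, util => ju}

val pairUp = new FlatMapFunction2[ju.Iterator[Integer], ju.Iterator[String], String] {
  def call(a: ju.Iterator[Integer], b: ju.Iterator[String]): jl.Iterable[String] = {
    // Pair elements from the two partitions until the shorter one is exhausted.
    val out = new ju.ArrayList[String]()
    while (a.hasNext && b.hasNext) out.add(b.next() + "-" + a.next())
    out
  }
}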

View file

@ -2,10 +2,9 @@ package spark.api.python
import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Collections}
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
for ((variable, value) <- envVars) {
currentEnvVars.put(variable, value)
}
val proc = pb.start()
val startTime = System.currentTimeMillis
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
new Thread("stderr reader for " + command) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
System.err.println(line)
}
}
}.start()
val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + command) {
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
val out = new PrintWriter(worker.getOutputStream)
val dOut = new DataOutputStream(worker.getOutputStream)
// Partition index
dOut.writeInt(split.index)
// sparkFilesDir
@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
}
dOut.flush()
out.flush()
proc.getOutputStream.close()
worker.shutdownOutput()
}
}.start()
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
val stream = new DataInputStream(worker.getInputStream)
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
if (hasNext) {
// FIXME: can deadlock if worker is waiting for us to
// respond to current message (currently irrelevant because
// output is shut down before we read any input)
_nextObj = read()
}
obj
}
@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case -3 =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
// We've finished the data section of the output, but we can still read some
// accumulator updates; let's do that, breaking when we get EOFException
while (true) {
val len2 = stream.readInt()
// We've finished the data section of the output, but we can still
// read some accumulator updates; let's do that, breaking when we
// get a negative length record.
var len2 = stream.readInt()
while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
new Array[Byte](0)
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e => throw e
}
@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@ -215,7 +211,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new Exception("Unexpected RDD type")
throw new SparkException("Unexpected RDD type")
}
}
@ -277,6 +273,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList

View file

@ -0,0 +1,95 @@
package spark.api.python
import java.io.{DataInputStream, IOException}
import java.net.{Socket, SocketException, InetAddress}
import scala.collection.JavaConversions._
import spark._
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1).map(_.toByte))  // getByAddress expects Array[Byte]
var daemonPort: Int = 0
def create(): Socket = {
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
new Socket(daemonHost, daemonPort)
} catch {
case exc: SocketException => {
logWarning("Python daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
new Socket(daemonHost, daemonPort)
}
case e => throw e
}
}
}
def stop() {
stopDaemon()
}
private def startDaemon() {
synchronized {
// Is it already running?
if (daemon != null) {
return
}
try {
// Create and start the daemon
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
daemon = pb.start()
daemonPort = new DataInputStream(daemon.getInputStream).readInt()
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME HACK: We copy the stream on the level of bytes to
// attempt to dodge encoding problems.
val in = daemon.getErrorStream
var buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
} catch {
case e => {
stopDaemon()
throw e
}
}
// Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
// detect our disappearance.
}
}
private def stopDaemon() {
synchronized {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
daemon = null
daemonPort = 0
}
}
}
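A rough sketch of the intended lifecycle, assuming SPARK_HOME is set and a python executable is on the PATH (the environment map is made up):
// One factory per (pythonExec, envVars) pair; it launches daemon.py lazily and
// hands out sockets connected to it.
val factory = new PythonWorkerFactory("python", Map("PYTHONPATH" -> "/opt/spark/python"))
val worker: java.net.Socket = factory.create()   // starts the daemon on first use
// ... write the serialized task to worker.getOutputStream, read results back ...
factory.stop()                                   // destroys the daemon process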

View file

@ -2,10 +2,11 @@ package spark.deploy
private[spark] class ApplicationDescription(
val name: String,
val cores: Int,
val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
val memoryPerSlave: Int,
val command: Command,
val sparkHome: String)
val sparkHome: String,
val appUiUrl: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
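For illustration, an application description with an unbounded core request would now be constructed roughly like this (all values are made up; Integer.MAX_VALUE is the convention noted in the comment above):
val desc = new ApplicationDescription(
  "my-app",
  Integer.MAX_VALUE,                                  // maxCores: no limit on total cores
  512,                                                // memoryPerSlave, in MB
  Command("spark.executor.StandaloneExecutorBackend", Seq(), Map()),
  "/opt/spark",                                       // sparkHome on the workers
  "http://driver-node:4040")                          // appUiUrl linked from the master UI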

View file

@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
import spark.deploy.master.{WorkerInfo, ApplicationInfo}
import spark.deploy.worker.ExecutorRunner
import scala.collection.immutable.List
import spark.Utils
private[spark] sealed trait DeployMessage extends Serializable
@ -19,7 +20,10 @@ case class RegisterWorker(
memory: Int,
webUiPort: Int,
publicAddress: String)
extends DeployMessage
extends DeployMessage {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}
private[spark]
case class ExecutorStateChanged(
@ -58,14 +62,16 @@ private[spark]
case class RegisteredApplication(appId: String) extends DeployMessage
private[spark]
case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
Utils.checkHostPort(hostPort, "Required hostport")
}
private[spark]
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
exitStatus: Option[Int])
private[spark]
case class appKilled(message: String)
case class ApplicationRemoved(message: String)
// Internal message in Client
@ -81,6 +87,9 @@ private[spark]
case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
def uri = "spark://" + host + ":" + port
}
@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
private[spark]
case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}

View file

@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
def write(obj: WorkerInfo) = JsObject(
"id" -> JsString(obj.id),
"host" -> JsString(obj.host),
"port" -> JsNumber(obj.port),
"webuiaddress" -> JsString(obj.webUiAddress),
"cores" -> JsNumber(obj.cores),
"coresused" -> JsNumber(obj.coresUsed),
@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
"starttime" -> JsNumber(obj.startTime),
"id" -> JsString(obj.id),
"name" -> JsString(obj.desc.name),
"cores" -> JsNumber(obj.desc.cores),
"cores" -> JsNumber(obj.desc.maxCores),
"user" -> JsString(obj.desc.user),
"memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
"submitdate" -> JsString(obj.submitDate.toString))
@ -34,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
def write(obj: ApplicationDescription) = JsObject(
"name" -> JsString(obj.name),
"cores" -> JsNumber(obj.cores),
"cores" -> JsNumber(obj.maxCores),
"memoryperslave" -> JsNumber(obj.memoryPerSlave),
"user" -> JsString(obj.user)
)

View file

@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
private val localIpAddress = Utils.localIpAddress
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localIpAddress + ":" + masterPort
val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}

View file

@ -3,6 +3,7 @@ package spark.deploy.client
import spark.deploy._
import akka.actor._
import akka.pattern.ask
import akka.util.Duration
import akka.util.duration._
import akka.pattern.AskTimeoutException
import spark.{SparkException, Logging}
@ -54,10 +55,15 @@ private[spark] class Client(
appId = appId_
listener.connected(appId)
case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
case ApplicationRemoved(message) =>
logError("Master removed our application: %s; stopping client".format(message))
markDisconnected()
context.stop(self)
case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
listener.executorAdded(fullId, workerId, host, cores, memory)
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
val fullId = appId + "/" + id
@ -107,7 +113,7 @@ private[spark] class Client(
def stop() {
if (actor != null) {
try {
val timeout = 5.seconds
val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
val future = actor.ask(StopClient)(timeout)
Await.result(future, timeout)
} catch {

View file

@ -12,7 +12,7 @@ private[spark] trait ClientListener {
def disconnected(): Unit
def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
}

View file

@ -16,7 +16,7 @@ private[spark] object TestClient {
System.exit(0)
}
def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
}
@ -25,7 +25,7 @@ private[spark] object TestClient {
val url = args(0)
val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
client.start()

View file

@ -10,7 +10,8 @@ private[spark] class ApplicationInfo(
val id: String,
val desc: ApplicationDescription,
val submitDate: Date,
val driver: ActorRef)
val driver: ActorRef,
val appUiUrl: String)
{
var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
@ -37,7 +38,7 @@ private[spark] class ApplicationInfo(
coresGranted -= exec.cores
}
def coresLeft: Int = desc.cores - coresGranted
def coresLeft: Int = desc.maxCores - coresGranted
private var _retryCount = 0
@ -60,4 +61,5 @@ private[spark] class ApplicationInfo(
System.currentTimeMillis() - startTime
}
}
}

View file

@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
import spark.util.AkkaUtils
private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
@ -35,18 +35,20 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
var firstApp: Option[ApplicationInfo] = None
Utils.checkHost(host, "Expected hostname")
val masterPublicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else ip
if (envVar != null) envVar else host
}
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + ip + ":" + port)
logInfo("Starting Spark master at spark://" + host + ":" + port)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
startWebUi()
@ -107,7 +109,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
} else {
logError("Application %s with ID %s failed %d times, removing it".format(
appInfo.desc.name, appInfo.id, appInfo.retryCount))
removeApplication(appInfo)
removeApplication(appInfo, ApplicationState.FAILED)
}
}
}
@ -129,23 +131,23 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
actorToApp.get(actor).foreach(removeApplication)
actorToApp.get(actor).foreach(finishApplication)
}
case RemoteClientDisconnected(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(removeApplication)
addressToApp.get(address).foreach(finishApplication)
}
case RemoteClientShutdown(transport, address) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(removeApplication)
addressToApp.get(address).foreach(finishApplication)
}
case RequestMasterState => {
sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
}
}
@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
// There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
workers += worker
idToWorker(worker.id) = worker
@ -242,7 +244,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
apps += app
idToApp(app.id) = app
actorToApp(driver) = app
@ -257,20 +259,26 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
return app
}
def removeApplication(app: ApplicationInfo) {
def finishApplication(app: ApplicationInfo) {
removeApplication(app, ApplicationState.FINISHED)
}
def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) {
if (apps.contains(app)) {
logInfo("Removing app " + app.id)
apps -= app
idToApp -= app.id
actorToApp -= app.driver
addressToWorker -= app.driver.path.address
addressToApp -= app.driver.path.address
completedApps += app // Remember it in our history
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
exec.state = ExecutorState.KILLED
}
app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed
app.markFinished(state)
app.driver ! ApplicationRemoved(state.toString)
schedule()
}
}
@ -302,7 +310,7 @@ private[spark] object Master {
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
actorSystem.awaitTermination()
}

View file

@ -7,13 +7,13 @@ import spark.Utils
* Command-line parser for the master.
*/
private[spark] class MasterArguments(args: Array[String]) {
var ip = Utils.localHostName()
var host = Utils.localHostName()
var port = 7077
var webUiPort = 8080
// Check for settings in environment variables
if (System.getenv("SPARK_MASTER_IP") != null) {
ip = System.getenv("SPARK_MASTER_IP")
if (System.getenv("SPARK_MASTER_HOST") != null) {
host = System.getenv("SPARK_MASTER_HOST")
}
if (System.getenv("SPARK_MASTER_PORT") != null) {
port = System.getenv("SPARK_MASTER_PORT").toInt
@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
ip = value
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
host = value
parse(tail)
case ("--host" | "-h") :: value :: tail =>
Utils.checkHost(value, "Please use hostname " + value)
host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
"Usage: Master [options]\n" +
"\n" +
"Options:\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -i HOST, --ip HOST Hostname to listen on (deprecated, please use --host or -h) \n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: 7077)\n" +
" --webui-port PORT Port for web UI (default: 8080)")
System.exit(exitCode)

View file

@ -3,7 +3,7 @@ package spark.deploy.master
import akka.actor.{ActorRef, ActorSystem}
import akka.dispatch.Await
import akka.pattern.ask
import akka.util.Timeout
import akka.util.{Duration, Timeout}
import akka.util.duration._
import cc.spray.Directives
import cc.spray.directives._
@ -22,7 +22,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
val RESOURCE_DIR = "spark/deploy/master/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(10 seconds)
implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {

View file

@ -2,6 +2,7 @@ package spark.deploy.master
import akka.actor.ActorRef
import scala.collection.mutable
import spark.Utils
private[spark] class WorkerInfo(
val id: String,
@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
val webUiPort: Int,
val publicAddress: String) {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
def hostPort: String = {
assert (port > 0)
host + ":" + port
}
def addExecutor(exec: ExecutorInfo) {
executors(exec.fullId) = exec
coresUsed += exec.cores

View file

@ -21,11 +21,13 @@ private[spark] class ExecutorRunner(
val memory: Int,
val worker: ActorRef,
val workerId: String,
val hostname: String,
val hostPort: String,
val sparkHome: File,
val workDir: File)
extends Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
@ -68,7 +70,7 @@ private[spark] class ExecutorRunner(
/** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
def substituteVariables(argument: String): String = argument match {
case "{{EXECUTOR_ID}}" => execId.toString
case "{{HOSTNAME}}" => hostname
case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
case "{{CORES}}" => cores.toString
case other => other
}

View file

@ -16,7 +16,7 @@ import spark.deploy.master.Master
import java.io.File
private[spark] class Worker(
ip: String,
host: String,
port: Int,
webUiPort: Int,
cores: Int,
@ -25,6 +25,9 @@ private[spark] class Worker(
workDirPath: String = null)
extends Actor with Logging {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
@ -39,7 +42,7 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
if (envVar != null) envVar else ip
if (envVar != null) envVar else host
}
var coresUsed = 0
@ -51,10 +54,14 @@ private[spark] class Worker(
def createWorkDir() {
workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
try {
if (!workDir.exists() && !workDir.mkdirs()) {
// This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs()
// So we attempt to create the directory and then check whether it was actually created.
workDir.mkdirs()
if ( !workDir.exists() || !workDir.isDirectory) {
logError("Failed to create work directory " + workDir)
System.exit(1)
}
assert (workDir.isDirectory)
} catch {
case e: Exception =>
logError("Failed to create work directory " + workDir, e)
@ -64,7 +71,7 @@ private[spark] class Worker(
override def preStart() {
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
ip, port, cores, Utils.memoryMegabytesToString(memory)))
host, port, cores, Utils.memoryMegabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
logInfo("Spark home: " + sparkHome)
createWorkDir()
@ -74,20 +81,14 @@ private[spark] class Worker(
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorFor(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
System.exit(1)
}
master = context.actorFor(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(master) // Doesn't work with remote actors, but useful for testing
}
def startWebUi() {
val webUi = new WorkerWebUI(context.system, self)
val webUi = new WorkerWebUI(context.system, self, workDir)
try {
AkkaUtils.startSprayServer(context.system, "0.0.0.0", webUiPort, webUi.handler)
} catch {
@ -112,7 +113,7 @@ private[spark] class Worker(
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@ -147,7 +148,7 @@ private[spark] class Worker(
masterDisconnected()
case RequestWorkerState => {
sender ! WorkerState(ip, port, workerId, executors.values.toList,
sender ! WorkerState(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
}
@ -162,7 +163,7 @@ private[spark] class Worker(
}
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
"worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
}
override def postStop() {
@ -173,7 +174,7 @@ private[spark] class Worker(
private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
actorSystem.awaitTermination()
}

View file

@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
* Command-line parser for the master.
*/
private[spark] class WorkerArguments(args: Array[String]) {
var ip = Utils.localHostName()
var host = Utils.localHostName()
var port = 0
var webUiPort = 8081
var cores = inferDefaultCores()
@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {
def parse(args: List[String]): Unit = args match {
case ("--ip" | "-i") :: value :: tail =>
ip = value
Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
host = value
parse(tail)
case ("--host" | "-h") :: value :: tail =>
Utils.checkHost(value, "Please use hostname " + value)
host = value
parse(tail)
case ("--port" | "-p") :: IntParam(value) :: tail =>
@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
" -c CORES, --cores CORES Number of cores to use\n" +
" -m MEM, --memory MEM Amount of memory to use (e.g. 1000M, 2G)\n" +
" -d DIR, --work-dir DIR Directory to run apps in (default: SPARK_HOME/work)\n" +
" -i IP, --ip IP IP address or DNS name to listen on\n" +
" -i HOST, --ip IP Hostname to listen on (deprecated, please use --host or -h)\n" +
" -h HOST, --host HOST Hostname to listen on\n" +
" -p PORT, --port PORT Port to listen on (default: random)\n" +
" --webui-port PORT Port for web UI (default: 8081)")
System.exit(exitCode)

View file

@ -3,7 +3,7 @@ package spark.deploy.worker
import akka.actor.{ActorRef, ActorSystem}
import akka.dispatch.Await
import akka.pattern.ask
import akka.util.Timeout
import akka.util.{Duration, Timeout}
import akka.util.duration._
import cc.spray.Directives
import cc.spray.typeconversion.TwirlSupport._
@ -12,16 +12,17 @@ import cc.spray.typeconversion.SprayJsonSupport._
import spark.deploy.{WorkerState, RequestWorkerState}
import spark.deploy.JsonProtocol._
import java.io.File
/**
* Web UI server for the standalone worker.
*/
private[spark]
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef, workDir: File) extends Directives {
val RESOURCE_DIR = "spark/deploy/worker/webui"
val STATIC_RESOURCE_DIR = "spark/deploy/static"
implicit val timeout = Timeout(10 seconds)
implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))
val handler = {
get {
@ -43,7 +44,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
path("log") {
parameters("appId", "executorId", "logType") { (appId, executorId, logType) =>
respondWithMediaType(cc.spray.http.MediaTypes.`text/plain`) {
getFromFileName("work/" + appId + "/" + executorId + "/" + logType)
getFromFileName(workDir.getPath() + "/" + appId + "/" + executorId + "/" + logType)
}
}
} ~

View file

@ -16,66 +16,68 @@ import java.nio.ByteBuffer
/**
* The Mesos executor for Spark.
*/
private[spark] class Executor extends Logging {
var urlClassLoader : ExecutorURLClassLoader = null
var threadPool: ExecutorService = null
var env: SparkEnv = null
private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
initLogging()
def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
// No ip or host:port - just hostname
Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
// must not have port specified.
assert (0 == Utils.parseHostPort(slaveHostname)._2)
// Set spark.* system properties from executor arg
for ((key, value) <- properties) {
System.setProperty(key, value)
}
// Make sure the local hostname we report matches the cluster scheduler's name for this host
Utils.setCustomHostname(slaveHostname)
// Create our ClassLoader and set it on this thread
urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
// Set spark.* system properties from executor arg
for ((key, value) <- properties) {
System.setProperty(key, value)
}
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
}
// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
)
}
)
// Initialize Spark environment (using system properties read above)
env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
// Initialize Spark environment (using system properties read above)
val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
// Start worker thread pool
threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
}
// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
@ -85,8 +87,9 @@ private[spark] class Executor extends Logging {
extends Runnable {
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@ -98,11 +101,25 @@ private[spark] class Executor extends Logging {
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
logInfo("Its generation is " + task.generation)
env.mapOutputTracker.updateGeneration(task.generation)
val taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
// TODO: I'd also like to track the time it takes to serialize the task results, but that is a huge headache,
// because we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
// just change the relevant bytes in the byte buffer.
val accumUpdates = Accumulators.values
val result = new TaskResult(value, accumUpdates)
val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedResult = ser.serialize(result)
logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
if (serializedResult.limit >= (akkaFrameSize - 1024)) {
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
return
}
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
@ -112,7 +129,7 @@ private[spark] class Executor extends Logging {
}
case t: Throwable => {
val reason = ExceptionFailure(t)
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
@ -137,26 +154,31 @@ private[spark] class Executor extends Logging {
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
loader = new URLClassLoader(urls, loader)
new ExecutorURLClassLoader(urls, loader)
}
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
/**
* If the REPL is in use, add another ClassLoader that will read
* new classes defined by the REPL as the user types code
*/
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = System.getProperty("spark.repl.class.uri")
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
loader = {
try {
val klass = Class.forName("spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
constructor.newInstance(classUri, loader)
} catch {
case _: ClassNotFoundException => loader
}
try {
val klass = Class.forName("spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
return constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
System.exit(1)
null
}
} else {
return parent
}
return new ExecutorURLClassLoader(Array(), loader)
}
/**

View file

@ -8,11 +8,12 @@ import com.google.protobuf.ByteString
import spark.{Utils, Logging}
import spark.TaskState
private[spark] class MesosExecutorBackend(executor: Executor)
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
with Logging {
var executor: Executor = null
var driver: ExecutorDriver = null
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@ -32,16 +33,19 @@ private[spark] class MesosExecutorBackend(executor: Executor)
logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
executor.initialize(
executor = new Executor(
executorInfo.getExecutorId.getValue,
slaveInfo.getHostname,
properties
)
properties)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
}
}
override def error(d: ExecutorDriver, message: String) {
@ -68,7 +72,7 @@ private[spark] object MesosExecutorBackend {
def main(args: Array[String]) {
MesosNativeLibrary.load()
// Create a new Executor and start it running
val runner = new MesosExecutorBackend(new Executor)
val runner = new MesosExecutorBackend()
new MesosExecutorDriver(runner).run()
}
}

View file

@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
import spark.scheduler.cluster.LaunchTask
import spark.scheduler.cluster.RegisterExecutorFailed
import spark.scheduler.cluster.RegisterExecutor
import spark.Utils
import spark.deploy.SparkHadoopUtil
private[spark] class StandaloneExecutorBackend(
executor: Executor,
driverUrl: String,
executorId: String,
hostname: String,
hostPort: String,
cores: Int)
extends Actor
with ExecutorBackend
with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
var driver: ActorRef = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorFor(driverUrl)
driver ! RegisterExecutor(executorId, hostname, cores)
driver ! RegisterExecutor(executorId, hostPort, cores)
context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
context.watch(driver) // Doesn't work with remote actors, but useful for testing
}
@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
override def receive = {
case RegisteredExecutor(sparkProperties) =>
logInfo("Successfully registered with driver")
executor.initialize(executorId, hostname, sparkProperties)
// Make this host instead of hostPort ?
executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@ -44,7 +49,12 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
if (executor == null) {
logError("Received launchTask but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
@ -58,11 +68,30 @@ private[spark] class StandaloneExecutorBackend(
private[spark] object StandaloneExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
}
// This will be run 'as' the user
def run0(args: Product) {
assert(4 == args.productArity)
runImpl(args.productElement(0).asInstanceOf[String],
args.productElement(1).asInstanceOf[String],
args.productElement(2).asInstanceOf[String],
args.productElement(3).asInstanceOf[Int])
}
private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
val actor = actorSystem.actorOf(
Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
name = "Executor")
actorSystem.awaitTermination()
}
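
For reference, a rough sketch of what splitting a "host:port" string involves; parseHostPort below is a hypothetical stand-in and may differ from the real spark.Utils helper:

    object HostPortDemo {
      // Hypothetical host:port splitter; the real spark.Utils version may cache
      // results and validate its input more strictly than shown here.
      def parseHostPort(hostPort: String): (String, Int) = {
        val idx = hostPort.lastIndexOf(':')
        require(idx > 0 && idx < hostPort.length - 1, "Expected host:port, got " + hostPort)
        (hostPort.substring(0, idx), hostPort.substring(idx + 1).toInt)
      }

      def main(args: Array[String]) {
        val (host, port) = parseHostPort("worker-7.example.com:33221")
        println(host + " / " + port)   // worker-7.example.com / 33221
      }
    }
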

View file

@ -0,0 +1,83 @@
package spark.executor
class TaskMetrics extends Serializable {
/**
* Name of the host the task runs on
*/
var hostname: String = _
/**
* Time taken on the executor to deserialize this task
*/
var executorDeserializeTime: Int = _
/**
* Time the executor spends actually running the task (including fetching shuffle data)
*/
var executorRunTime: Int = _
/**
* The number of bytes this task transmitted back to the driver as the TaskResult
*/
var resultSize: Long = _
/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here
*/
var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
/**
* If this task writes to shuffle output, metrics on the written shuffle data will be collected here
*/
var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
}
object TaskMetrics {
private[spark] def empty(): TaskMetrics = new TaskMetrics
}
class ShuffleReadMetrics extends Serializable {
/**
* Time when the shuffle finishes
*/
var shuffleFinishTime: Long = _
/**
* Total number of blocks fetched in a shuffle (remote or local)
*/
var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle
*/
var remoteBlocksFetched: Int = _
/**
* Local blocks fetched in a shuffle
*/
var localBlocksFetched: Int = _
/**
* Total time that is spent blocked waiting for shuffle to fetch data
*/
var fetchWaitTime: Long = _
/**
* The total amount of time for all the shuffle fetches. This adds up time from overlapping
* shuffles, so can be longer than task time
*/
var remoteFetchTime: Long = _
/**
* Total number of remote bytes read from a shuffle
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
}
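
A small usage sketch for the metrics classes above; the object name and all values are made up, and in practice the executor populates these fields while running the task:

    import spark.executor.{ShuffleReadMetrics, TaskMetrics}

    object TaskMetricsDemo {
      def main(args: Array[String]) {
        val metrics = new TaskMetrics
        metrics.hostname = "worker-1"
        metrics.executorDeserializeTime = 12   // ms spent deserializing the task
        metrics.executorRunTime = 850          // ms spent running it
        metrics.resultSize = 4096L             // bytes sent back to the driver

        val read = new ShuffleReadMetrics
        read.totalBlocksFetched = 20
        read.remoteBlocksFetched = 15
        read.localBlocksFetched = 5
        read.fetchWaitTime = 120L
        read.remoteBytesRead = 1L << 20
        metrics.shuffleReadMetrics = Some(read)

        println("run time = " + metrics.executorRunTime + " ms, remote bytes read = " +
          metrics.shuffleReadMetrics.get.remoteBytesRead)
      }
    }
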

View file

@ -0,0 +1,94 @@
package spark.network
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import spark.storage.BlockManager
private[spark]
class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
var gotChunkForSendingOnce = false
def size = initialSize
def currentSize() = {
if (buffers == null || buffers.isEmpty) {
0
} else {
buffers.map(_.remaining).reduceLeft(_ + _)
}
}
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
if (maxChunkSize <= 0) {
throw new Exception("Max chunk size is " + maxChunkSize)
}
if (size == 0 && gotChunkForSendingOnce == false) {
val newChunk = new MessageChunk(
new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
BlockManager.dispose(buffer)
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
buffer.duplicate()
} else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
}
None
}
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
// STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
if (buffers.size > 1) {
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
}
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
return Some(newChunk)
}
None
}
def flip() {
buffers.foreach(_.flip)
}
def hasAckId() = (ackId != 0)
def isCompletelyReceived() = !buffers(0).hasRemaining
override def toString = {
if (hasAckId) {
"BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
}
}
}
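
A standalone sketch of the slicing idea behind getChunkForSending above, using plain ByteBuffers rather than BufferMessage; sizes and names are illustrative:

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    // Repeatedly take at most maxChunkSize bytes from the front of a buffer,
    // advancing its position, until nothing remains.
    object ChunkingDemo {
      def chunks(buffer: ByteBuffer, maxChunkSize: Int): Seq[ByteBuffer] = {
        val out = new ArrayBuffer[ByteBuffer]()
        while (buffer.remaining > 0) {
          val chunk =
            if (buffer.remaining <= maxChunkSize) buffer.duplicate()
            else buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
          buffer.position(buffer.position + chunk.remaining)
          out += chunk
        }
        out
      }

      def main(args: Array[String]) {
        val data = ByteBuffer.wrap(new Array[Byte](150 * 1024))
        val pieces = chunks(data, 64 * 1024)
        println(pieces.map(_.remaining).mkString(", "))   // 65536, 65536, 22528
      }
    }
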

View file

@ -13,12 +13,13 @@ import java.net._
private[spark]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
val socketRemoteConnectionManagerId: ConnectionManagerId)
extends Logging {
def this(channel_ : SocketChannel, selector_ : Selector) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
))
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
}
channel.configureBlocking(false)
@ -33,16 +34,47 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
val remoteAddress = getRemoteAddress()
// Read channels typically do not register for write and write does not for read
// Now, we do have write registering for read too (temporarily), but this is to detect
// channel close NOT to actually read/consume data on it !
// How does this work if/when we move to SSL ?
// The interest to register with the selector when we want this connection to be selected.
def registerInterest()
// The interest to register with the selector when we want this connection to
// be de-selected.
// Traditionally 0, but in our case, because of the close-detection hack on SendingConnection,
// it will be SelectionKey.OP_READ (until we fix it properly).
def unregisterInterest()
// On receiving a read event, should we change the interest for this channel or not ?
// Will be true for ReceivingConnection, false for SendingConnection.
def changeInterestForRead(): Boolean
// On receiving a write event, should we change the interest for this channel or not ?
// Will be false for ReceivingConnection, true for SendingConnection.
// Actually, for now, should not get triggered for ReceivingConnection
def changeInterestForWrite(): Boolean
def getRemoteConnectionManagerId(): ConnectionManagerId = {
socketRemoteConnectionManagerId
}
def key() = channel.keyFor(selector)
def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
def read() {
throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
// Returns whether we have to register for further reads or not.
def read(): Boolean = {
throw new UnsupportedOperationException(
"Cannot read on connection of type " + this.getClass.toString)
}
def write() {
throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
// Returns whether we have to register for further writes or not.
def write(): Boolean = {
throw new UnsupportedOperationException(
"Cannot write on connection of type " + this.getClass.toString)
}
def close() {
@ -54,26 +86,32 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
callOnCloseCallback()
}
def onClose(callback: Connection => Unit) {onCloseCallback = callback}
def onClose(callback: Connection => Unit) {
onCloseCallback = callback
}
def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
def onException(callback: (Connection, Exception) => Unit) {
onExceptionCallback = callback
}
def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
def onKeyInterestChange(callback: (Connection, Int) => Unit) {
onKeyInterestChangeCallback = callback
}
def callOnExceptionCallback(e: Exception) {
if (onExceptionCallback != null) {
onExceptionCallback(this, e)
} else {
logError("Error in connection to " + remoteConnectionManagerId +
logError("Error in connection to " + getRemoteConnectionManagerId() +
" and OnExceptionCallback not registered", e)
}
}
def callOnCloseCallback() {
if (onCloseCallback != null) {
onCloseCallback(this)
} else {
logWarning("Connection to " + remoteConnectionManagerId +
logWarning("Connection to " + getRemoteConnectionManagerId() +
" closed and OnExceptionCallback not registered")
}
@ -81,7 +119,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
def changeConnectionKeyInterest(ops: Int) {
if (onKeyInterestChangeCallback != null) {
onKeyInterestChangeCallback(this, ops)
onKeyInterestChangeCallback(this, ops)
} else {
throw new Exception("OnKeyInterestChangeCallback not registered")
}
@ -105,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
print(" (" + position + ", " + length + ")")
buffer.position(curPosition)
}
}
private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId)
extends Connection(SocketChannel.open, selector_, remoteId_) {
private[spark]
class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
remoteId_ : ConnectionManagerId)
extends Connection(SocketChannel.open, selector_, remoteId_) {
class Outbox(fair: Int = 0) {
val messages = new Queue[Message]()
val defaultChunkSize = 65536 //32768 //16384
val defaultChunkSize = 65536 //32768 //16384
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
messages.synchronized{
messages.synchronized{
/*messages += message*/
messages.enqueue(message)
logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
logDebug("Added [" + message + "] to outbox for sending to " +
"[" + getRemoteConnectionManagerId() + "]")
}
}
@ -147,18 +186,18 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
message.started = true
message.startTime = System.currentTimeMillis
}
return chunk
return chunk
} else {
/*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
/*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
}
None
}
private def getChunkRR(): Option[MessageChunk] = {
messages.synchronized {
while (!messages.isEmpty) {
@ -170,15 +209,17 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
messages.enqueue(message)
nextMessageToBeUsed = nextMessageToBeUsed + 1
if (!message.started) {
logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
logDebug(
"Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
message.started = true
message.startTime = System.currentTimeMillis
}
logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
return chunk
logTrace(
"Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
return chunk
} else {
message.finishTime = System.currentTimeMillis
logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
"] in " + message.timeTaken )
}
}
@ -186,27 +227,40 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
None
}
}
val outbox = new Outbox(1)
private val outbox = new Outbox(1)
val currentBuffers = new ArrayBuffer[ByteBuffer]()
/*channel.socket.setSendBufferSize(256 * 1024)*/
override def getRemoteAddress() = address
override def getRemoteAddress() = address
val DEFAULT_INTEREST = SelectionKey.OP_READ
override def registerInterest() {
// Registering read too - does not really help in most cases, but for some
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
}
override def unregisterInterest() {
changeConnectionKeyInterest(DEFAULT_INTEREST)
}
def send(message: Message) {
outbox.synchronized {
outbox.addMessage(message)
if (channel.isConnected) {
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
registerInterest()
}
}
}
// MUST be called within the selector loop
def connect() {
try{
channel.connect(address)
channel.register(selector, SelectionKey.OP_CONNECT)
channel.connect(address)
logInfo("Initiating connection to [" + address + "]")
} catch {
case e: Exception => {
@ -216,36 +270,52 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
}
}
def finishConnect() {
def finishConnect(force: Boolean): Boolean = {
try {
channel.finishConnect
changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
// Typically, this should finish immediately since it was triggered by a connect
// selection - though need not necessarily always complete successfully.
val connected = channel.finishConnect
if (!force && !connected) {
logInfo(
"finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
return false
}
// Fall back to previous behavior - assume finishConnect completed.
// This will happen only after finishConnect has failed some number of times (10 or so),
// which is highly unlikely unless there was an unclean close of the socket, etc.
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
return true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
// ignore
return true
}
}
}
override def write() {
try{
while(true) {
override def write(): Boolean = {
try {
while (true) {
if (currentBuffers.size == 0) {
outbox.synchronized {
outbox.getChunk() match {
case Some(chunk) => {
currentBuffers ++= chunk.buffers
currentBuffers ++= chunk.buffers
}
case None => {
changeConnectionKeyInterest(SelectionKey.OP_READ)
return
// changeConnectionKeyInterest(0)
/*key.interestOps(0)*/
return false
}
}
}
}
if (currentBuffers.size > 0) {
val buffer = currentBuffers(0)
val remainingBytes = buffer.remaining
@ -254,69 +324,109 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
currentBuffers -= buffer
}
if (writtenBytes < remainingBytes) {
return
// re-register for write.
return true
}
}
}
} catch {
case e: Exception => {
logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
case e: Exception => {
logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
return false
}
}
// should not happen - to keep scala compiler happy
return true
}
override def read() {
// This is a hack to determine if remote socket was closed or not.
// SendingConnection DOES NOT expect to receive any data - if it does, it is an error
// For a bunch of cases, read will return -1 in case remote socket is closed : hence we
// register for reads to determine that.
override def read(): Boolean = {
// We don't expect the other side to send anything, so we just read to detect an error or EOF.
try {
val length = channel.read(ByteBuffer.allocate(1))
if (length == -1) { // EOF
close()
} else if (length > 0) {
logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
logWarning(
"Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
}
} catch {
case e: Exception =>
logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
}
false
}
override def changeInterestForRead(): Boolean = false
override def changeInterestForWrite(): Boolean = true
}
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
// Must be created within selector loop - else deadlock
private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
extends Connection(channel_, selector_) {
class Inbox() {
val messages = new HashMap[Int, BufferMessage]()
def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {
def createNewMessage: BufferMessage = {
val newMessage = Message.create(header).asInstanceOf[BufferMessage]
newMessage.started = true
newMessage.startTime = System.currentTimeMillis
logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
logDebug(
"Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
messages += ((newMessage.id, newMessage))
newMessage
}
val message = messages.getOrElseUpdate(header.id, createNewMessage)
logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
logTrace(
"Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
message.getChunkForReceiving(header.chunkSize)
}
def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
messages.get(chunk.header.id)
messages.get(chunk.header.id)
}
def removeMessage(message: Message) {
messages -= message.id
}
}
@volatile private var inferredRemoteManagerId: ConnectionManagerId = null
override def getRemoteConnectionManagerId(): ConnectionManagerId = {
val currId = inferredRemoteManagerId
if (currId != null) currId else super.getRemoteConnectionManagerId()
}
// The receiver's remote address is the local socket on the remote side, which is NOT
// the connection manager id of the receiver.
// We infer that from the messages we receive on the receiver socket.
private def processConnectionManagerId(header: MessageChunkHeader) {
val currId = inferredRemoteManagerId
if (header.address == null || currId != null) return
val managerId = ConnectionManagerId.fromSocketAddress(header.address)
if (managerId != null) {
inferredRemoteManagerId = managerId
}
}
val inbox = new Inbox()
val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
var onReceiveCallback: (Connection , Message) => Unit = null
@ -324,24 +434,29 @@ extends Connection(channel_, selector_) {
channel.register(selector, SelectionKey.OP_READ)
override def read() {
override def read(): Boolean = {
try {
while (true) {
if (currentChunk == null) {
val headerBytesRead = channel.read(headerBuffer)
if (headerBytesRead == -1) {
close()
return
return false
}
if (headerBuffer.remaining > 0) {
return
// re-register for read event ...
return true
}
headerBuffer.flip
if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
throw new Exception(
"Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
}
val header = MessageChunkHeader.create(headerBuffer)
headerBuffer.clear()
processConnectionManagerId(header)
header.typ match {
case Message.BUFFER_MESSAGE => {
if (header.totalSize == 0) {
@ -349,7 +464,8 @@ extends Connection(channel_, selector_) {
onReceiveCallback(this, Message.create(header))
}
currentChunk = null
return
// re-register for read event ...
return true
} else {
currentChunk = inbox.getChunk(header).orNull
}
@ -357,26 +473,28 @@ extends Connection(channel_, selector_) {
case _ => throw new Exception("Message of unknown type received")
}
}
if (currentChunk == null) throw new Exception("No message chunk to receive data")
val bytesRead = channel.read(currentChunk.buffer)
if (bytesRead == 0) {
return
// re-register for read event ...
return true
} else if (bytesRead == -1) {
close()
return
return false
}
/*logDebug("Read " + bytesRead + " bytes for the buffer")*/
if (currentChunk.buffer.remaining == 0) {
/*println("Filled buffer at " + System.currentTimeMillis)*/
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.finishTime = System.currentTimeMillis
logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
logDebug("Finished receiving [" + bufferMessage + "] from " +
"[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
if (onReceiveCallback != null) {
onReceiveCallback(this, bufferMessage)
}
@ -386,13 +504,32 @@ extends Connection(channel_, selector_) {
}
}
} catch {
case e: Exception => {
logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
case e: Exception => {
logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
callOnExceptionCallback(e)
close()
return false
}
}
// should not happen - to keep scala compiler happy
return true
}
def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}
override def changeInterestForRead(): Boolean = true
override def changeInterestForWrite(): Boolean = {
throw new IllegalStateException("Unexpected invocation right now")
}
override def registerInterest() {
// Registering read too - does not really help in most cases, but for some
// it does - so let us keep it for now.
changeConnectionKeyInterest(SelectionKey.OP_READ)
}
override def unregisterInterest() {
changeConnectionKeyInterest(0)
}
}
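
A small illustration of the register/unregister-interest idea above: interest ops on a SelectionKey are just a bitmask that gets OR-ed together or reset. The helper names below are made up, and only the bitmask formatting is exercised in main:

    import java.nio.channels.SelectionKey
    import scala.collection.mutable.ArrayBuffer

    object InterestOpsDemo {
      def registerWriteInterest(key: SelectionKey) {
        // Keep OP_READ so a remote close is still detected while we want to write.
        key.interestOps(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
      }

      def unregisterWriteInterest(key: SelectionKey) {
        key.interestOps(SelectionKey.OP_READ)
      }

      // Human-readable form of an interest-op bitmask.
      def describe(ops: Int): String = {
        val names = new ArrayBuffer[String]()
        if ((ops & SelectionKey.OP_READ) != 0) names += "READ"
        if ((ops & SelectionKey.OP_WRITE) != 0) names += "WRITE"
        if ((ops & SelectionKey.OP_CONNECT) != 0) names += "CONNECT"
        if ((ops & SelectionKey.OP_ACCEPT) != 0) names += "ACCEPT"
        if (names.isEmpty) "NONE" else names.mkString(" | ")
      }

      def main(args: Array[String]) {
        println(describe(SelectionKey.OP_WRITE | SelectionKey.OP_READ))   // READ | WRITE
        println(describe(0))                                              // NONE
      }
    }
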

View file

@ -6,28 +6,19 @@ import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.net._
import java.util.concurrent.Executors
import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}
import scala.collection.mutable.HashSet
import scala.collection.mutable.HashMap
import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.Queue
import scala.collection.mutable.ArrayBuffer
import akka.dispatch.{Await, Promise, ExecutionContext, Future}
import akka.util.Duration
import akka.util.duration._
private[spark] case class ConnectionManagerId(host: String, port: Int) {
def toSocketAddress() = new InetSocketAddress(host, port)
}
private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
}
}
private[spark] class ConnectionManager(port: Int) extends Logging {
class MessageStatus(
@ -41,73 +32,263 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
def markDone() { completionHandler(this) }
}
val selector = SelectorProvider.provider.openSelector()
val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
val serverChannel = ServerSocketChannel.open()
val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val messageStatuses = new HashMap[Int, MessageStatus]
val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
val sendMessageRequests = new Queue[(Message, SendingConnection)]
private val selector = SelectorProvider.provider.openSelector()
private val handleMessageExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
private val handleReadWriteExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.io.threads.min","4").toInt,
System.getProperty("spark.core.connection.io.threads.max","32").toInt,
System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
// Use a different, smaller thread pool - infrequently used, with very short-lived tasks that should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable]())
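
A standalone sketch of the pattern above - a bounded ThreadPoolExecutor sized from system properties; the property prefix and defaults below are illustrative, not actual Spark configuration keys:

    import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

    object PoolFromPropertiesDemo {
      // Build a pool whose min/max/keepalive are read from system properties,
      // falling back to the given defaults when the properties are unset.
      def boundedPool(prefix: String, minDefault: Int, maxDefault: Int): ThreadPoolExecutor = {
        new ThreadPoolExecutor(
          System.getProperty(prefix + ".threads.min", minDefault.toString).toInt,
          System.getProperty(prefix + ".threads.max", maxDefault.toString).toInt,
          System.getProperty(prefix + ".threads.keepalive", "60").toInt, TimeUnit.SECONDS,
          new LinkedBlockingDeque[Runnable]())
      }

      def main(args: Array[String]) {
        val pool = boundedPool("demo.io", 4, 32)
        pool.execute(new Runnable {
          override def run() { println("ran on " + Thread.currentThread.getName) }
        })
        pool.shutdown()
      }
    }
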
private val serverChannel = ServerSocketChannel.open()
private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
private val messageStatuses = new HashMap[Int, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
serverChannel.configureBlocking(false)
serverChannel.socket.setReuseAddress(true)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.setReceiveBufferSize(256 * 1024)
serverChannel.socket.bind(new InetSocketAddress(port))
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)
val selectorThread = new Thread("connection-manager-thread") {
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
selectorThread.start()
private def run() {
private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerWrite(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
writeRunnableStarted.synchronized {
// So that we do not trigger more write events while processing this one.
// The write method will re-register when done.
if (conn.changeInterestForWrite()) conn.unregisterInterest()
if (writeRunnableStarted.contains(key)) {
// key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
return
}
writeRunnableStarted += key
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
var register: Boolean = false
try {
register = conn.write()
} finally {
writeRunnableStarted.synchronized {
writeRunnableStarted -= key
if (register && conn.changeInterestForWrite()) {
conn.registerInterest()
}
}
}
}
} )
}
private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerRead(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
readRunnableStarted.synchronized {
// So that we do not trigger more read events while processing this one.
// The read method will re-register when done.
if (conn.changeInterestForRead()) conn.unregisterInterest()
if (readRunnableStarted.contains(key)) {
return
}
readRunnableStarted += key
}
handleReadWriteExecutor.execute(new Runnable {
override def run() {
var register: Boolean = false
try {
register = conn.read()
} finally {
readRunnableStarted.synchronized {
readRunnableStarted -= key
if (register && conn.changeInterestForRead()) {
conn.registerInterest()
}
}
}
}
} )
}
private def triggerConnect(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
if (conn == null) return
// prevent other events from being triggered
// Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
conn.changeConnectionKeyInterest(0)
handleConnectExecutor.execute(new Runnable {
override def run() {
var tries: Int = 10
while (tries >= 0) {
if (conn.finishConnect(false)) return
// Sleep ?
Thread.sleep(1)
tries -= 1
}
// Fall back to previous behavior: we should not really get here, since this method was
// triggered because the channel became connectable. But at times the first finishConnect
// does not succeed, hence the loop to retry a few times.
conn.finishConnect(true)
}
} )
}
// MUST be called within selector loop - else deadlock.
private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
try {
key.interestOps(0)
} catch {
// ignore exceptions
case e: Exception => logDebug("Ignoring exception", e)
}
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
// Pushing to connect threadpool
handleConnectExecutor.execute(new Runnable {
override def run() {
try {
conn.callOnExceptionCallback(e)
} catch {
// ignore exceptions
case e: Exception => logDebug("Ignoring exception", e)
}
try {
conn.close()
} catch {
// ignore exceptions
case e: Exception => logDebug("Ignoring exception", e)
}
}
})
}
def run() {
try {
while(!selectorThread.isInterrupted) {
for ((connectionManagerId, sendingConnection) <- connectionRequests) {
sendingConnection.connect()
addConnection(sendingConnection)
connectionRequests -= connectionManagerId
}
sendMessageRequests.synchronized {
while (!sendMessageRequests.isEmpty) {
val (message, connection) = sendMessageRequests.dequeue
connection.send(message)
}
while (! registerRequests.isEmpty) {
val conn: SendingConnection = registerRequests.dequeue
addListeners(conn)
conn.connect()
addConnection(conn)
}
while (!keyInterestChangeRequests.isEmpty) {
while(!keyInterestChangeRequests.isEmpty) {
val (key, ops) = keyInterestChangeRequests.dequeue
val connection = connectionsByKey(key)
val lastOps = key.interestOps()
key.interestOps(ops)
def intToOpStr(op: Int): String = {
val opStrs = ArrayBuffer[String]()
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
try {
if (key.isValid) {
val connection = connectionsByKey.getOrElse(key, null)
if (connection != null) {
val lastOps = key.interestOps()
key.interestOps(ops)
// hot loop - prevent materialization of string if trace not enabled.
if (isTraceEnabled()) {
def intToOpStr(op: Int): String = {
val opStrs = ArrayBuffer[String]()
if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
}
logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
}
} else {
logInfo("Key not valid ? " + key)
throw new CancelledKeyException()
}
} catch {
case e: CancelledKeyException => {
logInfo("key already cancelled ? " + key, e)
triggerForceCloseByException(key, e)
}
case e: Exception => {
logError("Exception processing key " + key, e)
triggerForceCloseByException(key, e)
}
}
logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
"] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
}
val selectedKeysCount = selector.select()
val selectedKeysCount =
try {
selector.select()
} catch {
// Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
case e: CancelledKeyException => {
// Some keys within the selectors list are invalid/closed. clear them.
val allKeys = selector.keys().iterator()
while (allKeys.hasNext()) {
val key = allKeys.next()
try {
if (! key.isValid) {
logInfo("Key not valid ? " + key)
throw new CancelledKeyException()
}
} catch {
case e: CancelledKeyException => {
logInfo("key already cancelled ? " + key, e)
triggerForceCloseByException(key, e)
}
case e: Exception => {
logError("Exception processing key " + key, e)
triggerForceCloseByException(key, e)
}
}
}
}
0
}
if (selectedKeysCount == 0) {
logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
}
@ -115,20 +296,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logInfo("Selector thread was interrupted!")
return
}
val selectedKeys = selector.selectedKeys().iterator()
while (selectedKeys.hasNext()) {
val key = selectedKeys.next
selectedKeys.remove()
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
} else if (key.isConnectable) {
connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
} else if (key.isReadable) {
connectionsByKey(key).read()
} else if (key.isWritable) {
connectionsByKey(key).write()
if (0 != selectedKeysCount) {
val selectedKeys = selector.selectedKeys().iterator()
while (selectedKeys.hasNext()) {
val key = selectedKeys.next
selectedKeys.remove()
try {
if (key.isValid) {
if (key.isAcceptable) {
acceptConnection(key)
} else
if (key.isConnectable) {
triggerConnect(key)
} else
if (key.isReadable) {
triggerRead(key)
} else
if (key.isWritable) {
triggerWrite(key)
}
} else {
logInfo("Key not valid ? " + key)
throw new CancelledKeyException()
}
} catch {
// weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
case e: CancelledKeyException => {
logInfo("key already cancelled ? " + key, e)
triggerForceCloseByException(key, e)
}
case e: Exception => {
logError("Exception processing key " + key, e)
triggerForceCloseByException(key, e)
}
}
}
}
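
For orientation, a minimal skeleton of the select-loop pattern that run() implements above: register a non-blocking server channel, select with a timeout, then dispatch on each selected key. Error handling, the worker pools, and the key-interest queues are omitted, and the loop is bounded so the demo terminates:

    import java.net.InetSocketAddress
    import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel}

    object SelectLoopSketch {
      def main(args: Array[String]) {
        val selector = Selector.open()
        val server = ServerSocketChannel.open()
        server.configureBlocking(false)
        server.socket.bind(new InetSocketAddress(0))
        server.register(selector, SelectionKey.OP_ACCEPT)

        var iterations = 0
        while (iterations < 3) {                 // bounded here so the demo terminates
          val ready = selector.select(100)       // ms timeout instead of blocking forever
          if (ready > 0) {
            val it = selector.selectedKeys().iterator()
            while (it.hasNext()) {
              val key = it.next()
              it.remove()
              if (key.isValid && key.isAcceptable) {
                val channel = server.accept()    // would normally become a ReceivingConnection
                if (channel != null) channel.close()
              }
            }
          }
          iterations += 1
        }
        selector.close()
        server.close()
      }
    }
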
@ -137,97 +338,119 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
case e: Exception => logError("Error in select loop", e)
}
}
private def acceptConnection(key: SelectionKey) {
def acceptConnection(key: SelectionKey) {
val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
val newChannel = serverChannel.accept()
val newConnection = new ReceivingConnection(newChannel, selector)
newConnection.onReceive(receiveMessage)
newConnection.onClose(removeConnection)
addConnection(newConnection)
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
var newChannel = serverChannel.accept()
// Accept them all in a tight loop. Non-blocking accept with no processing should be fine.
while (newChannel != null) {
try {
val newConnection = new ReceivingConnection(newChannel, selector)
newConnection.onReceive(receiveMessage)
addListeners(newConnection)
addConnection(newConnection)
logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
} catch {
// might happen in case of issues with registering with selector
case e: Exception => logError("Error in accept loop", e)
}
newChannel = serverChannel.accept()
}
}
private def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
}
private def addListeners(connection: Connection) {
connection.onKeyInterestChange(changeConnectionKeyInterest)
connection.onException(handleConnectionError)
connection.onClose(removeConnection)
}
private def removeConnection(connection: Connection) {
def addConnection(connection: Connection) {
connectionsByKey += ((connection.key, connection))
}
def removeConnection(connection: Connection) {
connectionsByKey -= connection.key
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
try {
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
})
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
if (sendingConnectionManagerId == null) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
val sendingConnection = connectionsById(sendingConnectionManagerId)
sendingConnection.close()
connectionsById -= sendingConnectionManagerId
messageStatuses.synchronized {
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
if (! sendingConnectionOpt.isDefined) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
val sendingConnection = sendingConnectionOpt.get
connectionsById -= remoteConnectionManagerId
sendingConnection.close()
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
assert (sendingConnectionManagerId == remoteConnectionManagerId)
messageStatuses.synchronized {
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
}
} finally {
// So that the selection keys can be removed.
wakeupSelector()
}
}
private def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
def handleConnectionError(connection: Connection, e: Exception) {
logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
removeConnection(connection)
}
private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
def changeConnectionKeyInterest(connection: Connection, ops: Int) {
keyInterestChangeRequests += ((connection.key, ops))
// so that registrations happen!
wakeupSelector()
}
private def receiveMessage(connection: Connection, message: Message) {
def receiveMessage(connection: Connection, message: Message) {
val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
val runnable = new Runnable() {
val creationTime = System.currentTimeMillis
def run() {
@ -247,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
if (bufferMessage.hasAckId) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
case Some(status) => {
messageStatuses -= bufferMessage.ackId
case Some(status) => {
messageStatuses -= bufferMessage.ackId
status
}
case None => {
case None => {
throw new Exception("Could not find reference for received ack message " + message.id)
null
}
@ -271,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logDebug("Not calling back as callback is null")
None
}
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@ -281,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
}
}
sendMessage(connectionManagerId, ackMessage.getOrElse {
sendMessage(connectionManagerId, ackMessage.getOrElse {
Message.createBufferMessage(bufferMessage.id)
})
}
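
A simplified sketch of the ack-correlation idea above: pending sends are tracked by message id, and an incoming message whose ackId matches completes them. PendingSend is a made-up stand-in for MessageStatus:

    import scala.collection.mutable.HashMap

    object AckTrackingSketch {
      class PendingSend(val id: Int, val onDone: Boolean => Unit)

      val pending = new HashMap[Int, PendingSend]

      // Record a send that expects an ack keyed by its message id.
      def send(id: Int)(onDone: Boolean => Unit) {
        pending.synchronized { pending += ((id, new PendingSend(id, onDone))) }
      }

      // Complete the pending send whose id matches the incoming ack id.
      def receiveAck(ackId: Int) {
        val status = pending.synchronized { pending.remove(ackId) }
        status match {
          case Some(p) => p.onDone(true)
          case None => println("No pending send for ack " + ackId)
        }
      }

      def main(args: Array[String]) {
        send(42) { ok => println("message 42 acked = " + ok) }
        receiveAck(42)     // prints: message 42 acked = true
        receiveAck(7)      // prints: No pending send for ack 7
      }
    }
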
@ -293,18 +516,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
def startNewConnection(): SendingConnection = {
val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
new SendingConnection(inetSocketAddress, selector, connectionManagerId))
newConnection
val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
registerRequests.enqueue(newConnection)
newConnection
}
val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
// I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
// If we do re-add it, we should consistently use it everywhere I guess ?
val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
message.senderAddress = id.toSocketAddress()
logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
/*connection.send(message)*/
sendMessageRequests.synchronized {
sendMessageRequests += ((message, connection))
}
connection.send(message)
wakeupSelector()
}
private def wakeupSelector() {
selector.wakeup()
}
@ -337,6 +564,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
logWarning("All connections not cleaned up")
}
handleMessageExecutor.shutdown()
handleReadWriteExecutor.shutdown()
handleConnectExecutor.shutdown()
logInfo("ConnectionManager stopped")
}
}
@ -346,17 +575,17 @@ private[spark] object ConnectionManager {
def main(args: Array[String]) {
val manager = new ConnectionManager(9999)
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
println("Received [" + msg + "] from [" + id + "]")
None
})
/*testSequentialSending(manager)*/
/*System.gc()*/
/*testParallelSending(manager)*/
/*System.gc()*/
/*testParallelDecreasingSending(manager)*/
/*System.gc()*/
@ -368,9 +597,9 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Sequential Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
buffer.flip
@ -386,7 +615,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@ -401,12 +630,12 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
val mb = size * count / 1024.0 / 1024.0
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
println("Started at " + startTime + ", finished at " + finishTime)
println("Started at " + startTime + ", finished at " + finishTime)
println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
println("--------------------------")
println()
@ -416,7 +645,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Parallel Decreasing Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val size = 10 * 1024 * 1024
val count = 10
val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
buffers.foreach(_.flip)
@ -431,7 +660,7 @@ private[spark] object ConnectionManager {
if (!g.isDefined) println("Failed")
})
val finishTime = System.currentTimeMillis
val ms = finishTime - startTime
val tput = mb * 1000.0 / ms
println("--------------------------")
@ -445,7 +674,7 @@ private[spark] object ConnectionManager {
println("--------------------------")
println("Continuous Sending")
println("--------------------------")
val size = 10 * 1024 * 1024
val size = 10 * 1024 * 1024
val count = 10
val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))

View file

@ -0,0 +1,21 @@
package spark.network
import java.net.InetSocketAddress
import spark.Utils
private[spark] case class ConnectionManagerId(host: String, port: Int) {
// DEBUG code
Utils.checkHost(host)
assert (port > 0)
def toSocketAddress() = new InetSocketAddress(host, port)
}
private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
}
}
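
A short usage sketch for ConnectionManagerId; since the class is private[spark], the demo lives in the same package, and the host and port are made up:

    package spark.network

    object ConnectionManagerIdDemo {
      def main(args: Array[String]) {
        val id = ConnectionManagerId("localhost", 9999)
        val addr = id.toSocketAddress()
        // Round-trip through an InetSocketAddress and back.
        val roundTripped = ConnectionManagerId.fromSocketAddress(addr)
        println(id + " -> " + addr + " -> " + roundTripped)
      }
    }
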

View file

@ -1,55 +1,10 @@
package spark.network
import spark._
import java.nio.ByteBuffer
import java.net.InetSocketAddress
import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer
import java.net.InetAddress
import java.net.InetSocketAddress
import storage.BlockManager
private[spark] class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val address: InetSocketAddress) {
lazy val buffer = {
val ip = address.getAddress.getAddress()
val port = address.getPort()
ByteBuffer.
allocate(MessageChunkHeader.HEADER_SIZE).
putLong(typ).
putInt(id).
putInt(totalSize).
putInt(chunkSize).
putInt(other).
putInt(ip.size).
put(ip).
putInt(port).
position(MessageChunkHeader.HEADER_SIZE).
flip.asInstanceOf[ByteBuffer]
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
" and sizes " + totalSize + " / " + chunkSize + " bytes"
}
private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
ab += header.buffer
if (buffer != null) {
ab += buffer
}
ab
}
override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
}
private[spark] abstract class Message(val typ: Long, val id: Int) {
var senderAddress: InetSocketAddress = null
@ -58,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
var finishTime = -1L
def size: Int
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]
def timeTaken(): String = (finishTime - startTime).toString + " ms"
override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
}
private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
extends Message(Message.BUFFER_MESSAGE, id_) {
val initialSize = currentSize()
var gotChunkForSendingOnce = false
def size = initialSize
def currentSize() = {
if (buffers == null || buffers.isEmpty) {
0
} else {
buffers.map(_.remaining).reduceLeft(_ + _)
}
}
def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
if (maxChunkSize <= 0) {
throw new Exception("Max chunk size is " + maxChunkSize)
}
if (size == 0 && gotChunkForSendingOnce == false) {
val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
gotChunkForSendingOnce = true
return Some(newChunk)
}
while(!buffers.isEmpty) {
val buffer = buffers(0)
if (buffer.remaining == 0) {
BlockManager.dispose(buffer)
buffers -= buffer
} else {
val newBuffer = if (buffer.remaining <= maxChunkSize) {
buffer.duplicate()
} else {
buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
}
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
gotChunkForSendingOnce = true
return Some(newChunk)
}
}
None
}
def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
// STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
if (buffers.size > 1) {
throw new Exception("Attempting to get chunk from message with multiple data buffers")
}
val buffer = buffers(0)
if (buffer.remaining > 0) {
if (buffer.remaining < chunkSize) {
throw new Exception("Not enough space in data buffer for receiving chunk")
}
val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
buffer.position(buffer.position + newBuffer.remaining)
val newChunk = new MessageChunk(new MessageChunkHeader(
typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
return Some(newChunk)
}
None
}
def flip() {
buffers.foreach(_.flip)
}
def hasAckId() = (ackId != 0)
def isCompletelyReceived() = !buffers(0).hasRemaining
override def toString = {
if (hasAckId) {
"BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
} else {
"BufferMessage(id = " + id + ", size = " + size + ")"
}
}
}
private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
throw new IllegalArgumentException("Cannot convert buffer data to Message")
}
val typ = buffer.getLong()
val id = buffer.getInt()
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
}
}
private[spark] object Message {
val BUFFER_MESSAGE = 1111111111L
@ -180,14 +31,16 @@ private[spark] object Message {
def getNewId() = synchronized {
lastId += 1
if (lastId == 0) lastId += 1
if (lastId == 0) {
lastId += 1
}
lastId
}
def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
if (dataBuffers == null) {
return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
}
}
if (dataBuffers.exists(_ == null)) {
throw new Exception("Attempting to create buffer message with null buffer")
}
@ -196,7 +49,7 @@ private[spark] object Message {
def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
createBufferMessage(dataBuffers, 0)
def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
if (dataBuffer == null) {
return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@ -204,15 +57,18 @@ private[spark] object Message {
return createBufferMessage(Array(dataBuffer), ackId)
}
}
def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
createBufferMessage(dataBuffer, 0)
def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
def createBufferMessage(ackId: Int): BufferMessage = {
createBufferMessage(new Array[ByteBuffer](0), ackId)
}
def create(header: MessageChunkHeader): Message = {
val newMessage: Message = header.typ match {
case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
case BUFFER_MESSAGE => new BufferMessage(header.id,
ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
}
newMessage.senderAddress = header.address
newMessage

View file

@ -0,0 +1,25 @@
package spark.network
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
private[network]
class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
val size = if (buffer == null) 0 else buffer.remaining
lazy val buffers = {
val ab = new ArrayBuffer[ByteBuffer]()
ab += header.buffer
if (buffer != null) {
ab += buffer
}
ab
}
override def toString = {
"" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
}
}

View file

@ -0,0 +1,58 @@
package spark.network
import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer
private[spark] class MessageChunkHeader(
val typ: Long,
val id: Int,
val totalSize: Int,
val chunkSize: Int,
val other: Int,
val address: InetSocketAddress) {
lazy val buffer = {
// No need to change this, at 'use' time, we do a reverse lookup of the hostname.
// Refer to network.Connection
val ip = address.getAddress.getAddress()
val port = address.getPort()
ByteBuffer.
allocate(MessageChunkHeader.HEADER_SIZE).
putLong(typ).
putInt(id).
putInt(totalSize).
putInt(chunkSize).
putInt(other).
putInt(ip.size).
put(ip).
putInt(port).
position(MessageChunkHeader.HEADER_SIZE).
flip.asInstanceOf[ByteBuffer]
}
override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
" and sizes " + totalSize + " / " + chunkSize + " bytes"
}
private[spark] object MessageChunkHeader {
val HEADER_SIZE = 40
def create(buffer: ByteBuffer): MessageChunkHeader = {
if (buffer.remaining != HEADER_SIZE) {
throw new IllegalArgumentException("Cannot convert buffer data to Message")
}
val typ = buffer.getLong()
val id = buffer.getInt()
val totalSize = buffer.getInt()
val chunkSize = buffer.getInt()
val other = buffer.getInt()
val ipSize = buffer.getInt()
val ipBytes = new Array[Byte](ipSize)
buffer.get(ipBytes)
val ip = InetAddress.getByAddress(ipBytes)
val port = buffer.getInt()
new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
}
}
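A hedged round-trip sketch of the fixed 40-byte header encoding above (the id, sizes, and address are illustrative; as with the other network classes, the snippet assumes it sits under the spark package):
import java.net.InetSocketAddress
import spark.network.{Message, MessageChunkHeader}
val addr = new InetSocketAddress("127.0.0.1", 7077)  // made-up address
// Arguments: typ, id, totalSize, chunkSize, other, address
val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 42, 1024, 1024, 0, addr)
val buf = header.buffer                              // 40 bytes, already flipped for reading
val parsed = MessageChunkHeader.create(buf)
assert(parsed.id == 42 && parsed.totalSize == 1024 && parsed.address.getPort == 7077)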

View file

@ -0,0 +1,57 @@
package spark.network.netty
import io.netty.buffer._
import spark.Logging
private[spark] class FileHeader (
val fileLen: Int,
val blockId: String) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
buf.writeInt(blockId.length)
blockId.foreach((x: Char) => buf.writeByte(x))
// Pad the rest of the header with zeros.
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
} else {
logInfo("too long header")
throw new Exception("too long header " + buf.readableBytes)
}
buf
}
}
private[spark] object FileHeader {
val HEADER_SIZE = 40
def getFileLenOffset = 0
def getFileLenSize = Integer.SIZE/8
def create(buf: ByteBuf): FileHeader = {
val length = buf.readInt
val idLength = buf.readInt
val idBuilder = new StringBuilder(idLength)
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
val blockId = idBuilder.toString()
new FileHeader(length, blockId)
}
def main(args: Array[String]) {
val header = new FileHeader(25, "block_0")
val buf = header.buffer
val newHeader = FileHeader.create(buf)
println("id=" + newHeader.blockId + ", size=" + newHeader.fileLen)
}
}

View file

@ -0,0 +1,101 @@
package spark.network.netty
import java.util.concurrent.Executors
import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.util.CharsetUtil
import spark.Logging
import spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
private[spark] class ShuffleCopier extends Logging {
def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
val fc = new FileClient(handler, connectTimeout)
try {
fc.init()
fc.connect(host, port)
fc.sendRequest(blockId)
fc.waitForClose()
fc.close()
} catch {
// Handle any socket-related exceptions in FileClient
case e: Exception => {
logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
handler.handleError(blockId)
}
}
}
def getBlock(cmId: ConnectionManagerId, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
blocks: Seq[(String, Long)],
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
}
}
}
private[spark] object ShuffleCopier extends Logging {
private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
override def handleError(blockId: String) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
}
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
System.exit(1)
}
val host = args(0)
val port = args(1).toInt
val file = args(2)
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(threads)  // size the pool to the requested number of copier threads
val tasks = (for (i <- Range(0, threads)) yield {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
copier.getBlock(host, port, file, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
copiers.shutdown
System.exit(0)
}
}

View file

@ -0,0 +1,53 @@
package spark.network.netty
import java.io.File
import spark.Logging
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
val server = new FileServer(pResolver, portIn)
server.start()
def stop() {
server.stop()
}
def port: Int = server.getPort()
}
/**
* An application for testing the shuffle sender as a standalone program.
*/
private[spark] object ShuffleSender {
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println(
"Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
System.exit(1)
}
val port = args(0).toInt
val subDirsPerLocalDir = args(1).toInt
val localDirs = args.drop(2).map(new File(_))
val pResolver = new PathResolver {
override def getAbsolutePath(blockId: String): String = {
if (!blockId.startsWith("shuffle_")) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory the block hashes to, and which subdirectory within that directory
val hash = math.abs(blockId.hashCode)
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId)
return file.getAbsolutePath
}
}
val sender = new ShuffleSender(port, pResolver)
}
}
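A hedged sketch of pairing the sender with a ShuffleCopier in-process; the directory, block id, and the use of port 0 to obtain an ephemeral port are assumptions, not taken from this commit:
import java.io.File
import spark.network.netty.{PathResolver, ShuffleCopier, ShuffleSender}
val resolver = new PathResolver {
  // Assumption: blocks are plain files named after their block id under /tmp/blocks.
  override def getAbsolutePath(blockId: String): String =
    new File("/tmp/blocks", blockId).getAbsolutePath
}
val sender = new ShuffleSender(0, resolver)  // assumes binding to port 0 picks a free port
val copier = new ShuffleCopier()
copier.getBlock("localhost", sender.port, "shuffle_0_0_0", ShuffleCopier.echoResultCollectCallBack)
sender.stop()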

View file

@ -1,7 +1,7 @@
package spark.rdd
import scala.collection.mutable.HashMap
import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
import spark.storage.BlockManager
private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
val index = idx
@ -11,12 +11,7 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = {
val blockManager = SparkEnv.get.blockManager
/*val locations = blockIds.map(id => blockManager.getLocations(id))*/
val locations = blockManager.getLocations(blockIds)
HashMap(blockIds.zip(locations):_*)
}
@transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)
override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]

View file

@ -8,6 +8,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
import spark.deploy.SparkHadoopUtil
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@ -21,13 +22,20 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
override def getPartitions: Array[Partition] = {
val dirContents = fs.listStatus(new Path(checkpointPath))
val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
val numPartitions = partitionFiles.size
if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
val cpath = new Path(checkpointPath)
val numPartitions =
// listStatus can throw exception if path does not exist.
if (fs.exists(cpath)) {
val dirContents = fs.listStatus(cpath)
val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
val numPart = partitionFiles.size
if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
}
numPart
} else 0
Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
}
@ -35,7 +43,7 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
checkpointData.get.cpFile = Some(checkpointPath)
override def getPreferredLocations(split: Partition): Seq[String] = {
val status = fs.getFileStatus(new Path(checkpointPath))
val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
val locations = fs.getFileBlockLocations(status, 0, status.getLen)
locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
}
@ -58,7 +66,7 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(new Configuration())
val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration())
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputPath = new Path(outputDir, finalOutputName)
@ -83,6 +91,7 @@ private[spark] object CheckpointRDD extends Logging {
if (!fs.rename(tempOutputPath, finalOutputPath)) {
if (!fs.exists(finalOutputPath)) {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
+ ctx.attemptId + " and final output path does not exist")
@ -95,7 +104,7 @@ private[spark] object CheckpointRDD extends Logging {
}
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val fs = path.getFileSystem(new Configuration())
val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = SparkEnv.get.serializer.newInstance()
@ -117,11 +126,11 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
val fs = path.getFileSystem(new Configuration())
val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
fs.delete(path)
fs.delete(path, true)
}
}
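For context, a minimal hedged sketch of how a checkpoint directory comes to contain the part- files that getPartitions scans (the checkpoint path is made up):
import spark.SparkContext
val sc = new SparkContext("local", "checkpoint-sketch")
sc.setCheckpointDir("hdfs://namenode:8020/tmp/spark-checkpoints")  // hypothetical HDFS path
val rdd = sc.makeRDD(1 to 1000, 10)
rdd.checkpoint()  // only marks the RDD; nothing is written yet
rdd.count()       // running a job materializes the RDD and writes one part- file per partition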

View file

@ -2,10 +2,11 @@ package spark.rdd
import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@ -28,7 +29,8 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
private[spark]
class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
override val index: Int = idx
override def hashCode(): Int = idx
}
@ -40,7 +42,24 @@ private[spark] class CoGroupAggregator
{ (b1, b2) => b1 ++ b2 })
with Serializable
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
/**
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
* @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
* is on, Spark does an extra pass over the data on the map side to merge
* all values belonging to the same key together. This can reduce the amount
* of data shuffled if and only if the number of distinct keys is very small,
* and the ratio of key size to value size is also very small.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
val mapSideCombine: Boolean = false,
val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
private val aggr = new CoGroupAggregator
@ -52,8 +71,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
new OneToOneDependency(rdd)
} else {
logInfo("Adding shuffle dependency with " + rdd)
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
if (mapSideCombine) {
val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
} else {
new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
}
}
}
}
@ -70,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
}.toList)
}.toArray)
}
array
}
@ -82,6 +105,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k)
if (seq != null) {
@ -92,6 +116,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
seq
}
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
// Read them from the parent
@ -102,8 +128,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
getSeq(k)(depNum) ++= vs
if (mapSideCombine) {
// With map side combine on, for each key, the shuffle fetcher returns a list of values.
fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, values) => getSeq(key)(depNum) ++= values
}
} else {
// With map side combine off, for each key the shuffle fetcher returns a single value.
fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
case (key, value) => getSeq(key)(depNum) += value
}
}
}
}
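A hedged sketch of constructing the new CoGroupedRDD directly (the data, partitioner, and casts are illustrative; in normal use, cogroup on a pair RDD builds this for you):
import spark.{HashPartitioner, RDD, SparkContext}
import spark.rdd.CoGroupedRDD
val sc = new SparkContext("local", "cogroup-sketch")
val ages = sc.parallelize(Seq(("alice", 30), ("bob", 25)))
val cities = sc.parallelize(Seq(("alice", "SF"), ("carol", "NYC")))
// mapSideCombine is left at its default (false): with many distinct keys and small values,
// the extra map-side pass would not pay for itself.
val grouped = new CoGroupedRDD[String](
  Seq(ages.asInstanceOf[RDD[(String, _)]], cities.asInstanceOf[RDD[(String, _)]]),
  new HashPartitioner(2))
grouped.collect().foreach { case (k, seqs) => println(k + " -> " + seqs.map(_.toList)) }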

View file

@ -37,8 +37,8 @@ class CoalescedRDD[T: ClassManifest](
prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) }
} else {
(0 until maxPartitions).map { i =>
val rangeStart = (i * prevSplits.length) / maxPartitions
val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
val rangeStart = ((i.toLong * prevSplits.length) / maxPartitions).toInt
val rangeEnd = (((i.toLong + 1) * prevSplits.length) / maxPartitions).toInt
new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray)
}.toArray
}
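The switch to i.toLong matters because the intermediate product can exceed Int.MaxValue; a hedged sketch of the arithmetic with made-up sizes:
// Hypothetical sizes: coalescing 3,000,000 upstream splits down to 2,000 partitions.
val prevSplitCount = 3000000
val maxPartitions = 2000
val i = 1500
val overflowed = (i * prevSplitCount) / maxPartitions                // Int math wraps: 102516
val correct = ((i.toLong * prevSplitCount) / maxPartitions).toInt    // 2250000, as intended
println("overflowed=" + overflowed + " correct=" + correct)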

Some files were not shown because too many files have changed in this diff.