Merge branch 'master' into streaming

Conflicts: .gitignore

commit c89af0a7f9

.gitignore
@@ -37,3 +37,4 @@ dependency-reduced-pom.xml
.ensime
.ensime_lucene
checkpoint
derby.log
@@ -12,11 +12,16 @@ This README file only contains basic setup instructions.

## Building

Spark requires Scala 2.9.2. The project is built using Simple Build Tool (SBT),
which is packaged with it. To build Spark and its example programs, run:
Spark requires Scala 2.9.2 (Scala 2.10 is not yet supported). The project is
built using Simple Build Tool (SBT), which is packaged with it. To build
Spark and its example programs, run:

    sbt/sbt package

Spark also supports building using Maven. If you would like to build using Maven,
see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html)
in the Spark documentation.

To run Spark, you will need to have Scala's bin directory in your `PATH`, or
you will need to set the `SCALA_HOME` environment variable to point to where
you've installed Scala. Scala must be accessible through one of these
@@ -3,8 +3,8 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.spark-project</groupId>
    <artifactId>parent</artifactId>
    <version>0.7.0-SNAPSHOT</version>
    <artifactId>spark-parent</artifactId>
    <version>0.8.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>
@@ -102,5 +102,42 @@
      </plugins>
    </build>
  </profile>
  <profile>
    <id>hadoop2-yarn</id>
    <dependencies>
      <dependency>
        <groupId>org.spark-project</groupId>
        <artifactId>spark-core</artifactId>
        <version>${project.version}</version>
        <classifier>hadoop2-yarn</classifier>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-client</artifactId>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-yarn-api</artifactId>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-yarn-common</artifactId>
        <scope>provided</scope>
      </dependency>
    </dependencies>
    <build>
      <plugins>
        <plugin>
          <groupId>org.apache.maven.plugins</groupId>
          <artifactId>maven-jar-plugin</artifactId>
          <configuration>
            <classifier>hadoop2-yarn</classifier>
          </configuration>
        </plugin>
      </plugins>
    </build>
  </profile>
</profiles>
</project>
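Note: the hadoop2-yarn profile above is opt-in, not active by default. Assuming the standard Maven CLI, it would be selected explicitly at build time, e.g.:

    mvn -Phadoop2-yarn package

which also applies the hadoop2-yarn jar classifier configured for maven-jar-plugin.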
@@ -4,8 +4,37 @@ import spark._
import spark.SparkContext._

import scala.collection.mutable.ArrayBuffer
import storage.StorageLevel

object Bagel extends Logging {
  val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK

  /**
   * Runs a Bagel program.
   * @param sc [[spark.SparkContext]] to use for the program.
   * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be
   *   the vertex id.
   * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
   *   empty array, i.e. sc.parallelize(Array[K, Message]()).
   * @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
   *   message before sending (which often involves network I/O).
   * @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
   *   and provides the result to each vertex in the next superstep.
   * @param partitioner [[spark.Partitioner]] partitions values by key
   * @param numPartitions number of partitions across which to split the graph.
   *   Default is the default parallelism of the SparkContext
   * @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
   *   Defaults to caching in memory.
   * @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
   *   optional Aggregator and the current superstep,
   *   and returns a set of (Vertex, outgoing Messages) pairs
   * @tparam K key
   * @tparam V vertex type
   * @tparam M message type
   * @tparam C combiner
   * @tparam A aggregator
   * @return an RDD of (K, V) pairs representing the graph after completion of the program
   */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
          C: Manifest, A: Manifest](
    sc: SparkContext,
@@ -14,7 +43,8 @@ object Bagel extends Logging {
    combiner: Combiner[M, C],
    aggregator: Option[Aggregator[V, A]],
    partitioner: Partitioner,
    numPartitions: Int
    numPartitions: Int,
    storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
  )(
    compute: (V, Option[C], Option[A], Int) => (V, Array[M])
  ): RDD[(K, V)] = {
@@ -32,8 +62,9 @@ object Bagel extends Logging {
      val combinedMsgs = msgs.combineByKey(
        combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
      val grouped = combinedMsgs.groupWith(verts)
      val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
      val (processed, numMsgs, numActiveVerts) =
        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep))
        comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)

      val timeTaken = System.currentTimeMillis - startTime
      logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
@@ -50,6 +81,7 @@ object Bagel extends Logging {
    verts
  }

  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
@@ -59,12 +91,29 @@ object Bagel extends Logging {
    numPartitions: Int
  )(
    compute: (V, Option[C], Int) => (V, Array[M])
  ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)

  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
    messages: RDD[(K, M)],
    combiner: Combiner[M, C],
    partitioner: Partitioner,
    numPartitions: Int,
    storageLevel: StorageLevel
  )(
    compute: (V, Option[C], Int) => (V, Array[M])
  ): RDD[(K, V)] = {
    run[K, V, M, C, Nothing](
      sc, vertices, messages, combiner, None, partitioner, numPartitions)(
      sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)(
      addAggregatorArg[K, V, M, C](compute))
  }

  /**
   * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]]
   * and default storage level
   */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
@@ -73,13 +122,29 @@ object Bagel extends Logging {
    numPartitions: Int
  )(
    compute: (V, Option[C], Int) => (V, Array[M])
  ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)

  /** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]] */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
    messages: RDD[(K, M)],
    combiner: Combiner[M, C],
    numPartitions: Int,
    storageLevel: StorageLevel
  )(
    compute: (V, Option[C], Int) => (V, Array[M])
  ): RDD[(K, V)] = {
    val part = new HashPartitioner(numPartitions)
    run[K, V, M, C, Nothing](
      sc, vertices, messages, combiner, None, part, numPartitions)(
      sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)(
      addAggregatorArg[K, V, M, C](compute))
  }

  /**
   * Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]],
   * [[spark.bagel.DefaultCombiner]] and the default storage level
   */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
@@ -87,10 +152,24 @@ object Bagel extends Logging {
    numPartitions: Int
  )(
    compute: (V, Option[Array[M]], Int) => (V, Array[M])
  ): RDD[(K, V)] = {
  ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)

  /**
   * Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]]
   * and [[spark.bagel.DefaultCombiner]]
   */
  def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
    sc: SparkContext,
    vertices: RDD[(K, V)],
    messages: RDD[(K, M)],
    numPartitions: Int,
    storageLevel: StorageLevel
  )(
    compute: (V, Option[Array[M]], Int) => (V, Array[M])
  ): RDD[(K, V)] = {
    val part = new HashPartitioner(numPartitions)
    run[K, V, M, Array[M], Nothing](
      sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions)(
      sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)(
      addAggregatorArg[K, V, M, Array[M]](compute))
  }

@@ -117,7 +196,8 @@ object Bagel extends Logging {
  private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
    sc: SparkContext,
    grouped: RDD[(K, (Seq[C], Seq[V]))],
    compute: (V, Option[C]) => (V, Array[M])
    compute: (V, Option[C]) => (V, Array[M]),
    storageLevel: StorageLevel
  ): (RDD[(K, (V, Array[M]))], Int, Int) = {
    var numMsgs = sc.accumulator(0)
    var numActiveVerts = sc.accumulator(0)
|
@ -135,7 +215,7 @@ object Bagel extends Logging {
|
|||
numActiveVerts += 1
|
||||
|
||||
Some((newVert, newMsgs))
|
||||
}.cache
|
||||
}.persist(storageLevel)
|
||||
|
||||
// Force evaluation of processed RDD for accurate performance measurements
|
||||
processed.foreach(x => {})
|
||||
|
@@ -166,6 +246,7 @@ trait Aggregator[V, A] {
  def mergeAggregators(a: A, b: A): A
}

/** Default combiner that simply appends messages together (i.e. performs no aggregation) */
class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
  def createCombiner(msg: M): Array[M] =
    Array(msg)
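Note: taken together, these changes let a caller choose where each superstep's intermediate RDD is persisted instead of always getting `.cache`. A minimal shell-style usage sketch, assuming hypothetical MyVertex/MyMessage classes analogous to the TestVertex/TestMessage classes in the suite below:

    import spark.SparkContext
    import spark.bagel.{Bagel, Vertex, Message}
    import spark.storage.StorageLevel

    // Illustrative graph types; any Vertex/Message[K] subclasses work.
    class MyVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
    class MyMessage(val targetId: String) extends Message[String] with Serializable

    val sc = new SparkContext("local", "bagel-storage-example")
    val verts = sc.parallelize((1 to 4).map(id => (id.toString, new MyVertex(true, 0))))
    val msgs = sc.parallelize(Array[(String, MyMessage)]())

    // Spill each superstep's intermediate RDD to disk rather than using
    // the new DEFAULT_STORAGE_LEVEL (MEMORY_AND_DISK).
    val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
      (self: MyVertex, inbox: Option[Array[MyMessage]], superstep: Int) =>
        (new MyVertex(superstep < 10, self.age + 1), Array[MyMessage]())
    }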
@@ -7,6 +7,7 @@ import org.scalatest.time.SpanSugar._
import scala.collection.mutable.ArrayBuffer

import spark._
import storage.StorageLevel

class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
@@ -22,6 +23,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
    }
    // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
    System.clearProperty("spark.driver.port")
    System.clearProperty("spark.hostPort")
  }

  test("halting by voting") {
@@ -79,4 +81,21 @@
      }
    }
  }

  test("using non-default persistence level") {
    failAfter(10 seconds) {
      sc = new SparkContext("local", "test")
      val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
      val msgs = sc.parallelize(Array[(String, TestMessage)]())
      val numSupersteps = 50
      val result =
        Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
          (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
            (new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
        }
      for ((id, vert) <- result.collect) {
        assert(vert.age === numSupersteps)
      }
    }
  }
}
@@ -30,7 +30,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##

usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <args...>"
usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"

# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -48,6 +48,8 @@ startStop=$1
shift
command=$1
shift
instance=$1
shift

spark_rotate_log ()
{
@@ -92,10 +94,10 @@ if [ "$SPARK_PID_DIR" = "" ]; then
fi

# some variables
export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.log
export SPARK_LOGFILE=spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.log
export SPARK_ROOT_LOGGER="INFO,DRFA"
log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$HOSTNAME.out
pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command.pid
log=$SPARK_LOG_DIR/spark-$SPARK_IDENT_STRING-$command-$instance-$HOSTNAME.out
pid=$SPARK_PID_DIR/spark-$SPARK_IDENT_STRING-$command-$instance.pid

# Set default scheduling priority
if [ "$SPARK_NICENESS" = "" ]; then
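Note: the new <spark-instance-number> argument keeps log and pid file names from colliding when several daemons of the same class run on one host. A hypothetical invocation starting a second worker (the master URL is illustrative):

    spark-daemon.sh start spark.deploy.worker.Worker 2 spark://master.example.com:7077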
@@ -2,7 +2,7 @@

# Run a Spark command on all slave hosts.

usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command args..."
usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."

# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
  fi
fi

"$bin"/spark-daemon.sh start spark.deploy.master.Master --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
@@ -6,9 +6,10 @@ bin=`cd "$bin"; pwd`
# Set SPARK_PUBLIC_DNS so slaves can be linked in master web UI
if [ "$SPARK_PUBLIC_DNS" = "" ]; then
    # If we appear to be running on EC2, use the public address by default:
    if [[ `hostname` == *ec2.internal ]]; then
    # NOTE: ec2-metadata is installed on Amazon Linux AMI. Check based on that and hostname
    if command -v ec2-metadata > /dev/null || [[ `hostname` == *ec2.internal ]]; then
        export SPARK_PUBLIC_DNS=`wget -q -O - http://instance-data.ec2.internal/latest/meta-data/public-hostname`
    fi
fi

"$bin"/spark-daemon.sh start spark.deploy.worker.Worker $1
"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@"
@@ -21,4 +21,13 @@ fi
echo "Master IP: $SPARK_MASTER_IP"

# Launch the slaves
exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
  exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT
else
  if [ "$SPARK_WORKER_WEBUI_PORT" = "" ]; then
    SPARK_WORKER_WEBUI_PORT=8081
  fi
  for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
    "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" $(( $i + 1 )) spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT --webui-port $(( $SPARK_WORKER_WEBUI_PORT + $i ))
  done
fi
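Note: to exercise the new loop, one would set SPARK_WORKER_INSTANCES (and optionally SPARK_WORKER_WEBUI_PORT) in conf/spark-env.sh; illustrative values:

    export SPARK_WORKER_INSTANCES=2
    export SPARK_WORKER_WEBUI_PORT=8081   # instance i is assigned port 8081 + i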
@@ -7,4 +7,4 @@ bin=`cd "$bin"; pwd`

. "$bin/spark-config.sh"

"$bin"/spark-daemon.sh stop spark.deploy.master.Master
"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1
@@ -7,4 +7,14 @@ bin=`cd "$bin"; pwd`

. "$bin/spark-config.sh"

"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
  . "${SPARK_CONF_DIR}/spark-env.sh"
fi

if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
  "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1
else
  for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
    "$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 ))
  done
fi
@@ -12,6 +12,7 @@
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes to be spawned on every slave machine
#
# Finally, Spark also relies on the following variables, but these can be set
# on just the *master* (i.e. in your driver program), and will automatically
core/pom.xml
@@ -3,8 +3,8 @@
  <modelVersion>4.0.0</modelVersion>
  <parent>
    <groupId>org.spark-project</groupId>
    <artifactId>parent</artifactId>
    <version>0.7.0-SNAPSHOT</version>
    <artifactId>spark-parent</artifactId>
    <version>0.8.0-SNAPSHOT</version>
    <relativePath>../pom.xml</relativePath>
  </parent>

@@ -32,8 +32,8 @@
      <artifactId>compress-lzf</artifactId>
    </dependency>
    <dependency>
      <groupId>asm</groupId>
      <artifactId>asm-all</artifactId>
      <groupId>org.ow2.asm</groupId>
      <artifactId>asm</artifactId>
    </dependency>
    <dependency>
      <groupId>com.google.protobuf</groupId>
@@ -73,7 +73,7 @@
    </dependency>
    <dependency>
      <groupId>cc.spray</groupId>
      <artifactId>spray-json_${scala.version}</artifactId>
      <artifactId>spray-json_2.9.2</artifactId>
    </dependency>
    <dependency>
      <groupId>org.tomdz.twirl</groupId>
@@ -81,13 +81,26 @@
    </dependency>
    <dependency>
      <groupId>com.github.scala-incubator.io</groupId>
      <artifactId>scala-io-file_${scala.version}</artifactId>
      <artifactId>scala-io-file_2.9.2</artifactId>
    </dependency>
    <dependency>
      <groupId>org.apache.mesos</groupId>
      <artifactId>mesos</artifactId>
    </dependency>
    <dependency>
      <groupId>io.netty</groupId>
      <artifactId>netty-all</artifactId>
    </dependency>
    <dependency>
      <groupId>log4j</groupId>
      <artifactId>log4j</artifactId>
    </dependency>

    <dependency>
      <groupId>org.apache.derby</groupId>
      <artifactId>derby</artifactId>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.version}</artifactId>
@@ -275,5 +288,72 @@
      </plugins>
    </build>
  </profile>
  <profile>
    <id>hadoop2-yarn</id>
    <dependencies>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-client</artifactId>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-yarn-api</artifactId>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-yarn-common</artifactId>
        <scope>provided</scope>
      </dependency>
      <dependency>
        <groupId>org.apache.hadoop</groupId>
        <artifactId>hadoop-yarn-client</artifactId>
        <scope>provided</scope>
      </dependency>
    </dependencies>
    <build>
      <plugins>
        <plugin>
          <groupId>org.codehaus.mojo</groupId>
          <artifactId>build-helper-maven-plugin</artifactId>
          <executions>
            <execution>
              <id>add-source</id>
              <phase>generate-sources</phase>
              <goals>
                <goal>add-source</goal>
              </goals>
              <configuration>
                <sources>
                  <source>src/main/scala</source>
                  <source>src/hadoop2-yarn/scala</source>
                </sources>
              </configuration>
            </execution>
            <execution>
              <id>add-scala-test-sources</id>
              <phase>generate-test-sources</phase>
              <goals>
                <goal>add-test-source</goal>
              </goals>
              <configuration>
                <sources>
                  <source>src/test/scala</source>
                </sources>
              </configuration>
            </execution>
          </executions>
        </plugin>
        <plugin>
          <groupId>org.apache.maven.plugins</groupId>
          <artifactId>maven-jar-plugin</artifactId>
          <configuration>
            <classifier>hadoop2-yarn</classifier>
          </configuration>
        </plugin>
      </plugins>
    </build>
  </profile>
</profiles>
</project>
@@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
  def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContext(conf, jobId)

  def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
    jobId, isMap, taskId, attemptId)
}

@@ -6,4 +6,7 @@ trait HadoopMapReduceUtil {
  def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContext(conf, jobId)

  def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContext(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
    jobId, isMap, taskId, attemptId)
}
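Note: the hadoop1 and hadoop2-yarn source trees (the latter added further below) expose traits with the same names and signatures, so calling code can mix in whichever shim the active build profile compiles. A hedged sketch of a caller (the class and values are illustrative):

    import org.apache.hadoop.mapred.{JobConf, HadoopMapRedUtil}

    class MyRecordWriterSetup(conf: JobConf) extends HadoopMapRedUtil {
      // The same call sites work against old-API and new-API Hadoop:
      val attempt = newTaskAttemptID("jt", 1, false, 0, 0)
      val context = newTaskAttemptContext(conf, attempt)
    }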
core/src/hadoop1/scala/spark/deploy/SparkHadoopUtil.scala
@@ -0,0 +1,23 @@
package spark.deploy
import org.apache.hadoop.conf.Configuration


/**
 * Contains util methods to interact with Hadoop from spark.
 */
object SparkHadoopUtil {

  def getUserNameFromEnvironment(): String = {
    // defaulting to -D ...
    System.getProperty("user.name")
  }

  def runAsUser(func: (Product) => Unit, args: Product) {

    // Add support, if it exists - for now, simply run func!
    func(args)
  }

  // Return an appropriate (subclass) of Configuration. Creating a config can initialize some Hadoop subsystems.
  def newConfiguration(): Configuration = new Configuration()
}
@@ -0,0 +1,13 @@

package org.apache.hadoop.mapred

import org.apache.hadoop.mapreduce.TaskType

trait HadoopMapRedUtil {
  def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)

  def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
    new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
}
@@ -0,0 +1,13 @@
package org.apache.hadoop.mapreduce

import org.apache.hadoop.conf.Configuration
import task.{TaskAttemptContextImpl, JobContextImpl}

trait HadoopMapReduceUtil {
  def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)

  def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) =
    new TaskAttemptID(jtIdentifier, jobId, if (isMap) TaskType.MAP else TaskType.REDUCE, taskId, attemptId)
}
@@ -0,0 +1,63 @@
package spark.deploy

import collection.mutable.HashMap
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import java.security.PrivilegedExceptionAction

/**
 * Contains util methods to interact with Hadoop from spark.
 */
object SparkHadoopUtil {

  val yarnConf = newConfiguration()

  def getUserNameFromEnvironment(): String = {
    // defaulting to env if -D is not present ...
    val retval = System.getProperty(Environment.USER.name, System.getenv(Environment.USER.name))

    // If nothing found, default to user we are running as
    if (retval == null) System.getProperty("user.name") else retval
  }

  def runAsUser(func: (Product) => Unit, args: Product) {
    runAsUser(func, args, getUserNameFromEnvironment())
  }

  def runAsUser(func: (Product) => Unit, args: Product, user: String) {

    // println("running as user " + jobUserName)

    UserGroupInformation.setConfiguration(yarnConf)
    val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(user)
    appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
      def run: AnyRef = {
        func(args)
        // no return value ...
        null
      }
    })
  }

  // Note that all params which start with SPARK are propagated all the way through, so if in yarn mode, this MUST be set to true.
  def isYarnMode(): Boolean = {
    val yarnMode = System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))
    java.lang.Boolean.valueOf(yarnMode)
  }

  // Set an env variable indicating we are running in YARN mode.
  // Note that anything with SPARK prefix gets propagated to all (remote) processes
  def setYarnMode() {
    System.setProperty("SPARK_YARN_MODE", "true")
  }

  def setYarnMode(env: HashMap[String, String]) {
    env("SPARK_YARN_MODE") = "true"
  }

  // Return an appropriate (subclass) of Configuration. Creating a config can initialize some Hadoop subsystems.
  // Always create a new config, don't reuse yarnConf.
  def newConfiguration(): Configuration = new YarnConfiguration(new Configuration())
}
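Note: callers hand runAsUser a function plus its arguments packed in a Product; the YARN variant runs it inside a UserGroupInformation.doAs for the submitting user, while the hadoop1 variant simply invokes the function. A hedged usage sketch (the entry point and arguments are illustrative):

    // Hypothetical worker entry point run as the job-submitting user.
    def boot(args: Product) {
      println("running as " + SparkHadoopUtil.getUserNameFromEnvironment())
    }
    SparkHadoopUtil.runAsUser(boot _, Tuple2("host", 1234))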
@@ -0,0 +1,329 @@
package spark.deploy.yarn

import java.net.Socket
import java.util.concurrent.CopyOnWriteArrayList
import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import scala.collection.JavaConversions._
import spark.{SparkContext, Logging, Utils}
import org.apache.hadoop.security.UserGroupInformation
import java.security.PrivilegedExceptionAction

class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging {

  def this(args: ApplicationMasterArguments) = this(args, new Configuration())

  private var rpc: YarnRPC = YarnRPC.create(conf)
  private var resourceManager: AMRMProtocol = null
  private var appAttemptId: ApplicationAttemptId = null
  private var userThread: Thread = null
  private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  private var yarnAllocator: YarnAllocationHandler = null

  def run() {

    // Initialization
    val jobUserName = Utils.getUserNameFromEnvironment()
    logInfo("running as user " + jobUserName)

    // run as user ...
    UserGroupInformation.setConfiguration(yarnConf)
    val appMasterUgi: UserGroupInformation = UserGroupInformation.createRemoteUser(jobUserName)
    appMasterUgi.doAs(new PrivilegedExceptionAction[AnyRef] {
      def run: AnyRef = {
        runImpl()
        return null
      }
    })
  }

  private def runImpl() {

    appAttemptId = getApplicationAttemptId()
    resourceManager = registerWithResourceManager()
    val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()

    // Compute number of threads for akka
    val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()

    if (minimumMemory > 0) {
      val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
      val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)

      if (numCore > 0) {
        // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
        // TODO: Uncomment when hadoop is on a version which has this fixed.
        // args.workerCores = numCore
      }
    }

    // Workaround until hadoop moves to something which has
    // https://issues.apache.org/jira/browse/HADOOP-8406
    // ignore result
    // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times
    // Hence args.workerCores = numCore disabled above. Any better option?
    // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf)

    ApplicationMaster.register(this)
    // Start the user's JAR
    userThread = startUserClass()

    // This is a bit hacky, but we need to wait until the spark.driver.port property has
    // been set by the Thread executing the user class.
    waitForSparkMaster()

    // Allocate all containers
    allocateWorkers()

    // Wait for the user class to finish
    userThread.join()

    // Finish the ApplicationMaster
    finishApplicationMaster()
    // TODO: Exit based on success/failure
    System.exit(0)
  }

  private def getApplicationAttemptId(): ApplicationAttemptId = {
    val envs = System.getenv()
    val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
    val containerId = ConverterUtils.toContainerId(containerIdString)
    val appAttemptId = containerId.getApplicationAttemptId()
    logInfo("ApplicationAttemptId: " + appAttemptId)
    return appAttemptId
  }

  private def registerWithResourceManager(): AMRMProtocol = {
    val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
      YarnConfiguration.RM_SCHEDULER_ADDRESS,
      YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
    logInfo("Connecting to ResourceManager at " + rmAddress)
    return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
  }

  private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
    logInfo("Registering the ApplicationMaster")
    val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
      .asInstanceOf[RegisterApplicationMasterRequest]
    appMasterRequest.setApplicationAttemptId(appAttemptId)
    // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
    // Users can then monitor stderr/stdout on that node if required.
    appMasterRequest.setHost(Utils.localHostName())
    appMasterRequest.setRpcPort(0)
    // What do we provide here? Might make sense to expose something sensible later?
    appMasterRequest.setTrackingUrl("")
    return resourceManager.registerApplicationMaster(appMasterRequest)
  }

  private def waitForSparkMaster() {
    logInfo("Waiting for spark driver to be reachable.")
    var driverUp = false
    while(!driverUp) {
      val driverHost = System.getProperty("spark.driver.host")
      val driverPort = System.getProperty("spark.driver.port")
      try {
        val socket = new Socket(driverHost, driverPort.toInt)
        socket.close()
        logInfo("Master now available: " + driverHost + ":" + driverPort)
        driverUp = true
      } catch {
        case e: Exception =>
          logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
          Thread.sleep(100)
      }
    }
  }

  private def startUserClass(): Thread = {
    logInfo("Starting the user JAR in a separate Thread")
    val mainMethod = Class.forName(args.userClass, false, Thread.currentThread.getContextClassLoader)
      .getMethod("main", classOf[Array[String]])
    val t = new Thread {
      override def run() {
        // Copy
        var mainArgs: Array[String] = new Array[String](args.userArgs.size())
        args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size())
        mainMethod.invoke(null, mainArgs)
      }
    }
    t.start()
    return t
  }

  private def allocateWorkers() {
    logInfo("Waiting for spark context initialization")

    try {
      var sparkContext: SparkContext = null
      ApplicationMaster.sparkContextRef.synchronized {
        var count = 0
        while (ApplicationMaster.sparkContextRef.get() == null) {
          logInfo("Waiting for spark context initialization ... " + count)
          count = count + 1
          ApplicationMaster.sparkContextRef.wait(10000L)
        }
        sparkContext = ApplicationMaster.sparkContextRef.get()
        assert(sparkContext != null)
        this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, sparkContext.preferredNodeLocationData)
      }

      logInfo("Allocating " + args.numWorkers + " workers.")
      // Wait until all containers have finished
      // TODO: This is a bit ugly. Can we make it nicer?
      // TODO: Handle container failure
      while(yarnAllocator.getNumWorkersRunning < args.numWorkers &&
        // If the user thread exits, then quit!
        userThread.isAlive) {

        this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
        ApplicationMaster.incrementAllocatorLoop(1)
        Thread.sleep(100)
      }
    } finally {
      // in case of exceptions, etc - ensure that the count is at least ALLOCATOR_LOOP_WAIT_COUNT,
      // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks
      ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
    }
    logInfo("All workers have launched.")

    // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
    if (userThread.isAlive) {
      // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.

      val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
      // must be <= timeoutInterval / 2.
      // On the other hand, also ensure that we are reasonably responsive without causing too many requests to RM,
      // so at least 1 minute or timeoutInterval / 10 - whichever is higher.
      val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval / 10, 60000L))
      launchReporterThread(interval)
    }
  }

  // TODO: We might want to extend this to allocate more containers in case they die!
  private def launchReporterThread(_sleepTime: Long): Thread = {
    val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime

    val t = new Thread {
      override def run() {
        while (userThread.isAlive) {
          val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
          if (missingWorkerCount > 0) {
            logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially?) lost containers")
            yarnAllocator.allocateContainers(missingWorkerCount)
          }
          else sendProgress()
          Thread.sleep(sleepTime)
        }
      }
    }
    // setting to daemon status, though this is usually not a good idea.
    t.setDaemon(true)
    t.start()
    logInfo("Started progress reporter thread - sleep time : " + sleepTime)
    return t
  }

  private def sendProgress() {
    logDebug("Sending progress")
    // simulated with an allocate request with no nodes requested ...
    yarnAllocator.allocateContainers(0)
  }

  /*
  def printContainers(containers: List[Container]) = {
    for (container <- containers) {
      logInfo("Launching shell command on a new container."
        + ", containerId=" + container.getId()
        + ", containerNode=" + container.getNodeId().getHost()
        + ":" + container.getNodeId().getPort()
        + ", containerNodeURI=" + container.getNodeHttpAddress()
        + ", containerState" + container.getState()
        + ", containerResourceMemory"
        + container.getResource().getMemory())
    }
  }
  */

  def finishApplicationMaster() {
    val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
      .asInstanceOf[FinishApplicationMasterRequest]
    finishReq.setAppAttemptId(appAttemptId)
    // TODO: Check if the application has failed or succeeded
    finishReq.setFinishApplicationStatus(FinalApplicationStatus.SUCCEEDED)
    resourceManager.finishApplicationMaster(finishReq)
  }

}

object ApplicationMaster {
  // number of times to wait for the allocator loop to complete.
  // each loop iteration waits for 100ms, so maximum of 3 seconds.
  // This is to ensure that we have a reasonable number of containers before we start
  // TODO: Currently, task to container mapping is computed once (TaskSetManager) - which need not be optimal as more
  // containers become available. Might need to handle this better.
  private val ALLOCATOR_LOOP_WAIT_COUNT = 30
  def incrementAllocatorLoop(by: Int) {
    val count = yarnAllocatorLoop.getAndAdd(by)
    if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
      yarnAllocatorLoop.synchronized {
        // to wake threads off wait ...
        yarnAllocatorLoop.notifyAll()
      }
    }
  }

  private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()

  def register(master: ApplicationMaster) {
    applicationMasters.add(master)
  }

  val sparkContextRef: AtomicReference[SparkContext] = new AtomicReference[SparkContext](null)
  val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)

  def sparkContextInitialized(sc: SparkContext): Boolean = {
    var modified = false
    sparkContextRef.synchronized {
      modified = sparkContextRef.compareAndSet(null, sc)
      sparkContextRef.notifyAll()
    }

    // Add a shutdown hook - as a best-effort measure in case users do not call sc.stop or do System.exit.
    // Should not really have to do this, but it helps yarn to evict resources earlier;
    // not to mention, it prevents the Client from declaring failure even though we exited properly.
    if (modified) {
      Runtime.getRuntime().addShutdownHook(new Thread with Logging {
        // This is not just to log, but also to ensure that the log system is initialized for this instance when we actually are 'run'
        logInfo("Adding shutdown hook for context " + sc)
        override def run() {
          logInfo("Invoking sc stop from shutdown hook")
          sc.stop()
          // best case ...
          for (master <- applicationMasters) master.finishApplicationMaster
        }
      } )
    }

    // Wait for initialization to complete and at least 'some' nodes to get allocated
    yarnAllocatorLoop.synchronized {
      while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
        yarnAllocatorLoop.wait(1000L)
      }
    }
    modified
  }

  def main(argStrings: Array[String]) {
    val args = new ApplicationMasterArguments(argStrings)
    new ApplicationMaster(args).run()
  }
}
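Note: the handshake between the ApplicationMaster and the user program hinges on sparkContextRef above - the user-class thread publishes the SparkContext and notifies, while allocateWorkers blocks in a wait loop until it appears. The same pattern in isolation, as a sketch detached from YARN:

    object ContextHandshake {
      import java.util.concurrent.atomic.AtomicReference

      val ref = new AtomicReference[String](null)

      // Publisher (user-class thread): set once, wake all waiters.
      def publish(v: String): Boolean = ref.synchronized {
        val set = ref.compareAndSet(null, v)
        ref.notifyAll()
        set
      }

      // Waiter (ApplicationMaster): sleep under the monitor until published.
      def await(): String = ref.synchronized {
        while (ref.get() == null) ref.wait(10000L)
        ref.get()
      }
    }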
@@ -0,0 +1,77 @@
package spark.deploy.yarn

import spark.util.IntParam
import collection.mutable.ArrayBuffer

class ApplicationMasterArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var userArgs: Seq[String] = Seq[String]()
  var workerMemory = 1024
  var workerCores = 1
  var numWorkers = 2

  parseArgs(args.toList)

  private def parseArgs(inputArgs: List[String]): Unit = {
    val userArgsBuffer = new ArrayBuffer[String]()

    var args = inputArgs

    while (!args.isEmpty) {

      args match {
        case ("--jar") :: value :: tail =>
          userJar = value
          args = tail

        case ("--class") :: value :: tail =>
          userClass = value
          args = tail

        case ("--args") :: value :: tail =>
          userArgsBuffer += value
          args = tail

        case ("--num-workers") :: IntParam(value) :: tail =>
          numWorkers = value
          args = tail

        case ("--worker-memory") :: IntParam(value) :: tail =>
          workerMemory = value
          args = tail

        case ("--worker-cores") :: IntParam(value) :: tail =>
          workerCores = value
          args = tail

        case Nil =>
          if (userJar == null || userClass == null) {
            printUsageAndExit(1)
          }

        case _ =>
          printUsageAndExit(1, args)
      }
    }

    userArgs = userArgsBuffer.readOnly
  }

  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
    if (unknownParam != null) {
      System.err.println("Unknown/unsupported param " + unknownParam)
    }
    System.err.println(
      "Usage: spark.deploy.yarn.ApplicationMaster [options] \n" +
      "Options:\n" +
      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
      "                       Multiple invocations are possible, each will be passed in order.\n" +
      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
      "  --worker-cores NUM   Number of cores for the workers (Default: 1)\n" +
      "  --worker-memory MEM  Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n")
    System.exit(exitCode)
  }
}
core/src/hadoop2-yarn/scala/spark/deploy/yarn/Client.scala
@@ -0,0 +1,272 @@
package spark.deploy.yarn

import java.net.{InetSocketAddress, URI}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.client.YarnClientImpl
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import spark.{Logging, Utils}
import org.apache.hadoop.yarn.util.{Apps, Records, ConverterUtils}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import spark.deploy.SparkHadoopUtil

class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging {

  def this(args: ClientArguments) = this(new Configuration(), args)

  var rpc: YarnRPC = YarnRPC.create(conf)
  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  def run() {
    init(yarnConf)
    start()
    logClusterResourceDetails()

    val newApp = super.getNewApplication()
    val appId = newApp.getApplicationId()

    verifyClusterResources(newApp)
    val appContext = createApplicationSubmissionContext(appId)
    val localResources = prepareLocalResources(appId, "spark")
    val env = setupLaunchEnv(localResources)
    val amContainer = createContainerLaunchContext(newApp, localResources, env)

    appContext.setQueue(args.amQueue)
    appContext.setAMContainerSpec(amContainer)
    appContext.setUser(args.amUser)

    submitApp(appContext)

    monitorApplication(appId)
    System.exit(0)
  }

  def logClusterResourceDetails() {
    val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
    logInfo("Got Cluster metric info from ASM, numNodeManagers=" + clusterMetrics.getNumNodeManagers)

    val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
    logInfo("Queue info .. queueName=" + queueInfo.getQueueName + ", queueCurrentCapacity=" + queueInfo.getCurrentCapacity +
      ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size +
      ", queueChildQueueCount=" + queueInfo.getChildQueues.size)
  }

  def verifyClusterResources(app: GetNewApplicationResponse) = {
    val maxMem = app.getMaximumResourceCapability().getMemory()
    logInfo("Max mem capability of resources in this cluster " + maxMem)

    // If the cluster does not have enough memory resources, exit.
    val requestedMem = (args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + args.numWorkers * args.workerMemory
    if (requestedMem > maxMem) {
      logError("Cluster cannot satisfy memory resource request of " + requestedMem)
      System.exit(1)
    }
  }

  def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = {
    logInfo("Setting up application submission context for ASM")
    val appContext = Records.newRecord(classOf[ApplicationSubmissionContext])
    appContext.setApplicationId(appId)
    appContext.setApplicationName("Spark")
    return appContext
  }

  def prepareLocalResources(appId: ApplicationId, appName: String): HashMap[String, LocalResource] = {
    logInfo("Preparing Local resources")
    val localResources = HashMap[String, LocalResource]()
    // Upload Spark and the application JAR to the remote file system
    // Add them as local resources to the AM
    val fs = FileSystem.get(conf)
    Map("spark.jar" -> System.getenv("SPARK_JAR"), "app.jar" -> args.userJar, "log4j.properties" -> System.getenv("SPARK_LOG4J_CONF"))
      .foreach { case(destName, _localPath) =>
        val localPath: String = if (_localPath != null) _localPath.trim() else ""
        if (! localPath.isEmpty()) {
          val src = new Path(localPath)
          val pathSuffix = appName + "/" + appId.getId() + destName
          val dst = new Path(fs.getHomeDirectory(), pathSuffix)
          logInfo("Uploading " + src + " to " + dst)
          fs.copyFromLocalFile(false, true, src, dst)
          val destStatus = fs.getFileStatus(dst)

          val amJarRsrc = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
          amJarRsrc.setType(LocalResourceType.FILE)
          amJarRsrc.setVisibility(LocalResourceVisibility.APPLICATION)
          amJarRsrc.setResource(ConverterUtils.getYarnUrlFromPath(dst))
          amJarRsrc.setTimestamp(destStatus.getModificationTime())
          amJarRsrc.setSize(destStatus.getLen())
          localResources(destName) = amJarRsrc
        }
      }
    return localResources
  }

  def setupLaunchEnv(localResources: HashMap[String, LocalResource]): HashMap[String, String] = {
    logInfo("Setting up the launch environment")
    val log4jConfLocalRes = localResources.getOrElse("log4j.properties", null)

    val env = new HashMap[String, String]()
    Apps.addToEnvironment(env, Environment.USER.name, args.amUser)

    // If log4j present, ensure ours overrides all others
    if (log4jConfLocalRes != null) Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")

    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
    Client.populateHadoopClasspath(yarnConf, env)
    SparkHadoopUtil.setYarnMode(env)
    env("SPARK_YARN_JAR_PATH") =
      localResources("spark.jar").getResource().getScheme.toString() + "://" +
      localResources("spark.jar").getResource().getFile().toString()
    env("SPARK_YARN_JAR_TIMESTAMP") = localResources("spark.jar").getTimestamp().toString()
    env("SPARK_YARN_JAR_SIZE") = localResources("spark.jar").getSize().toString()

    env("SPARK_YARN_USERJAR_PATH") =
      localResources("app.jar").getResource().getScheme.toString() + "://" +
      localResources("app.jar").getResource().getFile().toString()
    env("SPARK_YARN_USERJAR_TIMESTAMP") = localResources("app.jar").getTimestamp().toString()
    env("SPARK_YARN_USERJAR_SIZE") = localResources("app.jar").getSize().toString()

    if (log4jConfLocalRes != null) {
      env("SPARK_YARN_LOG4J_PATH") =
        log4jConfLocalRes.getResource().getScheme.toString() + "://" + log4jConfLocalRes.getResource().getFile().toString()
      env("SPARK_YARN_LOG4J_TIMESTAMP") = log4jConfLocalRes.getTimestamp().toString()
      env("SPARK_YARN_LOG4J_SIZE") = log4jConfLocalRes.getSize().toString()
    }

    // Add each SPARK-* key to the environment
    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v }
    return env
  }

  def userArgsToString(clientArgs: ClientArguments): String = {
    val prefix = " --args "
    val args = clientArgs.userArgs
    val retval = new StringBuilder()
    for (arg <- args) {
      retval.append(prefix).append(" '").append(arg).append("' ")
    }

    retval.toString
  }

  def createContainerLaunchContext(newApp: GetNewApplicationResponse,
                                   localResources: HashMap[String, LocalResource],
                                   env: HashMap[String, String]): ContainerLaunchContext = {
    logInfo("Setting up container launch context")
    val amContainer = Records.newRecord(classOf[ContainerLaunchContext])
    amContainer.setLocalResources(localResources)
    amContainer.setEnvironment(env)

    val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory()

    var amMemory = ((args.amMemory / minResMemory) * minResMemory) +
      (if (0 != (args.amMemory % minResMemory)) minResMemory else 0) - YarnAllocationHandler.MEMORY_OVERHEAD

    // Extra options for the JVM
    var JAVA_OPTS = ""

    // Add Xmx for am memory
    JAVA_OPTS += "-Xmx" + amMemory + "m "

    // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out.
    // The context is, the default gc for server class machines ends up using all cores to do gc - hence if there are multiple containers on the same
    // node, spark gc affects all other containers' performance (which can also be other spark containers).
    // Instead of using this, rely on cpusets by YARN to enforce that spark behaves 'properly' in multi-tenant environments. Not sure how the default java gc behaves if it is
    // limited to a subset of cores on a node.
    if (env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC"))) {
      // In our experiments, using the (default) throughput collector has severe perf ramifications on multi-tenant machines
      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
      JAVA_OPTS += " -XX:+CMSIncrementalMode "
      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
    }
    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
    }

    // Command for the ApplicationMaster
    val commands = List[String]("java " +
      " -server " +
      JAVA_OPTS +
      " spark.deploy.yarn.ApplicationMaster" +
      " --class " + args.userClass +
      " --jar " + args.userJar +
      userArgsToString(args) +
      " --worker-memory " + args.workerMemory +
      " --worker-cores " + args.workerCores +
      " --num-workers " + args.numWorkers +
      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
    logInfo("Command for the ApplicationMaster: " + commands(0))
    amContainer.setCommands(commands)

    val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource]
    // Memory for the ApplicationMaster
    capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
    amContainer.setResource(capability)

    return amContainer
  }

  def submitApp(appContext: ApplicationSubmissionContext) = {
    // Submit the application to the applications manager
    logInfo("Submitting application to ASM")
    super.submitApplication(appContext)
  }

  def monitorApplication(appId: ApplicationId): Boolean = {
    while(true) {
      Thread.sleep(1000)
      val report = super.getApplicationReport(appId)

      logInfo("Application report from ASM: \n" +
        "\t application identifier: " + appId.toString() + "\n" +
        "\t appId: " + appId.getId() + "\n" +
        "\t clientToken: " + report.getClientToken() + "\n" +
        "\t appDiagnostics: " + report.getDiagnostics() + "\n" +
        "\t appMasterHost: " + report.getHost() + "\n" +
        "\t appQueue: " + report.getQueue() + "\n" +
        "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
        "\t appStartTime: " + report.getStartTime() + "\n" +
        "\t yarnAppState: " + report.getYarnApplicationState() + "\n" +
        "\t distributedFinalState: " + report.getFinalApplicationStatus() + "\n" +
        "\t appTrackingUrl: " + report.getTrackingUrl() + "\n" +
        "\t appUser: " + report.getUser()
      )

      val state = report.getYarnApplicationState()
      val dsStatus = report.getFinalApplicationStatus()
      if (state == YarnApplicationState.FINISHED ||
        state == YarnApplicationState.FAILED ||
        state == YarnApplicationState.KILLED) {
        return true
      }
    }
    return true
  }
}

object Client {
  def main(argStrings: Array[String]) {
    val args = new ClientArguments(argStrings)
    SparkHadoopUtil.setYarnMode()
    new Client(args).run
  }

  // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
  def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
    for (c <- conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) {
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
    }
  }
}
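Note: end to end, a submission would set SPARK_JAR to the assembled Spark jar and launch the client class with the options parsed below. A hypothetical command line (jar name, main class, and classpath setup are illustrative):

    SPARK_JAR=spark-core-assembly.jar java -cp <spark-and-hadoop-classpath> \
      spark.deploy.yarn.Client \
      --jar my-app.jar --class com.example.MyApp \
      --num-workers 4 --worker-memory 2G --worker-cores 1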
@ -0,0 +1,105 @@
package spark.deploy.yarn

import spark.util.MemoryParam
import spark.util.IntParam
import collection.mutable.{ArrayBuffer, HashMap}
import spark.scheduler.{InputFormatInfo, SplitInfo}

// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware!
class ClientArguments(val args: Array[String]) {
  var userJar: String = null
  var userClass: String = null
  var userArgs: Seq[String] = Seq[String]()
  var workerMemory = 1024
  var workerCores = 1
  var numWorkers = 2
  var amUser = System.getProperty("user.name")
  var amQueue = System.getProperty("QUEUE", "default")
  var amMemory: Int = 512
  // TODO
  var inputFormatInfo: List[InputFormatInfo] = null

  parseArgs(args.toList)

  private def parseArgs(inputArgs: List[String]): Unit = {
    val userArgsBuffer: ArrayBuffer[String] = new ArrayBuffer[String]()
    val inputFormatMap: HashMap[String, InputFormatInfo] = new HashMap[String, InputFormatInfo]()

    var args = inputArgs

    while (!args.isEmpty) {

      args match {
        case ("--jar") :: value :: tail =>
          userJar = value
          args = tail

        case ("--class") :: value :: tail =>
          userClass = value
          args = tail

        case ("--args") :: value :: tail =>
          userArgsBuffer += value
          args = tail

        case ("--master-memory") :: MemoryParam(value) :: tail =>
          amMemory = value
          args = tail

        case ("--num-workers") :: IntParam(value) :: tail =>
          numWorkers = value
          args = tail

        case ("--worker-memory") :: MemoryParam(value) :: tail =>
          workerMemory = value
          args = tail

        case ("--worker-cores") :: IntParam(value) :: tail =>
          workerCores = value
          args = tail

        case ("--user") :: value :: tail =>
          amUser = value
          args = tail

        case ("--queue") :: value :: tail =>
          amQueue = value
          args = tail

        case Nil =>
          if (userJar == null || userClass == null) {
            printUsageAndExit(1)
          }

        case _ =>
          printUsageAndExit(1, args)
      }
    }

    userArgs = userArgsBuffer.readOnly
    inputFormatInfo = inputFormatMap.values.toList
  }

  def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
    if (unknownParam != null) {
      System.err.println("Unknown/unsupported param " + unknownParam)
    }
    System.err.println(
      "Usage: spark.deploy.yarn.Client [options] \n" +
      "Options:\n" +
      "  --jar JAR_PATH       Path to your application's JAR file (required)\n" +
      "  --class CLASS_NAME   Name of your application's main class (required)\n" +
      "  --args ARGS          Arguments to be passed to your application's main class.\n" +
      "                       Multiple invocations are possible, each will be passed in order.\n" +
      "  --num-workers NUM    Number of workers to start (Default: 2)\n" +
      "  --worker-cores NUM   Number of cores for the workers (Default: 1). This is unused right now.\n" +
      "  --master-memory MEM  Memory for Master (e.g. 1000M, 2G) (Default: 512 MB)\n" +
      "  --worker-memory MEM  Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
      "  --queue QUEUE        The hadoop queue to use for allocation requests (Default: 'default')\n" +
      "  --user USERNAME      Run the ApplicationMaster (and slaves) as a different user\n"
    )
    System.exit(exitCode)
  }

}
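As a quick illustration of the option grammar above, a minimal sketch of driving the parser directly; the jar path, class name and argument values are placeholders:

    // Hypothetical invocation of ClientArguments; paths and names are made up.
    val parsed = new ClientArguments(Array(
      "--jar", "/tmp/my-app.jar",
      "--class", "my.pkg.Main",
      "--args", "input", "--args", "output", // repeated --args accumulate in order
      "--num-workers", "4",
      "--queue", "default"))

    assert(parsed.numWorkers == 4)
    assert(parsed.userArgs == Seq("input", "output"))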
@ -0,0 +1,171 @@
package spark.deploy.yarn

import java.net.URI

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.security.UserGroupInformation
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment

import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap

import spark.{Logging, Utils}

class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String,
    slaveId: String, hostname: String, workerMemory: Int, workerCores: Int)
  extends Runnable with Logging {

  var rpc: YarnRPC = YarnRPC.create(conf)
  var cm: ContainerManager = null
  val yarnConf: YarnConfiguration = new YarnConfiguration(conf)

  def run = {
    logInfo("Starting Worker Container")
    cm = connectToCM
    startContainer
  }

  def startContainer = {
    logInfo("Setting up ContainerLaunchContext")

    val ctx = Records.newRecord(classOf[ContainerLaunchContext])
      .asInstanceOf[ContainerLaunchContext]

    ctx.setContainerId(container.getId())
    ctx.setResource(container.getResource())
    val localResources = prepareLocalResources
    ctx.setLocalResources(localResources)

    val env = prepareEnvironment
    ctx.setEnvironment(env)

    // Extra options for the JVM
    var JAVA_OPTS = ""
    // Set the JVM memory
    val workerMemoryString = workerMemory + "m"
    JAVA_OPTS += "-Xms" + workerMemoryString + " -Xmx" + workerMemoryString + " "
    if (env.isDefinedAt("SPARK_JAVA_OPTS")) {
      JAVA_OPTS += env("SPARK_JAVA_OPTS") + " "
    }
    // Commenting it out for now - so that people can refer to the properties if required. Remove it once the cpuset version is pushed out.
    // The context is: the default GC for server-class machines ends up using all cores to do GC - hence if there are multiple containers on the same
    // node, Spark's GC affects the performance of all other containers (which can also be other Spark containers).
    // Instead of using this, rely on cpusets by YARN to enforce that Spark behaves 'properly' in multi-tenant environments. Not sure how the default Java GC behaves if it is
    // limited to a subset of cores on a node.
    /*
    else {
      // If no java_opts specified, default to using -XX:+CMSIncrementalMode
      // It might be possible that other modes/config is being done in SPARK_JAVA_OPTS, so we don't want to mess with it.
      // In our experiments, using the (default) throughput collector has severe perf ramifications in multi-tenant machines
      // The options are based on
      // http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
      JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
      JAVA_OPTS += " -XX:+CMSIncrementalMode "
      JAVA_OPTS += " -XX:+CMSIncrementalPacing "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
      JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
    }
    */

    ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
    val commands = List[String]("java " +
      " -server " +
      // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling.
      // Not killing the task leaves various aspects of the worker and (to some extent) the jvm in an inconsistent state.
      // TODO: If the OOM is not recoverable by rescheduling it on a different node, then do 'something' to fail the job ... akin to blacklisting trackers in mapred?
      " -XX:OnOutOfMemoryError='kill %p' " +
      JAVA_OPTS +
      " spark.executor.StandaloneExecutorBackend " +
      masterAddress + " " +
      slaveId + " " +
      hostname + " " +
      workerCores +
      " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" +
      " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
    logInfo("Setting up worker with commands: " + commands)
    ctx.setCommands(commands)

    // Send the start request to the ContainerManager
    val startReq = Records.newRecord(classOf[StartContainerRequest])
      .asInstanceOf[StartContainerRequest]
    startReq.setContainerLaunchContext(ctx)
    cm.startContainer(startReq)
  }

  def prepareLocalResources: HashMap[String, LocalResource] = {
    logInfo("Preparing Local resources")
    val localResources = HashMap[String, LocalResource]()

    // Spark JAR
    val sparkJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
    sparkJarResource.setType(LocalResourceType.FILE)
    sparkJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
    sparkJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
      new URI(System.getenv("SPARK_YARN_JAR_PATH"))))
    sparkJarResource.setTimestamp(System.getenv("SPARK_YARN_JAR_TIMESTAMP").toLong)
    sparkJarResource.setSize(System.getenv("SPARK_YARN_JAR_SIZE").toLong)
    localResources("spark.jar") = sparkJarResource
    // User JAR
    val userJarResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
    userJarResource.setType(LocalResourceType.FILE)
    userJarResource.setVisibility(LocalResourceVisibility.APPLICATION)
    userJarResource.setResource(ConverterUtils.getYarnUrlFromURI(
      new URI(System.getenv("SPARK_YARN_USERJAR_PATH"))))
    userJarResource.setTimestamp(System.getenv("SPARK_YARN_USERJAR_TIMESTAMP").toLong)
    userJarResource.setSize(System.getenv("SPARK_YARN_USERJAR_SIZE").toLong)
    localResources("app.jar") = userJarResource

    // Log4j conf - if available
    if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
      val log4jConfResource = Records.newRecord(classOf[LocalResource]).asInstanceOf[LocalResource]
      log4jConfResource.setType(LocalResourceType.FILE)
      log4jConfResource.setVisibility(LocalResourceVisibility.APPLICATION)
      log4jConfResource.setResource(ConverterUtils.getYarnUrlFromURI(
        new URI(System.getenv("SPARK_YARN_LOG4J_PATH"))))
      log4jConfResource.setTimestamp(System.getenv("SPARK_YARN_LOG4J_TIMESTAMP").toLong)
      log4jConfResource.setSize(System.getenv("SPARK_YARN_LOG4J_SIZE").toLong)
      localResources("log4j.properties") = log4jConfResource
    }

    logInfo("Prepared Local resources " + localResources)
    return localResources
  }

  def prepareEnvironment: HashMap[String, String] = {
    val env = new HashMap[String, String]()
    // should we add this ?
    Apps.addToEnvironment(env, Environment.USER.name, Utils.getUserNameFromEnvironment())

    // If log4j present, ensure ours overrides all others
    if (System.getenv("SPARK_YARN_LOG4J_PATH") != null) {
      // Which is correct ?
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./log4j.properties")
      Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./")
    }

    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "./*")
    Apps.addToEnvironment(env, Environment.CLASSPATH.name, "$CLASSPATH")
    Client.populateHadoopClasspath(yarnConf, env)

    System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k, v) => env(k) = v }
    return env
  }

  def connectToCM: ContainerManager = {
    val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort()
    val cmAddress = NetUtils.createSocketAddr(cmHostPortStr)
    logInfo("Connecting to ContainerManager at " + cmHostPortStr)
    return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager]
  }

}
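prepareLocalResources above reads a fixed set of SPARK_YARN_* variables and fails fast (on the .toLong calls) if any is missing. A sketch of the matching producer side that the submitting process would have to export; every value here is a placeholder, and in practice the path, timestamp and size would come from the FileStatus of the staged jar:

    import scala.collection.mutable.HashMap

    // Hypothetical environment for one worker container launch.
    val workerEnv = HashMap(
      "SPARK_YARN_JAR_PATH"          -> "hdfs:///user/me/spark.jar",
      "SPARK_YARN_JAR_TIMESTAMP"     -> "1366000000000",
      "SPARK_YARN_JAR_SIZE"          -> "12345678",
      "SPARK_YARN_USERJAR_PATH"      -> "hdfs:///user/me/app.jar",
      "SPARK_YARN_USERJAR_TIMESTAMP" -> "1366000000000",
      "SPARK_YARN_USERJAR_SIZE"      -> "2345678")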
@ -0,0 +1,547 @@
package spark.deploy.yarn

import spark.{Logging, Utils}
import spark.scheduler.SplitInfo
import scala.collection
import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId, ContainerId, Priority, Resource, ResourceRequest, ContainerStatus, Container}
import spark.scheduler.cluster.{ClusterScheduler, StandaloneSchedulerBackend}
import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse}
import org.apache.hadoop.yarn.util.{RackResolver, Records}
import java.util.concurrent.{CopyOnWriteArrayList, ConcurrentHashMap}
import java.util.concurrent.atomic.AtomicInteger
import org.apache.hadoop.yarn.api.AMRMProtocol
import collection.JavaConversions._
import collection.mutable.{ArrayBuffer, HashMap, HashSet}
import org.apache.hadoop.conf.Configuration
import java.util.{Collections, Set => JSet}
import java.lang.{Boolean => JBoolean}

object AllocationType extends Enumeration("HOST", "RACK", "ANY") {
  type AllocationType = Value
  val HOST, RACK, ANY = Value
}

// too many params? refactor it 'somehow'?
// needs to be mt-safe
// Need to refactor this to make it 'cleaner' ... right now, all computation is reactive: should make it
// more proactive and decoupled.
// Note that right now, we assume all node asks are uniform in terms of capabilities and priority
// Refer to http://developer.yahoo.com/blogs/hadoop/posts/2011/03/mapreduce-nextgen-scheduler/ for more info
// on how we are requesting containers.
private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceManager: AMRMProtocol,
    val appAttemptId: ApplicationAttemptId,
    val maxWorkers: Int, val workerMemory: Int, val workerCores: Int,
    val preferredHostToCount: Map[String, Int],
    val preferredRackToCount: Map[String, Int])
  extends Logging {

  // These three are locked on allocatedHostToContainersMap. Complementary data structures
  // allocatedHostToContainersMap : containers which are running : host, Set<containerid>
  // allocatedContainerToHostMap: container to host mapping
  private val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]]()
  private val allocatedContainerToHostMap = new HashMap[ContainerId, String]()
  // allocatedRackCount is populated ONLY if allocation happens (or decremented if this is an allocated node)
  // As with the two data structures above, tightly coupled with them, and to be locked on allocatedHostToContainersMap
  private val allocatedRackCount = new HashMap[String, Int]()

  // containers which have been released.
  private val releasedContainerList = new CopyOnWriteArrayList[ContainerId]()
  // containers to be released in next request to RM
  private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean]

  private val numWorkersRunning = new AtomicInteger()
  // Used to generate a unique id per worker
  private val workerIdCounter = new AtomicInteger()
  private val lastResponseId = new AtomicInteger()

  def getNumWorkersRunning: Int = numWorkersRunning.intValue

  def isResourceConstraintSatisfied(container: Container): Boolean = {
    container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
  }

  def allocateContainers(workersToRequest: Int) {
    // We need to send the request only once from what I understand ... but for now, not modifying this much.

    // Keep polling the Resource Manager for containers
    val amResp = allocateWorkerResources(workersToRequest).getAMResponse

    val _allocatedContainers = amResp.getAllocatedContainers()
    if (_allocatedContainers.size > 0) {

      logDebug("Allocated " + _allocatedContainers.size + " containers, current count " +
        numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
        ", pendingReleaseContainers : " + pendingReleaseContainers)
      logDebug("Cluster Resources: " + amResp.getAvailableResources)

      val hostToContainers = new HashMap[String, ArrayBuffer[Container]]()

      // ignore if not satisfying constraints
      for (container <- _allocatedContainers) {
        if (isResourceConstraintSatisfied(container)) {
          // allocatedContainers += container

          val host = container.getNodeId.getHost
          val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]())

          containers += container
        }
        // Add all ignored containers to released list
        else releasedContainerList.add(container.getId())
      }

      // Find the appropriate containers to use
      // Slightly non trivial groupBy I guess ...
      val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
      val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]()
      val offRackContainers = new HashMap[String, ArrayBuffer[Container]]()

      for (candidateHost <- hostToContainers.keySet) {
        val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0)
        val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost)

        var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null)
        assert(remainingContainers != null)

        if (requiredHostCount >= remainingContainers.size) {
          // Since we got <= required containers, add all to dataLocalContainers
          dataLocalContainers.put(candidateHost, remainingContainers)
          // all consumed
          remainingContainers = null
        }
        else if (requiredHostCount > 0) {
          // container list has more containers than we need for data locality.
          // Split into two : data local container count of (remainingContainers.size - requiredHostCount)
          // and rest as remainingContainer
          val (dataLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredHostCount)
          dataLocalContainers.put(candidateHost, dataLocal)
          // remainingContainers = remaining

          // yarn has a nasty habit of allocating a tonne of containers on a host - discourage this:
          // add remaining to the release list. If we have insufficient containers, the next allocation cycle
          // will reallocate (but won't treat it as data local)
          for (container <- remaining) releasedContainerList.add(container.getId())
          remainingContainers = null
        }

        // now rack local
        if (remainingContainers != null) {
          val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)

          if (rack != null) {
            val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0)
            val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) -
              rackLocalContainers.get(rack).getOrElse(List()).size

            if (requiredRackCount >= remainingContainers.size) {
              // Add all to dataLocalContainers
              dataLocalContainers.put(rack, remainingContainers)
              // all consumed
              remainingContainers = null
            }
            else if (requiredRackCount > 0) {
              // container list has more containers than we need for data locality.
              // Split into two : data local container count of (remainingContainers.size - requiredRackCount)
              // and rest as remainingContainer
              val (rackLocal, remaining) = remainingContainers.splitAt(remainingContainers.size - requiredRackCount)
              val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]())

              existingRackLocal ++= rackLocal
              remainingContainers = remaining
            }
          }
        }

        // If still not consumed, then it is an off-rack host - add to that list.
        if (remainingContainers != null) {
          offRackContainers.put(candidateHost, remainingContainers)
        }
      }

      // Now that we have split the containers into various groups, go through them in order:
      // first host local, then rack local and then off rack (everything else).
      // Note that the list we create below tries to ensure that not all containers end up within a host
      // if there is a sufficiently large number of hosts/containers.

      val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size)
      allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers)
      allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers)
      allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers)

      // Run each of the allocated containers
      for (container <- allocatedContainers) {
        val numWorkersRunningNow = numWorkersRunning.incrementAndGet()
        val workerHostname = container.getNodeId.getHost
        val containerId = container.getId

        assert(container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))

        if (numWorkersRunningNow > maxWorkers) {
          logInfo("Ignoring container " + containerId + " at host " + workerHostname +
            " .. we already have the required number of containers")
          releasedContainerList.add(containerId)
          // reset counter back to old value.
          numWorkersRunning.decrementAndGet()
        }
        else {
          // deallocate + allocate can result in reusing ids wrongly - so use a different counter (workerIdCounter)
          val workerId = workerIdCounter.incrementAndGet().toString
          val driverUrl = "akka://spark@%s:%s/user/%s".format(
            System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
            StandaloneSchedulerBackend.ACTOR_NAME)

          logInfo("launching container on " + containerId + " host " + workerHostname)
          // just to be safe, simply remove it from pendingReleaseContainers. Should not be there, but ..
          pendingReleaseContainers.remove(containerId)

          val rack = YarnAllocationHandler.lookupRack(conf, workerHostname)
          allocatedHostToContainersMap.synchronized {
            val containerSet = allocatedHostToContainersMap.getOrElseUpdate(workerHostname, new HashSet[ContainerId]())

            containerSet += containerId
            allocatedContainerToHostMap.put(containerId, workerHostname)
            if (rack != null) allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1)
          }

          new Thread(
            new WorkerRunnable(container, conf, driverUrl, workerId,
              workerHostname, workerMemory, workerCores)
          ).start()
        }
      }
      logDebug("After allocated " + allocatedContainers.size + " containers (orig : " +
        _allocatedContainers.size + "), current count " + numWorkersRunning.get() +
        ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)
    }

    val completedContainers = amResp.getCompletedContainersStatuses()
    if (completedContainers.size > 0) {
      logDebug("Completed " + completedContainers.size + " containers, current count " + numWorkersRunning.get() +
        ", to-be-released " + releasedContainerList + ", pendingReleaseContainers : " + pendingReleaseContainers)

      for (completedContainer <- completedContainers) {
        val containerId = completedContainer.getContainerId

        // Was this released by us? If yes, then simply remove from containerSet and move on.
        if (pendingReleaseContainers.containsKey(containerId)) {
          pendingReleaseContainers.remove(containerId)
        }
        else {
          // simply decrement count - next iteration of ReporterThread will take care of allocating!
          numWorkersRunning.decrementAndGet()
          logInfo("Container completed, containerId: " + containerId + ", state: " + completedContainer.getState +
            ", diagnostics: " + completedContainer.getDiagnostics)
        }

        allocatedHostToContainersMap.synchronized {
          if (allocatedContainerToHostMap.containsKey(containerId)) {
            val host = allocatedContainerToHostMap.get(containerId).getOrElse(null)
            assert(host != null)

            val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null)
            assert(containerSet != null)

            containerSet -= containerId
            if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host)
            else allocatedHostToContainersMap.update(host, containerSet)

            allocatedContainerToHostMap -= containerId

            // doing this within locked context, sigh ... move to outside?
            val rack = YarnAllocationHandler.lookupRack(conf, host)
            if (rack != null) {
              val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1
              if (rackCount > 0) allocatedRackCount.put(rack, rackCount)
              else allocatedRackCount.remove(rack)
            }
          }
        }
      }
      logDebug("After completed " + completedContainers.size + " containers, current count " +
        numWorkersRunning.get() + ", to-be-released " + releasedContainerList +
        ", pendingReleaseContainers : " + pendingReleaseContainers)
    }
  }

  def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = {
    // First generate modified racks and the new set of hosts under them: then issue requests
    val rackToCounts = new HashMap[String, Int]()

    // Within this lock - used to read/write to the rack related maps too.
    for (container <- hostContainers) {
      val candidateHost = container.getHostName
      val candidateNumContainers = container.getNumContainers
      assert(YarnAllocationHandler.ANY_HOST != candidateHost)

      val rack = YarnAllocationHandler.lookupRack(conf, candidateHost)
      if (rack != null) {
        var count = rackToCounts.getOrElse(rack, 0)
        count += candidateNumContainers
        rackToCounts.put(rack, count)
      }
    }

    val requestedContainers: ArrayBuffer[ResourceRequest] =
      new ArrayBuffer[ResourceRequest](rackToCounts.size)
    for ((rack, count) <- rackToCounts) {
      requestedContainers +=
        createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY)
    }

    requestedContainers.toList
  }

  def allocatedContainersOnHost(host: String): Int = {
    var retval = 0
    allocatedHostToContainersMap.synchronized {
      retval = allocatedHostToContainersMap.getOrElse(host, Set()).size
    }
    retval
  }

  def allocatedContainersOnRack(rack: String): Int = {
    var retval = 0
    allocatedHostToContainersMap.synchronized {
      retval = allocatedRackCount.getOrElse(rack, 0)
    }
    retval
  }

  private def allocateWorkerResources(numWorkers: Int): AllocateResponse = {

    var resourceRequests: List[ResourceRequest] = null

    // default.
    if (numWorkers <= 0 || preferredHostToCount.isEmpty) {
      logDebug("numWorkers: " + numWorkers + ", host preferences ? " + preferredHostToCount.isEmpty)
      resourceRequests = List(
        createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY))
    }
    else {
      // request for all hosts in preferred nodes and for numWorkers -
      // candidates.size, request by default allocation policy.
      val hostContainerRequests: ArrayBuffer[ResourceRequest] =
        new ArrayBuffer[ResourceRequest](preferredHostToCount.size)
      for ((candidateHost, candidateCount) <- preferredHostToCount) {
        val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost)

        if (requiredCount > 0) {
          hostContainerRequests +=
            createResourceRequest(AllocationType.HOST, candidateHost, requiredCount, YarnAllocationHandler.PRIORITY)
        }
      }
      val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests(hostContainerRequests.toList)

      val anyContainerRequests: ResourceRequest =
        createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)

      val containerRequests: ArrayBuffer[ResourceRequest] =
        new ArrayBuffer[ResourceRequest](hostContainerRequests.size() + rackContainerRequests.size() + 1)

      containerRequests ++= hostContainerRequests
      containerRequests ++= rackContainerRequests
      containerRequests += anyContainerRequests

      resourceRequests = containerRequests.toList
    }

    val req = Records.newRecord(classOf[AllocateRequest])
    req.setResponseId(lastResponseId.incrementAndGet)
    req.setApplicationAttemptId(appAttemptId)

    req.addAllAsks(resourceRequests)

    val releasedContainerList = createReleasedContainerList()
    req.addAllReleases(releasedContainerList)

    if (numWorkers > 0) {
      logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " MB of memory each.")
    }
    else {
      logDebug("Empty allocation req .. release : " + releasedContainerList)
    }

    for (request <- resourceRequests) {
      logInfo("rsrcRequest ... host : " + request.getHostName + ", numContainers : " + request.getNumContainers +
        ", p = " + request.getPriority().getPriority + ", capability: " + request.getCapability)
    }
    resourceManager.allocate(req)
  }

  private def createResourceRequest(requestType: AllocationType.AllocationType,
      resource: String, numWorkers: Int, priority: Int): ResourceRequest = {

    // If hostname specified, we need at least two requests - node local and rack local.
    // There must be a third request - which is ANY: that will be specially handled.
    requestType match {
      case AllocationType.HOST => {
        assert(YarnAllocationHandler.ANY_HOST != resource)

        val hostname = resource
        val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority)

        // add to host->rack mapping
        YarnAllocationHandler.populateRackInfo(conf, hostname)

        nodeLocal
      }

      case AllocationType.RACK => {
        val rack = resource
        createResourceRequestImpl(rack, numWorkers, priority)
      }

      case AllocationType.ANY => {
        createResourceRequestImpl(YarnAllocationHandler.ANY_HOST, numWorkers, priority)
      }

      case _ => throw new IllegalArgumentException("Unexpected/unsupported request type .. " + requestType)
    }
  }

  private def createResourceRequestImpl(hostname: String, numWorkers: Int, priority: Int): ResourceRequest = {

    val rsrcRequest = Records.newRecord(classOf[ResourceRequest])
    val memCapability = Records.newRecord(classOf[Resource])
    // There probably is some overhead here, let's reserve a bit more memory.
    memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)
    rsrcRequest.setCapability(memCapability)

    val pri = Records.newRecord(classOf[Priority])
    pri.setPriority(priority)
    rsrcRequest.setPriority(pri)

    rsrcRequest.setHostName(hostname)

    rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0))
    rsrcRequest
  }

  def createReleasedContainerList(): ArrayBuffer[ContainerId] = {

    val retval = new ArrayBuffer[ContainerId](1)
    // iterator on COW list ...
    for (container <- releasedContainerList.iterator()) {
      retval += container
    }
    // remove from the original list.
    if (!retval.isEmpty) {
      releasedContainerList.removeAll(retval)
      for (v <- retval) pendingReleaseContainers.put(v, true)
      logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " +
        pendingReleaseContainers)
    }

    retval
  }
}

object YarnAllocationHandler {

  val ANY_HOST = "*"
  // all requests are issued with the same priority: we do not (yet) have any distinction between
  // request types (like map/reduce in hadoop for example)
  val PRIORITY = 1

  // Additional memory overhead - in MB
  val MEMORY_OVERHEAD = 384

  // host to rack map - saved from allocation requests
  // We are expecting this not to change.
  // Note that it is possible for this to change: and the RM will indicate that to us via an update
  // response to allocate. But we are punting on handling that for now.
  private val hostToRack = new ConcurrentHashMap[String, String]()
  private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]()

  def newAllocator(conf: Configuration,
      resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
      args: ApplicationMasterArguments,
      map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {

    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)

    new YarnAllocationHandler(conf, resourceManager, appAttemptId, args.numWorkers,
      args.workerMemory, args.workerCores, hostToCount, rackToCount)
  }

  def newAllocator(conf: Configuration,
      resourceManager: AMRMProtocol, appAttemptId: ApplicationAttemptId,
      maxWorkers: Int, workerMemory: Int, workerCores: Int,
      map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = {

    val (hostToCount, rackToCount) = generateNodeToWeight(conf, map)

    new YarnAllocationHandler(conf, resourceManager, appAttemptId, maxWorkers,
      workerMemory, workerCores, hostToCount, rackToCount)
  }

  // A simple method to copy the split info map.
  private def generateNodeToWeight(conf: Configuration, input: collection.Map[String, collection.Set[SplitInfo]]) :
  // host to count, rack to count
  (Map[String, Int], Map[String, Int]) = {

    if (input == null) return (Map[String, Int](), Map[String, Int]())

    val hostToCount = new HashMap[String, Int]
    val rackToCount = new HashMap[String, Int]

    for ((host, splits) <- input) {
      val hostCount = hostToCount.getOrElse(host, 0)
      hostToCount.put(host, hostCount + splits.size)

      val rack = lookupRack(conf, host)
      if (rack != null) {
        // Key the rack counts by rack, not by host (keying by host just duplicated the host map).
        val rackCount = rackToCount.getOrElse(rack, 0)
        rackToCount.put(rack, rackCount + splits.size)
      }
    }

    (hostToCount.toMap, rackToCount.toMap)
  }

  def lookupRack(conf: Configuration, host: String): String = {
    // containsKey, not contains: on a java.util.Map, contains() checks values.
    if (!hostToRack.containsKey(host)) populateRackInfo(conf, host)
    hostToRack.get(host)
  }

  def fetchCachedHostsForRack(rack: String): Option[Set[String]] = {
    val set = rackToHostSet.get(rack)
    if (set == null) return None

    // No better way to get a Set[String] from JSet?
    val convertedSet: collection.mutable.Set[String] = set
    Some(convertedSet.toSet)
  }

  def populateRackInfo(conf: Configuration, hostname: String) {
    Utils.checkHost(hostname)

    if (!hostToRack.containsKey(hostname)) {
      // If there are repeated failures to resolve, add to an ignore list?
      val rackInfo = RackResolver.resolve(conf, hostname)
      if (rackInfo != null && rackInfo.getNetworkLocation != null) {
        val rack = rackInfo.getNetworkLocation
        hostToRack.put(hostname, rack)
        if (!rackToHostSet.containsKey(rack)) {
          rackToHostSet.putIfAbsent(rack, Collections.newSetFromMap(new ConcurrentHashMap[String, JBoolean]()))
        }
        rackToHostSet.get(rack).add(hostname)

        // Since RackResolver caches, we are disabling this for now ...
      } /* else {
        // right? Else we will keep calling the rack resolver in case we can't resolve rack info ...
        hostToRack.put(hostname, null)
      } */
    }
  }
}
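generateNodeToWeight above collapses split-locality preferences into per-host and per-rack counts. A self-contained sketch of the same fold with the rack lookup stubbed out as a plain map (hosts, racks and split ids are all invented):

    import scala.collection.mutable.HashMap

    val rackOf = Map("host1" -> "/rack1", "host2" -> "/rack1", "host3" -> "/rack2") // stub

    def weights(input: Map[String, Set[String]]): (Map[String, Int], Map[String, Int]) = {
      val hostToCount = new HashMap[String, Int]
      val rackToCount = new HashMap[String, Int]
      for ((host, splits) <- input) {
        hostToCount(host) = hostToCount.getOrElse(host, 0) + splits.size
        for (rack <- rackOf.get(host)) { // a host with an unknown rack is simply skipped
          rackToCount(rack) = rackToCount.getOrElse(rack, 0) + splits.size
        }
      }
      (hostToCount.toMap, rackToCount.toMap)
    }

    // weights(Map("host1" -> Set("s1", "s2"), "host2" -> Set("s3")))
    //   == (Map("host1" -> 2, "host2" -> 1), Map("/rack1" -> 3))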
@ -0,0 +1,42 @@
package spark.scheduler.cluster

import spark._
import spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
import org.apache.hadoop.conf.Configuration

/**
 * A simple extension to ClusterScheduler - to ensure that appropriate initialization of the ApplicationMaster, etc. is done.
 */
private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {

  def this(sc: SparkContext) = this(sc, new Configuration())

  // Nothing else for now ... initialize the application master: which needs the SparkContext to determine how to allocate.
  // Note that only the first creation of a SparkContext influences this (and ideally, there must be only one SparkContext, right?)
  // Subsequent creations are ignored - since nodes are already allocated by then.

  // By default, rack is unknown
  override def getRackForHost(hostPort: String): Option[String] = {
    val host = Utils.parseHostPort(hostPort)._1
    val retval = YarnAllocationHandler.lookupRack(conf, host)
    if (retval != null) Some(retval) else None
  }

  // By default, if rack is unknown, return nothing
  override def getCachedHostsForRack(rack: String): Option[Set[String]] = {
    // rack is a String, so a null check is all that is needed here.
    if (rack == null) return None

    YarnAllocationHandler.fetchCachedHostsForRack(rack)
  }

  override def postStartHook() {
    val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
    if (sparkContextInitialized) {
      // Wait a few seconds for the slaves to bootstrap and register with the master - best-effort attempt
      Thread.sleep(3000L)
    }
    logInfo("YarnClusterScheduler.postStartHook done")
  }
}
@ -4,4 +4,7 @@ trait HadoopMapRedUtil {
  def newJobContext(conf: JobConf, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)

  def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
    jobId, isMap, taskId, attemptId)
}
@ -7,4 +7,7 @@ trait HadoopMapReduceUtil {
  def newJobContext(conf: Configuration, jobId: JobID): JobContext = new JobContextImpl(conf, jobId)

  def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = new TaskAttemptContextImpl(conf, attemptId)

  def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = new TaskAttemptID(jtIdentifier,
    jobId, isMap, taskId, attemptId)
}

core/src/hadoop2/scala/spark/deploy/SparkHadoopUtil.scala (new file, 23 lines)
@ -0,0 +1,23 @@
package spark.deploy
import org.apache.hadoop.conf.Configuration

/**
 * Contains util methods to interact with Hadoop from spark.
 */
object SparkHadoopUtil {

  def getUserNameFromEnvironment(): String = {
    // defaulting to -D ...
    System.getProperty("user.name")
  }

  def runAsUser(func: (Product) => Unit, args: Product) {
    // Add support, if it exists - for now, simply run func!
    func(args)
  }

  // Return an appropriate (subclass of) Configuration. Creating a config can initialize some hadoop subsystems.
  def newConfiguration(): Configuration = new Configuration()
}

core/src/main/java/spark/network/netty/FileClient.java (new file, 72 lines)
@ -0,0 +1,72 @@
package spark.network.netty;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

class FileClient {

  private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
  private FileClientHandler handler = null;
  private Channel channel = null;
  private Bootstrap bootstrap = null;
  private int connectTimeout = 60 * 1000; // 1 min

  public FileClient(FileClientHandler handler, int connectTimeout) {
    this.handler = handler;
    this.connectTimeout = connectTimeout;
  }

  public void init() {
    bootstrap = new Bootstrap();
    bootstrap.group(new OioEventLoopGroup())
      .channel(OioSocketChannel.class)
      .option(ChannelOption.SO_KEEPALIVE, true)
      .option(ChannelOption.TCP_NODELAY, true)
      .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
      .handler(new FileClientChannelInitializer(handler));
  }

  public void connect(String host, int port) {
    try {
      // Start the connection attempt.
      channel = bootstrap.connect(host, port).sync().channel();
      // ChannelFuture cf = channel.closeFuture();
      //cf.addListener(new ChannelCloseListener(this));
    } catch (InterruptedException e) {
      close();
    }
  }

  public void waitForClose() {
    try {
      channel.closeFuture().sync();
    } catch (InterruptedException e) {
      LOG.warn("FileClient interrupted", e);
    }
  }

  public void sendRequest(String file) {
    //assert(file == null);
    //assert(channel == null);
    channel.write(file + "\r\n");
  }

  public void close() {
    if (channel != null) {
      channel.close();
      channel = null;
    }
    if (bootstrap != null) {
      bootstrap.shutdown();
      bootstrap = null;
    }
  }
}
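A sketch of the call sequence the class is written for, assuming some FileClientHandler subclass is already in hand; host, port and block id are placeholders:

    // Hypothetical driver for FileClient; `handler` is some FileClientHandler.
    val client = new FileClient(handler, 60 * 1000)
    client.init()                      // build the Netty bootstrap
    client.connect("localhost", 12345) // placeholder host/port
    client.sendRequest("block-0")      // newline-framed by the string encoder
    client.waitForClose()              // the server closes once the file is sent
    client.close()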
@ -0,0 +1,24 @@
package spark.network.netty;

import io.netty.buffer.BufType;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.string.StringEncoder;


class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {

  private FileClientHandler fhandler;

  public FileClientChannelInitializer(FileClientHandler handler) {
    fhandler = handler;
  }

  @Override
  public void initChannel(SocketChannel channel) {
    // file no more than 2G
    channel.pipeline()
      .addLast("encoder", new StringEncoder(BufType.BYTE))
      .addLast("handler", fhandler);
  }
}
@ -0,0 +1,43 @@
package spark.network.netty;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;


abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {

  private FileHeader currentHeader = null;

  private volatile boolean handlerCalled = false;

  public boolean isComplete() {
    return handlerCalled;
  }

  public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
  public abstract void handleError(String blockId);

  @Override
  public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
    // Use direct buffer if possible.
    return ctx.alloc().ioBuffer();
  }

  @Override
  public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
    // get header
    if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
      currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
    }
    // get file
    if (in.readableBytes() >= currentHeader.fileLen()) {
      handle(ctx, in, currentHeader);
      handlerCalled = true;
      currentHeader = null;
      ctx.close();
    }
  }

}

core/src/main/java/spark/network/netty/FileServer.java (new file, 86 lines)
@ -0,0 +1,86 @@
package spark.network.netty;

import java.net.InetSocketAddress;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioServerSocketChannel;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
 * Server that accepts the path of a file and echoes back its content.
 */
class FileServer {

  private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());

  private ServerBootstrap bootstrap = null;
  private ChannelFuture channelFuture = null;
  private int port = 0;
  private Thread blockingThread = null;

  public FileServer(PathResolver pResolver, int port) {
    InetSocketAddress addr = new InetSocketAddress(port);

    // Configure the server.
    bootstrap = new ServerBootstrap();
    bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
      .channel(OioServerSocketChannel.class)
      .option(ChannelOption.SO_BACKLOG, 100)
      .option(ChannelOption.SO_RCVBUF, 1500)
      .childHandler(new FileServerChannelInitializer(pResolver));
    // Start the server.
    channelFuture = bootstrap.bind(addr);
    try {
      // Get the address we bound to.
      InetSocketAddress boundAddress =
        ((InetSocketAddress) channelFuture.sync().channel().localAddress());
      this.port = boundAddress.getPort();
    } catch (InterruptedException ie) {
      this.port = 0;
    }
  }

  /**
   * Start the file server asynchronously in a new thread.
   */
  public void start() {
    blockingThread = new Thread() {
      public void run() {
        try {
          channelFuture.channel().closeFuture().sync();
          LOG.info("FileServer exiting");
        } catch (InterruptedException e) {
          LOG.error("File server start got interrupted", e);
        }
        // NOTE: bootstrap is shutdown in stop()
      }
    };
    blockingThread.setDaemon(true);
    blockingThread.start();
  }

  public int getPort() {
    return port;
  }

  public void stop() {
    // Close the bound channel.
    if (channelFuture != null) {
      channelFuture.channel().close();
      channelFuture = null;
    }
    // Shutdown bootstrap.
    if (bootstrap != null) {
      bootstrap.shutdown();
      bootstrap = null;
    }
    // TODO: Shutdown all accepted channels as well?
  }
}
@ -0,0 +1,25 @@
package spark.network.netty;

import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;


class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {

  PathResolver pResolver;

  public FileServerChannelInitializer(PathResolver pResolver) {
    this.pResolver = pResolver;
  }

  @Override
  public void initChannel(SocketChannel channel) {
    channel.pipeline()
      .addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
      .addLast("strDecoder", new StringDecoder())
      .addLast("handler", new FileServerHandler(pResolver));
  }
}
@ -0,0 +1,65 @@
package spark.network.netty;

import java.io.File;
import java.io.FileInputStream;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;


class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {

  PathResolver pResolver;

  public FileServerHandler(PathResolver pResolver) {
    this.pResolver = pResolver;
  }

  @Override
  public void messageReceived(ChannelHandlerContext ctx, String blockId) {
    String path = pResolver.getAbsolutePath(blockId);
    // if getAbsolutePath returns null, stop right here
    if (path == null) {
      //ctx.close();
      return;
    }
    File file = new File(path);
    if (file.exists()) {
      if (!file.isFile()) {
        //logger.info("Not a file : " + file.getAbsolutePath());
        ctx.write(new FileHeader(0, blockId).buffer());
        ctx.flush();
        return;
      }
      long length = file.length();
      if (length > Integer.MAX_VALUE || length <= 0) {
        //logger.info("too large file : " + file.getAbsolutePath() + " of size " + length);
        ctx.write(new FileHeader(0, blockId).buffer());
        ctx.flush();
        return;
      }
      int len = new Long(length).intValue();
      //logger.info("Sending block " + blockId + " filelen = " + len);
      //logger.info("header = " + (new FileHeader(len, blockId)).buffer());
      ctx.write((new FileHeader(len, blockId)).buffer());
      try {
        ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
          .getChannel(), 0, file.length()));
      } catch (Exception e) {
        //logger.warning("Exception when sending file : " + file.getAbsolutePath());
        e.printStackTrace();
      }
    } else {
      //logger.warning("File not found: " + file.getAbsolutePath());
      ctx.write(new FileHeader(0, blockId).buffer());
    }
    ctx.flush();
  }

  @Override
  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
    cause.printStackTrace();
    ctx.close();
  }
}

core/src/main/java/spark/network/netty/PathResolver.java (new executable file, 12 lines)
@ -0,0 +1,12 @@
package spark.network.netty;


public interface PathResolver {
  /**
   * Get the absolute path of the file.
   *
   * @param fileId
   * @return the absolute path of the file
   */
  public String getAbsolutePath(String fileId);
}
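One plausible implementation of the interface, mapping block ids into a flat local directory; the directory is a placeholder, and returning null is what makes FileServerHandler above drop the request:

    // Hypothetical resolver: block ids become files under one local directory.
    val resolver = new PathResolver {
      private val baseDir = "/tmp/spark-blocks" // placeholder
      override def getAbsolutePath(fileId: String): String = {
        if (fileId == null || fileId.isEmpty) null
        else new java.io.File(baseDir, fileId).getAbsolutePath
      }
    }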
@ -3,18 +3,25 @@ package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap

import spark.executor.{ShuffleReadMetrics, TaskMetrics}
import spark.serializer.Serializer
import spark.storage.BlockManagerId
import spark.util.CompletionIterator


private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
  override def fetch[K, V](shuffleId: Int, reduceId: Int) = {

  override def fetch[K, V](
    shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer) = {

    logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
    val blockManager = SparkEnv.get.blockManager

    val startTime = System.currentTimeMillis
    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
    logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
      shuffleId, reduceId, System.currentTimeMillis - startTime))

    val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
    for (((address, size), index) <- statuses.zipWithIndex) {
      splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
@ -45,6 +52,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
      }
    }
  }
    blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)

    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
    val itr = blockFetcherItr.flatMap(unpackBlock)

    CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
      val shuffleMetrics = new ShuffleReadMetrics
      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
      shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
      shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
      metrics.shuffleReadMetrics = Some(shuffleMetrics)
    })
  }
}
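The metrics hook above depends on a callback firing exactly once when the wrapped iterator is drained. A generic sketch of that shape (not the spark.util.CompletionIterator source, just the pattern it implies):

    // Minimal iterator wrapper that runs a completion block once, when drained.
    class OnComplete[A](sub: Iterator[A], completion: => Unit) extends Iterator[A] {
      private var fired = false
      def hasNext: Boolean = {
        val more = sub.hasNext
        if (!more && !fired) { fired = true; completion }
        more
      }
      def next(): A = sub.next()
    }

    // e.g. new OnComplete(blocks.iterator, { /* record shuffle metrics here */ })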
@ -27,7 +27,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
    if (loading.contains(key)) {
      logInfo("Loading contains " + key + ", waiting...")
      while (loading.contains(key)) {
        try {loading.wait()} catch {case _ =>}
        try {loading.wait()} catch {case _ : Throwable =>}
      }
      logInfo("Loading no longer contains " + key + ", so returning cached result")
      // See whether someone else has successfully loaded it. The main way this would fail
@ -5,15 +5,22 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set

import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}

private[spark] object ClosureCleaner extends Logging {
  // Get an ASM class reader for a given class from the JAR that loaded it
  private def getClassReader(cls: Class[_]): ClassReader = {
    new ClassReader(cls.getResourceAsStream(
      cls.getName.replaceFirst("^.*\\.", "") + ".class"))
    // Copy data over, before delegating to ClassReader - else we can run out of open file handles.
    val className = cls.getName.replaceFirst("^.*\\.", "") + ".class"
    val resourceStream = cls.getResourceAsStream(className)
    // todo: Fixme - continuing with earlier behavior ...
    if (resourceStream == null) return new ClassReader(resourceStream)

    val baos = new ByteArrayOutputStream(128)
    Utils.copyStream(resourceStream, baos, true)
    new ClassReader(new ByteArrayInputStream(baos.toByteArray))
  }

  // Check whether a class represents a Scala closure
@ -154,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
  }
}

private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    return new EmptyVisitor {
    return new MethodVisitor(ASM4) {
      override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
        if (op == GETFIELD) {
          for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@ -180,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
  }
}

private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
  var myName: String = null

  override def visit(version: Int, access: Int, name: String, sig: String,
@ -190,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
|
|||
|
||||
override def visitMethod(access: Int, name: String, desc: String,
|
||||
sig: String, exceptions: Array[String]): MethodVisitor = {
|
||||
return new EmptyVisitor {
|
||||
return new MethodVisitor(ASM4) {
|
||||
override def visitMethodInsn(op: Int, owner: String, name: String,
|
||||
desc: String) {
|
||||
val argTypes = Type.getArgumentTypes(desc)
|
||||
|
|
|
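These visitor hunks track the ASM 3 to ASM 4 migration: ASM 4 removed the EmptyVisitor convenience class, so visitors now extend ClassVisitor and MethodVisitor directly and pass the API level (ASM4) to the constructor; callbacks that are not overridden default to no-ops. A rough sketch of the new style, assuming the ASM 4 jar is on the classpath (the visitor name is hypothetical):

    import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Opcodes}

    // ASM 4 visitors take the API version in the constructor; only the
    // callbacks you care about need to be overridden.
    class GetFieldPrinter extends ClassVisitor(Opcodes.ASM4) {
      override def visitMethod(access: Int, name: String, desc: String,
          sig: String, exceptions: Array[String]): MethodVisitor = {
        new MethodVisitor(Opcodes.ASM4) {
          override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
            if (op == Opcodes.GETFIELD) {
              println("GETFIELD " + owner + "." + name)
            }
          }
        }
      }
    }

    // Usage: walk the bytecode of a class available on the classpath.
    new ClassReader("java.lang.String").accept(new GetFieldPrinter, 0)
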
@@ -25,10 +25,12 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[(K, V)],
val partitioner: Partitioner)
val partitioner: Partitioner,
val serializerClass: String = null)
extends Dependency(rdd) {

val shuffleId: Int = rdd.context.newShuffleId()

@@ -3,18 +3,25 @@ package spark

import spark.storage.BlockManagerId

private[spark] class FetchFailedException(
val bmAddress: BlockManagerId,
val shuffleId: Int,
val mapId: Int,
val reduceId: Int,
taskEndReason: TaskEndReason,
message: String,
cause: Throwable)
extends Exception {

override def getMessage(): String =
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)

def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
cause)

def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(null, shuffleId, -1, reduceId),
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)

override def getMessage(): String = message

override def getCause(): Throwable = cause

def toTaskEndReason: TaskEndReason =
FetchFailed(bmAddress, shuffleId, mapId, reduceId)
def toTaskEndReason: TaskEndReason = taskEndReason

}

@@ -2,14 +2,10 @@ package org.apache.hadoop.mapred

import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text

import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.net.URI
import java.util.Date

import spark.Logging

@@ -24,7 +20,7 @@ import spark.SerializableWritable
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRedUtil with Serializable {

private val now = new Date()
private val conf = new SerializableWritable(jobConf)

@@ -106,6 +102,12 @@ class HadoopWriter(@transient jobConf: JobConf) extends Logging with HadoopMapRe
}
}

def commitJob() {
// always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}

def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}

@@ -157,27 +157,34 @@ class KryoSerializer extends spark.serializer.Serializer with Logging {

// Register maps with a special serializer since they have complex internal structure
class ScalaMapSerializer(buildMap: Array[(Any, Any)] => scala.collection.Map[Any, Any])
extends KSerializer[Array[(Any, Any)] => scala.collection.Map[Any, Any]] {

//hack, look at https://groups.google.com/forum/#!msg/kryo-users/Eu5V4bxCfws/k-8UQ22y59AJ
private final val FAKE_REFERENCE = new Object()
override def write(
kryo: Kryo,
output: KryoOutput,
obj: Array[(Any, Any)] => scala.collection.Map[Any, Any]) {
val map = obj.asInstanceOf[scala.collection.Map[Any, Any]]
kryo.writeObject(output, map.size.asInstanceOf[java.lang.Integer])
output.writeInt(map.size)
for ((k, v) <- map) {
kryo.writeClassAndObject(output, k)
kryo.writeClassAndObject(output, v)
}
}
override def read (
kryo: Kryo,
input: KryoInput,
cls: Class[Array[(Any, Any)] => scala.collection.Map[Any, Any]])
: Array[(Any, Any)] => scala.collection.Map[Any, Any] = {
val size = kryo.readObject(input, classOf[java.lang.Integer]).intValue
kryo.reference(FAKE_REFERENCE)
val size = input.readInt()
val elems = new Array[(Any, Any)](size)
for (i <- 0 until size)
elems(i) = (kryo.readClassAndObject(input), kryo.readClassAndObject(input))
for (i <- 0 until size) {
val k = kryo.readClassAndObject(input)
val v = kryo.readClassAndObject(input)
elems(i) = (k, v)
}
buildMap(elems).asInstanceOf[Array[(Any, Any)] => scala.collection.Map[Any, Any]]
}
}

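The serializer now writes the element count with output.writeInt and pins FAKE_REFERENCE into Kryo's reference table before reading elements back, working around the reference-resolution issue linked in the comment. For context, a minimal custom Kryo 2.x serializer showing the same write/read symmetry (the Point class is hypothetical):

    import com.esotericsoftware.kryo.{Kryo, Serializer}
    import com.esotericsoftware.kryo.io.{Input, Output}

    case class Point(x: Int, y: Int)

    // read() must consume exactly what write() produced, in the same order.
    class PointSerializer extends Serializer[Point] {
      override def write(kryo: Kryo, output: Output, p: Point) {
        output.writeInt(p.x)
        output.writeInt(p.y)
      }
      override def read(kryo: Kryo, input: Input, cls: Class[Point]): Point =
        Point(input.readInt(), input.readInt())
    }

    val kryo = new Kryo()
    kryo.register(classOf[Point], new PointSerializer)
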
@@ -68,6 +68,10 @@ trait Logging {
if (log.isErrorEnabled) log.error(msg, throwable)
}

protected def isTraceEnabled(): Boolean = {
log.isTraceEnabled
}

// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }

@@ -1,7 +1,6 @@
package spark

import java.io._
import java.util.concurrent.ConcurrentHashMap
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

import scala.collection.mutable.HashMap

@@ -12,8 +11,7 @@ import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import akka.util.Timeout
import akka.util.duration._

import spark.scheduler.MapStatus
import spark.storage.BlockManagerId

@@ -38,11 +36,14 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
}
}

private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
private[spark] class MapOutputTracker extends Logging {

val timeout = 10.seconds
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")

// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _

var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]

// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.

@@ -51,19 +52,7 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea

// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheGeneration = generation
val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

val actorName: String = "MapOutputTracker"
var trackerActor: ActorRef = if (isDriver) {
val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
logInfo("Registered MapOutputTrackerActor actor")
actor
} else {
val ip = System.getProperty("spark.driver.host", "localhost")
val port = System.getProperty("spark.driver.port", "7077").toInt
val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
actorSystem.actorFor(url)
}
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]

val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)

@@ -87,10 +76,9 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}

def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.get(shuffleId) != None) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
mapStatuses.put(shuffleId, new Array[MapStatus](numMaps))
}

def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {

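registerShuffle now relies on a single atomic putIfAbsent instead of a separate existence check followed by a put, which closes the race where two threads both pass the check before either inserts. The same idiom on a plain ConcurrentHashMap (names are illustrative):

    import java.util.concurrent.ConcurrentHashMap

    val registered = new ConcurrentHashMap[Int, Array[String]]()

    // putIfAbsent returns the previous value, or null if this call won;
    // a non-null result means another thread registered the id first.
    def register(id: Int, numSlots: Int) {
      if (registered.putIfAbsent(id, new Array[String](numSlots)) != null) {
        throw new IllegalArgumentException("id " + id + " registered twice")
      }
    }
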
@@ -111,8 +99,9 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}

def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var array = mapStatuses(shuffleId)
if (array != null) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null

@@ -125,13 +114,14 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
}

// Remembers which map output locations are currently being fetched on a worker
val fetching = new HashSet[Int]
private val fetching = new HashSet[Int]

// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done

@@ -142,31 +132,48 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
case e: InterruptedException =>
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, mapStatuses(shuffleId))
} else {
}

// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val host = System.getProperty("spark.hostname", Utils.localHostName)
// This try-finally prevents hangs due to timeouts:
var fetchedStatuses: Array[MapStatus] = null
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, host)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()

if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
} else {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}

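The reworked getServerStatuses is the classic fetch-deduplication idiom: threads that lose the race wait() on the `fetching` monitor, the winner performs the fetch, and notifyAll() fires from a finally block so waiters cannot hang if the fetch fails. A condensed, Spark-independent sketch of the idiom (fetchRemote is a stand-in for the tracker round trip):

    import scala.collection.mutable.{HashMap, HashSet}

    val cache = new HashMap[Int, String]
    val fetching = new HashSet[Int]

    def get(id: Int, fetchRemote: Int => String): String = {
      fetching.synchronized {
        while (fetching.contains(id)) {
          try { fetching.wait() } catch { case _: InterruptedException => }
        }
        cache.get(id) match {
          case Some(v) => return v       // fetched by someone else while we waited
          case None    => fetching += id // we won the race; others now block
        }
      }
      try {
        val v = fetchRemote(id)
        cache.synchronized { cache(id) = v }
        v
      } finally {
        // Always wake waiters, even if the fetch threw.
        fetching.synchronized { fetching -= id; fetching.notifyAll() }
      }
    }
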
@@ -204,7 +211,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
generationLock.synchronized {
if (newGen > generation) {
logInfo("Updating generation to " + newGen + " and clearing cache")
mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
generation = newGen
}
}

@@ -242,10 +250,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
objOut.writeObject(statuses)
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
objOut.close()
out.toByteArray
}

@@ -253,7 +264,10 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().asInstanceOf[Array[MapStatus]]
objIn.readObject().
// // drop all null's from status - not sure why they are occurring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}

@@ -263,14 +277,11 @@ private[spark] object MapOutputTracker {
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
def convertMapStatuses(
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
if (statuses == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
assert (statuses != null)
statuses.map {
status =>
if (status == null) {

@@ -1,5 +1,6 @@
package spark

import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat

@@ -10,6 +11,8 @@ import scala.collection.JavaConversions._

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter

@@ -17,12 +20,13 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat

import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}

import spark.partial.BoundedDouble
import spark.partial.PartialResult
import spark.rdd._
import spark.SparkContext._
import spark.Partitioner._

/**
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.

@@ -51,7 +55,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C,
partitioner: Partitioner,
mapSideCombine: Boolean = true): RDD[(K, C)] = {
mapSideCombine: Boolean = true,
serializerClass: String = null): RDD[(K, C)] = {
if (getKeyClass().isArray) {
if (mapSideCombine) {
throw new SparkException("Cannot use map-side combining with array keys.")

@@ -60,19 +65,18 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
val mapSideCombined = self.mapPartitions(aggregator.combineValuesByKey(_), true)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner)
val partitioned = new ShuffledRDD[K, C](mapSideCombined, partitioner, serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey(_), true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V](self, partitioner)
val values = new ShuffledRDD[K, V](self, partitioner, serializerClass)
values.mapPartitions(aggregator.combineValuesByKey(_), true)
}
}

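combineByKey is the primitive underneath reduceByKey and groupByKey: createCombiner seeds a per-key accumulator, mergeValue folds values into it within a partition, and mergeCombiners merges accumulators across partitions. A usage sketch computing per-key averages (assuming a SparkContext named sc):

    val pairs = sc.parallelize(Seq(("a", 1.0), ("a", 3.0), ("b", 4.0)))

    // Accumulate (sum, count) per key; divide once at the end.
    val sumCounts = pairs.combineByKey(
      (v: Double) => (v, 1),                                        // createCombiner
      (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),  // mergeValue
      (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2)) // mergeCombiners

    val avgByKey = sumCounts.mapValues { case (sum, count) => sum / count }
    // collect() would yield ("a", 2.0) and ("b", 4.0), in some order.
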
@@ -87,6 +91,42 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
}

/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)

// When deserializing, use a lazy val to create just one instance of the serializer per task
lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))

combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}

/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V, numPartitions: Int)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, new HashPartitioner(numPartitions))(func)
}

/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication).
*/
def foldByKey(zeroValue: V)(func: (V, V) => V): RDD[(K, V)] = {
foldByKey(zeroValue, defaultPartitioner(self))(func)
}

/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a

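Because the zero value in foldByKey above can be folded in once per partition rather than exactly once overall, it must be neutral for func, and it is cloned from serialized bytes so that mutable zeros are never shared across keys. A usage sketch (assuming a SparkContext named sc):

    val counts = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

    // 0 is a safe zero for addition: folding it in any number of times
    // leaves the sum unchanged.
    val summed = counts.foldByKey(0)(_ + _)
    // collect() would yield ("a", 4) and ("b", 2), in some order.
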
@@ -156,11 +196,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}

@@ -248,8 +290,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}

/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the default
* parallelism level.
* Simplified version of combineByKey that hash-partitions the resulting RDD using the
* existing partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: V => C, mergeValue: (C, V) => C, mergeCombiners: (C, C) => C)
: RDD[(K, C)] = {

@@ -259,7 +301,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
* "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
* parallelism level.
*/
def reduceByKey(func: (V, V) => V): RDD[(K, V)] = {
reduceByKey(defaultPartitioner(self), func)

@@ -267,7 +310,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](

/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the default parallelism level.
* resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): RDD[(K, Seq[V])] = {
groupByKey(defaultPartitioner(self))

@@ -295,7 +338,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the default level of parallelism.
* using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (V, Option[W]))] = {
leftOuterJoin(other, defaultPartitioner(self, other))

@@ -315,7 +358,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the default parallelism level.
* RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: RDD[(K, W)]): RDD[(K, (Option[V], W))] = {
rightOuterJoin(other, defaultPartitioner(self, other))

@@ -439,15 +482,21 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}

/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs. If any of
* the RDDs already has a partitioner, choose that one, otherwise use a default HashPartitioner.
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
def defaultPartitioner(rdds: RDD[_]*): Partitioner = {
for (r <- rdds if r.partitioner != None) {
return r.partitioner.get
}
return new HashPartitioner(self.context.defaultParallelism)
}
def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))

/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions))

/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V, W](self, other, p)

/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the

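subtractByKey keeps the pairs whose key has no match in `other`, and defaults to this RDD's own partitioner or partition count since the result can only shrink. A usage sketch (assuming a SparkContext named sc):

    val left  = sc.parallelize(Seq(("a", 1), ("b", 2), ("c", 3)))
    val right = sc.parallelize(Seq(("b", 99)))

    // Every pair whose key appears in `right` is dropped, whatever its value.
    val onlyLeft = left.subtractByKey(right)
    // collect() would yield ("a", 1) and ("c", 3), in some order.
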
@@ -479,6 +528,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}

/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress the result with the
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
}

/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.

@@ -510,8 +569,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = new TaskAttemptID(jobtrackerID,
stageId, false, context.splitId, attemptNumber)
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)

@@ -530,14 +588,29 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
*/
val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, true, 0, 0)
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
jobCommitter.commitJob(jobTaskContext)
jobCommitter.cleanupJob(jobTaskContext)
}

/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress with the supplied codec.
*/
def saveAsHadoopFile(
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
codec: Class[_ <: CompressionCodec]) {
saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
new JobConf(self.context.hadoopConfiguration), Some(codec))
}

/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.

@@ -547,11 +620,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf: JobConf = new JobConf(self.context.hadoopConfiguration),
codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
for (c <- codec) {
conf.setCompressMapOutput(true)
conf.set("mapred.output.compress", "true")
conf.setMapOutputCompressorClass(c)
conf.set("mapred.output.compression.codec", c.getCanonicalName)
conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
}
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)

@@ -602,6 +683,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
}

self.context.runJob(self, writeToFile _)
writer.commitJob()
writer.cleanup()
}

@@ -609,7 +691,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* Return an RDD with the keys of each tuple.
*/
def keys: RDD[K] = self.map(_._1)

/**
* Return an RDD with the values of each tuple.
*/

@@ -9,6 +9,35 @@ abstract class Partitioner extends Serializable {
def getPartition(key: Any): Int
}

object Partitioner {
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
* If any of the RDDs already has a partitioner, choose that one.
*
* Otherwise, we use a default HashPartitioner. For the number of partitions, if
* spark.default.parallelism is set, then we'll use the value from SparkContext
* defaultParallelism, otherwise we'll use the max number of upstream partitions.
*
* Unless spark.default.parallelism is set, the number of partitions will be the
* same as the number of partitions in the largest upstream RDD, as this should
* be least likely to cause out-of-memory errors.
*
* We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
*/
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
if (System.getProperty("spark.default.parallelism") != null) {
return new HashPartitioner(rdd.context.defaultParallelism)
} else {
return new HashPartitioner(bySize.head.partitions.size)
}
}
}

/**
* A [[spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
*

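The rule, in order: reuse a partitioner from the largest already-partitioned input; otherwise hash-partition with spark.default.parallelism if it is set, else with the largest upstream partition count. A sketch of the fallback arithmetic (assuming a SparkContext named sc and spark.default.parallelism unset):

    val a = sc.parallelize(1 to 100, 8).map(x => (x, x))  // 8 partitions, no partitioner
    val b = sc.parallelize(1 to 100, 2).map(x => (x, x))  // 2 partitions, no partitioner

    // Neither input carries a partitioner, so the join hash-partitions
    // into max(8, 2) = 8 partitions.
    val joined = a.join(b)
    // joined.partitions.size == 8
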
@@ -1,21 +1,21 @@
package spark

import java.net.URL
import java.util.{Date, Random}
import java.util.{HashMap => JHashMap}
import java.util.Random

import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap

import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat

import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}

import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
import spark.partial.GroupedCountEvaluator

@@ -30,10 +30,14 @@ import spark.rdd.MapPartitionsRDD
import spark.rdd.MapPartitionsWithIndexRDD
import spark.rdd.PipedRDD
import spark.rdd.SampledRDD
import spark.rdd.SubtractedRDD
import spark.rdd.ShuffledRDD
import spark.rdd.UnionRDD
import spark.rdd.ZippedRDD
import spark.rdd.ZippedPartitionsRDD2
import spark.rdd.ZippedPartitionsRDD3
import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
import spark.util.BoundedPriorityQueue

import SparkContext._

@@ -102,7 +106,7 @@ abstract class RDD[T: ClassManifest](
// =======================================================================

/** A unique ID for this RDD (within its SparkContext). */
val id = sc.newRddId()
val id: Int = sc.newRddId()

/** A friendly name for this RDD */
var name: String = null

@@ -113,9 +117,18 @@ abstract class RDD[T: ClassManifest](
this
}

/** User-defined generator of this RDD */
var generator = Utils.getCallSiteInfo.firstUserClass

/** Reset generator */
def setGenerator(_generator: String) = {
generator = _generator
}

/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
* it is computed. This can only be used to assign a new storage level if the RDD does not
* have a storage level set yet.
*/
def persist(newLevel: StorageLevel): RDD[T] = {
// TODO: Handle changes of StorageLevel

@@ -135,6 +148,20 @@ abstract class RDD[T: ClassManifest](
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): RDD[T] = persist()

/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
* @return This RDD.
*/
def unpersist(blocking: Boolean = true): RDD[T] = {
logInfo("Removing RDD " + id + " from persistence list")
sc.env.blockManager.master.removeRdd(id, blocking)
sc.persistentRdds.remove(id)
storageLevel = StorageLevel.NONE
this
}

/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel = storageLevel

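cache/persist and the new unpersist bracket an RDD's lifetime in the block store; with blocking = true the call returns only after the blocks are actually removed. A usage sketch (assuming a SparkContext named sc):

    val data = sc.parallelize(1 to 1000000).map(_ * 2).cache()

    data.count()  // first action computes the RDD and caches its blocks
    data.count()  // second action is served from the cache

    data.unpersist(true)  // synchronously drop the cached blocks
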
@@ -236,7 +263,14 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): RDD[T] = new CoalescedRDD(this, numPartitions)
def coalesce(numPartitions: Int, shuffle: Boolean = false): RDD[T] = {
if (shuffle) {
// include a shuffle step so that our upstream tasks are still distributed
new CoalescedRDD(new ShuffledRDD(map(x => (x, null)), new HashPartitioner(numPartitions)), numPartitions).keys
} else {
new CoalescedRDD(this, numPartitions)
}
}

/**
* Return a sampled subset of this RDD.

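The new shuffle flag on coalesce above trades a shuffle for parallelism: without it, shrinking to a handful of partitions also shrinks the upstream stage to that many tasks. A usage sketch (assuming a SparkContext named sc):

    val wide = sc.parallelize(1 to 1000000, 100)

    // Merges partitions locally; upstream work now runs in only 10 tasks.
    val narrow = wide.coalesce(10)

    // Upstream stage keeps its 100 tasks; a shuffle then redistributes
    // the data into 10 partitions.
    val viaShuffle = wide.coalesce(10, true)
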
@@ -247,8 +281,8 @@ abstract class RDD[T: ClassManifest](
def takeSample(withReplacement: Boolean, num: Int, seed: Int): Array[T] = {
var fraction = 0.0
var total = 0
var multiplier = 3.0
var initialCount = count()
val multiplier = 3.0
val initialCount = count()
var maxSelected = 0

if (initialCount > Integer.MAX_VALUE - 1) {

@@ -300,19 +334,26 @@ abstract class RDD[T: ClassManifest](
*/
def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)

/**
* Return an RDD of grouped items.
*/
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
groupBy[K](f, defaultPartitioner(this))

/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(numPartitions)
}
def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
groupBy(f, new HashPartitioner(numPartitions))

/**
* Return an RDD of grouped items.
*/
def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, sc.defaultParallelism)
def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}

/**
* Return an RDD created by piping elements to a forked external process.

@@ -322,13 +363,36 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)

/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
*
* @param command command to run in forked process.
* @param env environment variables to set.
* @param printPipeContext Before piping elements, this function is called as an opportunity
* to pipe context data. Print line function (like out.println) will be
* passed as printPipeContext's parameter.
* @param printRDDElement Use this function to customize how to pipe elements. This function
* will be called with each RDD element as the 1st parameter, and the
* print line function (like out.println()) as the 2nd parameter.
* An example of piping the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concat all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
* @return the result RDD
*/
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)

/**
* Return a new RDD by applying a function to each partition of this RDD.

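pipe streams each partition through an external process: elements go to the command's stdin one per line, and its stdout lines become the new RDD. A usage sketch (assuming a SparkContext named sc and a Unix grep on the PATH):

    val lines = sc.parallelize(Seq("spark", "hadoop", "sparkling"))

    // Each partition is fed to its own grep process.
    val matched = lines.pipe(Seq("grep", "spark"))
    // collect() would yield "spark" and "sparkling", in some order.
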
@@ -350,12 +414,68 @@ abstract class RDD[T: ClassManifest](
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
@deprecated("use mapPartitionsWithIndex")
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)

/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def mapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.map(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}

/**
* FlatMaps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def flatMapWith[A: ClassManifest, U: ClassManifest](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
}

/**
* Applies f to each element of this RDD, where f takes an additional parameter of type A.
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructA: Int => A)
(f:(T, A) => Unit) {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.map(t => {f(t, a); t})
}
(new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
}

/**
* Filters this RDD with p, where p takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructA: Int => A)
(p:(T, A) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.filter(t => p(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
}

/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of

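The *With family above builds one helper object per partition from the partition index and threads it through the user function; the canonical use is a per-partition random generator that is reproducible and never serialized into the closure. A usage sketch (assuming a SparkContext named sc):

    import java.util.Random

    val nums = sc.parallelize(1 to 100, 4)

    // constructA runs once per partition, seeding a Random from the
    // partition index; f then sees that same generator for every element.
    val jittered = nums.mapWith((index: Int) => new Random(index)) {
      (x, rng) => x + rng.nextDouble()
    }
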
@@ -364,6 +484,31 @@ abstract class RDD[T: ClassManifest](
*/
def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)

/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
* applying a function to the zipped partitions. Assumes that all the RDDs have the
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
def zipPartitions[B: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B]) => Iterator[V],
rdd2: RDD[B]): RDD[V] =
new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)

def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
rdd2: RDD[B],
rdd3: RDD[C]): RDD[V] =
new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)

def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
rdd2: RDD[B],
rdd3: RDD[C],
rdd4: RDD[D]): RDD[V] =
new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)


// Actions (launch a job to return a value to the user program)

/**

@@ -374,6 +519,14 @@ abstract class RDD[T: ClassManifest](
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
}

/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
}

/**
* Return an array that contains all of the elements in this RDD.
*/

@@ -396,7 +549,7 @@ abstract class RDD[T: ClassManifest](

/**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/

@@ -412,7 +565,23 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: RDD[T], p: Partitioner): RDD[T] = new SubtractedRDD[T](this, other, p)
def subtract(other: RDD[T], p: Partitioner): RDD[T] = {
if (partitioner == Some(p)) {
// Our partitioner knows how to handle T (which, since we have a partitioner, is
// really (K, V)) so make a new Partitioner that will de-tuple our fake tuples
val p2 = new Partitioner() {
override def numPartitions = p.numPartitions
override def getPartition(k: Any) = p.getPartition(k.asInstanceOf[(Any, _)]._1)
}
// Unfortunately, since we're making a new p2, we'll get ShuffleDependencies
// anyway, and when calling .keys, will not have a partitioner set, even though
// the SubtractedRDD will, thanks to p2's de-tupled partitioning, already be
// partitioned by the right/real keys (e.g. p).
this.map(x => (x, null)).subtractByKey(other.map((_, null)), p2).keys
} else {
this.map(x => (x, null)).subtractByKey(other.map((_, null)), p).keys
}
}

/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.

@@ -587,6 +756,24 @@ abstract class RDD[T: ClassManifest](
case _ => throw new UnsupportedOperationException("empty collection")
}

/**
* Returns the top K elements from this RDD as defined by
* the specified implicit Ordering[T].
* @param num the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
mapPartitions { items =>
val queue = new BoundedPriorityQueue[T](num)
queue ++= items
Iterator.single(queue)
}.reduce { (queue1, queue2) =>
queue1 ++= queue2
queue1
}.toArray
}

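top keeps a bounded priority queue of num elements per partition and merges the queues in reduce, so at most num elements per partition ever cross the network. A usage sketch (assuming a SparkContext named sc):

    val scores = sc.parallelize(Seq(7, 42, 3, 19, 42, 8))

    // The three largest values under the default Int ordering.
    val best = scores.top(3)
    // best == Array(42, 42, 19)
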
/**
* Save this RDD as a text file, using string representations of elements.
*/

@@ -595,6 +782,14 @@ abstract class RDD[T: ClassManifest](
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}

/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
this.map(x => (NullWritable.get(), new Text(x.toString)))
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}

/**
* Save this RDD as a SequenceFile of serialized objects.
*/

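The codec overload compresses each emitted part file through the given Hadoop codec. A usage sketch (assuming a SparkContext named sc and Hadoop's GzipCodec on the classpath):

    import org.apache.hadoop.io.compress.GzipCodec

    val lines = sc.parallelize(Seq("alpha", "beta", "gamma"))

    // Writes gzip-compressed part files under the target directory.
    lines.saveAsTextFile("/tmp/words-gz", classOf[GzipCodec])
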
@@ -653,7 +848,7 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE

/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
private[spark] val origin = Utils.formatSparkCallSite

private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]

@@ -1,6 +1,7 @@
package spark

import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
import rdd.{CheckpointRDD, CoalescedRDD}
import scheduler.{ResultTask, ShuffleMapTask}

@@ -62,14 +63,20 @@ private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
}
}

// Create the output path for the checkpoint
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id)
val fs = path.getFileSystem(new Configuration())
if (!fs.mkdirs(path)) {
throw new SparkException("Failed to create checkpoint path " + path)
}

// Save to file, and reload it as an RDD
val path = new Path(rdd.context.checkpointDir.get, "rdd-" + rdd.id).toString
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
val newRDD = new CheckpointRDD[T](rdd.context, path)
rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)

// Change the dependencies and partitions of the RDD
RDDCheckpointData.synchronized {
cpFile = Some(path)
cpFile = Some(path.toString)
cpRDD = Some(newRDD)
rdd.markCheckpointed(newRDD) // Update the RDD's dependencies and partitions
cpState = Checkpointed

@@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable

@@ -36,17 +37,17 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
self: RDD[(K, V)])
extends Logging
with Serializable {

private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
classManifest[T].erasure
} else {
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
// is of the form "java.lang.Object apply(java.lang.Object)"
// is not of the form "java.lang.Object apply(java.lang.Object)"
implicitly[T => Writable].getClass.getDeclaredMethods().filter(
m => m.getReturnType().toString != "java.lang.Object" &&
m => m.getReturnType().toString != "class java.lang.Object" &&
m.getName() == "apply")(0).getReturnType

}

@@ -62,24 +63,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
def saveAsSequenceFile(path: String) {
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u

val keyClass = getWritableClass[K]
val valueClass = getWritableClass[V]
val convertKey = !classOf[Writable].isAssignableFrom(self.getKeyClass)
val convertValue = !classOf[Writable].isAssignableFrom(self.getValueClass)

logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
self.saveAsHadoopFile(path, keyClass, valueClass, format)
self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
}
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
}
}
}

@@ -1,11 +1,16 @@
package spark

import spark.executor.TaskMetrics
import spark.serializer.Serializer

private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[(K,V)]

/** Stop the fetcher */
def stop() {}

@@ -198,7 +198,7 @@ private[spark] object SizeEstimator extends Logging {
val elem = JArray.get(array, index)
size += SizeEstimator.estimate(elem, state.visited)
}
state.size += ((length / 100.0) * size).toLong
state.size += ((length / (ARRAY_SAMPLE_SIZE * 1.0)) * size).toLong
}
}
}

@@ -1,52 +1,53 @@
package spark

import java.io._
import java.net.URI
import java.util.Properties
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.net.{URI, URLClassLoader}
import java.lang.ref.WeakReference

import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
import scala.util.DynamicVariable
import scala.collection.mutable.{ConcurrentMap, HashMap}

import akka.actor.Actor
import akka.actor.Actor._
import org.apache.hadoop.fs.{FileUtil, Path}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.FloatWritable
import org.apache.hadoop.io.DoubleWritable
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.DoubleWritable
import org.apache.hadoop.io.FloatWritable
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.mesos.{Scheduler, MesosNativeLibrary}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}

import spark.broadcast._
import spark.deploy.LocalSparkCluster
import spark.partial.ApproximateEvaluator
import spark.partial.PartialResult
import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
import org.apache.mesos.MesosNativeLibrary

import spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import spark.partial.{ApproximateEvaluator, PartialResult}
import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
import spark.scheduler.{DAGScheduler, ResultTask, ShuffleMapTask, SparkListener, SplitInfo, Stage, StageInfo, TaskScheduler}
import spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend, ClusterScheduler}
import spark.scheduler.local.LocalScheduler
import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import storage.BlockManagerUI
import util.{MetadataCleaner, TimeStampedHashMap}
import storage.{StorageStatus, StorageUtils, RDDInfo}
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}

/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark

@@ -64,7 +65,10 @@ class SparkContext(
    val appName: String,
    val sparkHome: String = null,
    val jars: Seq[String] = Nil,
    environment: Map[String, String] = Map())
    val environment: Map[String, String] = Map(),
    // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
    // This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
    val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
  extends Logging {

  // Ensure logging is initialized before we spawn any threads
@@ -72,7 +76,7 @@ class SparkContext(

  // Set Spark driver host and port system properties
  if (System.getProperty("spark.driver.host") == null) {
    System.setProperty("spark.driver.host", Utils.localIpAddress)
    System.setProperty("spark.driver.host", Utils.localHostName())
  }
  if (System.getProperty("spark.driver.port") == null) {
    System.setProperty("spark.driver.port", "0")
@@ -99,12 +103,14 @@ class SparkContext(
  private[spark] val addedJars = HashMap[String, Long]()

  // Keeps track of all persisted RDDs
  private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]()
  private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
  private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)

  // Add each JAR given through the constructor
  jars.foreach { addJar(_) }
  if (jars != null) {
    jars.foreach { addJar(_) }
  }

  // Environment variables to pass to our executors
  private[spark] val executorEnvs = HashMap[String, String]()
@@ -116,7 +122,9 @@ class SparkContext(
      executorEnvs(key) = value
    }
  }
  executorEnvs ++= environment
  if (environment != null) {
    executorEnvs ++= environment
  }

  // Create and start the scheduler
  private var taskScheduler: TaskScheduler = {
@@ -169,6 +177,22 @@ class SparkContext(
      }
      scheduler

    case "yarn-standalone" =>
      val scheduler = try {
        val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
        val cons = clazz.getConstructor(classOf[SparkContext])
        cons.newInstance(this).asInstanceOf[ClusterScheduler]
      } catch {
        // TODO: Enumerate the exact reasons why it can fail
        // But irrespective of it, it means we cannot proceed !
        case th: Throwable => {
          throw new SparkException("YARN mode not available ?", th)
        }
      }
      val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
      scheduler.initialize(backend)
      scheduler

    case _ =>
      if (MESOS_REGEX.findFirstIn(master).isEmpty) {
        logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
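The yarn-standalone branch above loads YarnClusterScheduler by name so the core build needs no compile-time dependency on the YARN module. A generic sketch of that reflection pattern (the helper name is ours, not the diff's):

    // Instantiate a class that takes a SparkContext, located by name at runtime;
    // callers catch the failure if the class is not on the classpath.
    def instantiateByName[T](className: String, ctx: SparkContext): T = {
      val clazz = Class.forName(className)
      val cons = clazz.getConstructor(classOf[SparkContext])
      cons.newInstance(ctx).asInstanceOf[T]
    }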
@@ -188,12 +212,12 @@ class SparkContext(
  }
  taskScheduler.start()

  private var dagScheduler = new DAGScheduler(taskScheduler)
  @volatile private var dagScheduler = new DAGScheduler(taskScheduler)
  dagScheduler.start()

  /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
  val hadoopConfiguration = {
    val conf = new Configuration()
    val conf = SparkHadoopUtil.newConfiguration()
    // Explicitly check for S3 environment variables
    if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
      conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
@@ -212,6 +236,22 @@ class SparkContext(

  private[spark] var checkpointDir: Option[String] = None

  // Thread Local variable that can be used by users to pass information down the stack
  private val localProperties = new DynamicVariable[Properties](null)

  def initLocalProperties() {
    localProperties.value = new Properties()
  }

  def addLocalProperties(key: String, value: String) {
    if(localProperties.value == null) {
      localProperties.value = new Properties()
    }
    localProperties.value.setProperty(key,value)
  }
  // Post init
  taskScheduler.postStartHook()

  // Methods for creating RDDs

  /** Distribute a local Scala collection to form an RDD. */
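addLocalProperties stores per-thread job properties in a DynamicVariable, so values set on one thread do not leak into jobs submitted from another. A hedged usage sketch (the property key is hypothetical):

    sc.initLocalProperties()
    sc.addLocalProperties("job.description", "nightly ETL")  // hypothetical key
    // Jobs run from this thread now carry these Properties, because runJob
    // passes localProperties.value through to the DAGScheduler (see below).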
@@ -466,13 +506,17 @@ class SparkContext(
    logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
  }

  def addSparkListener(listener: SparkListener) {
    dagScheduler.sparkListeners += listener
  }

  /**
   * Return a map from the slave to the max memory available for caching and the remaining
   * memory available for caching.
   */
  def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
    env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
      (blockManagerId.ip + ":" + blockManagerId.port, mem)
      (blockManagerId.host + ":" + blockManagerId.port, mem)
    }
  }
@@ -480,14 +524,18 @@ class SparkContext(
   * Return information about what RDDs are cached, if they are in mem or on disk, how much space
   * they take, etc.
   */
  def getRDDStorageInfo : Array[RDDInfo] = {
  def getRDDStorageInfo: Array[RDDInfo] = {
    StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
  }

  def getStageInfo: Map[Stage,StageInfo] = {
    dagScheduler.stageToInfos
  }

  /**
   * Return information about blocks stored in all of the slaves
   */
  def getExecutorStorageStatus : Array[StorageStatus] = {
  def getExecutorStorageStatus: Array[StorageStatus] = {
    env.blockManager.master.getStorageStatus
  }
@@ -505,13 +553,18 @@ class SparkContext(
   * filesystems), or an HTTP, HTTPS or FTP URI.
   */
  def addJar(path: String) {
    val uri = new URI(path)
    val key = uri.getScheme match {
      case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
      case _ => path
    if (null == path) {
      logWarning("null specified as parameter to addJar",
        new SparkException("null specified as parameter to addJar"))
    } else {
      val uri = new URI(path)
      val key = uri.getScheme match {
        case null | "file" => env.httpFileServer.addJar(new File(uri.getPath))
        case _ => path
      }
      addedJars(key) = System.currentTimeMillis
      logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
    }
    addedJars(key) = System.currentTimeMillis
    logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
  }

  /**
@@ -524,10 +577,13 @@ class SparkContext(

  /** Shut down the SparkContext. */
  def stop() {
    if (dagScheduler != null) {
    // Do this only if not stopped already - best case effort.
    // prevent NPE if stopped more than once.
    val dagSchedulerCopy = dagScheduler
    dagScheduler = null
    if (dagSchedulerCopy != null) {
      metadataCleaner.cancel()
      dagScheduler.stop()
      dagScheduler = null
      dagSchedulerCopy.stop()
      taskScheduler = null
      // TODO: Cache.stop()?
      env.stop()
@@ -543,6 +599,7 @@ class SparkContext(
    }
  }

  /**
   * Get Spark's home location from either a value set through the constructor,
   * or the spark.home Java property, or the SPARK_HOME environment variable
@@ -572,10 +629,10 @@ class SparkContext(
      partitions: Seq[Int],
      allowLocal: Boolean,
      resultHandler: (Int, U) => Unit) {
    val callSite = Utils.getSparkCallSite
    val callSite = Utils.formatSparkCallSite
    logInfo("Starting job: " + callSite)
    val start = System.nanoTime
    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler)
    val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
    rdd.doCheckpoint()
    result
@@ -654,12 +711,11 @@ class SparkContext(
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U,
      evaluator: ApproximateEvaluator[U, R],
      timeout: Long
    ): PartialResult[R] = {
    val callSite = Utils.getSparkCallSite
      timeout: Long): PartialResult[R] = {
    val callSite = Utils.formatSparkCallSite
    logInfo("Starting job: " + callSite)
    val start = System.nanoTime
    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout)
    val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
    logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
    result
  }
@@ -682,7 +738,7 @@ class SparkContext(
   */
  def setCheckpointDir(dir: String, useExisting: Boolean = false) {
    val path = new Path(dir)
    val fs = path.getFileSystem(new Configuration())
    val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
    if (!useExisting) {
      if (fs.exists(path)) {
        throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -693,7 +749,7 @@ class SparkContext(
    checkpointDir = Some(dir)
  }

  /** Default level of parallelism to use when not given by user (e.g. for reduce tasks) */
  /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
  def defaultParallelism: Int = taskScheduler.defaultParallelism

  /** Default min number of partitions for Hadoop RDDs when not given by user */
@@ -1,15 +1,19 @@
package spark

import akka.actor.ActorSystem
import akka.actor.ActorSystemImpl
import collection.mutable
import serializer.Serializer

import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider

import serializer.Serializer
import spark.broadcast.BroadcastManager
import spark.storage.BlockManager
import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
import spark.api.python.PythonWorkerFactory

/**
 * Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -21,6 +25,7 @@ import spark.util.AkkaUtils
class SparkEnv (
    val executorId: String,
    val actorSystem: ActorSystem,
    val serializerManager: SerializerManager,
    val serializer: Serializer,
    val closureSerializer: Serializer,
    val cacheManager: CacheManager,
@@ -30,10 +35,16 @@ class SparkEnv (
    val blockManager: BlockManager,
    val connectionManager: ConnectionManager,
    val httpFileServer: HttpFileServer,
    val sparkFilesDir: String
  ) {
    val sparkFilesDir: String,
    // To be set only as part of initialization of SparkContext.
    // (executorId, defaultHostPort) => executorHostPort
    // If executorId is NOT found, return defaultHostPort
    var executorIdToHostPort: Option[(String, String) => String]) {

  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()

  def stop() {
    pythonWorkers.foreach { case(key, worker) => worker.stop() }
    httpFileServer.stop()
    mapOutputTracker.stop()
    shuffleFetcher.stop()
@@ -45,6 +56,22 @@ class SparkEnv (
    // down, but let's call it anyway in case it gets fixed in a later release
    actorSystem.awaitTermination()
  }

  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
    synchronized {
      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
    }
  }

  def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
    val env = SparkEnv.get
    if (env.executorIdToHostPort.isEmpty) {
      // default to using host, not host port. Relevant to non cluster modes.
      return defaultHostPort
    }

    env.executorIdToHostPort.get(executorId, defaultHostPort)
  }
}

object SparkEnv extends Logging {
@@ -73,6 +100,16 @@ object SparkEnv extends Logging {
      System.setProperty("spark.driver.port", boundPort.toString)
    }

    // set only if unset until now.
    if (System.getProperty("spark.hostPort", null) == null) {
      if (!isDriver){
        // unexpected
        Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
      }
      Utils.checkHost(hostname)
      System.setProperty("spark.hostPort", hostname + ":" + boundPort)
    }

    val classLoader = Thread.currentThread.getContextClassLoader

    // Create an instance of the class named by the given Java system property, or by
@@ -82,24 +119,45 @@ object SparkEnv extends Logging {
      Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
    }

    val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
    val serializerManager = new SerializerManager

    val driverIp: String = System.getProperty("spark.driver.host", "localhost")
    val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
    val blockManagerMaster = new BlockManagerMaster(
      actorSystem, isDriver, isLocal, driverIp, driverPort)
    val serializer = serializerManager.setDefault(
      System.getProperty("spark.serializer", "spark.JavaSerializer"))

    val closureSerializer = serializerManager.get(
      System.getProperty("spark.closure.serializer", "spark.JavaSerializer"))

    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
      if (isDriver) {
        logInfo("Registering " + name)
        actorSystem.actorOf(Props(newActor), name = name)
      } else {
        val driverHost: String = System.getProperty("spark.driver.host", "localhost")
        val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
        Utils.checkHost(driverHost, "Expected hostname")
        val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
        logInfo("Connecting to " + name + ": " + url)
        actorSystem.actorFor(url)
      }
    }

    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
      "BlockManagerMaster",
      new spark.storage.BlockManagerMasterActor(isLocal)))
    val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)

    val connectionManager = blockManager.connectionManager

    val broadcastManager = new BroadcastManager(isDriver)

    val closureSerializer = instantiateClass[Serializer](
      "spark.closure.serializer", "spark.JavaSerializer")

    val cacheManager = new CacheManager(blockManager)

    val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
    // Have to assign trackerActor after initialization as MapOutputTrackerActor
    // requires the MapOutputTracker itself
    val mapOutputTracker = new MapOutputTracker()
    mapOutputTracker.trackerActor = registerOrLookup(
      "MapOutputTracker",
      new MapOutputTrackerActor(mapOutputTracker))

    val shuffleFetcher = instantiateClass[ShuffleFetcher](
      "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
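registerOrLookup gives each singleton actor one home: the driver creates it under a well-known name, and executors resolve the same name over Akka remoting. A hypothetical extra use of the helper, purely illustrative (the actor and its name are not in the diff):

    val echoActor = registerOrLookup(
      "Echo",                                           // assumed name
      new Actor { def receive = { case m => sender ! m } })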
@@ -126,6 +184,7 @@ object SparkEnv extends Logging {
    new SparkEnv(
      executorId,
      actorSystem,
      serializerManager,
      serializer,
      closureSerializer,
      cacheManager,
@@ -135,6 +194,7 @@ object SparkEnv extends Logging {
      blockManager,
      connectionManager,
      httpFileServer,
      sparkFilesDir)
      sparkFilesDir,
      None)
  }
}
@@ -1,9 +1,14 @@
package spark

import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer

class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
class TaskContext(
  val stageId: Int,
  val splitId: Int,
  val attemptId: Long,
  val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {

  @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
@@ -14,9 +14,19 @@ private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it

private[spark]
case class FetchFailed(bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int) extends TaskEndReason
private[spark] case class FetchFailed(
    bmAddress: BlockManagerId,
    shuffleId: Int,
    mapId: Int,
    reduceId: Int)
  extends TaskEndReason

private[spark] case class ExceptionFailure(exception: Throwable) extends TaskEndReason
private[spark] case class ExceptionFailure(
    className: String,
    description: String,
    stackTrace: Array[StackTraceElement])
  extends TaskEndReason

private[spark] case class OtherFailure(message: String) extends TaskEndReason

private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
@@ -1,23 +1,29 @@
package spark

import java.io._
import java.net._
import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address, ServerSocket}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{Executors, ThreadFactory, ThreadPoolExecutor}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}
import scala.collection.mutable.ArrayBuffer
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadFactory, ThreadPoolExecutor}
import java.util.regex.Pattern

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.collection.JavaConversions._
import scala.io.Source

import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
import scala.Some

import org.apache.hadoop.fs.{Path, FileSystem, FileUtil}

import spark.serializer.SerializerInstance
import spark.deploy.SparkHadoopUtil

/**
 * Various utility methods used by Spark.
 */
private object Utils extends Logging {

  /** Serialize an object using Java serialization */
  def serialize[T](o: T): Array[Byte] = {
    val bos = new ByteArrayOutputStream()
@@ -68,6 +74,40 @@ private object Utils extends Logging {
    return buf
  }

  private val shutdownDeletePaths = new collection.mutable.HashSet[String]()

  // Register the path to be deleted via shutdown hook
  def registerShutdownDeleteDir(file: File) {
    val absolutePath = file.getAbsolutePath()
    shutdownDeletePaths.synchronized {
      shutdownDeletePaths += absolutePath
    }
  }

  // Is the path already registered to be deleted via a shutdown hook ?
  def hasShutdownDeleteDir(file: File): Boolean = {
    val absolutePath = file.getAbsolutePath()
    shutdownDeletePaths.synchronized {
      shutdownDeletePaths.contains(absolutePath)
    }
  }

  // Note: if file is child of some registered path, while not equal to it, then return true;
  // else false. This is to ensure that two shutdown hooks do not try to delete each others
  // paths - resulting in IOException and incomplete cleanup.
  def hasRootAsShutdownDeleteDir(file: File): Boolean = {
    val absolutePath = file.getAbsolutePath()
    val retval = shutdownDeletePaths.synchronized {
      shutdownDeletePaths.find { path =>
        !absolutePath.equals(path) && absolutePath.startsWith(path)
      }.isDefined
    }
    if (retval) {
      logInfo("path = " + file + ", already present as root for deletion.")
    }
    retval
  }

  /** Create a temporary directory inside the given parent directory */
  def createTempDir(root: String = System.getProperty("java.io.tmpdir")): File = {
    var attempts = 0
@@ -76,8 +116,8 @@ private object Utils extends Logging {
    while (dir == null) {
      attempts += 1
      if (attempts > maxAttempts) {
        throw new IOException("Failed to create a temp directory after " + maxAttempts +
          " attempts!")
        throw new IOException("Failed to create a temp directory (under " + root + ") after " +
          maxAttempts + " attempts!")
      }
      try {
        dir = new File(root, "spark-" + UUID.randomUUID.toString)
@@ -86,13 +126,17 @@ private object Utils extends Logging {
      }
    } catch { case e: IOException => ; }
  }

  registerShutdownDeleteDir(dir)

  // Add a shutdown hook to delete the temp dir when the JVM exits
  Runtime.getRuntime.addShutdownHook(new Thread("delete Spark temp dir " + dir) {
    override def run() {
      Utils.deleteRecursively(dir)
      // Attempt to delete if some path which is parent of this is not already registered.
      if (! hasRootAsShutdownDeleteDir(dir)) Utils.deleteRecursively(dir)
    }
  })
  return dir
  dir
}

/** Copy all data from an InputStream to an OutputStream */
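createTempDir now registers each directory in the shutdown-delete set, and the hook skips deletion when a registered parent will already remove it. A hedged usage sketch (method names from the hunks above):

    val scratch = Utils.createTempDir()                        // deleted at JVM exit
    val nested  = Utils.createTempDir(scratch.getAbsolutePath)
    // nested sits under a registered root, so only scratch's hook deletes it:
    // Utils.hasRootAsShutdownDeleteDir(nested) == true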
@@ -135,40 +179,35 @@ private object Utils extends Logging {
        Utils.copyStream(in, out, true)
        if (targetFile.exists && !Files.equal(tempFile, targetFile)) {
          tempFile.delete()
          throw new SparkException("File " + targetFile + " exists and does not match contents of" +
            " " + url)
          throw new SparkException(
            "File " + targetFile + " exists and does not match contents of" + " " + url)
        } else {
          Files.move(tempFile, targetFile)
        }
      case "file" | null =>
        val sourceFile = if (uri.isAbsolute) {
          new File(uri)
        } else {
          new File(url)
        }
        if (targetFile.exists && !Files.equal(sourceFile, targetFile)) {
          throw new SparkException("File " + targetFile + " exists and does not match contents of" +
            " " + url)
        } else {
          // Remove the file if it already exists
          targetFile.delete()
          // Symlink the file locally.
          if (uri.isAbsolute) {
            // url is absolute, i.e. it starts with "file:///". Extract the source
            // file's absolute path from the url.
            val sourceFile = new File(uri)
            logInfo("Symlinking " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
            FileUtil.symLink(sourceFile.getAbsolutePath, targetFile.getAbsolutePath)
        // In the case of a local file, copy the local file to the target directory.
        // Note the difference between uri vs url.
        val sourceFile = if (uri.isAbsolute) new File(uri) else new File(url)
        if (targetFile.exists) {
          // If the target file already exists, warn the user if
          if (!Files.equal(sourceFile, targetFile)) {
            throw new SparkException(
              "File " + targetFile + " exists and does not match contents of" + " " + url)
          } else {
            // url is not absolute, i.e. itself is the path to the source file.
            logInfo("Symlinking " + url + " to " + targetFile.getAbsolutePath)
            FileUtil.symLink(url, targetFile.getAbsolutePath)
            // Do nothing if the file contents are the same, i.e. this file has been copied
            // previously.
            logInfo(sourceFile.getAbsolutePath + " has been previously copied to "
              + targetFile.getAbsolutePath)
          }
        } else {
          // The file does not exist in the target directory. Copy it there.
          logInfo("Copying " + sourceFile.getAbsolutePath + " to " + targetFile.getAbsolutePath)
          Files.copy(sourceFile, targetFile)
        }
      case _ =>
        // Use the Hadoop filesystem library, which supports file://, hdfs://, s3://, and others
        val uri = new URI(url)
        val conf = new Configuration()
        val conf = SparkHadoopUtil.newConfiguration()
        val fs = FileSystem.get(uri, conf)
        val in = fs.open(new Path(uri))
        val out = new FileOutputStream(tempFile)
@@ -227,8 +266,10 @@ private object Utils extends Logging {

  /**
   * Get the local host's IP address in dotted-quad format (e.g. 1.2.3.4).
   * Note, this is typically not used from within core spark.
   */
  lazy val localIpAddress: String = findLocalIpAddress()
  lazy val localIpAddressHostname: String = getAddressHostName(localIpAddress)

  private def findLocalIpAddress(): String = {
    val defaultIpOverride = System.getenv("SPARK_LOCAL_IP")
@@ -266,6 +307,8 @@ private object Utils extends Logging {
   * hostname it reports to the master.
   */
  def setCustomHostname(hostname: String) {
    // DEBUG code
    Utils.checkHost(hostname)
    customHostname = Some(hostname)
  }
@@ -273,7 +316,91 @@ private object Utils extends Logging {
   * Get the local machine's hostname.
   */
  def localHostName(): String = {
    customHostname.getOrElse(InetAddress.getLocalHost.getHostName)
    customHostname.getOrElse(localIpAddressHostname)
  }

  def getAddressHostName(address: String): String = {
    InetAddress.getByName(address).getHostName
  }

  def localHostPort(): String = {
    val retval = System.getProperty("spark.hostPort", null)
    if (retval == null) {
      logErrorWithStack("spark.hostPort not set but invoking localHostPort")
      return localHostName()
    }

    retval
  }

  /*
  // Used by DEBUG code : remove when all testing done
  private val ipPattern = Pattern.compile("^[0-9]+(\\.[0-9]+)*$")
  def checkHost(host: String, message: String = "") {
    // Currently catches only ipv4 pattern, this is just a debugging tool - not rigourous !
    // if (host.matches("^[0-9]+(\\.[0-9]+)*$")) {
    if (ipPattern.matcher(host).matches()) {
      Utils.logErrorWithStack("Unexpected to have host " + host + " which matches IP pattern. Message " + message)
    }
    if (Utils.parseHostPort(host)._2 != 0){
      Utils.logErrorWithStack("Unexpected to have host " + host + " which has port in it. Message " + message)
    }
  }

  // Used by DEBUG code : remove when all testing done
  def checkHostPort(hostPort: String, message: String = "") {
    val (host, port) = Utils.parseHostPort(hostPort)
    checkHost(host)
    if (port <= 0){
      Utils.logErrorWithStack("Unexpected to have port " + port + " which is not valid in " + hostPort + ". Message " + message)
    }
  }

  // Used by DEBUG code : remove when all testing done
  def logErrorWithStack(msg: String) {
    try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
    // temp code for debug
    System.exit(-1)
  }
  */

  // Once testing is complete in various modes, replace with this ?
  def checkHost(host: String, message: String = "") {}
  def checkHostPort(hostPort: String, message: String = "") {}

  // Used by DEBUG code : remove when all testing done
  def logErrorWithStack(msg: String) {
    try { throw new Exception } catch { case ex: Exception => { logError(msg, ex) } }
  }

  def getUserNameFromEnvironment(): String = {
    SparkHadoopUtil.getUserNameFromEnvironment
  }

  // Typically, this will be of order of number of nodes in cluster
  // If not, we should change it to LRUCache or something.
  private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]()

  def parseHostPort(hostPort: String): (String, Int) = {
    {
      // Check cache first.
      var cached = hostPortParseResults.get(hostPort)
      if (cached != null) return cached
    }

    val indx: Int = hostPort.lastIndexOf(':')
    // This is potentially broken - when dealing with ipv6 addresses for example, sigh ...
    // but then hadoop does not support ipv6 right now.
    // For now, we assume that if port exists, then it is valid - not check if it is an int > 0
    if (-1 == indx) {
      val retval = (hostPort, 0)
      hostPortParseResults.put(hostPort, retval)
      return retval
    }

    val retval = (hostPort.substring(0, indx).trim(), hostPort.substring(indx + 1).trim().toInt)
    hostPortParseResults.putIfAbsent(hostPort, retval)
    hostPortParseResults.get(hostPort)
  }

  private[spark] val daemonThreadFactory: ThreadFactory =
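parseHostPort splits on the last ':' and memoizes results in a ConcurrentHashMap. Hedged examples of its contract, per the code above (host names are illustrative):

    Utils.parseHostPort("host1.example.com:7077")   // ("host1.example.com", 7077)
    Utils.parseHostPort("host1.example.com")        // ("host1.example.com", 0), no port present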
@@ -395,13 +522,14 @@ private object Utils extends Logging {
    execute(command, new File("."))
  }

  private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
    val firstUserLine: Int, val firstUserClass: String)
  /**
   * When called inside a class in the spark package, returns the name of the user code class
   * (outside the spark package) that called into Spark, as well as which Spark method they called.
   * This is used, for example, to tell users where in their code each RDD got created.
   */
  def getSparkCallSite: String = {
  def getCallSiteInfo: CallSiteInfo = {
    val trace = Thread.currentThread.getStackTrace().filter( el =>
      (!el.getMethodName.contains("getStackTrace")))
@@ -413,6 +541,7 @@ private object Utils extends Logging {
    var firstUserFile = "<unknown>"
    var firstUserLine = 0
    var finished = false
    var firstUserClass = "<unknown>"

    for (el <- trace) {
      if (!finished) {
@@ -427,13 +556,19 @@ private object Utils extends Logging {
        else {
          firstUserLine = el.getLineNumber
          firstUserFile = el.getFileName
          firstUserClass = el.getClassName
          finished = true
        }
      }
    }
    "%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
    new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
  }

  def formatSparkCallSite = {
    val callSiteInfo = getCallSiteInfo
    "%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
      callSiteInfo.firstUserLine)
  }
  /**
   * Try to find a free port to bind to on the local host. This should ideally never be needed,
   * except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
@@ -57,6 +57,12 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
   */
  def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))

  /**
   * Return a new RDD that is reduced into `numPartitions` partitions.
   */
  def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
    fromRDD(srdd.coalesce(numPartitions, shuffle))

  /**
   * Return an RDD with the elements from `this` that are not in `other`.
   *
@@ -6,6 +6,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._

import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@@ -19,6 +20,7 @@ import spark.OrderedRDDFunctions
import spark.storage.StorageLevel
import spark.HashPartitioner
import spark.Partitioner
import spark.Partitioner._
import spark.RDD
import spark.SparkContext.rddToPairRDDFunctions
@@ -65,7 +67,13 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
  /**
   * Return a new RDD that is reduced into `numPartitions` partitions.
   */
  def coalesce(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.coalesce(numPartitions))
  def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions))

  /**
   * Return a new RDD that is reduced into `numPartitions` partitions.
   */
  def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
    fromRDD(rdd.coalesce(numPartitions, shuffle))

  /**
   * Return a sampled subset of this RDD.
@@ -159,6 +167,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
      : PartialResult[java.util.Map[K, BoundedDouble]] =
    rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)

  /**
   * Merge the values for each key using an associative function and a neutral "zero value" which may
   * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
   * list concatenation, 0 for addition, or 1 for multiplication.).
   */
  def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
    fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))

  /**
   * Merge the values for each key using an associative function and a neutral "zero value" which may
   * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
   * list concatenation, 0 for addition, or 1 for multiplication.).
   */
  def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
    fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))

  /**
   * Merge the values for each key using an associative function and a neutral "zero value" which may
   * be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
   * list concatenation, 0 for addition, or 1 for multiplication.).
   */
  def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
    fromRDD(rdd.foldByKey(zeroValue)(func))

  /**
   * Merge the values for each key using an associative reduce function. This will also perform
   * the merging locally on each mapper before sending results to a reducer, similarly to a
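The three foldByKey overloads all delegate to the Scala implementation, differing only in how the partitioner is chosen. A hedged sketch from the Scala side (sc is an assumed SparkContext):

    val folded = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))
      .foldByKey(0)(_ + _)        // zero value 0 is neutral for addition
    // folded.collect() contains ("a", 4) and ("b", 2)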
@@ -241,30 +273,30 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
    fromRDD(rdd.rightOuterJoin(other, partitioner))

  /**
   * Simplified version of combineByKey that hash-partitions the resulting RDD using the default
   * parallelism level.
   * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
   * partitioner/parallelism level.
   */
  def combineByKey[C](createCombiner: JFunction[V, C],
      mergeValue: JFunction2[C, V, C],
      mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
    implicit val cm: ClassManifest[C] =
      implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
    fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners))
    fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
  }

  /**
   * Merge the values for each key using an associative reduce function. This will also perform
   * the merging locally on each mapper before sending results to a reducer, similarly to a
   * "combiner" in MapReduce. Output will be hash-partitioned with the default parallelism level.
   * "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
   * parallelism level.
   */
  def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
    val partitioner = rdd.defaultPartitioner(rdd)
    fromRDD(reduceByKey(partitioner, func))
    fromRDD(reduceByKey(defaultPartitioner(rdd), func))
  }

  /**
   * Group the values for each key in the RDD into a single sequence. Hash-partitions the
   * resulting RDD with the default parallelism level.
   * resulting RDD with the existing partitioner/parallelism level.
   */
  def groupByKey(): JavaPairRDD[K, JList[V]] =
    fromRDD(groupByResultToJava(rdd.groupByKey()))
@@ -289,7 +321,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
   * Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
   * resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
   * pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
   * using the default level of parallelism.
   * using the existing partitioner/parallelism level.
   */
  def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Option[W])] =
    fromRDD(rdd.leftOuterJoin(other))
@@ -307,7 +339,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
   * Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
   * resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
   * pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
   * RDD using the default parallelism level.
   * RDD using the existing partitioner/parallelism level.
   */
  def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Option[V], W)] =
    fromRDD(rdd.rightOuterJoin(other))
@@ -428,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
    rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
  }

  /** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
  def saveAsHadoopFile[F <: OutputFormat[_, _]](
      path: String,
      keyClass: Class[_],
      valueClass: Class[_],
      outputFormatClass: Class[F],
      codec: Class[_ <: CompressionCodec]) {
    rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
  }

  /** Output the RDD to any Hadoop-supported file system. */
  def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
      path: String,
@@ -14,12 +14,18 @@ JavaRDDLike[T, JavaRDD[T]] {
  /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
  def cache(): JavaRDD[T] = wrapRDD(rdd.cache())

  /**
   * Set this RDD's storage level to persist its values across operations after the first time
   * it is computed. Can only be called once on each RDD.
   * it is computed. This can only be used to assign a new storage level if the RDD does not
   * have a storage level set yet.
   */
  def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))

  /**
   * Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
   */
  def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())

  // Transformations (return a new RDD)

  /**
@@ -31,7 +37,7 @@ JavaRDDLike[T, JavaRDD[T]] {
   * Return a new RDD containing the distinct elements in this RDD.
   */
  def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))

  /**
   * Return a new RDD containing only the elements that satisfy a predicate.
   */
@@ -43,12 +49,18 @@ JavaRDDLike[T, JavaRDD[T]] {
   */
  def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)

  /**
   * Return a new RDD that is reduced into `numPartitions` partitions.
   */
  def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
    rdd.coalesce(numPartitions, shuffle)

  /**
   * Return a sampled subset of this RDD.
   */
  def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
    wrapRDD(rdd.sample(withReplacement, fraction, seed))

  /**
   * Return the union of this RDD and another one. Any identical elements will appear multiple
   * times (use `.distinct()` to eliminate them).
@@ -57,7 +69,7 @@ JavaRDDLike[T, JavaRDD[T]] {

  /**
   * Return an RDD with the elements from `this` that are not in `other`.
   *
   * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
   * RDD will be <= us.
   */
@@ -74,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
   */
  def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
    wrapRDD(rdd.subtract(other, p))

}

object JavaRDD {
@@ -1,9 +1,10 @@
package spark.api.java

import java.util.{List => JList}
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._

import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@@ -182,6 +183,21 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
    JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
  }

  /**
   * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
   * applying a function to the zipped partitions. Assumes that all the RDDs have the
   * *same number of partitions*, but does *not* require them to have the same number
   * of elements in each partition.
   */
  def zipPartitions[U, V](
      f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V],
      other: JavaRDDLike[U, _]): JavaRDD[V] = {
    def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
      f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
    JavaRDD.fromRDD(
      rdd.zipPartitions(fn, other.rdd)(other.classManifest, f.elementType()))(f.elementType())
  }

  // Actions (launch a job to return a value to the user program)

  /**
@@ -295,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   */
  def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)

  /**
   * Save this RDD as a compressed text file, using string representations of elements.
   */
  def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
    rdd.saveAsTextFile(path, codec)

  /**
   * Save this RDD as a SequenceFile of serialized objects.
   */
@@ -336,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
  def toDebugString(): String = {
    rdd.toDebugString
  }

  /**
   * Returns the top K elements from this RDD as defined by
   * the specified Comparator[T].
   * @param num the number of top elements to return
   * @param comp the comparator that defines the order
   * @return an array of top elements
   */
  def top(num: Int, comp: Comparator[T]): JList[T] = {
    import scala.collection.JavaConversions._
    val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
    val arr: java.util.Collection[T] = topElems.toSeq
    new java.util.ArrayList(arr)
  }

  /**
   * Returns the top K elements from this RDD using the
   * natural ordering for T.
   * @param num the number of top elements to return
   * @return an array of top elements
   */
  def top(num: Int): JList[T] = {
    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
    top(num, comp)
  }
}
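top(num) builds on the comparator overload, using Guava's natural ordering by default. A hedged usage sketch via the underlying Scala RDD (sc is an assumed SparkContext):

    val r = sc.parallelize(Seq(3, 1, 4, 1, 5))
    r.top(2)                          // Array(5, 4), natural ordering
    r.top(2)(Ordering[Int].reverse)   // Array(1, 1), the two smallest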
@@ -31,8 +31,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
 * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
 * @param appName A name for your application, to display on the cluster web UI
 * @param sparkHome The SPARK_HOME directory on the slave nodes
 * @param jars Collection of JARs to send to the cluster. These can be paths on the local file
 *             system or HDFS, HTTP, HTTPS, or FTP URLs.
 * @param jarFile JAR file to send to the cluster. This can be a path on the local file system
 *                or an HDFS, HTTP, HTTPS, or FTP URL.
 */
def this(master: String, appName: String, sparkHome: String, jarFile: String) =
  this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
@@ -0,0 +1,11 @@
package spark.api.java.function

/**
 * A function that takes two inputs and returns zero or more output records.
 */
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
  @throws(classOf[Exception])
  def call(a: A, b: B) : java.lang.Iterable[C]

  def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}
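FlatMapFunction2 exists so Java code can hand a two-iterator function to the new zipPartitions. What it computes, sketched with the Scala signature the hunk wraps (sc is assumed, and how elements land in partitions is illustrative):

    val nums  = sc.parallelize(1 to 4, 2)
    val words = sc.parallelize(Seq("w", "x", "y", "z"), 2)  // same partition count
    val zipped = nums.zipPartitions(
      (is: Iterator[Int], ss: Iterator[String]) => is.zip(ss).map { case (i, s) => s * i },
      words)
    // e.g. zipped.collect() == Array("w", "xx", "yyy", "zzzz")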
@@ -2,10 +2,9 @@ package spark.api.python

import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Collections}
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

import scala.collection.JavaConversions._
import scala.io.Source

import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
@@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
    parent: RDD[T],
    command: Seq[String],
    envVars: java.util.Map[String, String],
    envVars: JMap[String, String],
    preservePartitoning: Boolean,
    pythonExec: String,
    broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](

  // Similar to Runtime.exec(), if we are given a single string, split it into words
  // using a standard StringTokenizer (i.e. by spaces)
  def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
      preservePartitoning: Boolean, pythonExec: String,
      broadcastVars: JList[Broadcast[Array[Byte]]],
      accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](

  override val partitioner = if (preservePartitoning) parent.partitioner else None

  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
    val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")

    val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
    // Add the environmental variables to the process.
    val currentEnvVars = pb.environment()

    for ((variable, value) <- envVars) {
      currentEnvVars.put(variable, value)
    }

    val proc = pb.start()
    val startTime = System.currentTimeMillis
    val env = SparkEnv.get

    // Start a thread to print the process's stderr to ours
    new Thread("stderr reader for " + command) {
      override def run() {
        for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
          System.err.println(line)
        }
      }
    }.start()
    val worker = env.createPythonWorker(pythonExec, envVars.toMap)

    // Start a thread to feed the process input from our parent's iterator
    new Thread("stdin writer for " + command) {
    new Thread("stdin writer for " + pythonExec) {
      override def run() {
        SparkEnv.set(env)
        val out = new PrintWriter(proc.getOutputStream)
        val dOut = new DataOutputStream(proc.getOutputStream)
        val out = new PrintWriter(worker.getOutputStream)
        val dOut = new DataOutputStream(worker.getOutputStream)
        // Partition index
        dOut.writeInt(split.index)
        // sparkFilesDir
@@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
        }
        dOut.flush()
        out.flush()
        proc.getOutputStream.close()
        worker.shutdownOutput()
      }
    }.start()

    // Return an iterator that read lines from the process's stdout
    val stream = new DataInputStream(proc.getInputStream)
    val stream = new DataInputStream(worker.getInputStream)
    return new Iterator[Array[Byte]] {
      def next(): Array[Byte] = {
        val obj = _nextObj
        _nextObj = read()
        if (hasNext) {
          // FIXME: can deadlock if worker is waiting for us to
          // respond to current message (currently irrelevant because
          // output is shutdown before we read any input)
          _nextObj = read()
        }
        obj
      }
@@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest](
          val obj = new Array[Byte](length)
          stream.readFully(obj)
          obj
        case -3 =>
          // Timing data from worker
          val bootTime = stream.readLong()
          val initTime = stream.readLong()
          val finishTime = stream.readLong()
          val boot = bootTime - startTime
          val init = initTime - bootTime
          val finish = finishTime - initTime
          val total = finishTime - startTime
          logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
          read
        case -2 =>
          // Signals that an exception has been thrown in python
          val exLength = stream.readInt()
@@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
          stream.readFully(obj)
          throw new PythonException(new String(obj))
        case -1 =>
          // We've finished the data section of the output, but we can still read some
          // accumulator updates; let's do that, breaking when we get EOFException
          while (true) {
            val len2 = stream.readInt()
          // We've finished the data section of the output, but we can still
          // read some accumulator updates; let's do that, breaking when we
          // get a negative length record.
          var len2 = stream.readInt()
          while (len2 >= 0) {
            val update = new Array[Byte](len2)
            stream.readFully(update)
            accumulator += Collections.singletonList(update)
            len2 = stream.readInt()
          }
          new Array[Byte](0)
      }
    } catch {
      case eof: EOFException => {
        val exitStatus = proc.waitFor()
        if (exitStatus != 0) {
          throw new Exception("Subprocess exited with status " + exitStatus)
        }
        new Array[Byte](0)
        throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
      }
      case e => throw e
    }
@@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
  override def compute(split: Partition, context: TaskContext) =
    prev.iterator(split, context).grouped(2).map {
      case Seq(a, b) => (a, b)
      case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
    }
  val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@@ -215,7 +211,7 @@ private[spark] object PythonRDD {
      dOut.write(s)
      dOut.writeByte(Pickle.STOP)
    } else {
      throw new Exception("Unexpected RDD type")
      throw new SparkException("Unexpected RDD type")
    }
  }
@@ -277,6 +273,8 @@ private class BytesToString extends spark.api.java.function.Function[Array[Byte]
 */
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
  extends AccumulatorParam[JList[Array[Byte]]] {

  Utils.checkHost(serverHost, "Expected hostname")

  override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList

core/src/main/scala/spark/api/python/PythonWorkerFactory.scala (new file, 95 lines)
@@ -0,0 +1,95 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+    extends Logging {
+  var daemon: Process = null
+  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+  var daemonPort: Int = 0
+
+  def create(): Socket = {
+    synchronized {
+      // Start the daemon if it hasn't been started
+      startDaemon()
+
+      // Attempt to connect, restart and retry once if it fails
+      try {
+        new Socket(daemonHost, daemonPort)
+      } catch {
+        case exc: SocketException => {
+          logWarning("Python daemon unexpectedly quit, attempting to restart")
+          stopDaemon()
+          startDaemon()
+          new Socket(daemonHost, daemonPort)
+        }
+        case e => throw e
+      }
+    }
+  }
+
+  def stop() {
+    stopDaemon()
+  }
+
+  private def startDaemon() {
+    synchronized {
+      // Is it already running?
+      if (daemon != null) {
+        return
+      }
+
+      try {
+        // Create and start the daemon
+        val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+        val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+        val workerEnv = pb.environment()
+        workerEnv.putAll(envVars)
+        daemon = pb.start()
+        daemonPort = new DataInputStream(daemon.getInputStream).readInt()
+
+        // Redirect the stderr to ours
+        new Thread("stderr reader for " + pythonExec) {
+          override def run() {
+            scala.util.control.Exception.ignoring(classOf[IOException]) {
+              // FIXME HACK: We copy the stream on the level of bytes to
+              // attempt to dodge encoding problems.
+              val in = daemon.getErrorStream
+              var buf = new Array[Byte](1024)
+              var len = in.read(buf)
+              while (len != -1) {
+                System.err.write(buf, 0, len)
+                len = in.read(buf)
+              }
+            }
+          }
+        }.start()
+      } catch {
+        case e => {
+          stopDaemon()
+          throw e
+        }
+      }

+      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+      // detect our disappearance.
+    }
+  }
+
+  private def stopDaemon() {
+    synchronized {
+      // Request shutdown of existing daemon by sending SIGTERM
+      if (daemon != null) {
+        daemon.destroy()
+      }
+
+      daemon = null
+      daemonPort = 0
+    }
+  }
+}
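
The create() method above encodes a simple recovery policy: connect to the daemon, and on a SocketException assume the daemon died, restart it, and retry exactly once. A compact standalone rendering of that shape; startService, stopService, and port are hypothetical stand-ins for startDaemon, stopDaemon, and daemonPort:

    import java.net.{Socket, SocketException}

    object RetryOnceSketch {
      def connectWithOneRestart(host: String, port: () => Int,
                                startService: () => Unit, stopService: () => Unit): Socket = {
        startService()               // idempotent: no-op if already running
        try {
          new Socket(host, port())
        } catch {
          case e: SocketException =>
            stopService()            // kill the wedged daemon...
            startService()           // ...bring up a fresh one (the port may change)...
            new Socket(host, port()) // ...and give it exactly one more chance
        }
      }
    }

The port is passed as a thunk rather than a value because a restart can rebind the daemon to a different port.
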

core/src/main/scala/spark/deploy/ApplicationDescription.scala
@@ -2,10 +2,11 @@ package spark.deploy

 private[spark] class ApplicationDescription(
     val name: String,
-    val cores: Int,
+    val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
     val memoryPerSlave: Int,
     val command: Command,
-    val sparkHome: String)
+    val sparkHome: String,
+    val appUiUrl: String)
   extends Serializable {

   val user = System.getProperty("user.name", "<unknown>")

core/src/main/scala/spark/deploy/DeployMessage.scala
@@ -4,6 +4,7 @@ import spark.deploy.ExecutorState.ExecutorState
 import spark.deploy.master.{WorkerInfo, ApplicationInfo}
 import spark.deploy.worker.ExecutorRunner
 import scala.collection.immutable.List
+import spark.Utils


 private[spark] sealed trait DeployMessage extends Serializable

@@ -19,7 +20,10 @@ case class RegisterWorker(
     memory: Int,
     webUiPort: Int,
     publicAddress: String)
-  extends DeployMessage
+  extends DeployMessage {
+  Utils.checkHost(host, "Required hostname")
+  assert (port > 0)
+}

 private[spark]
 case class ExecutorStateChanged(

@@ -58,14 +62,16 @@ private[spark]
 case class RegisteredApplication(appId: String) extends DeployMessage

 private[spark]
-case class ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int)
+case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
+  Utils.checkHostPort(hostPort, "Required hostport")
+}

 private[spark]
 case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
   exitStatus: Option[Int])

 private[spark]
-case class appKilled(message: String)
+case class ApplicationRemoved(message: String)

 // Internal message in Client

@@ -81,6 +87,9 @@ private[spark]
 case class MasterState(host: String, port: Int, workers: Array[WorkerInfo],
   activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {

+  Utils.checkHost(host, "Required hostname")
+  assert (port > 0)
+
   def uri = "spark://" + host + ":" + port
 }

@@ -92,4 +101,8 @@ private[spark] case object RequestWorkerState
 private[spark]
 case class WorkerState(host: String, port: Int, workerId: String, executors: List[ExecutorRunner],
   finishedExecutors: List[ExecutorRunner], masterUrl: String, cores: Int, memory: Int,
-  coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String)
+  coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {

+  Utils.checkHost(host, "Required hostname")
+  assert (port > 0)
+}
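
Utils.checkHost and Utils.checkHostPort are used throughout these hunks but are not defined anywhere in this diff. From the call sites their contract appears to be: checkHost asserts the string is a bare hostname (no colon), checkHostPort asserts it carries a host:port pair. A plausible reconstruction, not the actual implementation:

    object HostCheckSketch {
      def checkHost(host: String, message: String = "") {
        assert(host != null && host.indexOf(':') == -1, message)
      }

      def checkHostPort(hostPort: String, message: String = "") {
        assert(hostPort != null && hostPort.indexOf(':') != -1, message)
      }

      // Companion used elsewhere in this diff as Utils.parseHostPort(hostPort)._1 / ._2
      def parseHostPort(hostPort: String): (String, Int) = {
        val idx = hostPort.lastIndexOf(':')
        if (idx == -1) (hostPort, 0)
        else (hostPort.substring(0, idx), hostPort.substring(idx + 1).toInt)
      }

      def main(args: Array[String]) {
        checkHost("worker1", "Required hostname")           // passes
        checkHostPort("worker1:7078", "Required hostport")  // passes
        println(parseHostPort("worker1:7078"))              // (worker1,7078)
      }
    }
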

core/src/main/scala/spark/deploy/JsonProtocol.scala
@@ -12,6 +12,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
   def write(obj: WorkerInfo) = JsObject(
     "id" -> JsString(obj.id),
     "host" -> JsString(obj.host),
+    "port" -> JsNumber(obj.port),
     "webuiaddress" -> JsString(obj.webUiAddress),
     "cores" -> JsNumber(obj.cores),
     "coresused" -> JsNumber(obj.coresUsed),

@@ -25,7 +26,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
     "starttime" -> JsNumber(obj.startTime),
     "id" -> JsString(obj.id),
     "name" -> JsString(obj.desc.name),
-    "cores" -> JsNumber(obj.desc.cores),
+    "cores" -> JsNumber(obj.desc.maxCores),
     "user" -> JsString(obj.desc.user),
     "memoryperslave" -> JsNumber(obj.desc.memoryPerSlave),
     "submitdate" -> JsString(obj.submitDate.toString))

@@ -34,7 +35,7 @@ private[spark] object JsonProtocol extends DefaultJsonProtocol {
   implicit object AppDescriptionJsonFormat extends RootJsonWriter[ApplicationDescription] {
     def write(obj: ApplicationDescription) = JsObject(
       "name" -> JsString(obj.name),
-      "cores" -> JsNumber(obj.cores),
+      "cores" -> JsNumber(obj.maxCores),
       "memoryperslave" -> JsNumber(obj.memoryPerSlave),
       "user" -> JsString(obj.user)
     )

core/src/main/scala/spark/deploy/LocalSparkCluster.scala
@@ -18,7 +18,7 @@ import scala.collection.mutable.ArrayBuffer
 private[spark]
 class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {

-  private val localIpAddress = Utils.localIpAddress
+  private val localHostname = Utils.localHostName()
   private val masterActorSystems = ArrayBuffer[ActorSystem]()
   private val workerActorSystems = ArrayBuffer[ActorSystem]()

@@ -26,13 +26,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
   logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")

   /* Start the Master */
-  val (masterSystem, masterPort) = Master.startSystemAndActor(localIpAddress, 0, 0)
+  val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
   masterActorSystems += masterSystem
-  val masterUrl = "spark://" + localIpAddress + ":" + masterPort
+  val masterUrl = "spark://" + localHostname + ":" + masterPort

   /* Start the Workers */
   for (workerNum <- 1 to numWorkers) {
-    val (workerSystem, _) = Worker.startSystemAndActor(localIpAddress, 0, 0, coresPerWorker,
+    val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
       memoryPerWorker, masterUrl, null, Some(workerNum))
     workerActorSystems += workerSystem
   }

core/src/main/scala/spark/deploy/client/Client.scala
@@ -3,6 +3,7 @@ package spark.deploy.client
 import spark.deploy._
 import akka.actor._
 import akka.pattern.ask
+import akka.util.Duration
 import akka.util.duration._
 import akka.pattern.AskTimeoutException
 import spark.{SparkException, Logging}

@@ -54,10 +55,15 @@ private[spark] class Client(
         appId = appId_
         listener.connected(appId)

-      case ExecutorAdded(id: Int, workerId: String, host: String, cores: Int, memory: Int) =>
+      case ApplicationRemoved(message) =>
+        logError("Master removed our application: %s; stopping client".format(message))
+        markDisconnected()
+        context.stop(self)
+
+      case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) =>
         val fullId = appId + "/" + id
-        logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, host, cores))
-        listener.executorAdded(fullId, workerId, host, cores, memory)
+        logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores))
+        listener.executorAdded(fullId, workerId, hostPort, cores, memory)

       case ExecutorUpdated(id, state, message, exitStatus) =>
         val fullId = appId + "/" + id

@@ -107,7 +113,7 @@ private[spark] class Client(
   def stop() {
     if (actor != null) {
       try {
-        val timeout = 5.seconds
+        val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
         val future = actor.ask(StopClient)(timeout)
         Await.result(future, timeout)
       } catch {

core/src/main/scala/spark/deploy/client/ClientListener.scala
@@ -12,7 +12,7 @@ private[spark] trait ClientListener {

   def disconnected(): Unit

-  def executorAdded(fullId: String, workerId: String, host: String, cores: Int, memory: Int): Unit
+  def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit

   def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
 }

core/src/main/scala/spark/deploy/client/TestClient.scala
@@ -16,7 +16,7 @@ private[spark] object TestClient {
       System.exit(0)
     }

-    def executorAdded(id: String, workerId: String, host: String, cores: Int, memory: Int) {}
+    def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}

     def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
   }

@@ -25,7 +25,7 @@ private[spark] object TestClient {
     val url = args(0)
     val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0)
     val desc = new ApplicationDescription(
-      "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home")
+      "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
     val listener = new TestListener
     val client = new Client(actorSystem, url, desc, listener)
     client.start()

core/src/main/scala/spark/deploy/master/ApplicationInfo.scala
@@ -10,7 +10,8 @@ private[spark] class ApplicationInfo(
     val id: String,
     val desc: ApplicationDescription,
     val submitDate: Date,
-    val driver: ActorRef)
+    val driver: ActorRef,
+    val appUiUrl: String)
 {
   var state = ApplicationState.WAITING
   var executors = new mutable.HashMap[Int, ExecutorInfo]

@@ -37,7 +38,7 @@ private[spark] class ApplicationInfo(
     coresGranted -= exec.cores
   }

-  def coresLeft: Int = desc.cores - coresGranted
+  def coresLeft: Int = desc.maxCores - coresGranted

   private var _retryCount = 0

@@ -60,4 +61,5 @@ private[spark] class ApplicationInfo(
     System.currentTimeMillis() - startTime
   }
 }
+
 }

core/src/main/scala/spark/deploy/master/Master.scala
@@ -15,7 +15,7 @@ import spark.{Logging, SparkException, Utils}
 import spark.util.AkkaUtils


-private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor with Logging {
+private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
   val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
   val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000

@@ -35,18 +35,20 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor

   var firstApp: Option[ApplicationInfo] = None

+  Utils.checkHost(host, "Expected hostname")
+
   val masterPublicAddress = {
     val envVar = System.getenv("SPARK_PUBLIC_DNS")
-    if (envVar != null) envVar else ip
+    if (envVar != null) envVar else host
   }

   // As a temporary workaround before better ways of configuring memory, we allow users to set
   // a flag that will perform round-robin scheduling across the nodes (spreading out each app
   // among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
-  val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "false").toBoolean
+  val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean

   override def preStart() {
-    logInfo("Starting Spark master at spark://" + ip + ":" + port)
+    logInfo("Starting Spark master at spark://" + host + ":" + port)
     // Listen for remote client disconnection events, since they don't go through Akka's watch()
     context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
     startWebUi()
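
The comment block above contrasts the two allocation policies that spreadOutApps now toggles (with the default flipped to true): spreading an app's cores one at a time round-robin across workers, versus packing each worker before moving to the next. A toy model of just that arithmetic, with free-core counts as the only state; nothing here is the Master's actual scheduling code:

    object SpreadOutSketch {
      // Round-robin: hand out one core at a time across all workers.
      def spreadOut(coresNeeded: Int, free: Array[Int]): Array[Int] = {
        val assigned = new Array[Int](free.length)
        var left = coresNeeded
        var pos = 0
        while (left > 0 && free.exists(_ > 0)) {
          if (free(pos) > 0) {
            free(pos) -= 1; assigned(pos) += 1; left -= 1
          }
          pos = (pos + 1) % free.length
        }
        assigned
      }

      // Consolidate: fill each worker completely before touching the next.
      def consolidate(coresNeeded: Int, free: Array[Int]): Array[Int] = {
        val assigned = new Array[Int](free.length)
        var left = coresNeeded
        for (i <- 0 until free.length if left > 0) {
          val take = math.min(free(i), left)
          free(i) -= take; assigned(i) += take; left -= take
        }
        assigned
      }

      def main(args: Array[String]) {
        println(spreadOut(4, Array(4, 4, 4)).toList)    // List(2, 1, 1)
        println(consolidate(4, Array(4, 4, 4)).toList)  // List(4, 0, 0)
      }
    }
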

@@ -107,7 +109,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
       } else {
         logError("Application %s with ID %s failed %d times, removing it".format(
           appInfo.desc.name, appInfo.id, appInfo.retryCount))
-        removeApplication(appInfo)
+        removeApplication(appInfo, ApplicationState.FAILED)
       }
     }
   }

@@ -129,23 +131,23 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
       // The disconnected actor could've been either a worker or an app; remove whichever of
       // those we have an entry for in the corresponding actor hashmap
       actorToWorker.get(actor).foreach(removeWorker)
-      actorToApp.get(actor).foreach(removeApplication)
+      actorToApp.get(actor).foreach(finishApplication)
     }

     case RemoteClientDisconnected(transport, address) => {
       // The disconnected client could've been either a worker or an app; remove whichever it was
       addressToWorker.get(address).foreach(removeWorker)
-      addressToApp.get(address).foreach(removeApplication)
+      addressToApp.get(address).foreach(finishApplication)
     }

     case RemoteClientShutdown(transport, address) => {
       // The disconnected client could've been either a worker or an app; remove whichever it was
       addressToWorker.get(address).foreach(removeWorker)
-      addressToApp.get(address).foreach(removeApplication)
+      addressToApp.get(address).foreach(finishApplication)
     }

     case RequestMasterState => {
-      sender ! MasterState(ip, port, workers.toArray, apps.toArray, completedApps.toArray)
+      sender ! MasterState(host, port, workers.toArray, apps.toArray, completedApps.toArray)
     }
   }

@@ -211,13 +213,13 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
     logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
     worker.addExecutor(exec)
     worker.actor ! LaunchExecutor(exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
-    exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.host, exec.cores, exec.memory)
+    exec.application.driver ! ExecutorAdded(exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
   }

   def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
       publicAddress: String): WorkerInfo = {
     // There may be one or more refs to dead workers on this same node (w/ different ID's), remove them.
-    workers.filter(w => (w.host == host) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
+    workers.filter(w => (w.host == host && w.port == port) && (w.state == WorkerState.DEAD)).foreach(workers -= _)
     val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
     workers += worker
     idToWorker(worker.id) = worker

@@ -242,7 +244,7 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
   def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
     val now = System.currentTimeMillis()
     val date = new Date(now)
-    val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver)
+    val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
     apps += app
     idToApp(app.id) = app
     actorToApp(driver) = app

@@ -257,20 +259,26 @@ private[spark] class Master(ip: String, port: Int, webUiPort: Int) extends Actor
     return app
   }

-  def removeApplication(app: ApplicationInfo) {
+  def finishApplication(app: ApplicationInfo) {
+    removeApplication(app, ApplicationState.FINISHED)
+  }
+
+  def removeApplication(app: ApplicationInfo, state: ApplicationState.Value) {
     if (apps.contains(app)) {
       logInfo("Removing app " + app.id)
       apps -= app
       idToApp -= app.id
       actorToApp -= app.driver
-      addressToWorker -= app.driver.path.address
+      addressToApp -= app.driver.path.address
       completedApps += app // Remember it in our history
       waitingApps -= app
       for (exec <- app.executors.values) {
         exec.worker.removeExecutor(exec)
         exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
         exec.state = ExecutorState.KILLED
       }
-      app.markFinished(ApplicationState.FINISHED) // TODO: Mark it as FAILED if it failed
+      app.markFinished(state)
+      app.driver ! ApplicationRemoved(state.toString)
       schedule()
     }
   }

@@ -302,7 +310,7 @@ private[spark] object Master {

   def main(argStrings: Array[String]) {
     val args = new MasterArguments(argStrings)
-    val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort)
+    val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort)
     actorSystem.awaitTermination()
   }

core/src/main/scala/spark/deploy/master/MasterArguments.scala
@@ -7,13 +7,13 @@ import spark.Utils
  * Command-line parser for the master.
  */
 private[spark] class MasterArguments(args: Array[String]) {
-  var ip = Utils.localHostName()
+  var host = Utils.localHostName()
   var port = 7077
   var webUiPort = 8080

   // Check for settings in environment variables
-  if (System.getenv("SPARK_MASTER_IP") != null) {
-    ip = System.getenv("SPARK_MASTER_IP")
+  if (System.getenv("SPARK_MASTER_HOST") != null) {
+    host = System.getenv("SPARK_MASTER_HOST")
   }
   if (System.getenv("SPARK_MASTER_PORT") != null) {
     port = System.getenv("SPARK_MASTER_PORT").toInt

@@ -26,7 +26,13 @@ private[spark] class MasterArguments(args: Array[String]) {

   def parse(args: List[String]): Unit = args match {
     case ("--ip" | "-i") :: value :: tail =>
-      ip = value
+      Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+      host = value
       parse(tail)

+    case ("--host" | "-h") :: value :: tail =>
+      Utils.checkHost(value, "Please use hostname " + value)
+      host = value
+      parse(tail)
+
     case ("--port" | "-p") :: IntParam(value) :: tail =>

@@ -54,7 +60,8 @@ private[spark] class MasterArguments(args: Array[String]) {
       "Usage: Master [options]\n" +
       "\n" +
       "Options:\n" +
-      "  -i IP, --ip IP         IP address or DNS name to listen on\n" +
+      "  -i HOST, --ip HOST     Hostname to listen on (deprecated, please use --host or -h) \n" +
+      "  -h HOST, --host HOST   Hostname to listen on\n" +
       "  -p PORT, --port PORT   Port to listen on (default: 7077)\n" +
       "  --webui-port PORT      Port for web UI (default: 8080)")
     System.exit(exitCode)

core/src/main/scala/spark/deploy/master/MasterWebUI.scala
@@ -3,7 +3,7 @@ package spark.deploy.master
 import akka.actor.{ActorRef, ActorSystem}
 import akka.dispatch.Await
 import akka.pattern.ask
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
 import akka.util.duration._
 import cc.spray.Directives
 import cc.spray.directives._

@@ -22,7 +22,7 @@ class MasterWebUI(val actorSystem: ActorSystem, master: ActorRef) extends Direct
   val RESOURCE_DIR = "spark/deploy/master/webui"
   val STATIC_RESOURCE_DIR = "spark/deploy/static"

-  implicit val timeout = Timeout(10 seconds)
+  implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))

   val handler = {
     get {

core/src/main/scala/spark/deploy/master/WorkerInfo.scala
@@ -2,6 +2,7 @@ package spark.deploy.master

 import akka.actor.ActorRef
 import scala.collection.mutable
+import spark.Utils

 private[spark] class WorkerInfo(
   val id: String,

@@ -13,6 +14,9 @@ private[spark] class WorkerInfo(
   val webUiPort: Int,
   val publicAddress: String) {

+  Utils.checkHost(host, "Expected hostname")
+  assert (port > 0)
+
   var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
   var state: WorkerState.Value = WorkerState.ALIVE
   var coresUsed = 0

@@ -23,6 +27,11 @@ private[spark] class WorkerInfo(
   def coresFree: Int = cores - coresUsed
   def memoryFree: Int = memory - memoryUsed

+  def hostPort: String = {
+    assert (port > 0)
+    host + ":" + port
+  }
+
   def addExecutor(exec: ExecutorInfo) {
     executors(exec.fullId) = exec
     coresUsed += exec.cores

core/src/main/scala/spark/deploy/worker/ExecutorRunner.scala
@@ -21,11 +21,13 @@ private[spark] class ExecutorRunner(
     val memory: Int,
     val worker: ActorRef,
     val workerId: String,
-    val hostname: String,
+    val hostPort: String,
     val sparkHome: File,
     val workDir: File)
   extends Logging {

+  Utils.checkHostPort(hostPort, "Expected hostport")
+
   val fullId = appId + "/" + execId
   var workerThread: Thread = null
   var process: Process = null

@@ -68,7 +70,7 @@ private[spark] class ExecutorRunner(
   /** Replace variables such as {{EXECUTOR_ID}} and {{CORES}} in a command argument passed to us */
   def substituteVariables(argument: String): String = argument match {
     case "{{EXECUTOR_ID}}" => execId.toString
-    case "{{HOSTNAME}}" => hostname
+    case "{{HOSTNAME}}" => Utils.parseHostPort(hostPort)._1
     case "{{CORES}}" => cores.toString
     case other => other
   }

core/src/main/scala/spark/deploy/worker/Worker.scala
@@ -16,7 +16,7 @@ import spark.deploy.master.Master
 import java.io.File

 private[spark] class Worker(
-    ip: String,
+    host: String,
     port: Int,
     webUiPort: Int,
     cores: Int,

@@ -25,6 +25,9 @@ private[spark] class Worker(
     workDirPath: String = null)
   extends Actor with Logging {

+  Utils.checkHost(host, "Expected hostname")
+  assert (port > 0)
+
   val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs

   // Send a heartbeat every (heartbeat timeout) / 4 milliseconds

@@ -39,7 +42,7 @@ private[spark] class Worker(
   val finishedExecutors = new HashMap[String, ExecutorRunner]
   val publicAddress = {
     val envVar = System.getenv("SPARK_PUBLIC_DNS")
-    if (envVar != null) envVar else ip
+    if (envVar != null) envVar else host
   }

   var coresUsed = 0

@@ -51,10 +54,14 @@ private[spark] class Worker(
   def createWorkDir() {
     workDir = Option(workDirPath).map(new File(_)).getOrElse(new File(sparkHome, "work"))
     try {
-      if (!workDir.exists() && !workDir.mkdirs()) {
+      // This sporadically fails - not sure why ... !workDir.exists() && !workDir.mkdirs()
+      // So attempting to create and then check if directory was created or not.
+      workDir.mkdirs()
+      if ( !workDir.exists() || !workDir.isDirectory) {
         logError("Failed to create work directory " + workDir)
         System.exit(1)
       }
+      assert (workDir.isDirectory)
     } catch {
       case e: Exception =>
         logError("Failed to create work directory " + workDir, e)
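
The createWorkDir change above works around File.mkdirs() sporadically reporting failure even when the directory ends up existing: call mkdirs() unconditionally, ignore its return value, and verify the post-condition instead. The same pattern in isolation, with a hypothetical directory name:

    import java.io.File

    object WorkDirSketch {
      def ensureDirectory(dir: File): Boolean = {
        dir.mkdirs()                     // return value deliberately ignored
        dir.exists() && dir.isDirectory  // trust only the observable state
      }

      def main(args: Array[String]) {
        val dir = new File(System.getProperty("java.io.tmpdir"), "work-dir-sketch")
        if (!ensureDirectory(dir)) {
          System.err.println("Failed to create work directory " + dir)
          System.exit(1)
        }
        println("work dir ready: " + dir)
      }
    }
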

@@ -64,7 +71,7 @@ private[spark] class Worker(

   override def preStart() {
     logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
-      ip, port, cores, Utils.memoryMegabytesToString(memory)))
+      host, port, cores, Utils.memoryMegabytesToString(memory)))
     sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
     logInfo("Spark home: " + sparkHome)
     createWorkDir()

@@ -74,20 +81,14 @@ private[spark] class Worker(

   def connectToMaster() {
     logInfo("Connecting to master " + masterUrl)
-    try {
-      master = context.actorFor(Master.toAkkaUrl(masterUrl))
-      master ! RegisterWorker(workerId, ip, port, cores, memory, webUiPort, publicAddress)
-      context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
-      context.watch(master) // Doesn't work with remote actors, but useful for testing
-    } catch {
-      case e: Exception =>
-        logError("Failed to connect to master", e)
-        System.exit(1)
-    }
+    master = context.actorFor(Master.toAkkaUrl(masterUrl))
+    master ! RegisterWorker(workerId, host, port, cores, memory, webUiPort, publicAddress)
+    context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+    context.watch(master) // Doesn't work with remote actors, but useful for testing
   }

   def startWebUi() {
-    val webUi = new WorkerWebUI(context.system, self)
+    val webUi = new WorkerWebUI(context.system, self, workDir)
     try {
       AkkaUtils.startSprayServer(context.system, "0.0.0.0", webUiPort, webUi.handler)
     } catch {

@@ -112,7 +113,7 @@ private[spark] class Worker(
     case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
       logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
       val manager = new ExecutorRunner(
-        appId, execId, appDesc, cores_, memory_, self, workerId, ip, new File(execSparkHome_), workDir)
+        appId, execId, appDesc, cores_, memory_, self, workerId, host + ":" + port, new File(execSparkHome_), workDir)
       executors(appId + "/" + execId) = manager
       manager.start()
       coresUsed += cores_

@@ -147,7 +148,7 @@ private[spark] class Worker(
       masterDisconnected()

     case RequestWorkerState => {
-      sender ! WorkerState(ip, port, workerId, executors.values.toList,
+      sender ! WorkerState(host, port, workerId, executors.values.toList,
         finishedExecutors.values.toList, masterUrl, cores, memory,
         coresUsed, memoryUsed, masterWebUiUrl)
     }

@@ -162,7 +163,7 @@ private[spark] class Worker(
   }

   def generateWorkerId(): String = {
-    "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), ip, port)
+    "worker-%s-%s-%d".format(DATE_FORMAT.format(new Date), host, port)
   }

   override def postStop() {

@@ -173,7 +174,7 @@ private[spark] class Worker(
 private[spark] object Worker {
   def main(argStrings: Array[String]) {
     val args = new WorkerArguments(argStrings)
-    val (actorSystem, _) = startSystemAndActor(args.ip, args.port, args.webUiPort, args.cores,
+    val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
       args.memory, args.master, args.workDir)
     actorSystem.awaitTermination()
   }

core/src/main/scala/spark/deploy/worker/WorkerArguments.scala
@@ -9,7 +9,7 @@ import java.lang.management.ManagementFactory
  * Command-line parser for the master.
  */
 private[spark] class WorkerArguments(args: Array[String]) {
-  var ip = Utils.localHostName()
+  var host = Utils.localHostName()
   var port = 0
   var webUiPort = 8081
   var cores = inferDefaultCores()

@@ -38,7 +38,13 @@ private[spark] class WorkerArguments(args: Array[String]) {

   def parse(args: List[String]): Unit = args match {
     case ("--ip" | "-i") :: value :: tail =>
-      ip = value
+      Utils.checkHost(value, "ip no longer supported, please use hostname " + value)
+      host = value
       parse(tail)

+    case ("--host" | "-h") :: value :: tail =>
+      Utils.checkHost(value, "Please use hostname " + value)
+      host = value
+      parse(tail)
+
     case ("--port" | "-p") :: IntParam(value) :: tail =>

@@ -93,7 +99,8 @@ private[spark] class WorkerArguments(args: Array[String]) {
       "  -c CORES, --cores CORES  Number of cores to use\n" +
       "  -m MEM, --memory MEM     Amount of memory to use (e.g. 1000M, 2G)\n" +
       "  -d DIR, --work-dir DIR   Directory to run apps in (default: SPARK_HOME/work)\n" +
-      "  -i IP, --ip IP           IP address or DNS name to listen on\n" +
+      "  -i HOST, --ip IP         Hostname to listen on (deprecated, please use --host or -h)\n" +
+      "  -h HOST, --host HOST     Hostname to listen on\n" +
      "  -p PORT, --port PORT     Port to listen on (default: random)\n" +
       "  --webui-port PORT        Port for web UI (default: 8081)")
     System.exit(exitCode)

core/src/main/scala/spark/deploy/worker/WorkerWebUI.scala
@@ -3,7 +3,7 @@ package spark.deploy.worker
 import akka.actor.{ActorRef, ActorSystem}
 import akka.dispatch.Await
 import akka.pattern.ask
-import akka.util.Timeout
+import akka.util.{Duration, Timeout}
 import akka.util.duration._
 import cc.spray.Directives
 import cc.spray.typeconversion.TwirlSupport._

@@ -12,16 +12,17 @@ import cc.spray.typeconversion.SprayJsonSupport._

 import spark.deploy.{WorkerState, RequestWorkerState}
 import spark.deploy.JsonProtocol._
+import java.io.File

 /**
  * Web UI server for the standalone worker.
  */
 private[spark]
-class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Directives {
+class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef, workDir: File) extends Directives {
   val RESOURCE_DIR = "spark/deploy/worker/webui"
   val STATIC_RESOURCE_DIR = "spark/deploy/static"

-  implicit val timeout = Timeout(10 seconds)
+  implicit val timeout = Timeout(Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds"))

   val handler = {
     get {

@@ -43,7 +44,7 @@ class WorkerWebUI(val actorSystem: ActorSystem, worker: ActorRef) extends Direct
       path("log") {
         parameters("appId", "executorId", "logType") { (appId, executorId, logType) =>
           respondWithMediaType(cc.spray.http.MediaTypes.`text/plain`) {
-            getFromFileName("work/" + appId + "/" + executorId + "/" + logType)
+            getFromFileName(workDir.getPath() + "/" + appId + "/" + executorId + "/" + logType)
           }
         }
       } ~

core/src/main/scala/spark/executor/Executor.scala
@@ -16,66 +16,68 @@ import java.nio.ByteBuffer
 /**
  * The Mesos executor for Spark.
  */
-private[spark] class Executor extends Logging {
-  var urlClassLoader : ExecutorURLClassLoader = null
-  var threadPool: ExecutorService = null
-  var env: SparkEnv = null
+private[spark] class Executor(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) extends Logging {

   // Application dependencies (added through SparkContext) that we've fetched so far on this node.
   // Each map holds the master's timestamp for the version of that file or JAR we got.
-  val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
-  val currentJars: HashMap[String, Long] = new HashMap[String, Long]()
+  private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
+  private val currentJars: HashMap[String, Long] = new HashMap[String, Long]()

-  val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+  private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))

   initLogging()

-  def initialize(executorId: String, slaveHostname: String, properties: Seq[(String, String)]) {
-    // Make sure the local hostname we report matches the cluster scheduler's name for this host
-    Utils.setCustomHostname(slaveHostname)
+  // No ip or host:port - just hostname
+  Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+  // must not have port specified.
+  assert (0 == Utils.parseHostPort(slaveHostname)._2)

-    // Set spark.* system properties from executor arg
-    for ((key, value) <- properties) {
-      System.setProperty(key, value)
-    }
+  // Make sure the local hostname we report matches the cluster scheduler's name for this host
+  Utils.setCustomHostname(slaveHostname)

-    // Create our ClassLoader and set it on this thread
-    urlClassLoader = createClassLoader()
-    Thread.currentThread.setContextClassLoader(urlClassLoader)
+  // Set spark.* system properties from executor arg
+  for ((key, value) <- properties) {
+    System.setProperty(key, value)
+  }

-    // Make any thread terminations due to uncaught exceptions kill the entire
-    // executor process to avoid surprising stalls.
-    Thread.setDefaultUncaughtExceptionHandler(
-      new Thread.UncaughtExceptionHandler {
-        override def uncaughtException(thread: Thread, exception: Throwable) {
-          try {
-            logError("Uncaught exception in thread " + thread, exception)
-
-            // We may have been called from a shutdown hook. If so, we must not call System.exit().
-            // (If we do, we will deadlock.)
-            if (!Utils.inShutdown()) {
-              if (exception.isInstanceOf[OutOfMemoryError]) {
-                System.exit(ExecutorExitCode.OOM)
-              } else {
-                System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
-              }
+  // Create our ClassLoader and set it on this thread
+  private val urlClassLoader = createClassLoader()
+  private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
+  Thread.currentThread.setContextClassLoader(replClassLoader)
+
+  // Make any thread terminations due to uncaught exceptions kill the entire
+  // executor process to avoid surprising stalls.
+  Thread.setDefaultUncaughtExceptionHandler(
+    new Thread.UncaughtExceptionHandler {
+      override def uncaughtException(thread: Thread, exception: Throwable) {
+        try {
+          logError("Uncaught exception in thread " + thread, exception)
+
+          // We may have been called from a shutdown hook. If so, we must not call System.exit().
+          // (If we do, we will deadlock.)
+          if (!Utils.inShutdown()) {
+            if (exception.isInstanceOf[OutOfMemoryError]) {
+              System.exit(ExecutorExitCode.OOM)
+            } else {
+              System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
            }
-          } catch {
-            case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
-            case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
-          }
+        } catch {
+          case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
+          case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
         }
-      )
       }
+    )

-    // Initialize Spark environment (using system properties read above)
-    env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
-    SparkEnv.set(env)
+  // Initialize Spark environment (using system properties read above)
+  val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
+  SparkEnv.set(env)
+  private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")

-    // Start worker thread pool
-    threadPool = new ThreadPoolExecutor(
-      1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
-  }
+  // Start worker thread pool
+  val threadPool = new ThreadPoolExecutor(
+    1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])

   def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
     threadPool.execute(new TaskRunner(context, taskId, serializedTask))

@@ -85,8 +87,9 @@ private[spark] class Executor extends Logging {
     extends Runnable {

     override def run() {
+      val startTime = System.currentTimeMillis()
       SparkEnv.set(env)
-      Thread.currentThread.setContextClassLoader(urlClassLoader)
+      Thread.currentThread.setContextClassLoader(replClassLoader)
       val ser = SparkEnv.get.closureSerializer.newInstance()
       logInfo("Running task ID " + taskId)
       context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)

@@ -98,11 +101,25 @@ private[spark] class Executor extends Logging {
         val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
         logInfo("Its generation is " + task.generation)
         env.mapOutputTracker.updateGeneration(task.generation)
+        val taskStart = System.currentTimeMillis()
         val value = task.run(taskId.toInt)
+        val taskFinish = System.currentTimeMillis()
+        task.metrics.foreach{ m =>
+          m.hostname = Utils.localHostName
+          m.executorDeserializeTime = (taskStart - startTime).toInt
+          m.executorRunTime = (taskFinish - taskStart).toInt
+        }
+        //TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
+        // we need to serialize the task metrics first.  If TaskMetrics had a custom serialized format, we could
+        // just change the relevants bytes in the byte buffer
         val accumUpdates = Accumulators.values
-        val result = new TaskResult(value, accumUpdates)
+        val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
         val serializedResult = ser.serialize(result)
         logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
+        if (serializedResult.limit >= (akkaFrameSize - 1024)) {
+          context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
+          return
+        }
         context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
         logInfo("Finished task ID " + taskId)
       } catch {
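
The run() changes above bracket each phase with System.currentTimeMillis and store the deltas on the task's metrics. The same bookkeeping in miniature; Metrics and the fake phases are illustrative, not the patch's types:

    object TimingSketch {
      class Metrics {
        var executorDeserializeTime: Int = 0
        var executorRunTime: Int = 0
      }

      def timed(deserialize: () => Unit, run: () => Unit): Metrics = {
        val startTime = System.currentTimeMillis()
        deserialize()
        val taskStart = System.currentTimeMillis()
        run()
        val taskFinish = System.currentTimeMillis()
        val m = new Metrics
        m.executorDeserializeTime = (taskStart - startTime).toInt  // setup cost
        m.executorRunTime = (taskFinish - taskStart).toInt         // user code cost
        m
      }

      def main(args: Array[String]) {
        val m = timed(() => Thread.sleep(5), () => Thread.sleep(20))
        println("deserialize ~" + m.executorDeserializeTime + "ms, run ~" + m.executorRunTime + "ms")
      }
    }
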

@@ -112,7 +129,7 @@ private[spark] class Executor extends Logging {
         }

         case t: Throwable => {
-          val reason = ExceptionFailure(t)
+          val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
           context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))

           // TODO: Should we exit the whole executor here? On the one hand, the failed task may

@@ -137,26 +154,31 @@ private[spark] class Executor extends Logging {
     val urls = currentJars.keySet.map { uri =>
       new File(uri.split("/").last).toURI.toURL
     }.toArray
-    loader = new URLClassLoader(urls, loader)
+    new ExecutorURLClassLoader(urls, loader)
   }

-  // If the REPL is in use, add another ClassLoader that will read
-  // new classes defined by the REPL as the user types code
+  /**
+   * If the REPL is in use, add another ClassLoader that will read
+   * new classes defined by the REPL as the user types code
+   */
+  private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
     val classUri = System.getProperty("spark.repl.class.uri")
     if (classUri != null) {
       logInfo("Using REPL class URI: " + classUri)
-      loader = {
-        try {
-          val klass = Class.forName("spark.repl.ExecutorClassLoader")
-            .asInstanceOf[Class[_ <: ClassLoader]]
-          val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
-          constructor.newInstance(classUri, loader)
-        } catch {
-          case _: ClassNotFoundException => loader
-        }
+      try {
+        val klass = Class.forName("spark.repl.ExecutorClassLoader")
+          .asInstanceOf[Class[_ <: ClassLoader]]
+        val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
+        return constructor.newInstance(classUri, parent)
+      } catch {
+        case _: ClassNotFoundException =>
+          logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
+          System.exit(1)
+          null
       }
+    } else {
+      return parent
     }
-
-    return new ExecutorURLClassLoader(Array(), loader)
   }

   /**

core/src/main/scala/spark/executor/MesosExecutorBackend.scala
@@ -8,11 +8,12 @@ import com.google.protobuf.ByteString
 import spark.{Utils, Logging}
 import spark.TaskState

-private[spark] class MesosExecutorBackend(executor: Executor)
+private[spark] class MesosExecutorBackend
   extends MesosExecutor
   with ExecutorBackend
   with Logging {

+  var executor: Executor = null
   var driver: ExecutorDriver = null

   override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {

@@ -32,16 +33,19 @@ private[spark] class MesosExecutorBackend(executor: Executor)
     logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
     this.driver = driver
     val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray)
-    executor.initialize(
+    executor = new Executor(
       executorInfo.getExecutorId.getValue,
       slaveInfo.getHostname,
-      properties
-    )
+      properties)
   }

   override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
     val taskId = taskInfo.getTaskId.getValue.toLong
-    executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
+    if (executor == null) {
+      logError("Received launchTask but executor was null")
+    } else {
+      executor.launchTask(this, taskId, taskInfo.getData.asReadOnlyByteBuffer)
+    }
   }

   override def error(d: ExecutorDriver, message: String) {

@@ -68,7 +72,7 @@ private[spark] object MesosExecutorBackend {
   def main(args: Array[String]) {
     MesosNativeLibrary.load()
     // Create a new Executor and start it running
-    val runner = new MesosExecutorBackend(new Executor)
+    val runner = new MesosExecutorBackend()
     new MesosExecutorDriver(runner).run()
   }
 }

core/src/main/scala/spark/executor/StandaloneExecutorBackend.scala
@@ -12,23 +12,27 @@ import spark.scheduler.cluster.RegisteredExecutor
 import spark.scheduler.cluster.LaunchTask
 import spark.scheduler.cluster.RegisterExecutorFailed
 import spark.scheduler.cluster.RegisterExecutor
+import spark.Utils
+import spark.deploy.SparkHadoopUtil

 private[spark] class StandaloneExecutorBackend(
-    executor: Executor,
     driverUrl: String,
     executorId: String,
-    hostname: String,
+    hostPort: String,
     cores: Int)
   extends Actor
   with ExecutorBackend
   with Logging {

+  Utils.checkHostPort(hostPort, "Expected hostport")
+
+  var executor: Executor = null
   var driver: ActorRef = null

   override def preStart() {
     logInfo("Connecting to driver: " + driverUrl)
     driver = context.actorFor(driverUrl)
-    driver ! RegisterExecutor(executorId, hostname, cores)
+    driver ! RegisterExecutor(executorId, hostPort, cores)
     context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
     context.watch(driver) // Doesn't work with remote actors, but useful for testing
   }

@@ -36,7 +40,8 @@ private[spark] class StandaloneExecutorBackend(
   override def receive = {
     case RegisteredExecutor(sparkProperties) =>
       logInfo("Successfully registered with driver")
-      executor.initialize(executorId, hostname, sparkProperties)
+      // Make this host instead of hostPort ?
+      executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties)

     case RegisterExecutorFailed(message) =>
       logError("Slave registration failed: " + message)

@@ -44,7 +49,12 @@ private[spark] class StandaloneExecutorBackend(

     case LaunchTask(taskDesc) =>
       logInfo("Got assigned task " + taskDesc.taskId)
-      executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
+      if (executor == null) {
+        logError("Received launchTask but executor was null")
+        System.exit(1)
+      } else {
+        executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
+      }

     case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
       logError("Driver terminated or disconnected! Shutting down.")

@@ -58,11 +68,30 @@ private[spark] class StandaloneExecutorBackend(

 private[spark] object StandaloneExecutorBackend {
   def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+    SparkHadoopUtil.runAsUser(run0, Tuple4[Any, Any, Any, Any] (driverUrl, executorId, hostname, cores))
+  }
+
+  // This will be run 'as' the user
+  def run0(args: Product) {
+    assert(4 == args.productArity)
+    runImpl(args.productElement(0).asInstanceOf[String],
+      args.productElement(1).asInstanceOf[String],
+      args.productElement(2).asInstanceOf[String],
+      args.productElement(3).asInstanceOf[Int])
+  }
+
+  private def runImpl(driverUrl: String, executorId: String, hostname: String, cores: Int) {
+    // Debug code
+    Utils.checkHost(hostname)
+
     // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
     // before getting started with all our system properties, etc
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+    // set it
+    val sparkHostPort = hostname + ":" + boundPort
+    System.setProperty("spark.hostPort", sparkHostPort)
     val actor = actorSystem.actorOf(
-      Props(new StandaloneExecutorBackend(new Executor, driverUrl, executorId, hostname, cores)),
+      Props(new StandaloneExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
       name = "Executor")
     actorSystem.awaitTermination()
   }

core/src/main/scala/spark/executor/TaskMetrics.scala (new file, 83 lines)
@@ -0,0 +1,83 @@
+package spark.executor
+
+class TaskMetrics extends Serializable {
+  /**
+   * Host's name the task runs on
+   */
+  var hostname: String = _
+
+  /**
+   * Time taken on the executor to deserialize this task
+   */
+  var executorDeserializeTime: Int = _
+
+  /**
+   * Time the executor spends actually running the task (including fetching shuffle data)
+   */
+  var executorRunTime:Int = _
+
+  /**
+   * The number of bytes this task transmitted back to the driver as the TaskResult
+   */
+  var resultSize: Long = _
+
+  /**
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   */
+  var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+  /**
+   * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+   */
+  var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+}
+
+object TaskMetrics {
+  private[spark] def empty(): TaskMetrics = new TaskMetrics
+}
+
+
+class ShuffleReadMetrics extends Serializable {
+  /**
+   * Time when shuffle finishs
+   */
+  var shuffleFinishTime: Long = _
+
+  /**
+   * Total number of blocks fetched in a shuffle (remote or local)
+   */
+  var totalBlocksFetched: Int = _
+
+  /**
+   * Number of remote blocks fetched in a shuffle
+   */
+  var remoteBlocksFetched: Int = _
+
+  /**
+   * Local blocks fetched in a shuffle
+   */
+  var localBlocksFetched: Int = _
+
+  /**
+   * Total time that is spent blocked waiting for shuffle to fetch data
+   */
+  var fetchWaitTime: Long = _
+
+  /**
+   * The total amount of time for all the shuffle fetches. This adds up time from overlapping
+   * shuffles, so can be longer than task time
+   */
+  var remoteFetchTime: Long = _
+
+  /**
+   * Total number of remote bytes read from a shuffle
+   */
+  var remoteBytesRead: Long = _
+}
+
+class ShuffleWriteMetrics extends Serializable {
+  /**
+   * Number of bytes written for a shuffle
+   */
+  var shuffleBytesWritten: Long = _
+}
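
TaskMetrics above is a plain mutable bag that the executor fills and the driver reads back out of the TaskResult. A sketch of a producer and consumer using only the fields defined in this file (all values made up):

    object TaskMetricsSketch {
      def main(args: Array[String]) {
        val metrics = new spark.executor.TaskMetrics
        metrics.hostname = "worker1"
        metrics.executorDeserializeTime = 12
        metrics.executorRunTime = 340

        val read = new spark.executor.ShuffleReadMetrics
        read.totalBlocksFetched = 8
        read.remoteBlocksFetched = 6
        read.localBlocksFetched = 2
        read.remoteBytesRead = 1L << 20
        metrics.shuffleReadMetrics = Some(read)

        metrics.shuffleReadMetrics.foreach { r =>
          println("fetched " + r.totalBlocksFetched + " blocks (" +
            r.remoteBlocksFetched + " remote) on " + metrics.hostname)
        }
      }
    }
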

core/src/main/scala/spark/network/BufferMessage.scala (new file, 94 lines)
@@ -0,0 +1,94 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import spark.storage.BlockManager
+
+
+private[spark]
+class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
+  extends Message(Message.BUFFER_MESSAGE, id_) {
+
+  val initialSize = currentSize()
+  var gotChunkForSendingOnce = false
+
+  def size = initialSize
+
+  def currentSize() = {
+    if (buffers == null || buffers.isEmpty) {
+      0
+    } else {
+      buffers.map(_.remaining).reduceLeft(_ + _)
+    }
+  }
+
+  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
+    if (maxChunkSize <= 0) {
+      throw new Exception("Max chunk size is " + maxChunkSize)
+    }
+
+    if (size == 0 && gotChunkForSendingOnce == false) {
+      val newChunk = new MessageChunk(
+        new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
+      gotChunkForSendingOnce = true
+      return Some(newChunk)
+    }
+
+    while(!buffers.isEmpty) {
+      val buffer = buffers(0)
+      if (buffer.remaining == 0) {
+        BlockManager.dispose(buffer)
+        buffers -= buffer
+      } else {
+        val newBuffer = if (buffer.remaining <= maxChunkSize) {
+          buffer.duplicate()
+        } else {
+          buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
+        }
+        buffer.position(buffer.position + newBuffer.remaining)
+        val newChunk = new MessageChunk(new MessageChunkHeader(
+          typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+        gotChunkForSendingOnce = true
+        return Some(newChunk)
+      }
+    }
+    None
+  }
+
+  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
+    // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
+    if (buffers.size > 1) {
+      throw new Exception("Attempting to get chunk from message with multiple data buffers")
+    }
+    val buffer = buffers(0)
+    if (buffer.remaining > 0) {
+      if (buffer.remaining < chunkSize) {
+        throw new Exception("Not enough space in data buffer for receiving chunk")
+      }
+      val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
+      buffer.position(buffer.position + newBuffer.remaining)
+      val newChunk = new MessageChunk(new MessageChunkHeader(
+        typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
+      return Some(newChunk)
+    }
+    None
+  }
+
+  def flip() {
+    buffers.foreach(_.flip)
+  }
+
+  def hasAckId() = (ackId != 0)
+
+  def isCompletelyReceived() = !buffers(0).hasRemaining
+
+  override def toString = {
+    if (hasAckId) {
+      "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
+    } else {
+      "BufferMessage(id = " + id + ", size = " + size + ")"
+    }
+  }
+}
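
getChunkForSending above drains the message's buffers into chunks of at most maxChunkSize bytes, disposing each source buffer once it is exhausted. A sketch of the loop a sender would run; it sits in package spark.network because the class is package-private, and it assumes MessageChunk exposes its payload as a buffer field (MessageChunk itself is defined elsewhere, not in this diff):

    package spark.network

    import java.nio.ByteBuffer
    import scala.collection.mutable.ArrayBuffer

    object ChunkingSketch {
      def main(args: Array[String]) {
        val payload = ByteBuffer.wrap(new Array[Byte](150 * 1024))  // 150 KB message
        val msg = new BufferMessage(1, ArrayBuffer(payload), 0)

        val maxChunkSize = 65536
        var sent = 0
        var chunk = msg.getChunkForSending(maxChunkSize)
        while (chunk.isDefined) {  // each call yields <= maxChunkSize bytes
          sent += chunk.get.buffer.remaining
          chunk = msg.getChunkForSending(maxChunkSize)
        }
        println("sent " + sent + " bytes in chunks of <= " + maxChunkSize)
      }
    }
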

core/src/main/scala/spark/network/Connection.scala
@@ -13,12 +13,13 @@ import java.net._

 private[spark]
 abstract class Connection(val channel: SocketChannel, val selector: Selector,
-    val remoteConnectionManagerId: ConnectionManagerId) extends Logging {
+    val socketRemoteConnectionManagerId: ConnectionManagerId)
+  extends Logging {

   def this(channel_ : SocketChannel, selector_ : Selector) = {
     this(channel_, selector_,
-      ConnectionManagerId.fromSocketAddress(
-        channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]
-      ))
+      ConnectionManagerId.fromSocketAddress(
+        channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]))
   }

   channel.configureBlocking(false)

@@ -33,16 +34,47 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

   val remoteAddress = getRemoteAddress()

+  // Read channels typically do not register for write and write does not for read
+  // Now, we do have write registering for read too (temporarily), but this is to detect
+  // channel close NOT to actually read/consume data on it !
+  // How does this work if/when we move to SSL ?
+
+  // What is the interest to register with selector for when we want this connection to be selected
+  def registerInterest()
+
+  // What is the interest to register with selector for when we want this connection to
+  // be de-selected
+  // Traditionally, 0 - but in our case, for example, for close-detection on SendingConnection hack,
+  // it will be SelectionKey.OP_READ (until we fix it properly)
+  def unregisterInterest()
+
+  // On receiving a read event, should we change the interest for this channel or not ?
+  // Will be true for ReceivingConnection, false for SendingConnection.
+  def changeInterestForRead(): Boolean
+
+  // On receiving a write event, should we change the interest for this channel or not ?
+  // Will be false for ReceivingConnection, true for SendingConnection.
+  // Actually, for now, should not get triggered for ReceivingConnection
+  def changeInterestForWrite(): Boolean
+
+  def getRemoteConnectionManagerId(): ConnectionManagerId = {
+    socketRemoteConnectionManagerId
+  }
+
   def key() = channel.keyFor(selector)

   def getRemoteAddress() = channel.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]

-  def read() {
-    throw new UnsupportedOperationException("Cannot read on connection of type " + this.getClass.toString)
+  // Returns whether we have to register for further reads or not.
+  def read(): Boolean = {
+    throw new UnsupportedOperationException(
+      "Cannot read on connection of type " + this.getClass.toString)
   }

-  def write() {
-    throw new UnsupportedOperationException("Cannot write on connection of type " + this.getClass.toString)
+  // Returns whether we have to register for further writes or not.
+  def write(): Boolean = {
+    throw new UnsupportedOperationException(
+      "Cannot write on connection of type " + this.getClass.toString)
   }

   def close() {
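
registerInterest and unregisterInterest above ultimately come down to flipping bits in a SelectionKey's interest set. A self-contained sketch of that NIO mechanic, independent of Spark's classes:

    import java.nio.channels.{SelectionKey, Selector, ServerSocketChannel}

    object InterestOpsSketch {
      def main(args: Array[String]) {
        val selector = Selector.open()
        val server = ServerSocketChannel.open()
        server.configureBlocking(false)
        server.socket().bind(null)  // ephemeral port
        val key = server.register(selector, SelectionKey.OP_ACCEPT)

        // "registerInterest": OR the op in; "unregisterInterest": mask it out.
        key.interestOps(key.interestOps() | SelectionKey.OP_ACCEPT)
        println("accept on:  " + key.interestOps())
        key.interestOps(key.interestOps() & ~SelectionKey.OP_ACCEPT)
        println("accept off: " + key.interestOps())

        server.close()
        selector.close()
      }
    }
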

@@ -54,26 +86,32 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
     callOnCloseCallback()
   }

-  def onClose(callback: Connection => Unit) {onCloseCallback = callback}
+  def onClose(callback: Connection => Unit) {
+    onCloseCallback = callback
+  }

-  def onException(callback: (Connection, Exception) => Unit) {onExceptionCallback = callback}
+  def onException(callback: (Connection, Exception) => Unit) {
+    onExceptionCallback = callback
+  }

-  def onKeyInterestChange(callback: (Connection, Int) => Unit) {onKeyInterestChangeCallback = callback}
+  def onKeyInterestChange(callback: (Connection, Int) => Unit) {
+    onKeyInterestChangeCallback = callback
+  }

   def callOnExceptionCallback(e: Exception) {
     if (onExceptionCallback != null) {
       onExceptionCallback(this, e)
     } else {
-      logError("Error in connection to " + remoteConnectionManagerId +
+      logError("Error in connection to " + getRemoteConnectionManagerId() +
         " and OnExceptionCallback not registered", e)
     }
   }


   def callOnCloseCallback() {
     if (onCloseCallback != null) {
       onCloseCallback(this)
     } else {
-      logWarning("Connection to " + remoteConnectionManagerId +
+      logWarning("Connection to " + getRemoteConnectionManagerId() +
         " closed and OnExceptionCallback not registered")
     }
@@ -81,7 +119,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

   def changeConnectionKeyInterest(ops: Int) {
     if (onKeyInterestChangeCallback != null) {
       onKeyInterestChangeCallback(this, ops)
     } else {
       throw new Exception("OnKeyInterestChangeCallback not registered")
     }
@@ -105,24 +143,25 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
       print(" (" + position + ", " + length + ")")
       buffer.position(curPosition)
     }
   }
 }


-private[spark] class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
-    remoteId_ : ConnectionManagerId)
-  extends Connection(SocketChannel.open, selector_, remoteId_) {
+private[spark]
+class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
+    remoteId_ : ConnectionManagerId)
+  extends Connection(SocketChannel.open, selector_, remoteId_) {

   class Outbox(fair: Int = 0) {
     val messages = new Queue[Message]()
     val defaultChunkSize = 65536 //32768 //16384
     var nextMessageToBeUsed = 0

     def addMessage(message: Message) {
       messages.synchronized{
         /*messages += message*/
         messages.enqueue(message)
-        logDebug("Added [" + message + "] to outbox for sending to [" + remoteConnectionManagerId + "]")
+        logDebug("Added [" + message + "] to outbox for sending to " +
+          "[" + getRemoteConnectionManagerId() + "]")
       }
     }
@@ -147,18 +186,18 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
             message.started = true
             message.startTime = System.currentTimeMillis
           }
           return chunk
         } else {
-          /*logInfo("Finished sending [" + message + "] to [" + remoteConnectionManagerId + "]")*/
+          /*logInfo("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() + "]")*/
           message.finishTime = System.currentTimeMillis
-          logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+          logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
             "] in " + message.timeTaken )
         }
       }
     }
     None
   }

   private def getChunkRR(): Option[MessageChunk] = {
     messages.synchronized {
       while (!messages.isEmpty) {
@@ -170,15 +209,17 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
           messages.enqueue(message)
           nextMessageToBeUsed = nextMessageToBeUsed + 1
           if (!message.started) {
-            logDebug("Starting to send [" + message + "] to [" + remoteConnectionManagerId + "]")
+            logDebug(
+              "Starting to send [" + message + "] to [" + getRemoteConnectionManagerId() + "]")
             message.started = true
             message.startTime = System.currentTimeMillis
           }
-          logTrace("Sending chunk from [" + message+ "] to [" + remoteConnectionManagerId + "]")
+          logTrace(
+            "Sending chunk from [" + message+ "] to [" + getRemoteConnectionManagerId() + "]")
           return chunk
         } else {
           message.finishTime = System.currentTimeMillis
-          logDebug("Finished sending [" + message + "] to [" + remoteConnectionManagerId +
+          logDebug("Finished sending [" + message + "] to [" + getRemoteConnectionManagerId() +
             "] in " + message.timeTaken )
         }
       }
@@ -186,27 +227,40 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
       None
     }
   }

-  val outbox = new Outbox(1)
+  private val outbox = new Outbox(1)
   val currentBuffers = new ArrayBuffer[ByteBuffer]()

   /*channel.socket.setSendBufferSize(256 * 1024)*/

   override def getRemoteAddress() = address

+  val DEFAULT_INTEREST = SelectionKey.OP_READ
+
+  override def registerInterest() {
+    // Registering read too - does not really help in most cases, but for some
+    // it does - so let us keep it for now.
+    changeConnectionKeyInterest(SelectionKey.OP_WRITE | DEFAULT_INTEREST)
+  }
+
+  override def unregisterInterest() {
+    changeConnectionKeyInterest(DEFAULT_INTEREST)
+  }
+
   def send(message: Message) {
     outbox.synchronized {
       outbox.addMessage(message)
       if (channel.isConnected) {
-        changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+        registerInterest()
       }
     }
   }

+  // MUST be called within the selector loop
   def connect() {
     try{
-      channel.connect(address)
       channel.register(selector, SelectionKey.OP_CONNECT)
+      channel.connect(address)
       logInfo("Initiating connection to [" + address + "]")
     } catch {
       case e: Exception => {
@@ -216,36 +270,52 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
       }
     }
   }

-  def finishConnect() {
+  def finishConnect(force: Boolean): Boolean = {
     try {
-      channel.finishConnect
-      changeConnectionKeyInterest(SelectionKey.OP_WRITE | SelectionKey.OP_READ)
+      // Typically, this should finish immediately since it was triggered by a connect
+      // selection - though need not necessarily always complete successfully.
+      val connected = channel.finishConnect
+      if (!force && !connected) {
+        logInfo(
+          "finish connect failed [" + address + "], " + outbox.messages.size + " messages pending")
+        return false
+      }
+
+      // Fallback to previous behavior - assume finishConnect completed
+      // This will happen only when finishConnect failed for some repeated number of times
+      // (10 or so)
+      // Is highly unlikely unless there was an unclean close of socket, etc
+      registerInterest()
       logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
+      return true
     } catch {
       case e: Exception => {
         logWarning("Error finishing connection to " + address, e)
         callOnExceptionCallback(e)
+        // ignore
+        return true
       }
     }
   }

-  override def write() {
-    try{
-      while(true) {
+  override def write(): Boolean = {
+    try {
+      while (true) {
         if (currentBuffers.size == 0) {
           outbox.synchronized {
             outbox.getChunk() match {
               case Some(chunk) => {
                 currentBuffers ++= chunk.buffers
               }
               case None => {
-                changeConnectionKeyInterest(SelectionKey.OP_READ)
-                return
+                // changeConnectionKeyInterest(0)
+                /*key.interestOps(0)*/
+                return false
               }
             }
           }
         }

         if (currentBuffers.size > 0) {
           val buffer = currentBuffers(0)
           val remainingBytes = buffer.remaining
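The change above registers OP_CONNECT before initiating the connection and then completes it asynchronously via finishConnect. A minimal sketch of that sequence with a bounded retry — host and port are placeholders, and the demo assumes something is actually listening there:

    import java.net.InetSocketAddress
    import java.nio.channels.{SelectionKey, Selector, SocketChannel}

    object ConnectDemo {
      def main(args: Array[String]) {
        val selector = Selector.open()
        val channel = SocketChannel.open()
        channel.configureBlocking(false)
        // Register first, then connect, mirroring the reordering in the hunk above.
        channel.register(selector, SelectionKey.OP_CONNECT)
        channel.connect(new InetSocketAddress("localhost", 9999))

        selector.select()
        // finishConnect() may legitimately return false right after an OP_CONNECT
        // selection, so retry a bounded number of times before giving up.
        var tries = 10
        var connected = channel.finishConnect()
        while (!connected && tries > 0) {
          Thread.sleep(1)
          connected = channel.finishConnect()
          tries -= 1
        }
        println("connected = " + connected)
        channel.close()
        selector.close()
      }
    }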
@@ -254,69 +324,109 @@ extends Connection(SocketChannel.open, selector_, remoteId_) {
             currentBuffers -= buffer
           }
           if (writtenBytes < remainingBytes) {
-            return
+            // re-register for write.
+            return true
           }
         }
       }
     } catch {
       case e: Exception => {
-        logWarning("Error writing in connection to " + remoteConnectionManagerId, e)
+        logWarning("Error writing in connection to " + getRemoteConnectionManagerId(), e)
         callOnExceptionCallback(e)
         close()
+        return false
       }
     }
+    // should not happen - to keep scala compiler happy
+    return true
   }

-  override def read() {
+  // This is a hack to determine if remote socket was closed or not.
+  // SendingConnection DOES NOT expect to receive any data - if it does, it is an error
+  // For a bunch of cases, read will return -1 in case remote socket is closed : hence we
+  // register for reads to determine that.
+  override def read(): Boolean = {
     // We don't expect the other side to send anything; so, we just read to detect an error or EOF.
     try {
       val length = channel.read(ByteBuffer.allocate(1))
       if (length == -1) { // EOF
         close()
       } else if (length > 0) {
-        logWarning("Unexpected data read from SendingConnection to " + remoteConnectionManagerId)
+        logWarning(
+          "Unexpected data read from SendingConnection to " + getRemoteConnectionManagerId())
       }
     } catch {
       case e: Exception =>
-        logError("Exception while reading SendingConnection to " + remoteConnectionManagerId, e)
+        logError("Exception while reading SendingConnection to " + getRemoteConnectionManagerId(), e)
         callOnExceptionCallback(e)
         close()
     }

+    false
   }

+  override def changeInterestForRead(): Boolean = false
+
+  override def changeInterestForWrite(): Boolean = true
 }


+// Must be created within selector loop - else deadlock
 private[spark] class ReceivingConnection(channel_ : SocketChannel, selector_ : Selector)
   extends Connection(channel_, selector_) {

   class Inbox() {
     val messages = new HashMap[Int, BufferMessage]()

     def getChunk(header: MessageChunkHeader): Option[MessageChunk] = {

       def createNewMessage: BufferMessage = {
         val newMessage = Message.create(header).asInstanceOf[BufferMessage]
         newMessage.started = true
         newMessage.startTime = System.currentTimeMillis
-        logDebug("Starting to receive [" + newMessage + "] from [" + remoteConnectionManagerId + "]")
+        logDebug(
+          "Starting to receive [" + newMessage + "] from [" + getRemoteConnectionManagerId() + "]")
         messages += ((newMessage.id, newMessage))
         newMessage
       }

       val message = messages.getOrElseUpdate(header.id, createNewMessage)
-      logTrace("Receiving chunk of [" + message + "] from [" + remoteConnectionManagerId + "]")
+      logTrace(
+        "Receiving chunk of [" + message + "] from [" + getRemoteConnectionManagerId() + "]")
       message.getChunkForReceiving(header.chunkSize)
     }

     def getMessageForChunk(chunk: MessageChunk): Option[BufferMessage] = {
       messages.get(chunk.header.id)
     }

     def removeMessage(message: Message) {
       messages -= message.id
     }
   }

+  @volatile private var inferredRemoteManagerId: ConnectionManagerId = null
+
+  override def getRemoteConnectionManagerId(): ConnectionManagerId = {
+    val currId = inferredRemoteManagerId
+    if (currId != null) currId else super.getRemoteConnectionManagerId()
+  }
+
+  // The receiver's remote address is the local socket on remote side : which is NOT
+  // the connection manager id of the receiver.
+  // We infer that from the messages we receive on the receiver socket.
+  private def processConnectionManagerId(header: MessageChunkHeader) {
+    val currId = inferredRemoteManagerId
+    if (header.address == null || currId != null) return
+
+    val managerId = ConnectionManagerId.fromSocketAddress(header.address)
+
+    if (managerId != null) {
+      inferredRemoteManagerId = managerId
+    }
+  }
+
   val inbox = new Inbox()
   val headerBuffer: ByteBuffer = ByteBuffer.allocate(MessageChunkHeader.HEADER_SIZE)
   var onReceiveCallback: (Connection , Message) => Unit = null
@@ -324,24 +434,29 @@ extends Connection(channel_, selector_) {

   channel.register(selector, SelectionKey.OP_READ)

-  override def read() {
+  override def read(): Boolean = {
     try {
       while (true) {
         if (currentChunk == null) {
           val headerBytesRead = channel.read(headerBuffer)
           if (headerBytesRead == -1) {
             close()
-            return
+            return false
           }
           if (headerBuffer.remaining > 0) {
-            return
+            // re-register for read event ...
+            return true
           }
           headerBuffer.flip
           if (headerBuffer.remaining != MessageChunkHeader.HEADER_SIZE) {
-            throw new Exception("Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
+            throw new Exception(
+              "Unexpected number of bytes (" + headerBuffer.remaining + ") in the header")
           }
           val header = MessageChunkHeader.create(headerBuffer)
           headerBuffer.clear()
+
+          processConnectionManagerId(header)
+
           header.typ match {
             case Message.BUFFER_MESSAGE => {
               if (header.totalSize == 0) {
@@ -349,7 +464,8 @@ extends Connection(channel_, selector_) {
                   onReceiveCallback(this, Message.create(header))
                 }
                 currentChunk = null
-                return
+                // re-register for read event ...
+                return true
               } else {
                 currentChunk = inbox.getChunk(header).orNull
               }
@@ -357,26 +473,28 @@ extends Connection(channel_, selector_) {
           case _ => throw new Exception("Message of unknown type received")
         }
       }

       if (currentChunk == null) throw new Exception("No message chunk to receive data")

       val bytesRead = channel.read(currentChunk.buffer)
       if (bytesRead == 0) {
-        return
+        // re-register for read event ...
+        return true
       } else if (bytesRead == -1) {
         close()
-        return
+        return false
       }

       /*logDebug("Read " + bytesRead + " bytes for the buffer")*/

       if (currentChunk.buffer.remaining == 0) {
         /*println("Filled buffer at " + System.currentTimeMillis)*/
         val bufferMessage = inbox.getMessageForChunk(currentChunk).get
         if (bufferMessage.isCompletelyReceived) {
           bufferMessage.flip
           bufferMessage.finishTime = System.currentTimeMillis
-          logDebug("Finished receiving [" + bufferMessage + "] from [" + remoteConnectionManagerId + "] in " + bufferMessage.timeTaken)
+          logDebug("Finished receiving [" + bufferMessage + "] from " +
+            "[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
           if (onReceiveCallback != null) {
             onReceiveCallback(this, bufferMessage)
           }
@@ -386,13 +504,32 @@ extends Connection(channel_, selector_) {
         }
       }
     } catch {
       case e: Exception => {
-        logWarning("Error reading from connection to " + remoteConnectionManagerId, e)
+        logWarning("Error reading from connection to " + getRemoteConnectionManagerId(), e)
         callOnExceptionCallback(e)
         close()
+        return false
       }
     }
+    // should not happen - to keep scala compiler happy
+    return true
   }

   def onReceive(callback: (Connection, Message) => Unit) {onReceiveCallback = callback}

+  override def changeInterestForRead(): Boolean = true
+
+  override def changeInterestForWrite(): Boolean = {
+    throw new IllegalStateException("Unexpected invocation right now")
+  }
+
+  override def registerInterest() {
+    // Registering read too - does not really help in most cases, but for some
+    // it does - so let us keep it for now.
+    changeConnectionKeyInterest(SelectionKey.OP_READ)
+  }
+
+  override def unregisterInterest() {
+    changeConnectionKeyInterest(0)
+  }
 }
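Throughout the receiving path, read() now returns whether the caller should re-register for reads (true) or has hit EOF (false). The header-then-body structure it implements reduces to a small state machine; the sketch below uses a simplified 4-byte length prefix instead of MessageChunkHeader and assumes a blocking channel, so it illustrates the convention rather than the commit's code:

    import java.nio.ByteBuffer
    import java.nio.channels.SocketChannel

    object ChunkedReadDemo {
      // Returns true when the caller should re-register for further reads,
      // false when the channel hit EOF, mirroring the diff's convention.
      def readOnce(channel: SocketChannel, header: ByteBuffer): Boolean = {
        val n = channel.read(header)
        if (n == -1) return false              // EOF: caller closes the connection
        if (header.remaining > 0) return true  // partial header: wait for more data
        header.flip()
        val bodySize = header.getInt()
        val body = ByteBuffer.allocate(bodySize)
        // Simplification: drain the body in place (fine on a blocking channel).
        while (body.hasRemaining && channel.read(body) != -1) {}
        header.clear()                         // ready for the next message
        true
      }
    }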
@@ -6,28 +6,19 @@ import java.nio._
 import java.nio.channels._
 import java.nio.channels.spi._
 import java.net._
-import java.util.concurrent.Executors
+import java.util.concurrent.{LinkedBlockingDeque, TimeUnit, ThreadPoolExecutor}

+import scala.collection.mutable.HashSet
 import scala.collection.mutable.HashMap
 import scala.collection.mutable.SynchronizedMap
 import scala.collection.mutable.SynchronizedQueue
-import scala.collection.mutable.Queue
 import scala.collection.mutable.ArrayBuffer

 import akka.dispatch.{Await, Promise, ExecutionContext, Future}
 import akka.util.Duration
 import akka.util.duration._

-private[spark] case class ConnectionManagerId(host: String, port: Int) {
-  def toSocketAddress() = new InetSocketAddress(host, port)
-}
-
-private[spark] object ConnectionManagerId {
-  def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
-    new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
-  }
-}
-
 private[spark] class ConnectionManager(port: Int) extends Logging {

   class MessageStatus(
@@ -41,73 +32,263 @@ private[spark] class ConnectionManager(port: Int) extends Logging {

     def markDone() { completionHandler(this) }
   }

-  val selector = SelectorProvider.provider.openSelector()
-  val handleMessageExecutor = Executors.newFixedThreadPool(System.getProperty("spark.core.connection.handler.threads","20").toInt)
-  val serverChannel = ServerSocketChannel.open()
-  val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
-  val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
-  val messageStatuses = new HashMap[Int, MessageStatus]
-  val connectionRequests = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
-  val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
-  val sendMessageRequests = new Queue[(Message, SendingConnection)]
+  private val selector = SelectorProvider.provider.openSelector()
+
+  private val handleMessageExecutor = new ThreadPoolExecutor(
+    System.getProperty("spark.core.connection.handler.threads.min","20").toInt,
+    System.getProperty("spark.core.connection.handler.threads.max","60").toInt,
+    System.getProperty("spark.core.connection.handler.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable]())
+
+  private val handleReadWriteExecutor = new ThreadPoolExecutor(
+    System.getProperty("spark.core.connection.io.threads.min","4").toInt,
+    System.getProperty("spark.core.connection.io.threads.max","32").toInt,
+    System.getProperty("spark.core.connection.io.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable]())
+
+  // Use a different, yet smaller, thread pool - infrequently used with very short lived tasks : which should be executed asap
+  private val handleConnectExecutor = new ThreadPoolExecutor(
+    System.getProperty("spark.core.connection.connect.threads.min","1").toInt,
+    System.getProperty("spark.core.connection.connect.threads.max","8").toInt,
+    System.getProperty("spark.core.connection.connect.threads.keepalive","60").toInt, TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable]())
+
+  private val serverChannel = ServerSocketChannel.open()
+  private val connectionsByKey = new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
+  private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection] with SynchronizedMap[ConnectionManagerId, SendingConnection]
+  private val messageStatuses = new HashMap[Int, MessageStatus]
+  private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
+  private val registerRequests = new SynchronizedQueue[SendingConnection]

   implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())

-  var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
+  private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null

   serverChannel.configureBlocking(false)
   serverChannel.socket.setReuseAddress(true)
   serverChannel.socket.setReceiveBufferSize(256 * 1024)

   serverChannel.socket.bind(new InetSocketAddress(port))
   serverChannel.register(selector, SelectionKey.OP_ACCEPT)

   val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
   logInfo("Bound socket to port " + serverChannel.socket.getLocalPort() + " with id = " + id)

-  val selectorThread = new Thread("connection-manager-thread") {
+  private val selectorThread = new Thread("connection-manager-thread") {
     override def run() = ConnectionManager.this.run()
   }
   selectorThread.setDaemon(true)
   selectorThread.start()
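Each pool above is built the same way: core size, maximum size, and keep-alive all come from system properties. A sketch of that helper pattern with invented property names (worth noting: with an unbounded work queue a ThreadPoolExecutor never actually grows past its core size, a subtlety these pools inherit):

    import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}

    object PoolDemo {
      // prefix and defaults are this demo's own, not Spark's property names.
      def poolFromProps(prefix: String, minDefault: Int, maxDefault: Int) =
        new ThreadPoolExecutor(
          System.getProperty(prefix + ".min", minDefault.toString).toInt,
          System.getProperty(prefix + ".max", maxDefault.toString).toInt,
          System.getProperty(prefix + ".keepalive", "60").toInt, TimeUnit.SECONDS,
          new LinkedBlockingDeque[Runnable]())

      def main(args: Array[String]) {
        val pool = poolFromProps("demo.io.threads", 4, 32)
        pool.execute(new Runnable { override def run() { println("hello from pool") } })
        pool.shutdown()
      }
    }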
-  private def run() {
+  private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+  private def triggerWrite(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null)
+    if (conn == null) return
+
+    writeRunnableStarted.synchronized {
+      // So that we do not trigger more write events while processing this one.
+      // The write method will re-register when done.
+      if (conn.changeInterestForWrite()) conn.unregisterInterest()
+      if (writeRunnableStarted.contains(key)) {
+        // key.interestOps(key.interestOps() & ~ SelectionKey.OP_WRITE)
+        return
+      }
+
+      writeRunnableStarted += key
+    }
+    handleReadWriteExecutor.execute(new Runnable {
+      override def run() {
+        var register: Boolean = false
+        try {
+          register = conn.write()
+        } finally {
+          writeRunnableStarted.synchronized {
+            writeRunnableStarted -= key
+            if (register && conn.changeInterestForWrite()) {
+              conn.registerInterest()
+            }
+          }
+        }
+      }
+    } )
+  }
+  private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
+  private def triggerRead(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null)
+    if (conn == null) return
+
+    readRunnableStarted.synchronized {
+      // So that we do not trigger more read events while processing this one.
+      // The read method will re-register when done.
+      if (conn.changeInterestForRead()) conn.unregisterInterest()
+      if (readRunnableStarted.contains(key)) {
+        return
+      }
+
+      readRunnableStarted += key
+    }
+    handleReadWriteExecutor.execute(new Runnable {
+      override def run() {
+        var register: Boolean = false
+        try {
+          register = conn.read()
+        } finally {
+          readRunnableStarted.synchronized {
+            readRunnableStarted -= key
+            if (register && conn.changeInterestForRead()) {
+              conn.registerInterest()
+            }
+          }
+        }
+      }
+    } )
+  }
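triggerWrite and triggerRead share one pattern: a guard set ensures at most one worker per key, the event is masked while the worker runs, and it is re-armed only if the handler asks for it. Distilled into a generic sketch (Key stands in for SelectionKey; all names here are invented):

    import java.util.concurrent.Executors
    import scala.collection.mutable.HashSet

    object GuardSetDemo {
      type Key = String
      private val started = new HashSet[Key]()
      private val pool = Executors.newFixedThreadPool(4)

      def trigger(key: Key)(work: => Boolean)(rearm: Key => Unit) {
        started.synchronized {
          if (started.contains(key)) return  // a worker is already processing this key
          started += key
        }
        pool.execute(new Runnable {
          override def run() {
            var register = false
            try {
              register = work  // returns whether the event should be re-armed
            } finally {
              started.synchronized { started -= key }
              if (register) rearm(key)
            }
          }
        })
      }
    }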
+  private def triggerConnect(key: SelectionKey) {
+    val conn = connectionsByKey.getOrElse(key, null).asInstanceOf[SendingConnection]
+    if (conn == null) return
+
+    // prevent other events from being triggered
+    // Since we are still trying to connect, we do not need to do the additional steps in triggerWrite
+    conn.changeConnectionKeyInterest(0)
+
+    handleConnectExecutor.execute(new Runnable {
+      override def run() {
+
+        var tries: Int = 10
+        while (tries >= 0) {
+          if (conn.finishConnect(false)) return
+          // Sleep ?
+          Thread.sleep(1)
+          tries -= 1
+        }
+
+        // fallback to previous behavior : we should not really come here since this method was
+        // triggered since channel became connectable : but at times, the first finishConnect need not
+        // succeed : hence the loop to retry a few 'times'.
+        conn.finishConnect(true)
+      }
+    } )
+  }
+  // MUST be called within selector loop - else deadlock.
+  private def triggerForceCloseByException(key: SelectionKey, e: Exception) {
+    try {
+      key.interestOps(0)
+    } catch {
+      // ignore exceptions
+      case e: Exception => logDebug("Ignoring exception", e)
+    }
+
+    val conn = connectionsByKey.getOrElse(key, null)
+    if (conn == null) return
+
+    // Pushing to connect threadpool
+    handleConnectExecutor.execute(new Runnable {
+      override def run() {
+        try {
+          conn.callOnExceptionCallback(e)
+        } catch {
+          // ignore exceptions
+          case e: Exception => logDebug("Ignoring exception", e)
+        }
+        try {
+          conn.close()
+        } catch {
+          // ignore exceptions
+          case e: Exception => logDebug("Ignoring exception", e)
+        }
+      }
+    })
+  }
   def run() {
     try {
       while(!selectorThread.isInterrupted) {
-        for ((connectionManagerId, sendingConnection) <- connectionRequests) {
-          sendingConnection.connect()
-          addConnection(sendingConnection)
-          connectionRequests -= connectionManagerId
-        }
-        sendMessageRequests.synchronized {
-          while (!sendMessageRequests.isEmpty) {
-            val (message, connection) = sendMessageRequests.dequeue
-            connection.send(message)
-          }
+        while (! registerRequests.isEmpty) {
+          val conn: SendingConnection = registerRequests.dequeue
+          addListeners(conn)
+          conn.connect()
+          addConnection(conn)
         }

         while(!keyInterestChangeRequests.isEmpty) {
           val (key, ops) = keyInterestChangeRequests.dequeue
-          val connection = connectionsByKey(key)
-          val lastOps = key.interestOps()
-          key.interestOps(ops)
-
-          def intToOpStr(op: Int): String = {
-            val opStrs = ArrayBuffer[String]()
-            if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
-            if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
-            if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
-            if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
-            if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
-          }
-
-          logTrace("Changed key for connection to [" + connection.remoteConnectionManagerId +
-            "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+          try {
+            if (key.isValid) {
+              val connection = connectionsByKey.getOrElse(key, null)
+              if (connection != null) {
+                val lastOps = key.interestOps()
+                key.interestOps(ops)
+
+                // hot loop - prevent materialization of string if trace not enabled.
+                if (isTraceEnabled()) {
+                  def intToOpStr(op: Int): String = {
+                    val opStrs = ArrayBuffer[String]()
+                    if ((op & SelectionKey.OP_READ) != 0) opStrs += "READ"
+                    if ((op & SelectionKey.OP_WRITE) != 0) opStrs += "WRITE"
+                    if ((op & SelectionKey.OP_CONNECT) != 0) opStrs += "CONNECT"
+                    if ((op & SelectionKey.OP_ACCEPT) != 0) opStrs += "ACCEPT"
+                    if (opStrs.size > 0) opStrs.reduceLeft(_ + " | " + _) else " "
+                  }
+
+                  logTrace("Changed key for connection to [" + connection.getRemoteConnectionManagerId() +
+                    "] changed from [" + intToOpStr(lastOps) + "] to [" + intToOpStr(ops) + "]")
+                }
+              }
+            } else {
+              logInfo("Key not valid ? " + key)
+              throw new CancelledKeyException()
+            }
+          } catch {
+            case e: CancelledKeyException => {
+              logInfo("key already cancelled ? " + key, e)
+              triggerForceCloseByException(key, e)
+            }
+            case e: Exception => {
+              logError("Exception processing key " + key, e)
+              triggerForceCloseByException(key, e)
+            }
+          }
         }

-        val selectedKeysCount = selector.select()
+        val selectedKeysCount =
+          try {
+            selector.select()
+          } catch {
+            // Explicitly only dealing with CancelledKeyException here since other exceptions should be dealt with differently.
+            case e: CancelledKeyException => {
+              // Some keys within the selectors list are invalid/closed. clear them.
+              val allKeys = selector.keys().iterator()
+
+              while (allKeys.hasNext()) {
+                val key = allKeys.next()
+                try {
+                  if (! key.isValid) {
+                    logInfo("Key not valid ? " + key)
+                    throw new CancelledKeyException()
+                  }
+                } catch {
+                  case e: CancelledKeyException => {
+                    logInfo("key already cancelled ? " + key, e)
+                    triggerForceCloseByException(key, e)
+                  }
+                  case e: Exception => {
+                    logError("Exception processing key " + key, e)
+                    triggerForceCloseByException(key, e)
+                  }
+                }
+              }
+              0
+            }
+          }

         if (selectedKeysCount == 0) {
           logDebug("Selector selected " + selectedKeysCount + " of " + selector.keys.size + " keys")
         }
@@ -115,20 +296,40 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
           logInfo("Selector thread was interrupted!")
           return
         }

-        val selectedKeys = selector.selectedKeys().iterator()
-        while (selectedKeys.hasNext()) {
-          val key = selectedKeys.next
-          selectedKeys.remove()
-          if (key.isValid) {
-            if (key.isAcceptable) {
-              acceptConnection(key)
-            } else if (key.isConnectable) {
-              connectionsByKey(key).asInstanceOf[SendingConnection].finishConnect()
-            } else if (key.isReadable) {
-              connectionsByKey(key).read()
-            } else if (key.isWritable) {
-              connectionsByKey(key).write()
+        if (0 != selectedKeysCount) {
+          val selectedKeys = selector.selectedKeys().iterator()
+          while (selectedKeys.hasNext()) {
+            val key = selectedKeys.next
+            selectedKeys.remove()
+            try {
+              if (key.isValid) {
+                if (key.isAcceptable) {
+                  acceptConnection(key)
+                } else
+                if (key.isConnectable) {
+                  triggerConnect(key)
+                } else
+                if (key.isReadable) {
+                  triggerRead(key)
+                } else
+                if (key.isWritable) {
+                  triggerWrite(key)
+                }
+              } else {
+                logInfo("Key not valid ? " + key)
+                throw new CancelledKeyException()
+              }
+            } catch {
+              // weird, but we saw this happening - even though key.isValid was true, key.isAcceptable would throw CancelledKeyException.
+              case e: CancelledKeyException => {
+                logInfo("key already cancelled ? " + key, e)
+                triggerForceCloseByException(key, e)
+              }
+              case e: Exception => {
+                logError("Exception processing key " + key, e)
+                triggerForceCloseByException(key, e)
+              }
             }
           }
         }
@@ -137,97 +338,119 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
       case e: Exception => logError("Error in select loop", e)
     }
   }

-  private def acceptConnection(key: SelectionKey) {
+  def acceptConnection(key: SelectionKey) {
     val serverChannel = key.channel.asInstanceOf[ServerSocketChannel]
-    val newChannel = serverChannel.accept()
-    val newConnection = new ReceivingConnection(newChannel, selector)
-    newConnection.onReceive(receiveMessage)
-    newConnection.onClose(removeConnection)
-    addConnection(newConnection)
-    logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+
+    var newChannel = serverChannel.accept()
+
+    // accept them all in a tight loop. non blocking accept with no processing, should be fine
+    while (newChannel != null) {
+      try {
+        val newConnection = new ReceivingConnection(newChannel, selector)
+        newConnection.onReceive(receiveMessage)
+        addListeners(newConnection)
+        addConnection(newConnection)
+        logInfo("Accepted connection from [" + newConnection.remoteAddress.getAddress + "]")
+      } catch {
+        // might happen in case of issues with registering with selector
+        case e: Exception => logError("Error in accept loop", e)
+      }
+
+      newChannel = serverChannel.accept()
+    }
   }
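The accept path now drains the whole backlog per OP_ACCEPT event: on a non-blocking ServerSocketChannel, accept() returns null once nothing is pending. The same loop in isolation (demo names, not commit code):

    import java.nio.channels.{ServerSocketChannel, SocketChannel}

    object AcceptLoopDemo {
      def acceptAll(server: ServerSocketChannel)(handle: SocketChannel => Unit) {
        var channel = server.accept()  // null when the backlog is empty
        while (channel != null) {
          try {
            handle(channel)
          } catch {
            // keep draining even if one channel fails to register
            case e: Exception => System.err.println("Error in accept loop: " + e)
          }
          channel = server.accept()
        }
      }
    }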
-  private def addConnection(connection: Connection) {
-    connectionsByKey += ((connection.key, connection))
-    if (connection.isInstanceOf[SendingConnection]) {
-      val sendingConnection = connection.asInstanceOf[SendingConnection]
-      connectionsById += ((sendingConnection.remoteConnectionManagerId, sendingConnection))
-    }
+  private def addListeners(connection: Connection) {
     connection.onKeyInterestChange(changeConnectionKeyInterest)
     connection.onException(handleConnectionError)
     connection.onClose(removeConnection)
   }

-  private def removeConnection(connection: Connection) {
+  def addConnection(connection: Connection) {
+    connectionsByKey += ((connection.key, connection))
+  }
+
+  def removeConnection(connection: Connection) {
     connectionsByKey -= connection.key
-    if (connection.isInstanceOf[SendingConnection]) {
-      val sendingConnection = connection.asInstanceOf[SendingConnection]
-      val sendingConnectionManagerId = sendingConnection.remoteConnectionManagerId
-      logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
-
-      connectionsById -= sendingConnectionManagerId
-
-      messageStatuses.synchronized {
-        messageStatuses
-          .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
-            logInfo("Notifying " + status)
-            status.synchronized {
-              status.attempted = true
-              status.acked = false
-              status.markDone()
-            }
-          })
-
-        messageStatuses.retain((i, status) => {
-          status.connectionManagerId != sendingConnectionManagerId
-        })
-      }
-    } else if (connection.isInstanceOf[ReceivingConnection]) {
-      val receivingConnection = connection.asInstanceOf[ReceivingConnection]
-      val remoteConnectionManagerId = receivingConnection.remoteConnectionManagerId
-      logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
-
-      val sendingConnectionManagerId = connectionsById.keys.find(_.host == remoteConnectionManagerId.host).orNull
-      if (sendingConnectionManagerId == null) {
-        logError("Corresponding SendingConnectionManagerId not found")
-        return
-      }
-      logInfo("Corresponding SendingConnectionManagerId is " + sendingConnectionManagerId)
-
-      val sendingConnection = connectionsById(sendingConnectionManagerId)
-      sendingConnection.close()
-      connectionsById -= sendingConnectionManagerId
-
-      messageStatuses.synchronized {
-        for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
-          logInfo("Notifying " + s)
-          s.synchronized {
-            s.attempted = true
-            s.acked = false
-            s.markDone()
-          }
-        }
-
-        messageStatuses.retain((i, status) => {
-          status.connectionManagerId != sendingConnectionManagerId
-        })
-      }
-    }
-  }
+    try {
+      if (connection.isInstanceOf[SendingConnection]) {
+        val sendingConnection = connection.asInstanceOf[SendingConnection]
+        val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+        logInfo("Removing SendingConnection to " + sendingConnectionManagerId)
+
+        connectionsById -= sendingConnectionManagerId
+
+        messageStatuses.synchronized {
+          messageStatuses
+            .values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
+              logInfo("Notifying " + status)
+              status.synchronized {
+                status.attempted = true
+                status.acked = false
+                status.markDone()
+              }
+            })
+
+          messageStatuses.retain((i, status) => {
+            status.connectionManagerId != sendingConnectionManagerId
+          })
+        }
+      } else if (connection.isInstanceOf[ReceivingConnection]) {
+        val receivingConnection = connection.asInstanceOf[ReceivingConnection]
+        val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
+        logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)
+
+        val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
+        if (! sendingConnectionOpt.isDefined) {
+          logError("Corresponding SendingConnectionManagerId not found")
+          return
+        }
+
+        val sendingConnection = sendingConnectionOpt.get
+        connectionsById -= remoteConnectionManagerId
+        sendingConnection.close()
+
+        val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
+
+        assert (sendingConnectionManagerId == remoteConnectionManagerId)
+
+        messageStatuses.synchronized {
+          for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
+            logInfo("Notifying " + s)
+            s.synchronized {
+              s.attempted = true
+              s.acked = false
+              s.markDone()
+            }
+          }
+
+          messageStatuses.retain((i, status) => {
+            status.connectionManagerId != sendingConnectionManagerId
+          })
+        }
+      }
+    } finally {
+      // So that the selection keys can be removed.
+      wakeupSelector()
+    }
+  }
-  private def handleConnectionError(connection: Connection, e: Exception) {
-    logInfo("Handling connection error on connection to " + connection.remoteConnectionManagerId)
+  def handleConnectionError(connection: Connection, e: Exception) {
+    logInfo("Handling connection error on connection to " + connection.getRemoteConnectionManagerId())
     removeConnection(connection)
   }

-  private def changeConnectionKeyInterest(connection: Connection, ops: Int) {
+  def changeConnectionKeyInterest(connection: Connection, ops: Int) {
     keyInterestChangeRequests += ((connection.key, ops))
+    // so that registrations happen !
+    wakeupSelector()
   }

-  private def receiveMessage(connection: Connection, message: Message) {
+  def receiveMessage(connection: Connection, message: Message) {
     val connectionManagerId = ConnectionManagerId.fromSocketAddress(message.senderAddress)
     logDebug("Received [" + message + "] from [" + connectionManagerId + "]")
     val runnable = new Runnable() {
       val creationTime = System.currentTimeMillis
       def run() {
@@ -247,11 +470,11 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
         if (bufferMessage.hasAckId) {
           val sentMessageStatus = messageStatuses.synchronized {
             messageStatuses.get(bufferMessage.ackId) match {
               case Some(status) => {
                 messageStatuses -= bufferMessage.ackId
                 status
               }
               case None => {
                 throw new Exception("Could not find reference for received ack message " + message.id)
                 null
               }
@@ -271,7 +494,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
           logDebug("Not calling back as callback is null")
           None
         }

         if (ackMessage.isDefined) {
           if (!ackMessage.get.isInstanceOf[BufferMessage]) {
             logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type " + ackMessage.get.getClass())
@@ -281,7 +504,7 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
           }
         }

         sendMessage(connectionManagerId, ackMessage.getOrElse {
           Message.createBufferMessage(bufferMessage.id)
         })
       }
@@ -293,18 +516,22 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
   private def sendMessage(connectionManagerId: ConnectionManagerId, message: Message) {
     def startNewConnection(): SendingConnection = {
       val inetSocketAddress = new InetSocketAddress(connectionManagerId.host, connectionManagerId.port)
-      val newConnection = connectionRequests.getOrElseUpdate(connectionManagerId,
-        new SendingConnection(inetSocketAddress, selector, connectionManagerId))
+      val newConnection = new SendingConnection(inetSocketAddress, selector, connectionManagerId)
+      registerRequests.enqueue(newConnection)
+
       newConnection
     }
-    val lookupKey = ConnectionManagerId.fromSocketAddress(connectionManagerId.toSocketAddress)
-    val connection = connectionsById.getOrElse(lookupKey, startNewConnection())
+    // I removed the lookupKey stuff as part of merge ... should I re-add it ? We did not find it useful in our test-env ...
+    // If we do re-add it, we should consistently use it everywhere I guess ?
+    val connection = connectionsById.getOrElseUpdate(connectionManagerId, startNewConnection())
     message.senderAddress = id.toSocketAddress()
     logDebug("Sending [" + message + "] to [" + connectionManagerId + "]")
-    /*connection.send(message)*/
-    sendMessageRequests.synchronized {
-      sendMessageRequests += ((message, connection))
-    }
+    connection.send(message)
+
+    wakeupSelector()
   }

+  private def wakeupSelector() {
+    selector.wakeup()
+  }
@@ -337,6 +564,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
       logWarning("All connections not cleaned up")
     }
     handleMessageExecutor.shutdown()
+    handleReadWriteExecutor.shutdown()
+    handleConnectExecutor.shutdown()
     logInfo("ConnectionManager stopped")
   }
 }
@@ -346,17 +575,17 @@ private[spark] object ConnectionManager {

   def main(args: Array[String]) {
     val manager = new ConnectionManager(9999)
     manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => {
       println("Received [" + msg + "] from [" + id + "]")
       None
     })

     /*testSequentialSending(manager)*/
     /*System.gc()*/

     /*testParallelSending(manager)*/
     /*System.gc()*/

     /*testParallelDecreasingSending(manager)*/
     /*System.gc()*/
@@ -368,9 +597,9 @@ private[spark] object ConnectionManager {
     println("--------------------------")
     println("Sequential Sending")
     println("--------------------------")
     val size = 10 * 1024 * 1024
     val count = 10

     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
     buffer.flip
@@ -386,7 +615,7 @@ private[spark] object ConnectionManager {
     println("--------------------------")
     println("Parallel Sending")
     println("--------------------------")
     val size = 10 * 1024 * 1024
     val count = 10

     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
@@ -401,12 +630,12 @@ private[spark] object ConnectionManager {
       if (!g.isDefined) println("Failed")
     })
     val finishTime = System.currentTimeMillis

     val mb = size * count / 1024.0 / 1024.0
     val ms = finishTime - startTime
     val tput = mb * 1000.0 / ms
     println("--------------------------")
     println("Started at " + startTime + ", finished at " + finishTime)
     println("Sent " + count + " messages of size " + size + " in " + ms + " ms (" + tput + " MB/s)")
     println("--------------------------")
     println()
@@ -416,7 +645,7 @@ private[spark] object ConnectionManager {
     println("--------------------------")
     println("Parallel Decreasing Sending")
     println("--------------------------")
     val size = 10 * 1024 * 1024
     val count = 10
     val buffers = Array.tabulate(count)(i => ByteBuffer.allocate(size * (i + 1)).put(Array.tabulate[Byte](size * (i + 1))(x => x.toByte)))
     buffers.foreach(_.flip)
@@ -431,7 +660,7 @@ private[spark] object ConnectionManager {
       if (!g.isDefined) println("Failed")
     })
     val finishTime = System.currentTimeMillis

     val ms = finishTime - startTime
     val tput = mb * 1000.0 / ms
     println("--------------------------")
@@ -445,7 +674,7 @@ private[spark] object ConnectionManager {
     println("--------------------------")
     println("Continuous Sending")
     println("--------------------------")
     val size = 10 * 1024 * 1024
     val count = 10

     val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte))
core/src/main/scala/spark/network/ConnectionManagerId.scala (new file, 21 lines)
@@ -0,0 +1,21 @@
+package spark.network
+
+import java.net.InetSocketAddress
+
+import spark.Utils
+
+
+private[spark] case class ConnectionManagerId(host: String, port: Int) {
+  // DEBUG code
+  Utils.checkHost(host)
+  assert (port > 0)
+
+  def toSocketAddress() = new InetSocketAddress(host, port)
+}
+
+
+private[spark] object ConnectionManagerId {
+  def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
+    new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
+  }
+}
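Usage of the extracted class is a straight round trip through the socket address it denotes. A small sketch, assuming it is compiled alongside the spark.network package and that the host name resolves back to itself:

    import java.net.InetSocketAddress
    import spark.network.ConnectionManagerId

    object IdRoundTripDemo {
      def main(args: Array[String]) {
        val id = ConnectionManagerId("localhost", 9999)
        val addr: InetSocketAddress = id.toSocketAddress()
        val back = ConnectionManagerId.fromSocketAddress(addr)
        // Case-class equality: true as long as getHostName() returns "localhost".
        println(back == id)
      }
    }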
@@ -1,55 +1,10 @@
 package spark.network

 import spark._

+import java.nio.ByteBuffer
+import java.net.InetSocketAddress
+
 import scala.collection.mutable.ArrayBuffer

-import java.nio.ByteBuffer
-import java.net.InetAddress
-import java.net.InetSocketAddress
-import storage.BlockManager
-
-private[spark] class MessageChunkHeader(
-    val typ: Long,
-    val id: Int,
-    val totalSize: Int,
-    val chunkSize: Int,
-    val other: Int,
-    val address: InetSocketAddress) {
-  lazy val buffer = {
-    val ip = address.getAddress.getAddress()
-    val port = address.getPort()
-    ByteBuffer.
-      allocate(MessageChunkHeader.HEADER_SIZE).
-      putLong(typ).
-      putInt(id).
-      putInt(totalSize).
-      putInt(chunkSize).
-      putInt(other).
-      putInt(ip.size).
-      put(ip).
-      putInt(port).
-      position(MessageChunkHeader.HEADER_SIZE).
-      flip.asInstanceOf[ByteBuffer]
-  }
-
-  override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
-    " and sizes " + totalSize + " / " + chunkSize + " bytes"
-}
-
-private[spark] class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
-  val size = if (buffer == null) 0 else buffer.remaining
-  lazy val buffers = {
-    val ab = new ArrayBuffer[ByteBuffer]()
-    ab += header.buffer
-    if (buffer != null) {
-      ab += buffer
-    }
-    ab
-  }
-
-  override def toString = "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
-}

 private[spark] abstract class Message(val typ: Long, val id: Int) {
   var senderAddress: InetSocketAddress = null
@@ -58,120 +13,16 @@ private[spark] abstract class Message(val typ: Long, val id: Int) {
   var finishTime = -1L

   def size: Int

   def getChunkForSending(maxChunkSize: Int): Option[MessageChunk]

   def getChunkForReceiving(chunkSize: Int): Option[MessageChunk]

   def timeTaken(): String = (finishTime - startTime).toString + " ms"

   override def toString = this.getClass.getSimpleName + "(id = " + id + ", size = " + size + ")"
 }

-private[spark] class BufferMessage(id_ : Int, val buffers: ArrayBuffer[ByteBuffer], var ackId: Int)
-  extends Message(Message.BUFFER_MESSAGE, id_) {
-
-  val initialSize = currentSize()
-  var gotChunkForSendingOnce = false
-
-  def size = initialSize
-
-  def currentSize() = {
-    if (buffers == null || buffers.isEmpty) {
-      0
-    } else {
-      buffers.map(_.remaining).reduceLeft(_ + _)
-    }
-  }
-
-  def getChunkForSending(maxChunkSize: Int): Option[MessageChunk] = {
-    if (maxChunkSize <= 0) {
-      throw new Exception("Max chunk size is " + maxChunkSize)
-    }
-
-    if (size == 0 && gotChunkForSendingOnce == false) {
-      val newChunk = new MessageChunk(new MessageChunkHeader(typ, id, 0, 0, ackId, senderAddress), null)
-      gotChunkForSendingOnce = true
-      return Some(newChunk)
-    }
-
-    while(!buffers.isEmpty) {
-      val buffer = buffers(0)
-      if (buffer.remaining == 0) {
-        BlockManager.dispose(buffer)
-        buffers -= buffer
-      } else {
-        val newBuffer = if (buffer.remaining <= maxChunkSize) {
-          buffer.duplicate()
-        } else {
-          buffer.slice().limit(maxChunkSize).asInstanceOf[ByteBuffer]
-        }
-        buffer.position(buffer.position + newBuffer.remaining)
-        val newChunk = new MessageChunk(new MessageChunkHeader(
-          typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
-        gotChunkForSendingOnce = true
-        return Some(newChunk)
-      }
-    }
-    None
-  }
-
-  def getChunkForReceiving(chunkSize: Int): Option[MessageChunk] = {
-    // STRONG ASSUMPTION: BufferMessage created when receiving data has ONLY ONE data buffer
-    if (buffers.size > 1) {
-      throw new Exception("Attempting to get chunk from message with multiple data buffers")
-    }
-    val buffer = buffers(0)
-    if (buffer.remaining > 0) {
-      if (buffer.remaining < chunkSize) {
-        throw new Exception("Not enough space in data buffer for receiving chunk")
-      }
-      val newBuffer = buffer.slice().limit(chunkSize).asInstanceOf[ByteBuffer]
-      buffer.position(buffer.position + newBuffer.remaining)
-      val newChunk = new MessageChunk(new MessageChunkHeader(
-        typ, id, size, newBuffer.remaining, ackId, senderAddress), newBuffer)
-      return Some(newChunk)
-    }
-    None
-  }
-
-  def flip() {
-    buffers.foreach(_.flip)
-  }
-
-  def hasAckId() = (ackId != 0)
-
-  def isCompletelyReceived() = !buffers(0).hasRemaining
-
-  override def toString = {
-    if (hasAckId) {
-      "BufferAckMessage(aid = " + ackId + ", id = " + id + ", size = " + size + ")"
-    } else {
-      "BufferMessage(id = " + id + ", size = " + size + ")"
-    }
-  }
-}
-
-private[spark] object MessageChunkHeader {
-  val HEADER_SIZE = 40
-
-  def create(buffer: ByteBuffer): MessageChunkHeader = {
-    if (buffer.remaining != HEADER_SIZE) {
-      throw new IllegalArgumentException("Cannot convert buffer data to Message")
-    }
-    val typ = buffer.getLong()
-    val id = buffer.getInt()
-    val totalSize = buffer.getInt()
-    val chunkSize = buffer.getInt()
-    val other = buffer.getInt()
-    val ipSize = buffer.getInt()
-    val ipBytes = new Array[Byte](ipSize)
-    buffer.get(ipBytes)
-    val ip = InetAddress.getByAddress(ipBytes)
-    val port = buffer.getInt()
-    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
-  }
-}

 private[spark] object Message {
   val BUFFER_MESSAGE = 1111111111L
@@ -180,14 +31,16 @@ private[spark] object Message {

   def getNewId() = synchronized {
     lastId += 1
-    if (lastId == 0) lastId += 1
+    if (lastId == 0) {
+      lastId += 1
+    }
     lastId
   }

   def createBufferMessage(dataBuffers: Seq[ByteBuffer], ackId: Int): BufferMessage = {
     if (dataBuffers == null) {
       return new BufferMessage(getNewId(), new ArrayBuffer[ByteBuffer], ackId)
     }
     if (dataBuffers.exists(_ == null)) {
       throw new Exception("Attempting to create buffer message with null buffer")
     }
@@ -196,7 +49,7 @@ private[spark] object Message {

   def createBufferMessage(dataBuffers: Seq[ByteBuffer]): BufferMessage =
     createBufferMessage(dataBuffers, 0)

   def createBufferMessage(dataBuffer: ByteBuffer, ackId: Int): BufferMessage = {
     if (dataBuffer == null) {
       return createBufferMessage(Array(ByteBuffer.allocate(0)), ackId)
@@ -204,15 +57,18 @@ private[spark] object Message {
       return createBufferMessage(Array(dataBuffer), ackId)
     }
   }

   def createBufferMessage(dataBuffer: ByteBuffer): BufferMessage =
     createBufferMessage(dataBuffer, 0)

-  def createBufferMessage(ackId: Int): BufferMessage = createBufferMessage(new Array[ByteBuffer](0), ackId)
+  def createBufferMessage(ackId: Int): BufferMessage = {
+    createBufferMessage(new Array[ByteBuffer](0), ackId)
+  }

   def create(header: MessageChunkHeader): Message = {
     val newMessage: Message = header.typ match {
-      case BUFFER_MESSAGE => new BufferMessage(header.id, ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
+      case BUFFER_MESSAGE => new BufferMessage(header.id,
+        ArrayBuffer(ByteBuffer.allocate(header.totalSize)), header.other)
     }
     newMessage.senderAddress = header.address
     newMessage
core/src/main/scala/spark/network/MessageChunk.scala (new file, 25 lines)
@@ -0,0 +1,25 @@
+package spark.network
+
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+
+private[network]
+class MessageChunk(val header: MessageChunkHeader, val buffer: ByteBuffer) {
+
+  val size = if (buffer == null) 0 else buffer.remaining
+
+  lazy val buffers = {
+    val ab = new ArrayBuffer[ByteBuffer]()
+    ab += header.buffer
+    if (buffer != null) {
+      ab += buffer
+    }
+    ab
+  }
+
+  override def toString = {
+    "" + this.getClass.getSimpleName + " (id = " + header.id + ", size = " + size + ")"
+  }
+}
58
core/src/main/scala/spark/network/MessageChunkHeader.scala
Normal file
58
core/src/main/scala/spark/network/MessageChunkHeader.scala
Normal file
|
@ -0,0 +1,58 @@
|
|||
package spark.network

import java.net.InetAddress
import java.net.InetSocketAddress
import java.nio.ByteBuffer


private[spark] class MessageChunkHeader(
    val typ: Long,
    val id: Int,
    val totalSize: Int,
    val chunkSize: Int,
    val other: Int,
    val address: InetSocketAddress) {
  lazy val buffer = {
    // No need to change this, at 'use' time, we do a reverse lookup of the hostname.
    // Refer to network.Connection
    val ip = address.getAddress.getAddress()
    val port = address.getPort()
    ByteBuffer.
      allocate(MessageChunkHeader.HEADER_SIZE).
      putLong(typ).
      putInt(id).
      putInt(totalSize).
      putInt(chunkSize).
      putInt(other).
      putInt(ip.size).
      put(ip).
      putInt(port).
      position(MessageChunkHeader.HEADER_SIZE).
      flip.asInstanceOf[ByteBuffer]
  }

  override def toString = "" + this.getClass.getSimpleName + ":" + id + " of type " + typ +
    " and sizes " + totalSize + " / " + chunkSize + " bytes"
}


private[spark] object MessageChunkHeader {
  val HEADER_SIZE = 40

  def create(buffer: ByteBuffer): MessageChunkHeader = {
    if (buffer.remaining != HEADER_SIZE) {
      throw new IllegalArgumentException("Cannot convert buffer data to Message")
    }
    val typ = buffer.getLong()
    val id = buffer.getInt()
    val totalSize = buffer.getInt()
    val chunkSize = buffer.getInt()
    val other = buffer.getInt()
    val ipSize = buffer.getInt()
    val ipBytes = new Array[Byte](ipSize)
    buffer.get(ipBytes)
    val ip = InetAddress.getByAddress(ipBytes)
    val port = buffer.getInt()
    new MessageChunkHeader(typ, id, totalSize, chunkSize, other, new InetSocketAddress(ip, port))
  }
}
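A round-trip sketch for the header above (illustrative, not part of this diff). For an IPv4 address the serialized fields occupy 8 (typ) + 4*5 (id, totalSize, chunkSize, other, ipSize) + 4 (address bytes) + 4 (port) = 36 bytes, so HEADER_SIZE = 40 leaves 4 bytes of zero padding; a 16-byte IPv6 address would not fit in this layout.

    import java.net.InetSocketAddress

    val addr = new InetSocketAddress("127.0.0.1", 7077)
    val header = new MessageChunkHeader(Message.BUFFER_MESSAGE, 1, 1024, 512, 0, addr)
    // create() re-reads the fields in the same order buffer wrote them.
    val decoded = MessageChunkHeader.create(header.buffer)
    // decoded.totalSize == 1024 and decoded.chunkSize == 512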
57 core/src/main/scala/spark/network/netty/FileHeader.scala Normal file

@@ -0,0 +1,57 @@
package spark.network.netty

import io.netty.buffer._

import spark.Logging

private[spark] class FileHeader (
    val fileLen: Int,
    val blockId: String) extends Logging {

  lazy val buffer = {
    val buf = Unpooled.buffer()
    buf.capacity(FileHeader.HEADER_SIZE)
    buf.writeInt(fileLen)
    buf.writeInt(blockId.length)
    blockId.foreach((x: Char) => buf.writeByte(x))
    // pad the rest of the header
    if (FileHeader.HEADER_SIZE - buf.readableBytes > 0) {
      buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
    } else {
      logInfo("too long header")
      throw new Exception("too long header " + buf.readableBytes)
    }
    buf
  }
}

private[spark] object FileHeader {

  val HEADER_SIZE = 40

  def getFileLenOffset = 0
  def getFileLenSize = Integer.SIZE / 8

  def create(buf: ByteBuf): FileHeader = {
    val length = buf.readInt
    val idLength = buf.readInt
    val idBuilder = new StringBuilder(idLength)
    for (i <- 1 to idLength) {
      idBuilder += buf.readByte().asInstanceOf[Char]
    }
    val blockId = idBuilder.toString()
    new FileHeader(length, blockId)
  }

  def main(args: Array[String]) {
    val header = new FileHeader(25, "block_0")
    val buf = header.buffer
    val newHeader = FileHeader.create(buf)
    System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
  }
}
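A worked size check for this layout (illustrative): fileLen and the id length take 4 bytes each, the id takes one byte per character, and writeZero pads the remainder up to HEADER_SIZE.

    val occupied = 4 + 4 + "block_0".length          // 15 bytes written for the main() example
    val padding = FileHeader.HEADER_SIZE - occupied  // 25 bytes of zero padding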
101 core/src/main/scala/spark/network/netty/ShuffleCopier.scala Normal file

@@ -0,0 +1,101 @@
package spark.network.netty

import java.util.concurrent.Executors

import io.netty.buffer.ByteBuf
import io.netty.channel.ChannelHandlerContext
import io.netty.util.CharsetUtil

import spark.Logging
import spark.network.ConnectionManagerId

import scala.collection.JavaConverters._


private[spark] class ShuffleCopier extends Logging {

  def getBlock(host: String, port: Int, blockId: String,
      resultCollectCallback: (String, Long, ByteBuf) => Unit) {

    val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
    val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
    val fc = new FileClient(handler, connectTimeout)

    try {
      fc.init()
      fc.connect(host, port)
      fc.sendRequest(blockId)
      fc.waitForClose()
      fc.close()
    } catch {
      // Handle any socket-related exceptions in FileClient
      case e: Exception => {
        logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
        handler.handleError(blockId)
      }
    }
  }

  def getBlock(cmId: ConnectionManagerId, blockId: String,
      resultCollectCallback: (String, Long, ByteBuf) => Unit) {
    getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
  }

  def getBlocks(cmId: ConnectionManagerId,
      blocks: Seq[(String, Long)],
      resultCollectCallback: (String, Long, ByteBuf) => Unit) {

    for ((blockId, size) <- blocks) {
      getBlock(cmId, blockId, resultCollectCallback)
    }
  }
}


private[spark] object ShuffleCopier extends Logging {

  private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
    extends FileClientHandler with Logging {

    override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
      logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)")
      resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
    }

    override def handleError(blockId: String) {
      if (!isComplete) {
        resultCollectCallBack(blockId, -1, null)
      }
    }
  }

  def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
    if (size != -1) {
      logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
    }
  }

  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println("Usage: ShuffleCopier <host> <port> <shuffle_block_id> <threads>")
      System.exit(1)
    }
    val host = args(0)
    val port = args(1).toInt
    val file = args(2)
    val threads = if (args.length > 3) args(3).toInt else 10

    val copiers = Executors.newFixedThreadPool(threads)
    val tasks = (for (i <- Range(0, threads)) yield {
      Executors.callable(new Runnable() {
        def run() {
          val copier = new ShuffleCopier()
          copier.getBlock(host, port, file, echoResultCollectCallBack)
        }
      })
    }).asJava
    copiers.invokeAll(tasks)
    copiers.shutdown()
    System.exit(0)
  }
}
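For reference, a sketch of a resultCollectCallback that copies the fetched bytes out of Netty's buffer (hypothetical, not part of this diff; note that handleError reports failures with size == -1 and a null ByteBuf, so a callback must check size first):

    import io.netty.buffer.ByteBuf

    def collectToArray(blockId: String, size: Long, content: ByteBuf) {
      if (size == -1) {
        System.err.println("fetch of " + blockId + " failed")
      } else {
        val bytes = new Array[Byte](size.toInt)
        content.readBytes(bytes)   // drain the block's bytes into a local array
        println("fetched " + blockId + ": " + bytes.length + " bytes")
      }
    }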
53 core/src/main/scala/spark/network/netty/ShuffleSender.scala Normal file

@@ -0,0 +1,53 @@
package spark.network.netty

import java.io.File

import spark.Logging


private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {

  val server = new FileServer(pResolver, portIn)
  server.start()

  def stop() {
    server.stop()
  }

  def port: Int = server.getPort()
}


/**
 * An application for testing the shuffle sender as a standalone program.
 */
private[spark] object ShuffleSender {

  def main(args: Array[String]) {
    if (args.length < 3) {
      System.err.println(
        "Usage: ShuffleSender <port> <subDirsPerLocalDir> <list of shuffle_block_directories>")
      System.exit(1)
    }

    val port = args(0).toInt
    val subDirsPerLocalDir = args(1).toInt
    val localDirs = args.drop(2).map(new File(_))

    val pResolver = new PathResolver {
      override def getAbsolutePath(blockId: String): String = {
        if (!blockId.startsWith("shuffle_")) {
          throw new Exception("Block " + blockId + " is not a shuffle block")
        }
        // Figure out which local directory it hashes to, and which subdirectory in that
        val hash = math.abs(blockId.hashCode)
        val dirId = hash % localDirs.length
        val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
        val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
        val file = new File(subDir, blockId)
        return file.getAbsolutePath
      }
    }
    val sender = new ShuffleSender(port, pResolver)
  }
}
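The resolver recomputes the hashed directory layout, presumably matching how the block manager lays shuffle files out on disk. A worked example with 2 local dirs and 64 subdirs per dir (numbers are illustrative):

    val hash = 12345                          // stand-in for math.abs(blockId.hashCode)
    val dirId = hash % 2                      // = 1, picks localDirs(1)
    val subDirId = (hash / 2) % 64            // = 6172 % 64 = 28
    val subDirName = "%02x".format(subDirId)  // = "1c", so the file lands in <localDirs(1)>/1c/<blockId>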
core/src/main/scala/spark/rdd/BlockRDD.scala

@@ -1,7 +1,7 @@
package spark.rdd

import scala.collection.mutable.HashMap
import spark.{RDD, SparkContext, SparkEnv, Partition, TaskContext}
import spark.storage.BlockManager

private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
  val index = idx

@@ -11,12 +11,7 @@ private[spark]
class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[String])
  extends RDD[T](sc, Nil) {

  @transient lazy val locations_ = {
    val blockManager = SparkEnv.get.blockManager
    /*val locations = blockIds.map(id => blockManager.getLocations(id))*/
    val locations = blockManager.getLocations(blockIds)
    HashMap(blockIds.zip(locations):_*)
  }
  @transient lazy val locations_ = BlockManager.blockIdsToExecutorLocations(blockIds, SparkEnv.get)

  override def getPartitions: Array[Partition] = (0 until blockIds.size).map(i => {
    new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition]

core/src/main/scala/spark/rdd/CheckpointRDD.scala

@@ -8,6 +8,7 @@ import org.apache.hadoop.util.ReflectionUtils
import org.apache.hadoop.fs.Path
import java.io.{File, IOException, EOFException}
import java.text.NumberFormat
import spark.deploy.SparkHadoopUtil

private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}

@@ -21,13 +22,20 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
  @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)

  override def getPartitions: Array[Partition] = {
    val dirContents = fs.listStatus(new Path(checkpointPath))
    val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
    val numPartitions = partitionFiles.size
    if (numPartitions > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
        ! partitionFiles(numPartitions-1).endsWith(CheckpointRDD.splitIdToFile(numPartitions-1)))) {
      throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
    }
    val cpath = new Path(checkpointPath)
    val numPartitions =
      // listStatus can throw exception if path does not exist.
      if (fs.exists(cpath)) {
        val dirContents = fs.listStatus(cpath)
        val partitionFiles = dirContents.map(_.getPath.toString).filter(_.contains("part-")).sorted
        val numPart = partitionFiles.size
        if (numPart > 0 && (! partitionFiles(0).endsWith(CheckpointRDD.splitIdToFile(0)) ||
            ! partitionFiles(numPart-1).endsWith(CheckpointRDD.splitIdToFile(numPart-1)))) {
          throw new SparkException("Invalid checkpoint directory: " + checkpointPath)
        }
        numPart
      } else 0

    Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
  }

@@ -35,7 +43,7 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: Stri
  checkpointData.get.cpFile = Some(checkpointPath)

  override def getPreferredLocations(split: Partition): Seq[String] = {
    val status = fs.getFileStatus(new Path(checkpointPath))
    val status = fs.getFileStatus(new Path(checkpointPath, CheckpointRDD.splitIdToFile(split.index)))
    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
  }

@@ -58,7 +66,7 @@ private[spark] object CheckpointRDD extends Logging {
  def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
    val outputDir = new Path(path)
    val fs = outputDir.getFileSystem(new Configuration())
    val fs = outputDir.getFileSystem(SparkHadoopUtil.newConfiguration())

    val finalOutputName = splitIdToFile(ctx.splitId)
    val finalOutputPath = new Path(outputDir, finalOutputName)

@@ -83,6 +91,7 @@ private[spark] object CheckpointRDD extends Logging {
    if (!fs.rename(tempOutputPath, finalOutputPath)) {
      if (!fs.exists(finalOutputPath)) {
        logInfo("Deleting tempOutputPath " + tempOutputPath)
        fs.delete(tempOutputPath, false)
        throw new IOException("Checkpoint failed: failed to save output of task: "
          + ctx.attemptId + " and final output path does not exist")

@@ -95,7 +104,7 @@ private[spark] object CheckpointRDD extends Logging {
  }

  def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
    val fs = path.getFileSystem(new Configuration())
    val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
    val fileInputStream = fs.open(path, bufferSize)
    val serializer = SparkEnv.get.serializer.newInstance()

@@ -117,11 +126,11 @@ private[spark] object CheckpointRDD extends Logging {
    val sc = new SparkContext(cluster, "CheckpointRDD Test")
    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
    val path = new Path(hdfsPath, "temp")
    val fs = path.getFileSystem(new Configuration())
    val fs = path.getFileSystem(SparkHadoopUtil.newConfiguration())
    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
    assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
    fs.delete(path)
    fs.delete(path, true)
  }
}

core/src/main/scala/spark/rdd/CoGroupedRDD.scala

@@ -2,10 +2,11 @@ package spark.rdd
import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}

import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer

import spark.{Aggregator, Logging, Partitioner, RDD, SparkEnv, Partition, TaskContext}
import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}

@@ -28,7 +29,8 @@ private[spark] case class NarrowCoGroupSplitDep(
private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep

private[spark]
class CoGroupPartition(idx: Int, val deps: Seq[CoGroupSplitDep]) extends Partition with Serializable {
class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
  extends Partition with Serializable {
  override val index: Int = idx
  override def hashCode(): Int = idx
}
|
@ -40,7 +42,24 @@ private[spark] class CoGroupAggregator
|
|||
  { (b1, b2) => b1 ++ b2 })
  with Serializable

class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)

/**
 * An RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
 * tuple with the list of values for that key.
 *
 * @param rdds parent RDDs.
 * @param part partitioner used to partition the shuffle output.
 * @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
 *                       is on, Spark does an extra pass over the data on the map side to merge
 *                       all values belonging to the same key together. This can reduce the amount
 *                       of data shuffled if and only if the number of distinct keys is very small,
 *                       and the ratio of key size to value size is also very small.
 */
class CoGroupedRDD[K](
    @transient var rdds: Seq[RDD[(K, _)]],
    part: Partitioner,
    val mapSideCombine: Boolean = false,
    val serializerClass: String = null)
  extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {

  private val aggr = new CoGroupAggregator

@@ -52,8 +71,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
        new OneToOneDependency(rdd)
      } else {
        logInfo("Adding shuffle dependency with " + rdd)
        val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
        new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part)
        if (mapSideCombine) {
          val mapSideCombinedRDD = rdd.mapPartitions(aggr.combineValuesByKey(_), true)
          new ShuffleDependency[Any, ArrayBuffer[Any]](mapSideCombinedRDD, part, serializerClass)
        } else {
          new ShuffleDependency[Any, Any](rdd.asInstanceOf[RDD[(Any, Any)]], part, serializerClass)
        }
      }
    }
  }

@@ -70,7 +93,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
        case _ =>
          new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
      }
    }.toList)
    }.toArray)
    }
    array
  }

@@ -82,6 +105,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
    val numRdds = split.deps.size
    // e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
    val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]

    def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
      val seq = map.get(k)
      if (seq != null) {

@@ -92,6 +116,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
        seq
      }
    }

    val ser = SparkEnv.get.serializerManager.get(serializerClass)
    for ((dep, depNum) <- split.deps.zipWithIndex) dep match {
      case NarrowCoGroupSplitDep(rdd, _, itsSplit) => {
        // Read them from the parent

@@ -102,8 +128,16 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
      case ShuffleCoGroupSplitDep(shuffleId) => {
        // Read map outputs of shuffle
        val fetcher = SparkEnv.get.shuffleFetcher
        for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
          getSeq(k)(depNum) ++= vs
        if (mapSideCombine) {
          // With map side combine on, for each key, the shuffle fetcher returns a list of values.
          fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
            case (key, values) => getSeq(key)(depNum) ++= values
          }
        } else {
          // With map side combine off, for each key the shuffle fetcher returns a single value.
          fetcher.fetch[K, Any](shuffleId, split.index, context.taskMetrics, ser).foreach {
            case (key, value) => getSeq(key)(depNum) += value
          }
        }
      }
    }
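A usage sketch for the new constructor (hypothetical example; the parent RDDs and the local master string are made up):

    import spark.{HashPartitioner, SparkContext}
    import spark.rdd.CoGroupedRDD

    val sc = new SparkContext("local", "CoGroupedRDD example")
    val a = sc.parallelize(Seq(("k1", 1), ("k1", 2), ("k2", 3)))
    val b = sc.parallelize(Seq(("k1", "x"), ("k2", "y")))
    // With mapSideCombine = true, values are merged per key before the shuffle.
    val grouped = new CoGroupedRDD[String](Seq(a, b), new HashPartitioner(2), mapSideCombine = true)
    // Each key maps to one Seq per parent, e.g. k1 -> Seq(Seq(1, 2), Seq(x)).
    grouped.collect().foreach(println)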

core/src/main/scala/spark/rdd/CoalescedRDD.scala

@@ -37,8 +37,8 @@ class CoalescedRDD[T: ClassManifest](
      prevSplits.map(_.index).map{idx => new CoalescedRDDPartition(idx, prev, Array(idx)) }
    } else {
      (0 until maxPartitions).map { i =>
        val rangeStart = (i * prevSplits.length) / maxPartitions
        val rangeEnd = ((i + 1) * prevSplits.length) / maxPartitions
        val rangeStart = ((i.toLong * prevSplits.length) / maxPartitions).toInt
        val rangeEnd = (((i.toLong + 1) * prevSplits.length) / maxPartitions).toInt
        new CoalescedRDDPartition(i, prev, (rangeStart until rangeEnd).toArray)
      }.toArray
    }
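The .toLong here guards against Int overflow: with enough previous splits, i * prevSplits.length can exceed Int.MaxValue and go negative before the division. Illustrative numbers:

    val prevLen = 1000000000                                   // 1e9 previous splits (extreme, for illustration)
    val maxPartitions = 4
    val i = 3
    val overflowed = (i * prevLen) / maxPartitions             // 3e9 wraps to -1294967296, giving -323741824
    val fixed = ((i.toLong * prevLen) / maxPartitions).toInt   // 750000000, as intended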
Some files were not shown because too many files have changed in this diff.