Merge remote-tracking branch 'mesos/master' into yarnUILink

Conflicts:
	core/src/main/scala/org/apache/spark/ui/UIUtils.scala
	core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
	core/src/main/scala/org/apache/spark/ui/jobs/StageTable.scala
	docs/running-on-yarn.md
This commit is contained in:
Y.CORP.YAHOO.COM\tgraves 2013-09-03 08:36:59 -05:00
commit 547fc4a412
1058 changed files with 71622 additions and 69643 deletions

196
LICENSE
View file

@ -200,3 +200,199 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
=======================================================================
Apache Spark Subcomponents:
The Apache Spark project contains subcomponents with separate copyright
notices and license terms. Your use of the source code for the these
subcomponents is subject to the terms and conditions of the following
licenses.
=======================================================================
For the Boto EC2 library (ec2/third_party/boto*.zip):
=======================================================================
Copyright (c) 2006-2008 Mitch Garnaat http://garnaat.org/
Permission is hereby granted, free of charge, to any person obtaining a
copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish, dis-
tribute, sublicense, and/or sell copies of the Software, and to permit
persons to whom the Software is furnished to do so, subject to the fol-
lowing conditions:
The above copyright notice and this permission notice shall be included
in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABIL-
ITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHOR BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
IN THE SOFTWARE.
========================================================================
For CloudPickle (pyspark/cloudpickle.py):
========================================================================
Copyright (c) 2012, Regents of the University of California.
Copyright (c) 2009 `PiCloud, Inc. <http://www.picloud.com>`_.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of the University of California, Berkeley nor the
names of its contributors may be used to endorse or promote
products derived from this software without specific prior written
permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================================================================
For Py4J (python/lib/py4j0.7.egg and files in assembly/lib/net/sf/py4j):
========================================================================
Copyright (c) 2009-2011, Barthelemy Dagenais All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
- The name of the author may not be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
========================================================================
For DPark join code (python/pyspark/join.py):
========================================================================
Copyright (c) 2011, Douban Inc. <http://www.douban.com/>
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of the Douban Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
========================================================================
For sorttable (core/src/main/resources/org/apache/spark/ui/static/sorttable.js):
========================================================================
Copyright (c) 1997-2007 Stuart Langridge
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
========================================================================
For Scala Interpreter classes (all .scala files in repl/src/main/scala
except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala):
========================================================================
Copyright (c) 2002-2013 EPFL
Copyright (c) 2011-2013 Typesafe, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
- Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.
- Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
- Neither the name of the EPFL nor the names of its contributors may be
used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.

View file

@ -1,12 +1,12 @@
# Spark
# Apache Spark
Lightning-Fast Cluster Computing - <http://www.spark-project.org/>
Lightning-Fast Cluster Computing - <http://spark.incubator.apache.org/>
## Online Documentation
You can find the latest Spark documentation, including a programming
guide, on the project webpage at <http://spark-project.org/documentation.html>.
guide, on the project webpage at <http://spark.incubator.apache.org/documentation.html>.
This README file only contains basic setup instructions.
@ -18,24 +18,22 @@ Spark and its example programs, run:
sbt/sbt assembly
Spark also supports building using Maven. If you would like to build using Maven,
see the [instructions for building Spark with Maven](http://spark-project.org/docs/latest/building-with-maven.html)
in the Spark documentation..
Once you've built Spark, the easiest way to start using it is the shell:
To run Spark, you will need to have Scala's bin directory in your `PATH`, or
you will need to set the `SCALA_HOME` environment variable to point to where
you've installed Scala. Scala must be accessible through one of these
methods on your cluster's worker nodes as well as its master.
./spark-shell
To run one of the examples, use `./run-example <class> <params>`. For example:
Or, for the Python API, the Python shell (`./pyspark`).
./run-example spark.examples.SparkLR local[2]
Spark also comes with several sample programs in the `examples` directory.
To run one of them, use `./run-example <class> <params>`. For example:
./run-example org.apache.spark.examples.SparkLR local[2]
will run the Logistic Regression example locally on 2 CPUs.
Each of the example programs prints usage help if no params are given.
All of the Spark samples take a `<host>` parameter that is the cluster URL
All of the Spark samples take a `<master>` parameter that is the cluster URL
to connect to. This can be a mesos:// or spark:// URL, or "local" to run
locally with one thread, or "local[N]" to run locally with N threads.
@ -58,13 +56,13 @@ versions without YARN, use:
$ SPARK_HADOOP_VERSION=2.0.0-mr1-cdh4.2.0 sbt/sbt assembly
For Apache Hadoop 2.x, 0.23.x, Cloudera CDH MRv2, and other Hadoop versions
with YARN, also set `SPARK_WITH_YARN=true`:
with YARN, also set `SPARK_YARN=true`:
# Apache Hadoop 2.0.5-alpha
$ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_WITH_YARN=true sbt/sbt assembly
$ SPARK_HADOOP_VERSION=2.0.5-alpha SPARK_YARN=true sbt/sbt assembly
# Cloudera CDH 4.2.0 with MapReduce v2
$ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_WITH_YARN=true sbt/sbt assembly
$ SPARK_HADOOP_VERSION=2.0.0-cdh4.2.0 SPARK_YARN=true sbt/sbt assembly
For convenience, these variables may also be set through the `conf/spark-env.sh` file
described below.
@ -81,22 +79,25 @@ If your project is built with Maven, add this to your POM file's `<dependencies>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client</artifactId>
<!-- the brackets are needed to tell Maven that this is a hard dependency on version "1.2.1" exactly -->
<version>[1.2.1]</version>
<version>1.2.1</version>
</dependency>
## Configuration
Please refer to the "Configuration" guide in the online documentation for a
full overview on how to configure Spark. At the minimum, you will need to
create a `conf/spark-env.sh` script (copy `conf/spark-env.sh.template`) and
set the following two variables:
Please refer to the [Configuration guide](http://spark.incubator.apache.org/docs/latest/configuration.html)
in the online documentation for an overview on how to configure Spark.
- `SCALA_HOME`: Location where Scala is installed.
- `MESOS_NATIVE_LIBRARY`: Your Mesos library (only needed if you want to run
on Mesos). For example, this might be `/usr/local/lib/libmesos.so` on Linux.
## Apache Incubator Notice
Apache Spark is an effort undergoing incubation at The Apache Software
Foundation (ASF), sponsored by the Apache Incubator. Incubation is required of
all newly accepted projects until a further review indicates that the
infrastructure, communications, and decision making process have stabilized in
a manner consistent with other successful ASF projects. While incubation status
is not necessarily a reflection of the completeness or stability of the code,
it does indicate that the project has yet to be fully endorsed by the ASF.
## Contributing to Spark

View file

@ -19,16 +19,16 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
<version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-assembly</artifactId>
<name>Spark Project Assembly</name>
<url>http://spark-project.org/</url>
<url>http://spark.incubator.apache.org/</url>
<repositories>
<!-- A repository in the local filesystem for the Py4J JAR, which is not in Maven central -->
@ -40,27 +40,27 @@
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-bagel</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-mllib</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-repl</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-streaming</artifactId>
<version>${project.version}</version>
</dependency>
@ -121,7 +121,7 @@
<id>hadoop2-yarn</id>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-yarn</artifactId>
<version>${project.version}</version>
</dependency>

View file

@ -30,9 +30,9 @@
</fileSet>
<fileSet>
<directory>
${project.parent.basedir}/core/src/main/resources/spark/ui/static/
${project.parent.basedir}/core/src/main/resources/org/apache/spark/ui/static/
</directory>
<outputDirectory>/ui-resources/spark/ui/static</outputDirectory>
<outputDirectory>/ui-resources/org/apache/spark/ui/static</outputDirectory>
<includes>
<include>**/*</include>
</includes>
@ -63,10 +63,10 @@
<dependencySets>
<dependencySet>
<includes>
<include>org.spark-project:*:jar</include>
<include>org.apache.spark:*:jar</include>
</includes>
<excludes>
<exclude>org.spark-project:spark-assembly:jar</exclude>
<exclude>org.apache.spark:spark-assembly:jar</exclude>
</excludes>
</dependencySet>
<dependencySet>
@ -77,7 +77,7 @@
<useProjectArtifact>false</useProjectArtifact>
<excludes>
<exclude>org.apache.hadoop:*:jar</exclude>
<exclude>org.spark-project:*:jar</exclude>
<exclude>org.apache.spark:*:jar</exclude>
</excludes>
</dependencySet>
</dependencySets>

View file

@ -19,21 +19,21 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
<version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-bagel</artifactId>
<packaging>jar</packaging>
<name>Spark Project Bagel</name>
<url>http://spark-project.org/</url>
<url>http://spark.incubator.apache.org/</url>
<dependencies>
<dependency>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core</artifactId>
<version>${project.version}</version>
</dependency>

View file

@ -0,0 +1,293 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.bagel
import org.apache.spark._
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
object Bagel extends Logging {
val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
/**
* Runs a Bagel program.
* @param sc [[org.apache.spark.SparkContext]] to use for the program.
* @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be
* the vertex id.
* @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
* empty array, i.e. sc.parallelize(Array[K, Message]()).
* @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
* message before sending (which often involves network I/O).
* @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
* and provides the result to each vertex in the next superstep.
* @param partitioner [[org.apache.spark.Partitioner]] partitions values by key
* @param numPartitions number of partitions across which to split the graph.
* Default is the default parallelism of the SparkContext
* @param storageLevel [[org.apache.spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
* Defaults to caching in memory.
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
* optional Aggregator and the current superstep,
* and returns a set of (Vertex, outgoing Messages) pairs
* @tparam K key
* @tparam V vertex type
* @tparam M message type
* @tparam C combiner
* @tparam A aggregator
* @return an RDD of (K, V) pairs representing the graph after completion of the program
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C: Manifest, A: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
aggregator: Option[Aggregator[V, A]],
partitioner: Partitioner,
numPartitions: Int,
storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
)(
compute: (V, Option[C], Option[A], Int) => (V, Array[M])
): RDD[(K, V)] = {
val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism
var superstep = 0
var verts = vertices
var msgs = messages
var noActivity = false
do {
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
val aggregated = agg(verts, aggregator)
val combinedMsgs = msgs.combineByKey(
combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
val grouped = combinedMsgs.groupWith(verts)
val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
val (processed, numMsgs, numActiveVerts) =
comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
verts = processed.mapValues { case (vert, msgs) => vert }
msgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
superstep += 1
noActivity = numMsgs == 0 && numActiveVerts == 0
} while (!noActivity)
verts
}
/** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default storage level */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]]
* and default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default [[org.apache.spark.HashPartitioner]]*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default [[org.apache.spark.HashPartitioner]],
* [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numPartitions: Int
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/**
* Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], the default [[org.apache.spark.HashPartitioner]]
* and [[org.apache.spark.bagel.DefaultCombiner]]
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, Array[M], Nothing](
sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, Array[M]](compute))
}
/**
* Aggregates the given vertices using the given aggregator, if it
* is specified.
*/
private def agg[K, V <: Vertex, A: Manifest](
verts: RDD[(K, V)],
aggregator: Option[Aggregator[V, A]]
): Option[A] = aggregator match {
case Some(a) =>
Some(verts.map {
case (id, vert) => a.createAggregator(vert)
}.reduce(a.mergeAggregators(_, _)))
case None => None
}
/**
* Processes the given vertex-message RDD using the compute
* function. Returns the processed RDD, the number of messages
* created, and the number of active vertices.
*/
private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
grouped: RDD[(K, (Seq[C], Seq[V]))],
compute: (V, Option[C]) => (V, Array[M]),
storageLevel: StorageLevel
): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
case (_, vs) if vs.size == 0 => None
case (c, vs) =>
val (newVert, newMsgs) =
compute(vs(0), c match {
case Seq(comb) => Some(comb)
case Seq() => None
})
numMsgs += newMsgs.size
if (newVert.active)
numActiveVerts += 1
Some((newVert, newMsgs))
}.persist(storageLevel)
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
(processed, numMsgs.value, numActiveVerts.value)
}
/**
* Converts a compute function that doesn't take an aggregator to
* one that does, so it can be passed to Bagel.run.
*/
private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C](
compute: (V, Option[C], Int) => (V, Array[M])
): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
(vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
compute(vert, msgs, superstep)
}
}
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
trait Aggregator[V, A] {
def createAggregator(vert: V): A
def mergeAggregators(a: A, b: A): A
}
/** Default combiner that simply appends messages together (i.e. performs no aggregation) */
class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
def createCombiner(msg: M): Array[M] =
Array(msg)
def mergeMsg(combiner: Array[M], msg: M): Array[M] =
combiner :+ msg
def mergeCombiners(a: Array[M], b: Array[M]): Array[M] =
a ++ b
}
/**
* Represents a Bagel vertex.
*
* Subclasses may store state along with each vertex and must
* inherit from java.io.Serializable or scala.Serializable.
*/
trait Vertex {
def active: Boolean
}
/**
* Represents a Bagel message to a target vertex.
*
* Subclasses may contain a payload to deliver to the target vertex
* and must inherit from java.io.Serializable or scala.Serializable.
*/
trait Message[K] {
def targetId: K
}

View file

@ -1,294 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.bagel
import spark._
import spark.SparkContext._
import scala.collection.mutable.ArrayBuffer
import storage.StorageLevel
object Bagel extends Logging {
val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK
/**
* Runs a Bagel program.
* @param sc [[spark.SparkContext]] to use for the program.
* @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the Key will be
* the vertex id.
* @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often this will be an
* empty array, i.e. sc.parallelize(Array[K, Message]()).
* @param combiner [[spark.bagel.Combiner]] combines multiple individual messages to a given vertex into one
* message before sending (which often involves network I/O).
* @param aggregator [[spark.bagel.Aggregator]] performs a reduce across all vertices after each superstep,
* and provides the result to each vertex in the next superstep.
* @param partitioner [[spark.Partitioner]] partitions values by key
* @param numPartitions number of partitions across which to split the graph.
* Default is the default parallelism of the SparkContext
* @param storageLevel [[spark.storage.StorageLevel]] to use for caching of intermediate RDDs in each superstep.
* Defaults to caching in memory.
* @param compute function that takes a Vertex, optional set of (possibly combined) messages to the Vertex,
* optional Aggregator and the current superstep,
* and returns a set of (Vertex, outgoing Messages) pairs
* @tparam K key
* @tparam V vertex type
* @tparam M message type
* @tparam C combiner
* @tparam A aggregator
* @return an RDD of (K, V) pairs representing the graph after completion of the program
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest,
C: Manifest, A: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
aggregator: Option[Aggregator[V, A]],
partitioner: Partitioner,
numPartitions: Int,
storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL
)(
compute: (V, Option[C], Option[A], Int) => (V, Array[M])
): RDD[(K, V)] = {
val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism
var superstep = 0
var verts = vertices
var msgs = messages
var noActivity = false
do {
logInfo("Starting superstep "+superstep+".")
val startTime = System.currentTimeMillis
val aggregated = agg(verts, aggregator)
val combinedMsgs = msgs.combineByKey(
combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner)
val grouped = combinedMsgs.groupWith(verts)
val superstep_ = superstep // Create a read-only copy of superstep for capture in closure
val (processed, numMsgs, numActiveVerts) =
comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel)
val timeTaken = System.currentTimeMillis - startTime
logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000))
verts = processed.mapValues { case (vert, msgs) => vert }
msgs = processed.flatMap {
case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m))
}
superstep += 1
noActivity = numMsgs == 0 && numActiveVerts == 0
} while (!noActivity)
verts
}
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default storage level */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] */
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
partitioner: Partitioner,
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]]
* and default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numPartitions: Int
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/** Runs a Bagel program with no [[spark.bagel.Aggregator]] and the default [[spark.HashPartitioner]]*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
combiner: Combiner[M, C],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[C], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, C, Nothing](
sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, C](compute))
}
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], default [[spark.HashPartitioner]],
* [[spark.bagel.DefaultCombiner]] and the default storage level
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numPartitions: Int
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute)
/**
* Runs a Bagel program with no [[spark.bagel.Aggregator]], the default [[spark.HashPartitioner]]
* and [[spark.bagel.DefaultCombiner]]
*/
def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest](
sc: SparkContext,
vertices: RDD[(K, V)],
messages: RDD[(K, M)],
numPartitions: Int,
storageLevel: StorageLevel
)(
compute: (V, Option[Array[M]], Int) => (V, Array[M])
): RDD[(K, V)] = {
val part = new HashPartitioner(numPartitions)
run[K, V, M, Array[M], Nothing](
sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)(
addAggregatorArg[K, V, M, Array[M]](compute))
}
/**
* Aggregates the given vertices using the given aggregator, if it
* is specified.
*/
private def agg[K, V <: Vertex, A: Manifest](
verts: RDD[(K, V)],
aggregator: Option[Aggregator[V, A]]
): Option[A] = aggregator match {
case Some(a) =>
Some(verts.map {
case (id, vert) => a.createAggregator(vert)
}.reduce(a.mergeAggregators(_, _)))
case None => None
}
/**
* Processes the given vertex-message RDD using the compute
* function. Returns the processed RDD, the number of messages
* created, and the number of active vertices.
*/
private def comp[K: Manifest, V <: Vertex, M <: Message[K], C](
sc: SparkContext,
grouped: RDD[(K, (Seq[C], Seq[V]))],
compute: (V, Option[C]) => (V, Array[M]),
storageLevel: StorageLevel
): (RDD[(K, (V, Array[M]))], Int, Int) = {
var numMsgs = sc.accumulator(0)
var numActiveVerts = sc.accumulator(0)
val processed = grouped.flatMapValues {
case (_, vs) if vs.size == 0 => None
case (c, vs) =>
val (newVert, newMsgs) =
compute(vs(0), c match {
case Seq(comb) => Some(comb)
case Seq() => None
})
numMsgs += newMsgs.size
if (newVert.active)
numActiveVerts += 1
Some((newVert, newMsgs))
}.persist(storageLevel)
// Force evaluation of processed RDD for accurate performance measurements
processed.foreach(x => {})
(processed, numMsgs.value, numActiveVerts.value)
}
/**
* Converts a compute function that doesn't take an aggregator to
* one that does, so it can be passed to Bagel.run.
*/
private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C](
compute: (V, Option[C], Int) => (V, Array[M])
): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = {
(vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) =>
compute(vert, msgs, superstep)
}
}
trait Combiner[M, C] {
def createCombiner(msg: M): C
def mergeMsg(combiner: C, msg: M): C
def mergeCombiners(a: C, b: C): C
}
trait Aggregator[V, A] {
def createAggregator(vert: V): A
def mergeAggregators(a: A, b: A): A
}
/** Default combiner that simply appends messages together (i.e. performs no aggregation) */
class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable {
def createCombiner(msg: M): Array[M] =
Array(msg)
def mergeMsg(combiner: Array[M], msg: M): Array[M] =
combiner :+ msg
def mergeCombiners(a: Array[M], b: Array[M]): Array[M] =
a ++ b
}
/**
* Represents a Bagel vertex.
*
* Subclasses may store state along with each vertex and must
* inherit from java.io.Serializable or scala.Serializable.
*/
trait Vertex {
def active: Boolean
}
/**
* Represents a Bagel message to a target vertex.
*
* Subclasses may contain a payload to deliver to the target vertex
* and must inherit from java.io.Serializable or scala.Serializable.
*/
trait Message[K] {
def targetId: K
}

View file

@ -1,118 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.bagel
import org.scalatest.{FunSuite, Assertions, BeforeAndAfter}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import scala.collection.mutable.ArrayBuffer
import spark._
import storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
after {
if (sc != null) {
sc.stop()
sc = null
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
test("halting by voting") {
sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
test("halting by message silence") {
sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
case Some(ms) if (superstep < numSupersteps - 1) =>
ms
case _ =>
Array[TestMessage]()
}
(new TestVertex(self.active, self.age + 1), msgsOut)
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
test("large number of iterations") {
// This tests whether jobs with a large number of iterations finish in a reasonable time,
// because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
}
test("using non-default persistence level") {
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
}
}

View file

@ -0,0 +1,116 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.bagel
import org.scalatest.{BeforeAndAfter, FunSuite, Assertions}
import org.scalatest.concurrent.Timeouts
import org.scalatest.time.SpanSugar._
import org.apache.spark._
import org.apache.spark.storage.StorageLevel
class TestVertex(val active: Boolean, val age: Int) extends Vertex with Serializable
class TestMessage(val targetId: String) extends Message[String] with Serializable
class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeouts {
var sc: SparkContext = _
after {
if (sc != null) {
sc.stop()
sc = null
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
}
test("halting by voting") {
sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
test("halting by message silence") {
sc = new SparkContext("local", "test")
val verts = sc.parallelize(Array("a", "b", "c", "d").map(id => (id, new TestVertex(false, 0))))
val msgs = sc.parallelize(Array("a" -> new TestMessage("a")))
val numSupersteps = 5
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
val msgsOut =
msgs match {
case Some(ms) if (superstep < numSupersteps - 1) =>
ms
case _ =>
Array[TestMessage]()
}
(new TestVertex(self.active, self.age + 1), msgsOut)
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
test("large number of iterations") {
// This tests whether jobs with a large number of iterations finish in a reasonable time,
// because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
}
test("using non-default persistence level") {
failAfter(10 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
val numSupersteps = 50
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
(new TestVertex(superstep < numSupersteps - 1, self.age + 1), Array[TestMessage]())
}
for ((id, vert) <- result.collect) {
assert(vert.age === numSupersteps)
}
}
}
}

View file

@ -28,30 +28,27 @@ set FWDIR=%~dp0..\
rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
set MLLIB_DIR=%FWDIR%mllib
set TOOLS_DIR=%FWDIR%tools
set YARN_DIR=%FWDIR%yarn
set STREAMING_DIR=%FWDIR%streaming
set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%MLLIB_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%TOOLS_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%YARN_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%SPARK_CLASSPATH%;%FWDIR%conf
if exist "%FWDIR%RELEASE" (
for %%d in ("%FWDIR%jars\spark-assembly*.jar") do (
set ASSEMBLY_JAR=%%d
)
) else (
for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
set ASSEMBLY_JAR=%%d
)
)
set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
if "x%SPARK_TESTING%"=="x1" (
rem Add test clases to path
set CLASSPATH=%CLASSPATH%;%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
)
rem Add hadoop conf dir - else FileSystem.*, etc fail
rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
@ -64,9 +61,6 @@ if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
:no_yarn_conf_dir
rem Add Scala standard library
set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar
rem A bit of a hack to allow calling this script within run2.cmd without seeing output
if "%DONT_PRINT_CLASSPATH%"=="1" goto exit

View file

@ -42,7 +42,7 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
# If the slaves file is specified in the command line,
# then it takes precedence over the definition in
# then it takes precedence over the definition in
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES
@ -58,8 +58,6 @@ if [ "$HOSTLIST" = "" ]; then
fi
fi
echo $"${@// /\\ }"
# By default disable strict host key checking
if [ "$SPARK_SSH_OPTS" = "" ]; then
SPARK_SSH_OPTS="-o StrictHostKeyChecking=no"

View file

@ -75,6 +75,7 @@ if [ "$SPARK_IDENT_STRING" = "" ]; then
export SPARK_IDENT_STRING="$USER"
fi
export SPARK_PRINT_LAUNCH_COMMAND="1"
# get log directory
@ -124,13 +125,19 @@ case $startStop in
rsync -a -e ssh --delete --exclude=.svn --exclude='logs/*' --exclude='contrib/hod/logs/*' $SPARK_MASTER/ "$SPARK_HOME"
fi
spark_rotate_log $log
spark_rotate_log "$log"
echo starting $command, logging to $log
echo "Spark Daemon: $command" > $log
cd "$SPARK_PREFIX"
nohup nice -n $SPARK_NICENESS "$SPARK_PREFIX"/spark-class $command "$@" >> "$log" 2>&1 < /dev/null &
echo $! > $pid
sleep 1; head "$log"
newpid=$!
echo $newpid > $pid
sleep 2
# Check if the process has died; in that case we'll tail the log so the user can see
if ! kill -0 $newpid >/dev/null 2>&1; then
echo "failed to launch $command:"
tail -2 "$log" | sed 's/^/ /'
echo "full log in $log"
fi
;;
(stop)

View file

@ -49,4 +49,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
fi
fi
"$bin"/spark-daemon.sh start spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT
"$bin"/spark-daemon.sh start org.apache.spark.deploy.master.Master 1 --ip $SPARK_MASTER_IP --port $SPARK_MASTER_PORT --webui-port $SPARK_MASTER_WEBUI_PORT

View file

@ -32,4 +32,4 @@ if [ "$SPARK_PUBLIC_DNS" = "" ]; then
fi
fi
"$bin"/spark-daemon.sh start spark.deploy.worker.Worker "$@"
"$bin"/spark-daemon.sh start org.apache.spark.deploy.worker.Worker "$@"

View file

@ -35,8 +35,6 @@ if [ "$SPARK_MASTER_IP" = "" ]; then
SPARK_MASTER_IP=`hostname`
fi
echo "Master IP: $SPARK_MASTER_IP"
# Launch the slaves
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
exec "$bin/slaves.sh" cd "$SPARK_HOME" \; "$bin/start-slave.sh" 1 spark://$SPARK_MASTER_IP:$SPARK_MASTER_PORT

View file

@ -20,6 +20,7 @@
# Start all spark daemons.
# Run this on the master nde
bin=`dirname "$0"`
bin=`cd "$bin"; pwd`

View file

@ -24,4 +24,4 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
"$bin"/spark-daemon.sh stop spark.deploy.master.Master 1
"$bin"/spark-daemon.sh stop org.apache.spark.deploy.master.Master 1

View file

@ -29,9 +29,9 @@ if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
fi
if [ "$SPARK_WORKER_INSTANCES" = "" ]; then
"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker 1
"$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker 1
else
for ((i=0; i<$SPARK_WORKER_INSTANCES; i++)); do
"$bin"/spark-daemons.sh stop spark.deploy.worker.Worker $(( $i + 1 ))
"$bin"/spark-daemons.sh stop org.apache.spark.deploy.worker.Worker $(( $i + 1 ))
done
fi

View file

@ -1,24 +1,21 @@
#!/usr/bin/env bash
# This file contains environment variables required to run Spark. Copy it as
# spark-env.sh and edit that to configure Spark for your site. At a minimum,
# the following two variables should be set:
# - SCALA_HOME, to point to your Scala installation, or SCALA_LIBRARY_PATH to
# point to the directory for Scala library JARs (if you install Scala as a
# Debian or RPM package, these are in a separate path, often /usr/share/java)
# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos
# spark-env.sh and edit that to configure Spark for your site.
#
# If using the standalone deploy mode, you can also set variables for it:
# - SPARK_MASTER_IP, to bind the master to a different IP address
# The following variables can be set in this file:
# - SPARK_LOCAL_IP, to set the IP address Spark binds to on this node
# - MESOS_NATIVE_LIBRARY, to point to your libmesos.so if you use Mesos
# - SPARK_JAVA_OPTS, to set node-specific JVM options for Spark. Note that
# we recommend setting app-wide options in the application's driver program.
# Examples of node-specific options : -Dspark.local.dir, GC options
# Examples of app-wide options : -Dspark.serializer
#
# If using the standalone deploy mode, you can also set variables for it here:
# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports
# - SPARK_WORKER_CORES, to set the number of cores to use on this machine
# - SPARK_WORKER_MEMORY, to set how much memory to use (e.g. 1000m, 2g)
# - SPARK_WORKER_PORT / SPARK_WORKER_WEBUI_PORT
# - SPARK_WORKER_INSTANCES, to set the number of worker instances/processes
# to be spawned on every slave machine
# - SPARK_JAVA_OPTS, to set the jvm options for executor backend. Note: This is
# only for node-specific options, whereas app-specific options should be set
# in the application.
# Examples of node-speicic options : -Dspark.local.dir, GC related options.
# Examples of app-specific options : -Dspark.serializer
# - SPARK_WORKER_INSTANCES, to set the number of worker processes per node

View file

@ -19,17 +19,17 @@
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-parent</artifactId>
<version>0.8.0-SNAPSHOT</version>
<relativePath>../pom.xml</relativePath>
</parent>
<groupId>org.spark-project</groupId>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core</artifactId>
<packaging>jar</packaging>
<name>Spark Project Core</name>
<url>http://spark-project.org/</url>
<url>http://spark.incubator.apache.org/</url>
<dependencies>
<dependency>

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FileClient {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
private int connectTimeout = 60*1000; // 1 min
public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
this.connectTimeout = connectTimeout;
}
public void init() {
bootstrap = new Bootstrap();
bootstrap.group(new OioEventLoopGroup())
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
public void connect(String host, int port) {
try {
// Start the connection attempt.
channel = bootstrap.connect(host, port).sync().channel();
// ChannelFuture cf = channel.closeFuture();
//cf.addListener(new ChannelCloseListener(this));
} catch (InterruptedException e) {
close();
}
}
public void waitForClose() {
try {
channel.closeFuture().sync();
} catch (InterruptedException e) {
LOG.warn("FileClient interrupted", e);
}
}
public void sendRequest(String file) {
//assert(file == null);
//assert(channel == null);
channel.write(file + "\r\n");
}
public void close() {
if(channel != null) {
channel.close();
channel = null;
}
if ( bootstrap!=null) {
bootstrap.shutdown();
bootstrap = null;
}
}
}

View file

@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import io.netty.buffer.BufType;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.string.StringEncoder;
class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
private FileClientHandler fhandler;
public FileClientChannelInitializer(FileClientHandler handler) {
fhandler = handler;
}
@Override
public void initChannel(SocketChannel channel) {
// file no more than 2G
channel.pipeline()
.addLast("encoder", new StringEncoder(BufType.BYTE))
.addLast("handler", fhandler);
}
}

View file

@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
private volatile boolean handlerCalled = false;
public boolean isComplete() {
return handlerCalled;
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
// Use direct buffer if possible.
return ctx.alloc().ioBuffer();
}
@Override
public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
// get header
if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
}
// get file
if(in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
handlerCalled = true;
currentHeader = null;
ctx.close();
}
}
}

View file

@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Server that accept the path of a file an echo back its content.
*/
class FileServer {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private ServerBootstrap bootstrap = null;
private ChannelFuture channelFuture = null;
private int port = 0;
private Thread blockingThread = null;
public FileServer(PathResolver pResolver, int port) {
InetSocketAddress addr = new InetSocketAddress(port);
// Configure the server.
bootstrap = new ServerBootstrap();
bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
.channel(OioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 100)
.option(ChannelOption.SO_RCVBUF, 1500)
.childHandler(new FileServerChannelInitializer(pResolver));
// Start the server.
channelFuture = bootstrap.bind(addr);
try {
// Get the address we bound to.
InetSocketAddress boundAddress =
((InetSocketAddress) channelFuture.sync().channel().localAddress());
this.port = boundAddress.getPort();
} catch (InterruptedException ie) {
this.port = 0;
}
}
/**
* Start the file server asynchronously in a new thread.
*/
public void start() {
blockingThread = new Thread() {
public void run() {
try {
channelFuture.channel().closeFuture().sync();
LOG.info("FileServer exiting");
} catch (InterruptedException e) {
LOG.error("File server start got interrupted", e);
}
// NOTE: bootstrap is shutdown in stop()
}
};
blockingThread.setDaemon(true);
blockingThread.start();
}
public int getPort() {
return port;
}
public void stop() {
// Close the bound channel.
if (channelFuture != null) {
channelFuture.channel().close();
channelFuture = null;
}
// Shutdown bootstrap.
if (bootstrap != null) {
bootstrap.shutdown();
bootstrap = null;
}
// TODO: Shutdown all accepted channels as well ?
}
}

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;
class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
PathResolver pResolver;
public FileServerChannelInitializer(PathResolver pResolver) {
this.pResolver = pResolver;
}
@Override
public void initChannel(SocketChannel channel) {
channel.pipeline()
.addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
.addLast("strDecoder", new StringDecoder())
.addLast("handler", new FileServerHandler(pResolver));
}
}

View file

@ -0,0 +1,82 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
import java.io.File;
import java.io.FileInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
PathResolver pResolver;
public FileServerHandler(PathResolver pResolver){
this.pResolver = pResolver;
}
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockId) {
String path = pResolver.getAbsolutePath(blockId);
// if getFilePath returns null, close the channel
if (path == null) {
//ctx.close();
return;
}
File file = new File(path);
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
}

View file

@ -0,0 +1,29 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.network.netty;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
}

View file

@ -1,89 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FileClient {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
private int connectTimeout = 60*1000; // 1 min
public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
this.connectTimeout = connectTimeout;
}
public void init() {
bootstrap = new Bootstrap();
bootstrap.group(new OioEventLoopGroup())
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
public void connect(String host, int port) {
try {
// Start the connection attempt.
channel = bootstrap.connect(host, port).sync().channel();
// ChannelFuture cf = channel.closeFuture();
//cf.addListener(new ChannelCloseListener(this));
} catch (InterruptedException e) {
close();
}
}
public void waitForClose() {
try {
channel.closeFuture().sync();
} catch (InterruptedException e) {
LOG.warn("FileClient interrupted", e);
}
}
public void sendRequest(String file) {
//assert(file == null);
//assert(channel == null);
channel.write(file + "\r\n");
}
public void close() {
if(channel != null) {
channel.close();
channel = null;
}
if ( bootstrap!=null) {
bootstrap.shutdown();
bootstrap = null;
}
}
}

View file

@ -1,41 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import io.netty.buffer.BufType;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.string.StringEncoder;
class FileClientChannelInitializer extends ChannelInitializer<SocketChannel> {
private FileClientHandler fhandler;
public FileClientChannelInitializer(FileClientHandler handler) {
fhandler = handler;
}
@Override
public void initChannel(SocketChannel channel) {
// file no more than 2G
channel.pipeline()
.addLast("encoder", new StringEncoder(BufType.BYTE))
.addLast("handler", fhandler);
}
}

View file

@ -1,60 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
private volatile boolean handlerCalled = false;
public boolean isComplete() {
return handlerCalled;
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
// Use direct buffer if possible.
return ctx.alloc().ioBuffer();
}
@Override
public void inboundBufferUpdated(ChannelHandlerContext ctx, ByteBuf in) {
// get header
if (currentHeader == null && in.readableBytes() >= FileHeader.HEADER_SIZE()) {
currentHeader = FileHeader.create(in.readBytes(FileHeader.HEADER_SIZE()));
}
// get file
if(in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
handlerCalled = true;
currentHeader = null;
ctx.close();
}
}
}

View file

@ -1,103 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioServerSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Server that accept the path of a file an echo back its content.
*/
class FileServer {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private ServerBootstrap bootstrap = null;
private ChannelFuture channelFuture = null;
private int port = 0;
private Thread blockingThread = null;
public FileServer(PathResolver pResolver, int port) {
InetSocketAddress addr = new InetSocketAddress(port);
// Configure the server.
bootstrap = new ServerBootstrap();
bootstrap.group(new OioEventLoopGroup(), new OioEventLoopGroup())
.channel(OioServerSocketChannel.class)
.option(ChannelOption.SO_BACKLOG, 100)
.option(ChannelOption.SO_RCVBUF, 1500)
.childHandler(new FileServerChannelInitializer(pResolver));
// Start the server.
channelFuture = bootstrap.bind(addr);
try {
// Get the address we bound to.
InetSocketAddress boundAddress =
((InetSocketAddress) channelFuture.sync().channel().localAddress());
this.port = boundAddress.getPort();
} catch (InterruptedException ie) {
this.port = 0;
}
}
/**
* Start the file server asynchronously in a new thread.
*/
public void start() {
blockingThread = new Thread() {
public void run() {
try {
channelFuture.channel().closeFuture().sync();
LOG.info("FileServer exiting");
} catch (InterruptedException e) {
LOG.error("File server start got interrupted", e);
}
// NOTE: bootstrap is shutdown in stop()
}
};
blockingThread.setDaemon(true);
blockingThread.start();
}
public int getPort() {
return port;
}
public void stop() {
// Close the bound channel.
if (channelFuture != null) {
channelFuture.channel().close();
channelFuture = null;
}
// Shutdown bootstrap.
if (bootstrap != null) {
bootstrap.shutdown();
bootstrap = null;
}
// TODO: Shutdown all accepted channels as well ?
}
}

View file

@ -1,42 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.DelimiterBasedFrameDecoder;
import io.netty.handler.codec.Delimiters;
import io.netty.handler.codec.string.StringDecoder;
class FileServerChannelInitializer extends ChannelInitializer<SocketChannel> {
PathResolver pResolver;
public FileServerChannelInitializer(PathResolver pResolver) {
this.pResolver = pResolver;
}
@Override
public void initChannel(SocketChannel channel) {
channel.pipeline()
.addLast("framer", new DelimiterBasedFrameDecoder(8192, Delimiters.lineDelimiter()))
.addLast("strDecoder", new StringDecoder())
.addLast("handler", new FileServerHandler(pResolver));
}
}

View file

@ -1,82 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
import java.io.File;
import java.io.FileInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
PathResolver pResolver;
public FileServerHandler(PathResolver pResolver){
this.pResolver = pResolver;
}
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockId) {
String path = pResolver.getAbsolutePath(blockId);
// if getFilePath returns null, close the channel
if (path == null) {
//ctx.close();
return;
}
File file = new File(path);
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
cause.printStackTrace();
ctx.close();
}
}

View file

@ -1,29 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package spark.network.netty;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
}

File diff suppressed because one or more lines are too long

View file

Before

Width:  |  Height:  |  Size: 14 KiB

After

Width:  |  Height:  |  Size: 14 KiB

View file

@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
.navbar {
height: 50px;
font-size: 15px;
margin-bottom: 15px;
}
.navbar .navbar-inner {
height: 50px;
}
.navbar .brand {
margin-right: 20px;
margin-bottom: 0;
margin-top: 0;
margin-left: 10px;
padding: 0;
}
.navbar .nav > li {
height: 50px;
}
.navbar .nav > li a {
height: 30px;
line-height: 30px;
}
.navbar-text {
height: 50px;
line-height: 50px;
}
table.sortable thead {
cursor: pointer;
}
.progress {
margin-bottom: 0px; position: relative
}
.progress-completed .bar,
.progress .bar-completed {
background-color: #3EC0FF;
background-image: -moz-linear-gradient(top, #44CBFF, #34B0EE);
background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#44CBFF), to(#34B0EE));
background-image: -webkit-linear-gradient(top, #44CBFF, #34B0EE);
background-image: -o-linear-gradient(top, #44CBFF, #34B0EE);
background-image: linear-gradient(to bottom, #64CBFF, #54B0EE);
background-repeat: repeat-x;
filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FF44CBFF', endColorstr='#FF34B0EE', GradientType=0);
}
.progress-running .bar,
.progress .bar-running {
background-color: #A0DFFF;
background-image: -moz-linear-gradient(top, #A4EDFF, #94DDFF);
background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#A4EDFF), to(#94DDFF));
background-image: -webkit-linear-gradient(top, #A4EDFF, #94DDFF);
background-image: -o-linear-gradient(top, #A4EDFF, #94DDFF);
background-image: linear-gradient(to bottom, #A4EDFF, #94DDFF);
background-repeat: repeat-x;
filter: progid:dximagetransform.microsoft.gradient(startColorstr='#FFA4EDFF', endColorstr='#FF94DDFF', GradientType=0);
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View file

@ -1,82 +0,0 @@
.navbar .brand {
height: 50px;
width: 110px;
margin-left: 1px;
padding: 0;
}
.version {
line-height: 30px;
vertical-align: bottom;
font-size: 12px;
padding: 0;
margin: 0;
font-weight: bold;
color: #777;
}
.navbar-inner {
padding-top: 2px;
height: 50px;
}
.navbar-inner .nav {
margin-top: 5px;
font-size: 15px;
}
#infolist {
margin-left: 400px;
margin-top: 14px;
}
#infolist li {
display: inline;
list-style-type: none;
list-style-position: outside;
padding-right: 20px;
padding-top: 10px;
padding-bottom: 10px;
}
.progress-cell {
width: 134px;
border-right: 0;
padding: 0;
padding-top: 7px;
padding-left: 4px;
line-height: 15px !important;
}
.table-fixed {
table-layout:fixed;
}
.table td {
vertical-align: middle !important;
}
.progress-completed .bar,
.progress .bar-completed {
background-color: #b3def9;
background-image: -moz-linear-gradient(top, #addfff, #badcf2);
background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#addfff), to(#badcf2));
background-image: -webkit-linear-gradient(top, #addfff, #badcf2);
background-image: -o-linear-gradient(top, #addfff, #badcf2);
background-image: linear-gradient(to bottom, #addfff, #badcf2);
background-repeat: repeat-x;
filter: progid:dximagetransform.microsoft.gradient(startColorstr='#ffaddfff', endColorstr='#ffbadcf2', GradientType=0);
}
.progress-running .bar,
.progress .bar-running {
background-color: #c2ebfa;
background-image: -moz-linear-gradient(top, #bdedff, #c7e8f5);
background-image: -webkit-gradient(linear, 0 0, 0 100%, from(#bdedff), to(#c7e8f5));
background-image: -webkit-linear-gradient(top, #bdedff, #c7e8f5);
background-image: -o-linear-gradient(top, #bdedff, #c7e8f5);
background-image: linear-gradient(to bottom, #bdedff, #c7e8f5);
background-repeat: repeat-x;
filter: progid:dximagetransform.microsoft.gradient(startColorstr='#ffbdedff', endColorstr='#ffc7e8f5', GradientType=0);
}

View file

@ -0,0 +1,257 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io._
import scala.collection.mutable.Map
import scala.collection.generic.Growable
import org.apache.spark.serializer.JavaSerializer
/**
* A datatype that can be accumulated, i.e. has an commutative and associative "add" operation,
* but where the result type, `R`, may be different from the element type being added, `T`.
*
* You must define how to add data, and how to merge two of these together. For some datatypes,
* such as a counter, these might be the same operation. In that case, you can use the simpler
* [[org.apache.spark.Accumulator]]. They won't always be the same, though -- e.g., imagine you are
* accumulating a set. You will add items to the set, and you will union two sets together.
*
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `R` and `T`
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
class Accumulable[R, T] (
@transient initialValue: R,
param: AccumulableParam[R, T])
extends Serializable {
val id = Accumulators.newId
@transient private var value_ = initialValue // Current value on master
val zero = param.zero(initialValue) // Zero value to be passed to workers
var deserialized = false
Accumulators.register(this, true)
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
def += (term: T) { value_ = param.addAccumulator(value_, term) }
/**
* Add more data to this accumulator / accumulable
* @param term the data to add
*/
def add(term: T) { value_ = param.addAccumulator(value_, term) }
/**
* Merge two accumulable objects together
*
* Normally, a user will not want to use this version, but will instead call `+=`.
* @param term the other `R` that will get merged with this
*/
def ++= (term: R) { value_ = param.addInPlace(value_, term)}
/**
* Merge two accumulable objects together
*
* Normally, a user will not want to use this version, but will instead call `add`.
* @param term the other `R` that will get merged with this
*/
def merge(term: R) { value_ = param.addInPlace(value_, term)}
/**
* Access the accumulator's current value; only allowed on master.
*/
def value: R = {
if (!deserialized) {
value_
} else {
throw new UnsupportedOperationException("Can't read accumulator value in task")
}
}
/**
* Get the current value of this accumulator from within a task.
*
* This is NOT the global value of the accumulator. To get the global value after a
* completed operation on the dataset, call `value`.
*
* The typical use of this method is to directly mutate the local value, eg., to add
* an element to a Set.
*/
def localValue = value_
/**
* Set the accumulator's value; only allowed on master.
*/
def value_= (newValue: R) {
if (!deserialized) value_ = newValue
else throw new UnsupportedOperationException("Can't assign accumulator value in task")
}
/**
* Set the accumulator's value; only allowed on master
*/
def setValue(newValue: R) {
this.value = newValue
}
// Called by Java when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
value_ = zero
deserialized = true
Accumulators.register(this, false)
}
override def toString = value_.toString
}
/**
* Helper object defining how to accumulate values of a particular type. An implicit
* AccumulableParam needs to be available when you create Accumulables of a specific type.
*
* @tparam R the full accumulated data (result type)
* @tparam T partial data that can be added in
*/
trait AccumulableParam[R, T] extends Serializable {
/**
* Add additional data to the accumulator value. Is allowed to modify and return `r`
* for efficiency (to avoid allocating objects).
*
* @param r the current value of the accumulator
* @param t the data to be added to the accumulator
* @return the new value of the accumulator
*/
def addAccumulator(r: R, t: T): R
/**
* Merge two accumulated values together. Is allowed to modify and return the first value
* for efficiency (to avoid allocating objects).
*
* @param r1 one set of accumulated data
* @param r2 another set of accumulated data
* @return both data sets merged together
*/
def addInPlace(r1: R, r2: R): R
/**
* Return the "zero" (identity) value for an accumulator type, given its initial value. For
* example, if R was a vector of N dimensions, this would return a vector of N zeroes.
*/
def zero(initialValue: R): R
}
private[spark]
class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
extends AccumulableParam[R,T] {
def addAccumulator(growable: R, elem: T): R = {
growable += elem
growable
}
def addInPlace(t1: R, t2: R): R = {
t1 ++= t2
t1
}
def zero(initialValue: R): R = {
// We need to clone initialValue, but it's hard to specify that R should also be Cloneable.
// Instead we'll serialize it to a buffer and load it back.
val ser = new JavaSerializer().newInstance()
val copy = ser.deserialize[R](ser.serialize(initialValue))
copy.clear() // In case it contained stuff
copy
}
}
/**
* A simpler value of [[org.apache.spark.Accumulable]] where the result type being accumulated is the same
* as the types of elements being merged.
*
* @param initialValue initial value of accumulator
* @param param helper object defining how to add elements of type `T`
* @tparam T result type
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T])
extends Accumulable[T,T](initialValue, param)
/**
* A simpler version of [[org.apache.spark.AccumulableParam]] where the only datatype you can add in is the same type
* as the accumulated value. An implicit AccumulatorParam object needs to be available when you create
* Accumulators of a specific type.
*
* @tparam T type of value to accumulate
*/
trait AccumulatorParam[T] extends AccumulableParam[T, T] {
def addAccumulator(t1: T, t2: T): T = {
addInPlace(t1, t2)
}
}
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulable[_, _]]()
val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
var lastId: Long = 0
def newId: Long = synchronized {
lastId += 1
return lastId
}
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
originals(a.id) = a
} else {
val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
accums(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
localAccums.remove(Thread.currentThread)
}
}
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
ret(id) = accum.localValue
}
return ret
}
// Add values to the original accumulators with some given IDs
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (originals.contains(id)) {
originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
}
}
}
}

View file

@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions._
/** A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for (kv <- iter) {
val oldC = combiners.get(kv._1)
if (oldC == null) {
combiners.put(kv._1, createCombiner(kv._2))
} else {
combiners.put(kv._1, mergeValue(oldC, kv._2))
}
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
iter.foreach { case(k, c) =>
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, mergeCombiners(oldC, c))
}
}
combiners.iterator
}
}

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
: Iterator[T] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))
val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
for (((address, size), index) <- statuses.zipWithIndex) {
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
}
def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Some(block) => {
block.asInstanceOf[Iterator[T]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block")
}
}
}
}
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
})
}
}

View file

@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.storage.{BlockManager, StorageLevel}
import org.apache.spark.rdd.RDD
/** Spark class responsible for passing RDDs split contents to the BlockManager and making
sure a node doesn't load two copies of an RDD at once.
*/
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
private val loading = new HashSet[String]
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
logInfo("Cache key is " + key)
blockManager.get(key) match {
case Some(cachedValues) =>
// Partition is in cache, so just return its values
logInfo("Found partition in cache!")
return cachedValues.asInstanceOf[Iterator[T]]
case None =>
// Mark the split as loading (unless someone else marks it first)
loading.synchronized {
if (loading.contains(key)) {
logInfo("Loading contains " + key + ", waiting...")
while (loading.contains(key)) {
try {loading.wait()} catch {case _ : Throwable =>}
}
logInfo("Loading no longer contains " + key + ", so returning cached result")
// See whether someone else has successfully loaded it. The main way this would fail
// is for the RDD-level cache eviction policy if someone else has loaded the same RDD
// partition but we didn't want to make space for it. However, that case is unlikely
// because it's unlikely that two threads would work on the same RDD partition. One
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
case None =>
logInfo("Whoever was loading " + key + " failed; we'll try it ourselves")
loading.add(key)
}
} else {
loading.add(key)
}
}
try {
// If we got here, we have to load the split
val elements = new ArrayBuffer[Any]
logInfo("Computing partition " + split)
elements ++= rdd.computeOrReadCheckpoint(split, context)
// Try to put this block in the blockManager
blockManager.put(key, elements, storageLevel, true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
loading.remove(key)
loading.notifyAll()
}
}
}
}
}

View file

@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.rdd.RDD
/**
* Base class for dependencies.
*/
abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
/**
* Base class for dependencies where each partition of the parent RDD is used by at most one
* partition of the child RDD. Narrow dependencies allow for pipelined execution.
*/
abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
/**
* Get the parent partitions for a child partition.
* @param partitionId a partition of the child RDD
* @return the partitions of the parent RDD that the child partition depends upon
*/
def getParents(partitionId: Int): Seq[Int]
}
/**
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
* @param serializerClass class name of the serializer to use
*/
class ShuffleDependency[K, V](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
val serializerClass: String = null)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
}
/**
* Represents a one-to-one dependency between partitions of the parent and child RDDs.
*/
class OneToOneDependency[T](rdd: RDD[T]) extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = List(partitionId)
}
/**
* Represents a one-to-one dependency between ranges of partitions in the parent and child RDDs.
* @param rdd the parent RDD
* @param inStart the start of the range in the parent RDD
* @param outStart the start of the range in the child RDD
* @param length the length of the range
*/
class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
extends NarrowDependency[T](rdd) {
override def getParents(partitionId: Int) = {
if (partitionId >= outStart && partitionId < outStart + length) {
List(partitionId - outStart + inStart)
} else {
Nil
}
}
}

View file

@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.storage.BlockManagerId
private[spark] class FetchFailedException(
taskEndReason: TaskEndReason,
message: String,
cause: Throwable)
extends Exception {
def this (bmAddress: BlockManagerId, shuffleId: Int, mapId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(bmAddress, shuffleId, mapId, reduceId),
"Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId),
cause)
def this (shuffleId: Int, reduceId: Int, cause: Throwable) =
this(FetchFailed(null, shuffleId, -1, reduceId),
"Unable to fetch locations from master: %d %d".format(shuffleId, reduceId), cause)
override def getMessage(): String = message
override def getCause(): Throwable = cause
def toTaskEndReason: TaskEndReason = taskEndReason
}

View file

@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.{File}
import com.google.common.io.Files
import org.apache.spark.util.Utils
private[spark] class HttpFileServer extends Logging {
var baseDir : File = null
var fileDir : File = null
var jarDir : File = null
var httpServer : HttpServer = null
var serverUri : String = null
def initialize() {
baseDir = Utils.createTempDir()
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
httpServer = new HttpServer(baseDir)
httpServer.start()
serverUri = httpServer.uri
}
def stop() {
httpServer.stop()
}
def addFile(file: File) : String = {
addFileToDir(file, fileDir)
return serverUri + "/files/" + file.getName
}
def addJar(file: File) : String = {
addFileToDir(file, jarDir)
return serverUri + "/jars/" + file.getName
}
def addFileToDir(file: File, dir: File) : String = {
Files.copy(file, new File(dir, file.getName))
return dir + "/" + file.getName
}
}

View file

@ -0,0 +1,89 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io.File
import java.net.InetAddress
import org.eclipse.jetty.server.Server
import org.eclipse.jetty.server.bio.SocketConnector
import org.eclipse.jetty.server.handler.DefaultHandler
import org.eclipse.jetty.server.handler.HandlerList
import org.eclipse.jetty.server.handler.ResourceHandler
import org.eclipse.jetty.util.thread.QueuedThreadPool
import org.apache.spark.util.Utils
/**
* Exception type thrown by HttpServer when it is in the wrong state for an operation.
*/
private[spark] class ServerStateException(message: String) extends Exception(message)
/**
* An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext
* as well as classes created by the interpreter when the user types in code. This is just a wrapper
* around a Jetty server.
*/
private[spark] class HttpServer(resourceBase: File) extends Logging {
private var server: Server = null
private var port: Int = -1
def start() {
if (server != null) {
throw new ServerStateException("Server is already started")
} else {
server = new Server()
val connector = new SocketConnector
connector.setMaxIdleTime(60*1000)
connector.setSoLingerTime(-1)
connector.setPort(0)
server.addConnector(connector)
val threadPool = new QueuedThreadPool
threadPool.setDaemon(true)
server.setThreadPool(threadPool)
val resHandler = new ResourceHandler
resHandler.setResourceBase(resourceBase.getAbsolutePath)
val handlerList = new HandlerList
handlerList.setHandlers(Array(resHandler, new DefaultHandler))
server.setHandler(handlerList)
server.start()
port = server.getConnectors()(0).getLocalPort()
}
}
def stop() {
if (server == null) {
throw new ServerStateException("Server is already stopped")
} else {
server.stop()
port = -1
server = null
}
}
/**
* Get the URI of this HTTP server (http://host:port)
*/
def uri: String = {
if (server == null) {
throw new ServerStateException("Server is not started")
} else {
return "http://" + Utils.localIpAddress + ":" + port
}
}
}

View file

@ -0,0 +1,95 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.slf4j.Logger
import org.slf4j.LoggerFactory
/**
* Utility trait for classes that want to log data. Creates a SLF4J logger for the class and allows
* logging messages at different levels using methods that only evaluate parameters lazily if the
* log level is enabled.
*/
trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
@transient private var log_ : Logger = null
// Method to get or create the logger for this object
protected def log: Logger = {
if (log_ == null) {
var className = this.getClass.getName
// Ignore trailing $'s in the class names for Scala objects
if (className.endsWith("$")) {
className = className.substring(0, className.length - 1)
}
log_ = LoggerFactory.getLogger(className)
}
return log_
}
// Log methods that take only a String
protected def logInfo(msg: => String) {
if (log.isInfoEnabled) log.info(msg)
}
protected def logDebug(msg: => String) {
if (log.isDebugEnabled) log.debug(msg)
}
protected def logTrace(msg: => String) {
if (log.isTraceEnabled) log.trace(msg)
}
protected def logWarning(msg: => String) {
if (log.isWarnEnabled) log.warn(msg)
}
protected def logError(msg: => String) {
if (log.isErrorEnabled) log.error(msg)
}
// Log methods that take Throwables (Exceptions/Errors) too
protected def logInfo(msg: => String, throwable: Throwable) {
if (log.isInfoEnabled) log.info(msg, throwable)
}
protected def logDebug(msg: => String, throwable: Throwable) {
if (log.isDebugEnabled) log.debug(msg, throwable)
}
protected def logTrace(msg: => String, throwable: Throwable) {
if (log.isTraceEnabled) log.trace(msg, throwable)
}
protected def logWarning(msg: => String, throwable: Throwable) {
if (log.isWarnEnabled) log.warn(msg, throwable)
}
protected def logError(msg: => String, throwable: Throwable) {
if (log.isErrorEnabled) log.error(msg, throwable)
}
protected def isTraceEnabled(): Boolean = {
log.isTraceEnabled
}
// Method for ensuring that logging is initialized, to avoid having multiple
// threads do it concurrently (as SLF4J initialization is not thread safe).
protected def initLogging() { log }
}

View file

@ -0,0 +1,338 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String)
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
sender ! true
context.stop(self)
}
}
private[spark] class MapOutputTracker extends Logging {
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var epoch: Long = 0
private val epochLock = new java.lang.Object
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with MapOutputTracker", e)
}
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
// Called on possibly remote nodes to get the server URIs and output sizes for a given shuffle
def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
val statuses = mapStatuses.get(shuffleId).orNull
if (statuses == null) {
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
if (fetching.contains(shuffleId)) {
// Someone else is fetching it; wait for them to be done
while (fetching.contains(shuffleId)) {
try {
fetching.wait()
} catch {
case e: InterruptedException =>
}
}
}
// Either while we waited the fetch happened successfully, or
// someone fetched it in between the get and the fetching.synchronized.
fetchedStatuses = mapStatuses.get(shuffleId).orNull
if (fetchedStatuses == null) {
// We have to do the fetch, get others to wait for us.
fetching += shuffleId
}
}
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
val hostPort = Utils.localHostPort()
// This try-finally prevents hangs due to timeouts:
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
fetching.synchronized {
fetching -= shuffleId
fetching.notifyAll()
}
}
}
if (fetchedStatuses != null) {
fetchedStatuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
}
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
}
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
}
}
}
private def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
communicate(StopMapOutputTracker)
mapStatuses.clear()
metadataCleaner.cancel()
trackerActor = null
}
// Called on master to increment the epoch number
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
// Called on master or workers to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
}
}
// Called on workers to update the epoch number, potentially clearing old outputs
// because of a fetch failure. (Each worker task calls this with the latest epoch
// number on the master at the time it was created.)
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
epoch = newEpoch
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
if (epoch > cacheEpoch) {
cachedSerializedStatuses.clear()
cacheEpoch = epoch
}
cachedSerializedStatuses.get(shuffleId) match {
case Some(bytes) =>
return bytes
case None =>
statuses = mapStatuses(shuffleId)
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
if (epoch == epochGotten) {
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
}
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
statuses.synchronized {
objOut.writeObject(statuses)
}
objOut.close()
out.toByteArray
}
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),
// throw a FetchFailedException.
private def convertMapStatuses(
shuffleId: Int,
reduceId: Int,
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
assert (statuses != null)
statuses.map {
status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
} else {
(status.location, decompressSize(status.compressedSizes(reduceId)))
}
}
}
/**
* Compress a size in bytes to 8 bits for efficient reporting of map output sizes.
* We do this by encoding the log base 1.1 of the size as an integer, which can support
* sizes up to 35 GB with at most 10% error.
*/
def compressSize(size: Long): Byte = {
if (size == 0) {
0
} else if (size <= 1L) {
1
} else {
math.min(255, math.ceil(math.log(size) / math.log(LOG_BASE)).toInt).toByte
}
}
/**
* Decompress an 8-bit encoded block size, using the reverse operation of compressSize.
*/
def decompressSize(compressedSize: Byte): Long = {
if (compressedSize == 0) {
0
} else {
math.pow(LOG_BASE, (compressedSize & 0xFF)).toLong
}
}
}

View file

@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
/**
* A partition of an RDD.
*/
trait Partition extends Serializable {
/**
* Get the split's index within its parent RDD
*/
def index: Int
// A better default implementation of HashCode
override def hashCode(): Int = index
}

View file

@ -0,0 +1,138 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.util.Utils
import org.apache.spark.rdd.RDD
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
* Maps each key to a partition ID, from 0 to `numPartitions - 1`.
*/
abstract class Partitioner extends Serializable {
def numPartitions: Int
def getPartition(key: Any): Int
}
object Partitioner {
/**
* Choose a partitioner to use for a cogroup-like operation between a number of RDDs.
*
* If any of the RDDs already has a partitioner, choose that one.
*
* Otherwise, we use a default HashPartitioner. For the number of partitions, if
* spark.default.parallelism is set, then we'll use the value from SparkContext
* defaultParallelism, otherwise we'll use the max number of upstream partitions.
*
* Unless spark.default.parallelism is set, He number of partitions will be the
* same as the number of partitions in the largest upstream RDD, as this should
* be least likely to cause out-of-memory errors.
*
* We use two method parameters (rdd, others) to enforce callers passing at least 1 RDD.
*/
def defaultPartitioner(rdd: RDD[_], others: RDD[_]*): Partitioner = {
val bySize = (Seq(rdd) ++ others).sortBy(_.partitions.size).reverse
for (r <- bySize if r.partitioner != None) {
return r.partitioner.get
}
if (System.getProperty("spark.default.parallelism") != null) {
return new HashPartitioner(rdd.context.defaultParallelism)
} else {
return new HashPartitioner(bySize.head.partitions.size)
}
}
}
/**
* A [[org.apache.spark.Partitioner]] that implements hash-based partitioning using Java's `Object.hashCode`.
*
* Java arrays have hashCodes that are based on the arrays' identities rather than their contents,
* so attempting to partition an RDD[Array[_]] or RDD[(Array[_], _)] using a HashPartitioner will
* produce an unexpected or incorrect result.
*/
class HashPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
def getPartition(key: Any): Int = key match {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
case _ =>
false
}
}
/**
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
* Determines the ranges by sampling the RDD passed in.
*/
class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
private val ascending: Boolean = true)
extends Partitioner {
// An array of upper bounds for the first (partitions - 1) partitions
private val rangeBounds: Array[K] = {
if (partitions == 1) {
Array()
} else {
val rddSize = rdd.count()
val maxSampleSize = partitions * 20.0
val frac = math.min(maxSampleSize / math.max(rddSize, 1), 1.0)
val rddSample = rdd.sample(false, frac, 1).map(_._1).collect().sortWith(_ < _)
if (rddSample.length == 0) {
Array()
} else {
val bounds = new Array[K](partitions - 1)
for (i <- 0 until partitions - 1) {
val index = (rddSample.length - 1) * (i + 1) / partitions
bounds(i) = rddSample(index)
}
bounds
}
}
}
def numPartitions = partitions
def getPartition(key: Any): Int = {
// TODO: Use a binary search here if number of partitions is large
val k = key.asInstanceOf[K]
var partition = 0
while (partition < rangeBounds.length && k > rangeBounds(partition)) {
partition += 1
}
if (ascending) {
partition
} else {
rangeBounds.length - partition
}
}
override def equals(other: Any): Boolean = other match {
case r: RangePartitioner[_,_] =>
r.rangeBounds.sameElements(rangeBounds) && r.ascending == ascending
case _ =>
false
}
}

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io._
import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable
import org.apache.hadoop.conf.Configuration
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
def value = t
override def toString = t.toString
private def writeObject(out: ObjectOutputStream) {
out.defaultWriteObject()
new ObjectWritable(t).write(out)
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
val ow = new ObjectWritable()
ow.setConf(new Configuration())
ow.readFields(in)
t = ow.get().asInstanceOf[T]
}
}

View file

@ -0,0 +1,35 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.serializer.Serializer
private[spark] abstract class ShuffleFetcher {
/**
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */
def stop() {}
}

View file

@ -0,0 +1,999 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import java.io._
import java.net.URI
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.util.DynamicVariable
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.ArrayWritable
import org.apache.hadoop.io.BooleanWritable
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.DoubleWritable
import org.apache.hadoop.io.FloatWritable
import org.apache.hadoop.io.IntWritable
import org.apache.hadoop.io.LongWritable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.FileInputFormat
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileInputFormat
import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
ClusterScheduler, Schedulable, SchedulingMode}
import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI.
* @param sparkHome Location where Spark is installed on cluster nodes.
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes.
*/
class SparkContext(
val master: String,
val appName: String,
val sparkHome: String = null,
val jars: Seq[String] = Nil,
val environment: Map[String, String] = Map(),
// This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) too.
// This is typically generated from InputFormatInfo.computePreferredLocations .. host, set of data-local splits on host
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map())
extends Logging {
// Ensure logging is initialized before we spawn any threads
initLogging()
// Set Spark driver host and port system properties
if (System.getProperty("spark.driver.host") == null) {
System.setProperty("spark.driver.host", Utils.localHostName())
}
if (System.getProperty("spark.driver.port") == null) {
System.setProperty("spark.driver.port", "0")
}
val isLocal = (master == "local" || master.startsWith("local["))
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createFromSystemProperties(
"<driver>",
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port").toInt,
true,
isLocal)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
private[spark] val addedFiles = HashMap[String, Long]()
private[spark] val addedJars = HashMap[String, Long]()
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
// Initalize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
val startTime = System.currentTimeMillis()
// Add each JAR given through the constructor
if (jars != null) {
jars.foreach { addJar(_) }
}
// Environment variables to pass to our executors
private[spark] val executorEnvs = HashMap[String, String]()
// Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner
for (key <- Seq("SPARK_CLASSPATH", "SPARK_LIBRARY_PATH", "SPARK_JAVA_OPTS", "SPARK_TESTING")) {
val value = System.getenv(key)
if (value != null) {
executorEnvs(key) = value
}
}
// Since memory can be set with a system property too, use that
executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m"
if (environment != null) {
executorEnvs ++= environment
}
// Create and start the scheduler
private var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
master match {
case "local" =>
new LocalScheduler(1, 0, this)
case LOCAL_N_REGEX(threads) =>
new LocalScheduler(threads.toInt, 0, this)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
new LocalScheduler(threads.toInt, maxFailures.toInt, this)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
throw new SparkException(
"Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
memoryPerSlaveInt, SparkContext.executorMemoryRequested))
}
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
scheduler
case "yarn-standalone" =>
val scheduler = try {
val clazz = Class.forName("spark.scheduler.cluster.YarnClusterScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(this).asInstanceOf[ClusterScheduler]
} catch {
// TODO: Enumerate the exact reasons why it can fail
// But irrespective of it, it means we cannot proceed !
case th: Throwable => {
throw new SparkException("YARN mode not available ?", th)
}
}
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
} else {
new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
}
scheduler.initialize(backend)
scheduler
}
}
taskScheduler.start()
@volatile private var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
ui.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
val conf = env.hadoop.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null && System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
conf.set("fs.s3.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3n.awsAccessKeyId", System.getenv("AWS_ACCESS_KEY_ID"))
conf.set("fs.s3.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
conf
}
private[spark] var checkpointDir: Option[String] = None
// Thread Local variable that can be used by users to pass information down the stack
private val localProperties = new DynamicVariable[Properties](null)
def initLocalProperties() {
localProperties.value = new Properties()
}
def setLocalProperty(key: String, value: String) {
if (localProperties.value == null) {
localProperties.value = new Properties()
}
if (value == null) {
localProperties.value.remove(key)
} else {
localProperties.value.setProperty(key, value)
}
}
/** Set a human readable description of the current job. */
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
// Post init
taskScheduler.postStartHook()
val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler)
val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
SparkEnv.get.metricsSystem.registerSource(blockManagerSource)
}
initDriverMetrics()
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
/** Distribute a local Scala collection to form an RDD, with one or more
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minSplits: Int = defaultMinSplits): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], minSplits)
.map(pair => pair._2.toString)
}
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
def hadoopRDD[K, V](
conf: JobConf,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
def hadoopFile[K, V](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int = defaultMinSplits
) : RDD[(K, V)] = {
val conf = new JobConf(hadoopConfiguration)
FileInputFormat.setInputPaths(conf, path)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path, minSplits)
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
: RDD[(K, V)] = {
hadoopFile(path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]],
minSplits)
}
/**
* Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
* val file = sparkContext.hadoopFile[LongWritable, Text, TextInputFormat](path)
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
(implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
newAPIHadoopFile(
path,
fm.erasure.asInstanceOf[Class[F]],
km.erasure.asInstanceOf[Class[K]],
vm.erasure.asInstanceOf[Class[V]])
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
path: String,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int
): RDD[(K, V)] = {
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] =
sequenceFile(path, keyClass, valueClass, defaultMinSplits)
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
* WritableConverter. For example, to access a SequenceFile where the keys are Text and the
* values are IntWritable, you could simply write
* {{{
* sparkContext.sequenceFile[String, Int](path, ...)
* }}}
*
* WritableConverters are provided in a somewhat strange way (by an implicit function) to support
* both subclasses of Writable and types for which we define a converter (e.g. Int to
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
* for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
(implicit km: ClassManifest[K], vm: ClassManifest[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
val writables = hadoopFile(path, format,
kc.writableClass(km).asInstanceOf[Class[Writable]],
vc.writableClass(vm).asInstanceOf[Class[Writable]], minSplits)
writables.map{case (k,v) => (kc.convert(k), vc.convert(v))}
}
/**
* Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
* BytesWritable values that contain a serialized partition. This is still an experimental storage
* format and may not be supported exactly as is in future Spark releases. It will also be pretty
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T: ClassManifest](
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minSplits)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes))
}
protected[spark] def checkpointFile[T: ClassManifest](
path: String
): RDD[T] = {
new CheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
new Accumulator(initialValue, param)
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, to which tasks can add values with `+=`.
* Only the driver can access the accumuable's `value`.
* @tparam T accumulator type
* @tparam R type that can be added to the accumulator
*/
def accumulable[T, R](initialValue: T)(implicit param: AccumulableParam[T, R]) =
new Accumulable(initialValue, param)
/**
* Create an accumulator from a "mutable collection" type.
*
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T](initialValue: R) = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
}
/**
* Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.broadcast.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T) = env.broadcastManager.newBroadcast[T](value, isLocal)
/**
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
}
addedFiles(key) = System.currentTimeMillis
// Fetch the file locally in case a job is executed locally.
// Jobs that run through LocalScheduler will already fetch the required dependencies,
// but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here.
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory))
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
}
def addSparkListener(listener: SparkListener) {
dagScheduler.addSparkListener(listener)
}
/**
* Return a map from the slave to the max memory available for caching and the remaining
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
}
/**
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
/**
* Returns an immutable map of RDDs that have marked themselves as persistent via cache() call.
* Note that this does not necessarily mean the caching or computation was successful.
*/
def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
def getStageInfo: Map[Stage,StageInfo] = {
dagScheduler.stageToInfos
}
/**
* Return information about blocks stored in all of the slaves
*/
def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
/**
* Return pools for fair scheduler
* TODO(xiajunluan): We should take nested pools into account
*/
def getAllPools: ArrayBuffer[Schedulable] = {
taskScheduler.rootPool.schedulableQueue
}
/**
* Return the pool associated with the given name, if one exists
*/
def getPoolForName(pool: String): Option[Schedulable] = {
taskScheduler.rootPool.schedulableNameToSchedulable.get(pool)
}
/**
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
taskScheduler.schedulingMode
}
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
addedFiles.clear()
}
/**
* Gets the locality information associated with the partition in a particular rdd
* @param rdd of interest
* @param partition to be looked up for locality
* @return list of preferred locations for the partition
*/
private [spark] def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
dagScheduler.getPreferredLocs(rdd, partition)
}
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
if (path == null) {
logWarning("null specified as parameter to addJar",
new SparkException("null specified as parameter to addJar"))
} else {
var key = ""
if (path.contains("\\")) {
// For local paths with backslashes on Windows, URI throws an exception
key = env.httpFileServer.addJar(new File(path))
} else {
val uri = new URI(path)
key = uri.getScheme match {
case null | "file" =>
if (env.hadoop.isYarnMode()) {
logWarning("local jar specified as parameter to addJar under Yarn mode")
return
}
env.httpFileServer.addJar(new File(uri.getPath))
case _ =>
path
}
}
addedJars(key) = System.currentTimeMillis
logInfo("Added JAR " + path + " at " + key + " with timestamp " + addedJars(key))
}
}
/**
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
*/
def clearJars() {
addedJars.clear()
}
/** Shut down the SparkContext. */
def stop() {
ui.stop()
// Do this only if not stopped already - best case effort.
// prevent NPE if stopped more than once.
val dagSchedulerCopy = dagScheduler
dagScheduler = null
if (dagSchedulerCopy != null) {
metadataCleaner.cancel()
dagSchedulerCopy.stop()
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
// Clean up locally linked files
clearFiles()
clearJars()
SparkEnv.set(null)
ShuffleMapTask.clearCache()
ResultTask.clearCache()
logInfo("Successfully stopped SparkContext")
} else {
logInfo("SparkContext already stopped")
}
}
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
* (in that order of preference). If neither of these is set, return None.
*/
private[spark] def getSparkHome(): Option[String] = {
if (sparkHome != null) {
Some(sparkHome)
} else if (System.getProperty("spark.home") != null) {
Some(System.getProperty("spark.home"))
} else if (System.getenv("SPARK_HOME") != null) {
Some(System.getenv("SPARK_HOME"))
} else {
None
}
}
/**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
* Run a function on a given set of partitions in an RDD and return the results as an array. The
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
val results = new Array[U](partitions.size)
runJob[T, U](rdd, func, partitions, allowLocal, (index, res) => results(index) = res)
results
}
/**
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
allowLocal: Boolean
): Array[U] = {
runJob(rdd, (context: TaskContext, iter: Iterator[T]) => func(iter), partitions, allowLocal)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
{
runJob[T, U](rdd, processPartition, 0 until rdd.partitions.size, false, resultHandler)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
def runJob[T, U: ClassManifest](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
{
val processFunc = (context: TaskContext, iter: Iterator[T]) => processPartition(iter)
runJob[T, U](rdd, processFunc, 0 until rdd.partitions.size, false, resultHandler)
}
/**
* Run a job that can return approximate results.
*/
def runApproximateJob[T, U, R](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
result
}
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
*/
private[spark] def clean[F <: AnyRef](f: F): F = {
ClosureCleaner.clean(f)
return f
}
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists and useExisting is set to true, then the
* exisiting directory will be used. Otherwise an exception will be thrown to
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val env = SparkEnv.get
val path = new Path(dir)
val fs = path.getFileSystem(env.hadoop.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
} else {
fs.mkdirs(path)
}
}
checkpointDir = Some(dir)
}
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
def defaultParallelism: Int = taskScheduler.defaultParallelism
/** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
private[spark] def newShuffleId(): Int = nextShuffleId.getAndIncrement()
private val nextRddId = new AtomicInteger(0)
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
/** Called by MetadataCleaner to clean up the persistentRdds map periodically */
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
}
/**
* The SparkContext object contains a number of implicit conversions and parameters for use with
* various Spark features.
*/
object SparkContext {
val SPARK_JOB_DESCRIPTION = "spark.job.description"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
}
implicit object IntAccumulatorParam extends AccumulatorParam[Int] {
def addInPlace(t1: Int, t2: Int): Int = t1 + t2
def zero(initialValue: Int) = 0
}
implicit object LongAccumulatorParam extends AccumulatorParam[Long] {
def addInPlace(t1: Long, t2: Long) = t1 + t2
def zero(initialValue: Long) = 0l
}
implicit object FloatAccumulatorParam extends AccumulatorParam[Float] {
def addInPlace(t1: Float, t2: Float) = t1 + t2
def zero(initialValue: Float) = 0f
}
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions[K, V, (K, V)](rdd)
implicit def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]) = new DoubleRDDFunctions(rdd)
implicit def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
new DoubleRDDFunctions(rdd.map(x => num.toDouble(x)))
// Implicit conversions to common Writable types, for saveAsSequenceFile
implicit def intToIntWritable(i: Int) = new IntWritable(i)
implicit def longToLongWritable(l: Long) = new LongWritable(l)
implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob)
implicit def stringToText(s: String) = new Text(s)
private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
// Helper objects for converting common types to Writable
private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = {
val wClass = classManifest[W].erasure.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
implicit def intWritableConverter() = simpleWritableConverter[Int, IntWritable](_.get)
implicit def longWritableConverter() = simpleWritableConverter[Long, LongWritable](_.get)
implicit def doubleWritableConverter() = simpleWritableConverter[Double, DoubleWritable](_.get)
implicit def floatWritableConverter() = simpleWritableConverter[Float, FloatWritable](_.get)
implicit def booleanWritableConverter() = simpleWritableConverter[Boolean, BooleanWritable](_.get)
implicit def bytesWritableConverter() = simpleWritableConverter[Array[Byte], BytesWritable](_.getBytes)
implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
* their JARs to SparkContext
*/
def jarOfClass(cls: Class[_]): Seq[String] = {
val uri = cls.getResource("/" + cls.getName.replace('.', '/') + ".class")
if (uri != null) {
val uriStr = uri.toString
if (uriStr.startsWith("jar:file:")) {
// URI will be of the form "jar:file:/path/foo.jar!/package/cls.class", so pull out the /path/foo.jar
List(uriStr.substring("jar:file:".length, uriStr.indexOf('!')))
} else {
Nil
}
} else {
Nil
}
}
/** Find the JAR that contains the class of a particular object */
def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass)
/** Get the amount of memory per executor requested through system properties or SPARK_MEM */
private[spark] val executorMemoryRequested = {
// TODO: Might need to add some extra memory for the non-heap parts of the JVM
Option(System.getProperty("spark.executor.memory"))
.orElse(Option(System.getenv("SPARK_MEM")))
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
}
/**
* A class encapsulating how to convert some type T to Writable. It stores both the Writable class
* corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
* The getter for the writable class takes a ClassManifest[T] in case this is a generic object
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
private[spark] class WritableConverter[T](
val writableClass: ClassManifest[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable

View file

@ -0,0 +1,240 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import collection.mutable
import serializer.Serializer
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
* including the serializer, Akka actor system, block manager, map output tracker, etc. Currently
* Spark code finds the SparkEnv through a thread-local variable, so each thread that accesses these
* objects needs to have the right SparkEnv set. You can get the current environment with
* SparkEnv.get (e.g. after creating a SparkContext) and set it with SparkEnv.set.
*/
class SparkEnv (
val executorId: String,
val actorSystem: ActorSystem,
val serializerManager: SerializerManager,
val serializer: Serializer,
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
val shuffleFetcher: ShuffleFetcher,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
val httpFileServer: HttpFileServer,
val sparkFilesDir: String,
val metricsSystem: MetricsSystem) {
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if(yarnMode) {
try {
Class.forName("spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
} catch {
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
}
} else {
new SparkHadoopUtil
}
}
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
metricsSystem.stop()
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
actorSystem.awaitTermination()
}
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
synchronized {
val key = (pythonExec, envVars)
pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
}
object SparkEnv extends Logging {
private val env = new ThreadLocal[SparkEnv]
@volatile private var lastSetSparkEnv : SparkEnv = _
def set(e: SparkEnv) {
lastSetSparkEnv = e
env.set(e)
}
/**
* Returns the ThreadLocal SparkEnv, if non-null. Else returns the SparkEnv
* previously set in any thread.
*/
def get: SparkEnv = {
Option(env.get()).getOrElse(lastSetSparkEnv)
}
/**
* Returns the ThreadLocal SparkEnv.
*/
def getThreadLocal : SparkEnv = {
env.get()
}
def createFromSystemProperties(
executorId: String,
hostname: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean): SparkEnv = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, port)
// Bit of a hack: If this is the driver and our port was 0 (meaning bind to any free port),
// figure out which port number Akka actually bound to and set spark.driver.port to it.
if (isDriver && port == 0) {
System.setProperty("spark.driver.port", boundPort.toString)
}
// set only if unset until now.
if (System.getProperty("spark.hostPort", null) == null) {
if (!isDriver){
// unexpected
Utils.logErrorWithStack("Unexpected NOT to have spark.hostPort set")
}
Utils.checkHost(hostname)
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
}
val classLoader = Thread.currentThread.getContextClassLoader
// Create an instance of the class named by the given Java system property, or by
// defaultClassName if the property is not set, and return it as a T
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = System.getProperty(propertyName, defaultClassName)
Class.forName(name, true, classLoader).newInstance().asInstanceOf[T]
}
val serializerManager = new SerializerManager
val serializer = serializerManager.setDefault(
System.getProperty("spark.serializer", "org.apache.spark.serializer.JavaSerializer"))
val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"))
def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
if (isDriver) {
logInfo("Registering " + name)
actorSystem.actorOf(Props(newActor), name = name)
} else {
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
Utils.checkHost(driverHost, "Expected hostname")
val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
actorSystem.actorFor(url)
}
}
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal)))
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
val connectionManager = blockManager.connectionManager
val broadcastManager = new BroadcastManager(isDriver)
val cacheManager = new CacheManager(blockManager)
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
val httpFileServer = new HttpFileServer()
httpFileServer.initialize()
System.setProperty("spark.fileserver.uri", httpFileServer.serverUri)
val metricsSystem = if (isDriver) {
MetricsSystem.createMetricsSystem("driver")
} else {
MetricsSystem.createMetricsSystem("executor")
}
metricsSystem.start()
// Set the sparkFiles directory, used when downloading dependencies. In local mode,
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
Utils.createTempDir().getAbsolutePath
} else {
"."
}
// Warn about deprecated spark.cache.class property
if (System.getProperty("spark.cache.class") != null) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
"levels using the RDD.persist() method instead.")
}
new SparkEnv(
executorId,
actorSystem,
serializerManager,
serializer,
closureSerializer,
cacheManager,
mapOutputTracker,
shuffleFetcher,
broadcastManager,
blockManager,
connectionManager,
httpFileServer,
sparkFilesDir,
metricsSystem)
}
}

View file

@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
class SparkException(message: String, cause: Throwable)
extends Exception(message, cause) {
def this(message: String) = this(message, null)
}

View file

@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark;
import java.io.File;
/**
* Resolves paths to files added through `SparkContext.addFile()`.
*/
public class SparkFiles {
private SparkFiles() {}
/**
* Get the absolute path of a file added through `SparkContext.addFile()`.
*/
public static String get(String filename) {
return new File(getRootDirectory(), filename).getAbsolutePath();
}
/**
* Get the root directory that contains files added through `SparkContext.addFile()`.
*/
public static String getRootDirectory() {
return SparkEnv.get().sparkFilesDir();
}
}

View file

@ -0,0 +1,201 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.util.Date
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
/**
* Internal helper class that saves an RDD using a Hadoop OutputFormat. This is only public
* because we need to access this class from the `spark` package to use some package-private Hadoop
* functions, but this class should not be used directly by users.
*
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
private var jobID = 0
private var splitID = 0
private var attemptID = 0
private var jID: SerializableWritable[JobID] = null
private var taID: SerializableWritable[TaskAttemptID] = null
@transient private var writer: RecordWriter[AnyRef,AnyRef] = null
@transient private var format: OutputFormat[AnyRef,AnyRef] = null
@transient private var committer: OutputCommitter = null
@transient private var jobContext: JobContext = null
@transient private var taskContext: TaskAttemptContext = null
def preSetup() {
setIDs(0, 0, 0)
setConfParams()
val jCtxt = getJobContext()
getOutputCommitter().setupJob(jCtxt)
}
def setup(jobid: Int, splitid: Int, attemptid: Int) {
setIDs(jobid, splitid, attemptid)
setConfParams()
}
def open() {
val numfmt = NumberFormat.getInstance()
numfmt.setMinimumIntegerDigits(5)
numfmt.setGroupingUsed(false)
val outputName = "part-" + numfmt.format(splitID)
val path = FileOutputFormat.getOutputPath(conf.value)
val fs: FileSystem = {
if (path != null) {
path.getFileSystem(conf.value)
} else {
FileSystem.get(conf.value)
}
}
getOutputCommitter().setupTask(getTaskContext())
writer = getOutputFormat().getRecordWriter(
fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
if (writer!=null) {
//println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
}
}
def close() {
writer.close(Reporter.NULL)
}
def commit() {
val taCtxt = getTaskContext()
val cmtr = getOutputCommitter()
if (cmtr.needsTaskCommit(taCtxt)) {
try {
cmtr.commitTask(taCtxt)
logInfo (taID + ": Committed")
} catch {
case e: IOException => {
logError("Error committing the output of task: " + taID.value, e)
cmtr.abortTask(taCtxt)
throw e
}
}
} else {
logWarning ("No need to commit output of task: " + taID.value)
}
}
def commitJob() {
// always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}
def cleanup() {
getOutputCommitter().cleanupJob(getJobContext())
}
// ********* Private Functions *********
private def getOutputFormat(): OutputFormat[AnyRef,AnyRef] = {
if (format == null) {
format = conf.value.getOutputFormat()
.asInstanceOf[OutputFormat[AnyRef,AnyRef]]
}
return format
}
private def getOutputCommitter(): OutputCommitter = {
if (committer == null) {
committer = conf.value.getOutputCommitter
}
return committer
}
private def getJobContext(): JobContext = {
if (jobContext == null) {
jobContext = newJobContext(conf.value, jID.value)
}
return jobContext
}
private def getTaskContext(): TaskAttemptContext = {
if (taskContext == null) {
taskContext = newTaskAttemptContext(conf.value, taID.value)
}
return taskContext
}
private def setIDs(jobid: Int, splitid: Int, attemptid: Int) {
jobID = jobid
splitID = splitid
attemptID = attemptid
jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid))
taID = new SerializableWritable[TaskAttemptID](
new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID))
}
private def setConfParams() {
conf.value.set("mapred.job.id", jID.value.toString)
conf.value.set("mapred.tip.id", taID.value.getTaskID.toString)
conf.value.set("mapred.task.id", taID.value.toString)
conf.value.setBoolean("mapred.task.is.map", true)
conf.value.setInt("mapred.task.partition", splitID)
}
}
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
return new JobID(jobtrackerID, id)
}
def createPathFromString(path: String, conf: JobConf): Path = {
if (path == null) {
throw new IllegalArgumentException("Output path is null")
}
var outputPath = new Path(path)
val fs = outputPath.getFileSystem(conf)
if (outputPath == null || fs == null) {
throw new IllegalArgumentException("Incorrectly formatted output path")
}
outputPath = outputPath.makeQualified(fs)
return outputPath
}
}

View file

@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
class TaskContext(
val stageId: Int,
val splitId: Int,
val attemptId: Long,
val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
def executeOnCompleteCallbacks() {
onCompleteCallbacks.foreach{_()}
}
}

View file

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
/**
* Various possible reasons why a task ended. The low-level TaskScheduler is supposed to retry
* tasks several times for "ephemeral" failures, and only report back failures that require some
* old stages to be resubmitted, such as shuffle map fetch failures.
*/
private[spark] sealed trait TaskEndReason
private[spark] case object Success extends TaskEndReason
private[spark]
case object Resubmitted extends TaskEndReason // Task was finished earlier but we've now lost it
private[spark] case class FetchFailed(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
reduceId: Int)
extends TaskEndReason
private[spark] case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
metrics: Option[TaskMetrics])
extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason
private[spark] case class TaskResultTooBigFailure() extends TaskEndReason

View file

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
private[spark] object TaskState
extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
val FINISHED_STATES = Set(FINISHED, FAILED, KILLED, LOST)
type TaskState = Value
def isFinished(state: TaskState) = FINISHED_STATES.contains(state)
def toMesos(state: TaskState): MesosTaskState = state match {
case LAUNCHING => MesosTaskState.TASK_STARTING
case RUNNING => MesosTaskState.TASK_RUNNING
case FINISHED => MesosTaskState.TASK_FINISHED
case FAILED => MesosTaskState.TASK_FAILED
case KILLED => MesosTaskState.TASK_KILLED
case LOST => MesosTaskState.TASK_LOST
}
def fromMesos(mesosState: MesosTaskState): TaskState = mesosState match {
case MesosTaskState.TASK_STAGING => LAUNCHING
case MesosTaskState.TASK_STARTING => LAUNCHING
case MesosTaskState.TASK_RUNNING => RUNNING
case MesosTaskState.TASK_FINISHED => FINISHED
case MesosTaskState.TASK_FAILED => FAILED
case MesosTaskState.TASK_KILLED => KILLED
case MesosTaskState.TASK_LOST => LOST
}
}

View file

@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.StatCounter
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.storage.StorageLevel
import java.lang.Double
import org.apache.spark.Partitioner
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
override def wrapRDD(rdd: RDD[Double]): JavaDoubleRDD =
new JavaDoubleRDD(rdd.map(_.doubleValue))
// Common RDD functions
import JavaDoubleRDD.fromRDD
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaDoubleRDD = fromRDD(srdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[Double, java.lang.Boolean]): JavaDoubleRDD =
fromRDD(srdd.filter(x => f(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.coalesce(numPartitions))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
fromRDD(srdd.subtract(other))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaDoubleRDD, numPartitions: Int): JavaDoubleRDD =
fromRDD(srdd.subtract(other, numPartitions))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaDoubleRDD, p: Partitioner): JavaDoubleRDD =
fromRDD(srdd.subtract(other, p))
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaDoubleRDD =
fromRDD(srdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd))
// Double RDD functions
/** Add up the elements in this RDD. */
def sum(): Double = srdd.sum()
/**
* Return a [[org.apache.spark.util.StatCounter]] object that captures the mean, variance and count
* of the RDD's elements in one operation.
*/
def stats(): StatCounter = srdd.stats()
/** Compute the mean of this RDD's elements. */
def mean(): Double = srdd.mean()
/** Compute the variance of this RDD's elements. */
def variance(): Double = srdd.variance()
/** Compute the standard deviation of this RDD's elements. */
def stdev(): Double = srdd.stdev()
/**
* Compute the sample standard deviation of this RDD's elements (which corrects for bias in
* estimating the standard deviation by dividing by N-1 instead of N).
*/
def sampleStdev(): Double = srdd.sampleStdev()
/**
* Compute the sample variance of this RDD's elements (which corrects for bias in
* estimating the standard variance by dividing by N-1 instead of N).
*/
def sampleVariance(): Double = srdd.sampleVariance()
/** Return the approximate mean of the elements in this RDD. */
def meanApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.meanApprox(timeout, confidence)
/** (Experimental) Approximate operation to return the mean within a timeout. */
def meanApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.meanApprox(timeout)
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
srdd.sumApprox(timeout, confidence)
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
}
object JavaDoubleRDD {
def fromRDD(rdd: RDD[scala.Double]): JavaDoubleRDD = new JavaDoubleRDD(rdd)
implicit def toRDD(rdd: JavaDoubleRDD): RDD[scala.Double] = rdd.srdd
}

View file

@ -0,0 +1,601 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import java.util.{List => JList}
import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.conf.Configuration
import org.apache.spark.HashPartitioner
import org.apache.spark.Partitioner
import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.api.java.function.{Function2 => JFunction2}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.partial.BoundedDouble
import org.apache.spark.partial.PartialResult
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.OrderedRDDFunctions
import org.apache.spark.storage.StorageLevel
class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
override val classManifest: ClassManifest[(K, V)] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
import JavaPairRDD._
// Common RDD functions
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[(K, V), java.lang.Boolean]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.filter(x => f(x).booleanValue()))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.coalesce(numPartitions))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.union(other.rdd))
// first() has to be overridden here so that the generated method has the signature
// 'public scala.Tuple2 first()'; if the trait's definition is used,
// then the method has the signature 'public java.lang.Object first()',
// causing NoSuchMethodErrors at runtime.
override def first(): (K, V) = rdd.first()
// Pair RDD functions
/**
* Generic function to combine the elements for each key using a custom set of aggregation
* functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
* "combined type" C * Note that V and C can be different -- for example, one might group an
* RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
*
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
*
* In addition, users can control the partitioning of the output RDD, and whether to perform
* map-side aggregation (if a mapper can produce multiple items with the same key).
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
implicit val cm: ClassManifest[C] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
mergeCombiners,
partitioner
))
}
/**
* Simplified version of combineByKey that hash-partitions the output RDD.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
numPartitions: Int): JavaPairRDD[K, C] =
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
*/
def reduceByKey(partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(partitioner, func))
/**
* Merge the values for each key using an associative reduce function, but return the results
* immediately to the master as a Map. This will also perform the merging locally on each mapper
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
mapAsJavaMap(rdd.reduceByKeyLocally(func))
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
/**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue, partitioner)(func))
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, numPartitions: Int, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue, numPartitions)(func))
/**
* Merge the values for each key using an associative function and a neutral "zero value" which may
* be added to the result an arbitrary number of times, and must not change the result (e.g., Nil for
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, func: JFunction2[V, V, V]): JavaPairRDD[K, V] =
fromRDD(rdd.foldByKey(zeroValue)(func))
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with numPartitions partitions.
*/
def reduceByKey(func: JFunction2[V, V, V], numPartitions: Int): JavaPairRDD[K, V] =
fromRDD(rdd.reduceByKey(func, numPartitions))
/**
* Group the values for each key in the RDD into a single sequence. Allows controlling the
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(partitioner)))
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with into `numPartitions` partitions.
*/
def groupByKey(numPartitions: Int): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey(numPartitions)))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaPairRDD[K, V], numPartitions: Int): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other, numPartitions))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaPairRDD[K, V], p: Partitioner): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other, p))
/**
* Return a copy of the RDD partitioned using the specified partitioner.
*/
def partitionBy(partitioner: Partitioner): JavaPairRDD[K, V] =
fromRDD(rdd.partitionBy(partitioner))
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce.
*/
def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, partitioner))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (V, Optional[W])] = {
val joinResult = rdd.leftOuterJoin(other, partitioner)
fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Uses the given Partitioner to
* partition the output RDD.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (Optional[V], W)] = {
val joinResult = rdd.rightOuterJoin(other, partitioner)
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
/**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
implicit val cm: ClassManifest[C] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
/**
* Merge the values for each key using an associative reduce function. This will also perform
* the merging locally on each mapper before sending results to a reducer, similarly to a
* "combiner" in MapReduce. Output will be hash-partitioned with the existing partitioner/
* parallelism level.
*/
def reduceByKey(func: JFunction2[V, V, V]): JavaPairRDD[K, V] = {
fromRDD(reduceByKey(defaultPartitioner(rdd), func))
}
/**
* Group the values for each key in the RDD into a single sequence. Hash-partitions the
* resulting RDD with the existing partitioner/parallelism level.
*/
def groupByKey(): JavaPairRDD[K, JList[V]] =
fromRDD(groupByResultToJava(rdd.groupByKey()))
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other))
/**
* Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each
* pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and
* (k, v2) is in `other`. Performs a hash join across the cluster.
*/
def join[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, W)] =
fromRDD(rdd.join(other, numPartitions))
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* using the existing partitioner/parallelism level.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (V, Optional[W])] = {
val joinResult = rdd.leftOuterJoin(other)
fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
}
/**
* Perform a left outer join of `this` and `other`. For each element (k, v) in `this`, the
* resulting RDD will either contain all pairs (k, (v, Some(w))) for w in `other`, or the
* pair (k, (v, None)) if no elements in `other` have key k. Hash-partitions the output
* into `numPartitions` partitions.
*/
def leftOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (V, Optional[W])] = {
val joinResult = rdd.leftOuterJoin(other, numPartitions)
fromRDD(joinResult.mapValues{case (v, w) => (v, JavaUtils.optionToOptional(w))})
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD using the existing partitioner/parallelism level.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (Optional[V], W)] = {
val joinResult = rdd.rightOuterJoin(other)
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
/**
* Perform a right outer join of `this` and `other`. For each element (k, w) in `other`, the
* resulting RDD will either contain all pairs (k, (Some(v), w)) for v in `this`, or the
* pair (k, (None, w)) if no elements in `this` have key k. Hash-partitions the resulting
* RDD into the given number of partitions.
*/
def rightOuterJoin[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (Optional[V], W)] = {
val joinResult = rdd.rightOuterJoin(other, numPartitions)
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
fromRDD(rdd.mapValues(f))
}
/**
* Pass each value in the key-value pair RDD through a flatMap function without changing the
* keys; this also retains the original RDD's partitioning.
*/
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
implicit val cm: ClassManifest[U] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
fromRDD(rdd.flatMapValues(fn))
}
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other, partitioner)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], partitioner: Partitioner)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, partitioner)))
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.cogroup(other)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2)))
/**
* For each key k in `this` or `other`, return a resulting RDD that contains a tuple with the
* list of values for that key in `this` as well as `other`.
*/
def cogroup[W](other: JavaPairRDD[K, W], numPartitions: Int): JavaPairRDD[K, (JList[V], JList[W])]
= fromRDD(cogroupResultToJava(rdd.cogroup(other, numPartitions)))
/**
* For each key k in `this` or `other1` or `other2`, return a resulting RDD that contains a
* tuple with the list of values for that key in `this`, `other1` and `other2`.
*/
def cogroup[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2], numPartitions: Int)
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.cogroup(other1, other2, numPartitions)))
/** Alias for cogroup. */
def groupWith[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, (JList[V], JList[W])] =
fromRDD(cogroupResultToJava(rdd.groupWith(other)))
/** Alias for cogroup. */
def groupWith[W1, W2](other1: JavaPairRDD[K, W1], other2: JavaPairRDD[K, W2])
: JavaPairRDD[K, (JList[V], JList[W1], JList[W2])] =
fromRDD(cogroupResult2ToJava(rdd.groupWith(other1, other2)))
/**
* Return the list of values in the RDD for key `key`. This operation is done efficiently if the
* RDD has a known partitioner by only searching the partition that the key maps to.
*/
def lookup(key: K): JList[V] = seqAsJavaList(rdd.lookup(key))
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
conf: JobConf) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
codec: Class[_ <: CompressionCodec]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
conf: Configuration) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass, conf)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F]) {
rdd.saveAsNewAPIHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/**
* Output the RDD to any Hadoop-supported storage system, using a Hadoop JobConf object for
* that storage system. The JobConf should set an OutputFormat and any output paths required
* (e.g. a table name to write to) in the same way as it would be configured for a Hadoop
* MapReduce job.
*/
def saveAsHadoopDataset(conf: JobConf) {
rdd.saveAsHadoopDataset(conf)
}
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements in
* ascending order. Calling `collect` or `save` on the resulting RDD will return or output an
* ordered list of records (in the `save` case, they will be written to multiple `part-X` files
* in the filesystem, in order of the keys).
*/
def sortByKey(): JavaPairRDD[K, V] = sortByKey(true)
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(ascending: Boolean): JavaPairRDD[K, V] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
sortByKey(comp, ascending)
}
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(comp: Comparator[K]): JavaPairRDD[K, V] = sortByKey(comp, true)
/**
* Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
* `collect` or `save` on the resulting RDD will return or output an ordered list of records
* (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
* order of the keys).
*/
def sortByKey(comp: Comparator[K], ascending: Boolean): JavaPairRDD[K, V] = {
class KeyOrdering(val a: K) extends Ordered[K] {
override def compare(b: K) = comp.compare(a, b)
}
implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
}
/**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
/**
* Return an RDD with the values of each tuple.
*/
def values(): JavaRDD[V] = JavaRDD.fromRDD[V](rdd.map(_._2))
}
object JavaPairRDD {
def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
vcm: ClassManifest[T]): RDD[(K, JList[T])] =
rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
(x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
seqAsJavaList(x._2),
seqAsJavaList(x._3)))
def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
}

View file

@ -0,0 +1,115 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
// Common RDD functions
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
* have a storage level set yet..
*/
def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
// Transformations (return a new RDD)
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
/**
* Return a new RDD containing the distinct elements in this RDD.
*/
def distinct(numPartitions: Int): JavaRDD[T] = wrapRDD(rdd.distinct(numPartitions))
/**
* Return a new RDD containing only the elements that satisfy a predicate.
*/
def filter(f: JFunction[T, java.lang.Boolean]): JavaRDD[T] =
wrapRDD(rdd.filter((x => f(x).booleanValue())))
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int): JavaRDD[T] = rdd.coalesce(numPartitions)
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*/
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
rdd.coalesce(numPartitions, shuffle)
/**
* Return a sampled subset of this RDD.
*/
def sample(withReplacement: Boolean, fraction: Double, seed: Int): JavaRDD[T] =
wrapRDD(rdd.sample(withReplacement, fraction, seed))
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
def subtract(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.subtract(other))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaRDD[T], numPartitions: Int): JavaRDD[T] =
wrapRDD(rdd.subtract(other, numPartitions))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
}
object JavaRDD {
implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
}

View file

@ -0,0 +1,428 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.{SparkContext, Partition, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import org.apache.spark.partial.{PartialResult, BoundedDouble}
import org.apache.spark.storage.StorageLevel
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
def rdd: RDD[T]
/** Set of partitions in this RDD. */
def splits: JList[Partition] = new java.util.ArrayList(rdd.partitions.toSeq)
/** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
def context: SparkContext = rdd.context
/** A unique ID for this RDD (within its SparkContext). */
def id: Int = rdd.id
/** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */
def getStorageLevel: StorageLevel = rdd.getStorageLevel
/**
* Internal method to this RDD; will read from cache if applicable, or otherwise compute it.
* This should ''not'' be called by users directly, but is available for implementors of custom
* subclasses of RDD.
*/
def iterator(split: Partition, taskContext: TaskContext): java.util.Iterator[T] =
asJavaIterator(rdd.iterator(split, taskContext))
// Transformations (return a new RDD)
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: JFunction[T, R]): JavaRDD[R] =
new JavaRDD(rdd.map(f)(f.returnType()))(f.returnType())
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[R](f: DoubleFunction[T]): JavaDoubleRDD =
new JavaDoubleRDD(rdd.map(x => f(x).doubleValue()))
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
JavaRDD.fromRDD(rdd.flatMap(fn)(f.elementType()))(f.elementType())
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
new JavaDoubleRDD(rdd.flatMap(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
JavaPairRDD[K2, V2] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
/**
* Return an RDD created by coalescing all elements within each partition into an array.
*/
def glom(): JavaRDD[JList[T]] =
new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
/**
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
other.classManifest)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val vcm: ClassManifest[JList[T]] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val vcm: ClassManifest[JList[T]] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: String): JavaRDD[String] = rdd.pipe(command)
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: JList[String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command))
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: JList[String], env: java.util.Map[String, String]): JavaRDD[String] =
rdd.pipe(asScalaBuffer(command), mapAsScalaMap(env))
/**
* Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
* second element in each RDD, etc. Assumes that the two RDDs have the *same number of
* partitions* and the *same number of elements in each partition* (e.g. one was made through
* a map on the other).
*/
def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
}
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
* applying a function to the zipped partitions. Assumes that all the RDDs have the
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
def zipPartitions[U, V](
other: JavaRDDLike[U, _],
f: FlatMapFunction2[java.util.Iterator[T], java.util.Iterator[U], V]): JavaRDD[V] = {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
}
// Actions (launch a job to return a value to the user program)
/**
* Applies a function f to all elements of this RDD.
*/
def foreach(f: VoidFunction[T]) {
val cleanF = rdd.context.clean(f)
rdd.foreach(cleanF)
}
/**
* Return an array that contains all of the elements in this RDD.
*/
def collect(): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.collect().toSeq
new java.util.ArrayList(arr)
}
/**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
* modify t1 and return it as its result value to avoid object allocation; however, it should not
* modify t2.
*/
def fold(zeroValue: T)(f: JFunction2[T, T, T]): T =
rdd.fold(zeroValue)(f)
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using
* given combine functions and a neutral "zero value". This function can return a different result
* type, U, than the type of this RDD, T. Thus, we need one operation for merging a T into an U
* and one operation for merging two U's, as in scala.TraversableOnce. Both of these functions are
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*/
def aggregate[U](zeroValue: U)(seqOp: JFunction2[U, T, U],
combOp: JFunction2[U, U, U]): U =
rdd.aggregate(zeroValue)(seqOp, combOp)(seqOp.returnType)
/**
* Return the number of elements in the RDD.
*/
def count(): Long = rdd.count()
/**
* (Experimental) Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long, confidence: Double): PartialResult[BoundedDouble] =
rdd.countApprox(timeout, confidence)
/**
* (Experimental) Approximate version of count() that returns a potentially incomplete result
* within a timeout, even if not all tasks have finished.
*/
def countApprox(timeout: Long): PartialResult[BoundedDouble] =
rdd.countApprox(timeout)
/**
* Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
rdd.countByValueApprox(timeout).map(mapAsJavaMap)
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
* it will be slow if a lot of partitions are required. In that case, use collect() to get the
* whole RDD instead.
*/
def take(num: Int): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.take(num).toSeq
new java.util.ArrayList(arr)
}
def takeSample(withReplacement: Boolean, num: Int, seed: Int): JList[T] = {
import scala.collection.JavaConversions._
val arr: java.util.Collection[T] = rdd.takeSample(withReplacement, num, seed).toSeq
new java.util.ArrayList(arr)
}
/**
* Return the first element in this RDD.
*/
def first(): T = rdd.first()
/**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
rdd.saveAsTextFile(path, codec)
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
def saveAsObjectFile(path: String) = rdd.saveAsObjectFile(path)
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
/**
* Mark this RDD for checkpointing. It will be saved to a file inside the checkpoint
* directory set with SparkContext.setCheckpointDir() and all references to its parent
* RDDs will be removed. This function must be called before any job has been
* executed on this RDD. It is strongly recommended that this RDD is persisted in
* memory, otherwise saving it on a file will require recomputation.
*/
def checkpoint() = rdd.checkpoint()
/**
* Return whether this RDD has been checkpointed or not
*/
def isCheckpointed: Boolean = rdd.isCheckpointed
/**
* Gets the name of the file to which this RDD was checkpointed
*/
def getCheckpointFile(): Optional[String] = {
JavaUtils.optionToOptional(rdd.getCheckpointFile)
}
/** A description of this RDD and its recursive dependencies for debugging. */
def toDebugString(): String = {
rdd.toDebugString
}
/**
* Returns the top K elements from this RDD as defined by
* the specified Comparator[T].
* @param num the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
def top(num: Int, comp: Comparator[T]): JList[T] = {
import scala.collection.JavaConversions._
val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
val arr: java.util.Collection[T] = topElems.toSeq
new java.util.ArrayList(arr)
}
/**
* Returns the top K elements from this RDD using the
* natural ordering for T.
* @param num the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
top(num, comp)
}
/**
* Returns the first K elements from this RDD as defined by
* the specified Comparator[T] and maintains the order.
* @param num the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
def takeOrdered(num: Int, comp: Comparator[T]): JList[T] = {
import scala.collection.JavaConversions._
val topElems = rdd.takeOrdered(num)(Ordering.comparatorToOrdering(comp))
val arr: java.util.Collection[T] = topElems.toSeq
new java.util.ArrayList(arr)
}
/**
* Returns the first K elements from this RDD using the
* natural ordering for T while maintain the order.
* @param num the number of top elements to return
* @return an array of top elements
*/
def takeOrdered(num: Int): JList[T] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
takeOrdered(num, comp)
}
}

View file

@ -0,0 +1,418 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import com.google.common.base.Optional
import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext}
import org.apache.spark.SparkContext.IntAccumulatorParam
import org.apache.spark.SparkContext.DoubleAccumulatorParam
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and
* works with Java collections instead of Scala ones.
*/
class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround {
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
*/
def this(master: String, appName: String) = this(new SparkContext(master, appName))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jarFile JAR file to send to the cluster. This can be a path on the local file system
* or an HDFS, HTTP, HTTPS, or FTP URL.
*/
def this(master: String, appName: String, sparkHome: String, jarFile: String) =
this(new SparkContext(master, appName, sparkHome, Seq(jarFile)))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
*/
def this(master: String, appName: String, sparkHome: String, jars: Array[String]) =
this(new SparkContext(master, appName, sparkHome, jars.toSeq))
/**
* @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]).
* @param appName A name for your application, to display on the cluster web UI
* @param sparkHome The SPARK_HOME directory on the slave nodes
* @param jars Collection of JARs to send to the cluster. These can be paths on the local file
* system or HDFS, HTTP, HTTPS, or FTP URLs.
* @param environment Environment variables to set on worker nodes
*/
def this(master: String, appName: String, sparkHome: String, jars: Array[String],
environment: JMap[String, String]) =
this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment))
private[spark] val env = sc.env
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T]): JavaRDD[T] =
parallelize(list, sc.defaultParallelism)
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
implicit val kcm: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val vcm: ClassManifest[V] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]]): JavaPairRDD[K, V] =
parallelizePairs(list, sc.defaultParallelism)
/** Distribute a local Scala collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double], numSlices: Int): JavaDoubleRDD =
JavaDoubleRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list).map(_.doubleValue()),
numSlices))
/** Distribute a local Scala collection to form an RDD. */
def parallelizeDoubles(list: java.util.List[java.lang.Double]): JavaDoubleRDD =
parallelizeDoubles(list, sc.defaultParallelism)
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String): JavaRDD[String] = sc.textFile(path)
/**
* Read a text file from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minSplits: Int): JavaRDD[String] = sc.textFile(path, minSplits)
/**Get an RDD for a Hadoop SequenceFile with given key and value types. */
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
/**Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
}
/**
* Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
* BytesWritable values that contain a serialized partition. This is still an experimental storage
* format and may not be supported exactly as is in future Spark releases. It will also be pretty
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
sc.objectFile(path, minSplits)(cm)
}
/**
* Load an RDD saved as a SequenceFile containing serialized objects, with NullWritable keys and
* BytesWritable values that contain a serialized partition. This is still an experimental storage
* format and may not be supported exactly as is in future Spark releases. It will also be pretty
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
sc.objectFile(path)(cm)
}
/**
* Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
def hadoopRDD[K, V, F <: InputFormat[K, V]](
conf: JobConf,
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
/**
* Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
def hadoopRDD[K, V, F <: InputFormat[K, V]](
conf: JobConf,
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat */
def hadoopFile[K, V, F <: InputFormat[K, V]](
path: String,
inputFormatClass: Class[F],
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(keyClass)
implicit val vcm = ClassManifest.fromClass(valueClass)
new JavaPairRDD(sc.hadoopFile(path,
inputFormatClass, keyClass, valueClass))
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
path: String,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V],
conf: Configuration): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(kClass)
implicit val vcm = ClassManifest.fromClass(vClass)
new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
}
/**
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): JavaPairRDD[K, V] = {
implicit val kcm = ClassManifest.fromClass(kClass)
implicit val vcm = ClassManifest.fromClass(vClass)
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
/** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
implicit val cm: ClassManifest[T] = first.classManifest
sc.union(rdds)(cm)
}
/** Build the union of two or more RDDs. */
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
implicit val cm: ClassManifest[(K, V)] = first.classManifest
implicit val kcm: ClassManifest[K] = first.kManifest
implicit val vcm: ClassManifest[V] = first.vManifest
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
/** Build the union of two or more RDDs. */
override def union(first: JavaDoubleRDD, rest: java.util.List[JavaDoubleRDD]): JavaDoubleRDD = {
val rdds: Seq[RDD[Double]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.srdd)
new JavaDoubleRDD(sc.union(rdds))
}
/**
* Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def intAccumulator(initialValue: Int): Accumulator[java.lang.Integer] =
sc.accumulator(initialValue)(IntAccumulatorParam).asInstanceOf[Accumulator[java.lang.Integer]]
/**
* Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def doubleAccumulator(initialValue: Double): Accumulator[java.lang.Double] =
sc.accumulator(initialValue)(DoubleAccumulatorParam).asInstanceOf[Accumulator[java.lang.Double]]
/**
* Create an [[org.apache.spark.Accumulator]] integer variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Int): Accumulator[java.lang.Integer] = intAccumulator(initialValue)
/**
* Create an [[org.apache.spark.Accumulator]] double variable, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator(initialValue: Double): Accumulator[java.lang.Double] =
doubleAccumulator(initialValue)
/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, which tasks can "add" values
* to using the `add` method. Only the master can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, accumulatorParam: AccumulatorParam[T]): Accumulator[T] =
sc.accumulator(initialValue)(accumulatorParam)
/**
* Create an [[org.apache.spark.Accumulable]] shared variable of the given type, to which tasks can
* "add" values with `add`. Only the master can access the accumuable's `value`.
*/
def accumulable[T, R](initialValue: T, param: AccumulableParam[T, R]): Accumulable[T, R] =
sc.accumulable(initialValue)(param)
/**
* Broadcast a read-only variable to the cluster, returning a [[org.apache.spark.Broadcast]] object for
* reading it in distributed functions. The variable will be sent to each cluster only once.
*/
def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
/** Shut down the SparkContext. */
def stop() {
sc.stop()
}
/**
* Get Spark's home location from either a value set through the constructor,
* or the spark.home Java property, or the SPARK_HOME environment variable
* (in that order of preference). If neither of these is set, return None.
*/
def getSparkHome(): Optional[String] = JavaUtils.optionToOptional(sc.getSparkHome())
/**
* Add a file to be downloaded with this Spark job on every node.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(path)` to find its download location.
*/
def addFile(path: String) {
sc.addFile(path)
}
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
*/
def addJar(path: String) {
sc.addJar(path)
}
/**
* Clear the job's list of JARs added by `addJar` so that they do not get downloaded to
* any new nodes.
*/
def clearJars() {
sc.clearJars()
}
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
*/
def clearFiles() {
sc.clearFiles()
}
/**
* Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
*/
def hadoopConfiguration(): Configuration = {
sc.hadoopConfiguration
}
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists and useExisting is set to true, then the
* exisiting directory will be used. Otherwise an exception will be thrown to
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean) {
sc.setCheckpointDir(dir, useExisting)
}
/**
* Set the directory under which RDDs are going to be checkpointed. The directory must
* be a HDFS path if running on a cluster. If the directory does not exist, it will
* be created. If the directory exists, an exception will be thrown to prevent accidental
* overriding of checkpoint files.
*/
def setCheckpointDir(dir: String) {
sc.setCheckpointDir(dir)
}
protected def checkpointFile[T](path: String): JavaRDD[T] = {
implicit val cm: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
new JavaRDD(sc.checkpointFile(path))
}
}
object JavaSparkContext {
implicit def fromSparkContext(sc: SparkContext): JavaSparkContext = new JavaSparkContext(sc)
implicit def toSparkContext(jsc: JavaSparkContext): SparkContext = jsc.sc
}

View file

@ -0,0 +1,64 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java;
import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
// See
// http://scala-programming-language.1934581.n4.nabble.com/Workaround-for-implementing-java-varargs-in-2-7-2-final-tp1944767p1944772.html
abstract class JavaSparkContextVarargsWorkaround {
public <T> JavaRDD<T> union(JavaRDD<T>... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
ArrayList<JavaRDD<T>> rest = new ArrayList<JavaRDD<T>>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
return union(rdds[0], rest);
}
public JavaDoubleRDD union(JavaDoubleRDD... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
ArrayList<JavaDoubleRDD> rest = new ArrayList<JavaDoubleRDD>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
return union(rdds[0], rest);
}
public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V>... rdds) {
if (rdds.length == 0) {
throw new IllegalArgumentException("Union called on empty list");
}
ArrayList<JavaPairRDD<K, V>> rest = new ArrayList<JavaPairRDD<K, V>>(rdds.length - 1);
for (int i = 1; i < rdds.length; i++) {
rest.add(rdds[i]);
}
return union(rdds[0], rest);
}
// These methods take separate "first" and "rest" elements to avoid having the same type erasure
abstract public <T> JavaRDD<T> union(JavaRDD<T> first, List<JavaRDD<T>> rest);
abstract public JavaDoubleRDD union(JavaDoubleRDD first, List<JavaDoubleRDD> rest);
abstract public <K, V> JavaPairRDD<K, V> union(JavaPairRDD<K, V> first, List<JavaPairRDD<K, V>> rest);
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java
import com.google.common.base.Optional
object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
case None => Optional.absent()
}
}

View file

@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java;
import org.apache.spark.storage.StorageLevel;
/**
* Expose some commonly useful storage level constants.
*/
public class StorageLevels {
public static final StorageLevel NONE = new StorageLevel(false, false, false, 1);
public static final StorageLevel DISK_ONLY = new StorageLevel(true, false, false, 1);
public static final StorageLevel DISK_ONLY_2 = new StorageLevel(true, false, false, 2);
public static final StorageLevel MEMORY_ONLY = new StorageLevel(false, true, true, 1);
public static final StorageLevel MEMORY_ONLY_2 = new StorageLevel(false, true, true, 2);
public static final StorageLevel MEMORY_ONLY_SER = new StorageLevel(false, true, false, 1);
public static final StorageLevel MEMORY_ONLY_SER_2 = new StorageLevel(false, true, false, 2);
public static final StorageLevel MEMORY_AND_DISK = new StorageLevel(true, true, true, 1);
public static final StorageLevel MEMORY_AND_DISK_2 = new StorageLevel(true, true, true, 2);
public static final StorageLevel MEMORY_AND_DISK_SER = new StorageLevel(true, true, false, 1);
public static final StorageLevel MEMORY_AND_DISK_SER_2 = new StorageLevel(true, true, false, 2);
/**
* Create a new StorageLevel object.
* @param useDisk saved to disk, if true
* @param useMemory saved to memory, if true
* @param deserialized saved as deserialized objects, if true
* @param replication replication factor
*/
public static StorageLevel create(boolean useDisk, boolean useMemory, boolean deserialized, int replication) {
return StorageLevel.apply(useDisk, useMemory, deserialized, replication);
}
}

View file

@ -0,0 +1,37 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns zero or more records of type Double from each input record.
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
implements Serializable {
public abstract Iterable<Double> call(T t);
@Override
public final Iterable<Double> apply(T t) { return call(t); }
}

View file

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns Doubles, and can be used to construct DoubleRDDs.
*/
// DoubleFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
public abstract Double call(T t) throws Exception;
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
/**
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
/**
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
@throws(classOf[Exception])
def call(a: A, b:B) : java.lang.Iterable[C]
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}

View file

@ -0,0 +1,39 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* Base class for functions whose return types do not create special RDDs. PairFunction and
* DoubleFunction are handled separately, to allow PairRDDs and DoubleRDDs to be constructed
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}

View file

@ -0,0 +1,38 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
/**
* A two-argument function that takes arguments of type T1 and T2 and returns an R.
*/
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
public abstract R call(T1 t1, T2 t2) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}

View file

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns zero or more key-value pair records from each input record. The
* key-value pairs are represented as scala.Tuple2 objects.
*/
// PairFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and PairFlatMapFunction.
public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
public ClassManifest<V> valueType() {
return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}

View file

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
* A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
public abstract Tuple2<K, V> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}
public ClassManifest<V> valueType() {
return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}

View file

@ -0,0 +1,33 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
/**
* A function with no return value.
*/
// This allows Java users to write void methods without having to return Unit.
abstract class VoidFunction[T] extends Serializable {
@throws(classOf[Exception])
def call(t: T) : Unit
}
// VoidFunction cannot extend AbstractFunction1 (because that would force users to explicitly
// return Unit), so it is implicitly converted to a Function1[T, Unit]:
object VoidFunction {
implicit def toFunction[T](f: VoidFunction[T]) : Function1[T, Unit] = ((x : T) => f.call(x))
}

View file

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
import scala.runtime.AbstractFunction1
/**
* Subclass of Function1 for ease of calling from Java. The main thing it does is re-expose the
* apply() method as call() and declare that it can throw Exception (since AbstractFunction1.apply
* isn't marked to allow that).
*/
private[spark] abstract class WrappedFunction1[T, R] extends AbstractFunction1[T, R] {
@throws(classOf[Exception])
def call(t: T): R
final def apply(t: T): R = call(t)
}

View file

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
import scala.runtime.AbstractFunction2
/**
* Subclass of Function2 for ease of calling from Java. The main thing it does is re-expose the
* apply() method as call() and declare that it can throw Exception (since AbstractFunction2.apply
* isn't marked to allow that).
*/
private[spark] abstract class WrappedFunction2[T1, T2, R] extends AbstractFunction2[T1, T2, R] {
@throws(classOf[Exception])
def call(t1: T1, t2: T2): R
final def apply(t1: T1, t2: T2): R = call(t1, t2)
}

View file

@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python
import org.apache.spark.Partitioner
import java.util.Arrays
import org.apache.spark.util.Utils
/**
* A [[org.apache.spark.Partitioner]] that performs handling of byte arrays, for use by the Python API.
*
* Stores the unique id() of the Python-side partitioning function so that it is incorporated into
* equality comparisons. Correctness requires that the id is a unique identifier for the
* lifetime of the program (i.e. that it is not re-used as the id of a different partitioning
* function). This can be ensured by using the Python id() function and maintaining a reference
* to the Python partitioning function so that its id() is not reused.
*/
private[spark] class PythonPartitioner(
override val numPartitions: Int,
val pyPartitionFunctionId: Long)
extends Partitioner {
override def getPartition(key: Any): Int = key match {
case null => 0
case key: Array[Byte] => Utils.nonNegativeMod(Arrays.hashCode(key), numPartitions)
case _ => Utils.nonNegativeMod(key.hashCode(), numPartitions)
}
override def equals(other: Any): Boolean = other match {
case h: PythonPartitioner =>
h.numPartitions == numPartitions && h.pyPartitionFunctionId == pyPartitionFunctionId
case _ =>
false
}
}

View file

@ -0,0 +1,346 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python
import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
broadcastVars, accumulator)
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
try {
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
dataOut.writeLong(broadcast.id)
dataOut.writeInt(broadcast.value.length)
dataOut.write(broadcast.value)
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
for (f <- pythonIncludes) {
PythonRDD.writeAsPickle(f, dataOut)
}
dataOut.flush()
// Serialized user code
for (elem <- command) {
printOut.println(elem)
}
printOut.flush()
// Data values
for (elem <- parent.iterator(split, context)) {
PythonRDD.writeAsPickle(elem, dataOut)
}
dataOut.flush()
printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
// This can happen for legitimate reasons if the Python code stops returning data before we are done
// passing elements through, e.g., for take(). Just log a message to say it happened.
logInfo("stdin writer to Python finished early")
logDebug("stdin writer to Python finished early", e)
}
}
}.start()
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
if (hasNext) {
// FIXME: can deadlock if worker is waiting for us to
// respond to current message (currently irrelevant because
// output is shutdown before we read any input)
_nextObj = read()
}
obj
}
private def read(): Array[Byte] = {
try {
stream.readInt() match {
case length if length > 0 =>
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case -3 =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
// We've finished the data section of the output, but we can still
// read some accumulator updates; let's do that, breaking when we
// get a negative length record.
var len2 = stream.readInt()
while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e => throw e
}
}
var _nextObj = read()
def hasNext = _nextObj.length != 0
}
}
val asJavaRDD : JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
}
/** Thrown for exceptions in user Python code. */
private class PythonException(msg: String) extends Exception(msg)
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
* This is used by PySpark's shuffle operations.
*/
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Array[Byte], Array[Byte])](prev) {
override def getPartitions = prev.partitions
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
private[spark] object PythonRDD {
/** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
def stripPickle(arr: Array[Byte]) : Array[Byte] = {
arr.slice(2, arr.length - 1)
}
/**
* Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
* The data format is a 32-bit integer representing the pickled object's length (in bytes),
* followed by the pickled data.
*
* Pickle module:
*
* http://docs.python.org/2/library/pickle.html
*
* The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
*
* http://hg.python.org/cpython/file/2.6/Lib/pickle.py
* http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
*
* @param elem the object to write
* @param dOut a data output stream
*/
def writeAsPickle(elem: Any, dOut: DataOutputStream) {
if (elem.isInstanceOf[Array[Byte]]) {
val arr = elem.asInstanceOf[Array[Byte]]
dOut.writeInt(arr.length)
dOut.write(arr)
} else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(PythonRDD.stripPickle(t._1))
dOut.write(PythonRDD.stripPickle(t._2))
dOut.writeByte(Pickle.TUPLE2)
dOut.writeByte(Pickle.STOP)
} else if (elem.isInstanceOf[String]) {
// For uniformity, strings are wrapped into Pickles.
val s = elem.asInstanceOf[String].getBytes("UTF-8")
val length = 2 + 1 + 4 + s.length + 1
dOut.writeInt(length)
dOut.writeByte(Pickle.PROTO)
dOut.writeByte(Pickle.TWO)
dOut.write(Pickle.BINUNICODE)
dOut.writeInt(Integer.reverseBytes(s.length))
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new SparkException("Unexpected RDD type")
}
}
def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
try {
while (true) {
val length = file.readInt()
val obj = new Array[Byte](length)
file.readFully(obj)
objs.append(obj)
}
} catch {
case eof: EOFException => {}
case e => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
writeIteratorToPickleFile(items.asScala, filename)
}
def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
writeAsPickle(item, file)
}
file.close()
}
def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
implicit val cm : ClassManifest[T] = rdd.elementClassManifest
rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
}
}
private object Pickle {
val PROTO: Byte = 0x80.toByte
val TWO: Byte = 0x02.toByte
val BINUNICODE: Byte = 'X'
val STOP: Byte = '.'
val TUPLE2: Byte = 0x86.toByte
val EMPTY_LIST: Byte = ']'
val MARK: Byte = '('
val APPENDS: Byte = 'e'
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
/**
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
: JList[Array[Byte]] = {
if (serverHost == null) {
// This happens on the worker node, where we just want to remember all the updates
val1.addAll(val2)
val1
} else {
// This happens on the master, where we pass the updates to Python through a socket
val socket = new Socket(serverHost, serverPort)
val in = socket.getInputStream
val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
out.writeInt(val2.size)
for (array <- val2) {
out.writeInt(array.length)
out.write(array)
}
out.flush()
// Wait for a byte from the Python side as an acknowledgement
val byteRead = in.read()
if (byteRead == -1) {
throw new SparkException("EOF reached before Python server acknowledged")
}
socket.close()
null
}
}
}

View file

@ -0,0 +1,223 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.python
import java.io.{OutputStreamWriter, File, DataInputStream, IOException}
import java.net.{ServerSocket, Socket, SocketException, InetAddress}
import scala.collection.JavaConversions._
import org.apache.spark._
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
// Because forking processes from Java is expensive, we prefer to launch a single Python daemon
// (pyspark/daemon.py) and tell it to fork new workers for our tasks. This daemon currently
// only works on UNIX-based systems now because it uses signals for child management, so we can
// also fall back to launching workers (pyspark/worker.py) directly.
val useDaemon = !System.getProperty("os.name").startsWith("Windows")
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
def create(): Socket = {
if (useDaemon) {
createThroughDaemon()
} else {
createSimpleWorker()
}
}
/**
* Connect to a worker launched through pyspark/daemon.py, which forks python processes itself
* to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
*/
private def createThroughDaemon(): Socket = {
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
new Socket(daemonHost, daemonPort)
} catch {
case exc: SocketException => {
logWarning("Python daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
new Socket(daemonHost, daemonPort)
}
case e => throw e
}
}
}
/**
* Launch a worker by executing worker.py directly and telling it to connect to us.
*/
private def createSimpleWorker(): Socket = {
var serverSocket: ServerSocket = null
try {
serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
// Create and start the worker
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/worker.py"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
workerEnv.put("PYTHONPATH", pythonPath)
val worker = pb.start()
// Redirect the worker's stderr to ours
new Thread("stderr reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = worker.getErrorStream
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
// Redirect worker's stdout to our stderr
new Thread("stdout reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = worker.getInputStream
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
// Tell the worker our port
val out = new OutputStreamWriter(worker.getOutputStream)
out.write(serverSocket.getLocalPort + "\n")
out.flush()
// Wait for it to connect to our socket
serverSocket.setSoTimeout(10000)
try {
return serverSocket.accept()
} catch {
case e: Exception =>
throw new SparkException("Python worker did not connect back in time", e)
}
} finally {
if (serverSocket != null) {
serverSocket.close()
}
}
null
}
def stop() {
stopDaemon()
}
private def startDaemon() {
synchronized {
// Is it already running?
if (daemon != null) {
return
}
try {
// Create and start the daemon
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
val pythonPath = sparkHome + "/python/" + File.pathSeparator + workerEnv.get("PYTHONPATH")
workerEnv.put("PYTHONPATH", pythonPath)
daemon = pb.start()
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val in = daemon.getErrorStream
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
val in = new DataInputStream(daemon.getInputStream)
daemonPort = in.readInt()
// Redirect further stdout output to our stderr
new Thread("stdout reader for " + pythonExec) {
setDaemon(true)
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME: We copy the stream on the level of bytes to avoid encoding problems.
val buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
} catch {
case e => {
stopDaemon()
throw e
}
}
// Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
// detect our disappearance.
}
}
private def stopDaemon() {
synchronized {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
daemon = null
daemonPort = 0
}
}
}

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,70 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.util.concurrent.atomic.AtomicLong
import org.apache.spark._
abstract class Broadcast[T](private[spark] val id: Long) extends Serializable {
def value: T
// We cannot have an abstract readObject here due to some weird issues with
// readObject having to be 'private' in sub-classes.
override def toString = "Broadcast(" + id + ")"
}
private[spark]
class BroadcastManager(val _isDriver: Boolean) extends Logging with Serializable {
private var initialized = false
private var broadcastFactory: BroadcastFactory = null
initialize()
// Called by SparkContext or Executor before using Broadcast
private def initialize() {
synchronized {
if (!initialized) {
val broadcastFactoryClass = System.getProperty(
"spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
broadcastFactory =
Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory]
// Initialize appropriate BroadcastFactory and BroadcastObject
broadcastFactory.initialize(isDriver)
initialized = true
}
}
}
def stop() {
broadcastFactory.stop()
}
private val nextBroadcastId = new AtomicLong(0)
def newBroadcast[T](value_ : T, isLocal: Boolean) =
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
def isDriver = _isDriver
}

View file

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
/**
* An interface for all the broadcast implementations in Spark (to allow
* multiple broadcast implementations). SparkContext uses a user-specified
* BroadcastFactory implementation to instantiate a particular broadcast for the
* entire Spark job.
*/
private[spark] trait BroadcastFactory {
def initialize(isDriver: Boolean): Unit
def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def stop(): Unit
}

View file

@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
import java.net.URL
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
if (!isLocal) {
HttpBroadcast.write(id, value_)
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) => value_ = x.asInstanceOf[T]
case None => {
logInfo("Started reading broadcast variable " + id)
val start = System.nanoTime
value_ = HttpBroadcast.read[T](id)
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
}
}
private[spark] class HttpBroadcastFactory extends BroadcastFactory {
def initialize(isDriver: Boolean) { HttpBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
}
private object HttpBroadcast extends Logging {
private var initialized = false
private var broadcastDir: File = null
private var compress: Boolean = false
private var bufferSize: Int = 65536
private var serverUri: String = null
private var server: HttpServer = null
private val files = new TimeStampedHashSet[String]
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
private lazy val compressionCodec = CompressionCodec.createCodec()
def initialize(isDriver: Boolean) {
synchronized {
if (!initialized) {
bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
compress = System.getProperty("spark.broadcast.compress", "true").toBoolean
if (isDriver) {
createServer()
}
serverUri = System.getProperty("spark.httpBroadcast.uri")
initialized = true
}
}
}
def stop() {
synchronized {
if (server != null) {
server.stop()
server = null
}
initialized = false
cleaner.cancel()
}
}
private def createServer() {
broadcastDir = Utils.createTempDir(Utils.getLocalDir)
server = new HttpServer(broadcastDir)
server.start()
serverUri = server.uri
System.setProperty("spark.httpBroadcast.uri", serverUri)
logInfo("Broadcast server started at " + serverUri)
}
def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + id)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
} else {
new FastBufferedOutputStream(new FileOutputStream(file), bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject(value)
serOut.close()
files += file.getAbsolutePath
}
def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + id
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())
} else {
new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
obj
}
def cleanup(cleanupTime: Long) {
val iterator = files.internalMap.entrySet().iterator()
while(iterator.hasNext) {
val entry = iterator.next()
val (file, time) = (entry.getKey, entry.getValue)
if (time < cleanupTime) {
try {
iterator.remove()
new File(file.toString).delete()
logInfo("Deleted broadcast file '" + file + "'")
} catch {
case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e)
}
}
}
}
}

View file

@ -0,0 +1,410 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import java.util.Random
import scala.collection.mutable.Map
import org.apache.spark._
import org.apache.spark.util.Utils
private object MultiTracker
extends Logging {
// Tracker Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
_isDriver = __isDriver
if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// Set DriverHostAddress to the driver's IP address for the slaves to read
System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
}
}
}
def stop() {
stopBroadcast = true
}
// Load common parameters
private var DriverHostAddress_ = System.getProperty(
"spark.MultiTracker.DriverHostAddress", "")
private var DriverTrackerPort_ = System.getProperty(
"spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty(
"spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxChatSlots_ = System.getProperty(
"spark.broadcast.maxChatSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isDriver = _isDriver
// Common config params
def DriverHostAddress = DriverHostAddress_
def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxChatSlots = MaxChatSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(TrackerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
if (stopBroadcast) {
logInfo("Stopping TrackMultipleValues...")
}
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
// First, read message type
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (id -> gInfo)
}
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
oos.flush()
} else {
throw new SparkException("Undefined messageType at TrackMultipleValues")
}
} catch {
case e: Exception => {
logError("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send messageType/intention
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send Long and receive GuideInfo
oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
case e: Exception => logError("getGuideInfo had a " + e)
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
Thread.sleep(MultiTracker.ranGen.nextInt(
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
MultiTracker.MinKnockInterval)
retriesLeft -= 1
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
return gInfo
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
oosST.writeObject(gInfo)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
// Helper method to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BlockSize)
if (byteArray.length % BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BlockSize)) {
val thisBlockSize = math.min(BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper method to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}

View file

@ -0,0 +1,54 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.util.BitSet
import org.apache.spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*/
private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}
/**
* Helper Object of SourceInfo for its constants
*/
private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
// Broadcast has already finished. Try default mechanism.
val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}

View file

@ -0,0 +1,603 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, Random, UUID}
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import org.apache.spark._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var listOfSources = ListBuffer[SourceInfo]()
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast()
}
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Create a variableInfo object and store it in valueInfos
var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
private def initializeWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToDriver: Socket = null
var oosDriver: ObjectOutputStream = null
var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Driver and send this worker's Information
clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
oosDriver.flush()
oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
logDebug("Connected to Driver's guiding object")
// Send local source information
oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
oosDriver.flush()
// Receive source information from Driver
var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Driver
oosDriver.writeObject(sourceInfo)
if (oisDriver != null) {
oisDriver.close()
}
if (oosDriver != null) {
oosDriver.close()
}
if (clientSocketToDriver != null) {
clientSocketToDriver.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
/**
* Tries to receive broadcast from the source and returns Boolean status.
* This might be called multiple times to retry a defined number of times.
*/
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
logDebug("Inside receiveSingleTransmission")
logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
}
} catch {
case e: Exception => logError("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
listOfSources.synchronized {
setOfCompletedSources.synchronized {
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
}
}
}
}
}
if (clientSocket != null) {
logDebug("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications() {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo = listIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal
gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo:SourceInfo = null
override def run() {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
// (Currently removing a source based on just one failure notification!)
listOfSources = listOfSources - selectedSourceInfo
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
}
} catch {
case e: Exception => {
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
}
}
}
} finally {
logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
// Assuming the caller to have a synchronized block on listOfSources
// Select one with the most leechers. This will level-wise fill the tree
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if ((source.hostAddress != skipSourceInfo.hostAddress ||
source.listenPort != skipSourceInfo.listenPort) &&
source.currentLeechers < MultiTracker.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
}
}
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
var threadPool = Utils.newDaemonCachedThreadPool()
override def run() {
var serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run() {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
// If not a valid range, stop broadcast
if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
sendObject
}
} catch {
case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject() {
// Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized { hasBlocksLock.wait() }
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => logError("sendObject had a " + e)
}
logDebug("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}

View file

@ -0,0 +1,32 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
private[spark] class ApplicationDescription(
val name: String,
val maxCores: Int, /* Integer.MAX_VALUE denotes an unlimited number of cores */
val memoryPerSlave: Int,
val command: Command,
val sparkHome: String,
val appUiUrl: String)
extends Serializable {
val user = System.getProperty("user.name", "<unknown>")
override def toString: String = "ApplicationDescription(" + name + ")"
}

View file

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
import scala.collection.Map
private[spark] case class Command(
mainClass: String,
arguments: Seq[String],
environment: Map[String, String]) {
}

View file

@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
import org.apache.spark.deploy.worker.ExecutorRunner
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
private[deploy] object DeployMessages {
// Worker to Master
case class RegisterWorker(
id: String,
host: String,
port: Int,
cores: Int,
memory: Int,
webUiPort: Int,
publicAddress: String)
extends DeployMessage {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}
case class ExecutorStateChanged(
appId: String,
execId: Int,
state: ExecutorState,
message: Option[String],
exitStatus: Option[Int])
extends DeployMessage
case class Heartbeat(workerId: String) extends DeployMessage
// Master to Worker
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
case class KillExecutor(appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
appId: String,
execId: Int,
appDesc: ApplicationDescription,
cores: Int,
memory: Int,
sparkHome: String)
extends DeployMessage
// Client to Master
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
// Master to Client
case class RegisteredApplication(appId: String) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
Utils.checkHostPort(hostPort, "Required hostport")
}
case class ExecutorUpdated(id: Int, state: ExecutorState, message: Option[String],
exitStatus: Option[Int])
case class ApplicationRemoved(message: String)
// Internal message in Client
case object StopClient
// MasterWebUI To Master
case object RequestMasterState
// Master to MasterWebUI
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
def uri = "spark://" + host + ":" + port
}
// WorkerWebUI to Worker
case object RequestWorkerState
// Worker to WorkerWebUI
case class WorkerStateResponse(host: String, port: Int, workerId: String,
executors: List[ExecutorRunner], finishedExecutors: List[ExecutorRunner], masterUrl: String,
cores: Int, memory: Int, coresUsed: Int, memoryUsed: Int, masterWebUiUrl: String) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
}
// Actor System to Master
case object CheckForWorkerTimeOut
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
private[spark] object ExecutorState
extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
type ExecutorState = Value
def isFinished(state: ExecutorState): Boolean = Seq(KILLED, FAILED, LOST).contains(state)
}

View file

@ -0,0 +1,86 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
import net.liftweb.json.JsonDSL._
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
import org.apache.spark.deploy.master.{ApplicationInfo, WorkerInfo}
import org.apache.spark.deploy.worker.ExecutorRunner
private[spark] object JsonProtocol {
def writeWorkerInfo(obj: WorkerInfo) = {
("id" -> obj.id) ~
("host" -> obj.host) ~
("port" -> obj.port) ~
("webuiaddress" -> obj.webUiAddress) ~
("cores" -> obj.cores) ~
("coresused" -> obj.coresUsed) ~
("memory" -> obj.memory) ~
("memoryused" -> obj.memoryUsed) ~
("state" -> obj.state.toString)
}
def writeApplicationInfo(obj: ApplicationInfo) = {
("starttime" -> obj.startTime) ~
("id" -> obj.id) ~
("name" -> obj.desc.name) ~
("cores" -> obj.desc.maxCores) ~
("user" -> obj.desc.user) ~
("memoryperslave" -> obj.desc.memoryPerSlave) ~
("submitdate" -> obj.submitDate.toString)
}
def writeApplicationDescription(obj: ApplicationDescription) = {
("name" -> obj.name) ~
("cores" -> obj.maxCores) ~
("memoryperslave" -> obj.memoryPerSlave) ~
("user" -> obj.user)
}
def writeExecutorRunner(obj: ExecutorRunner) = {
("id" -> obj.execId) ~
("memory" -> obj.memory) ~
("appid" -> obj.appId) ~
("appdesc" -> writeApplicationDescription(obj.appDesc))
}
def writeMasterState(obj: MasterStateResponse) = {
("url" -> ("spark://" + obj.uri)) ~
("workers" -> obj.workers.toList.map(writeWorkerInfo)) ~
("cores" -> obj.workers.map(_.cores).sum) ~
("coresused" -> obj.workers.map(_.coresUsed).sum) ~
("memory" -> obj.workers.map(_.memory).sum) ~
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
}
def writeWorkerState(obj: WorkerStateResponse) = {
("id" -> obj.workerId) ~
("masterurl" -> obj.masterUrl) ~
("masterwebuiurl" -> obj.masterWebUiUrl) ~
("cores" -> obj.cores) ~
("coresused" -> obj.coresUsed) ~
("memory" -> obj.memory) ~
("memoryused" -> obj.memoryUsed) ~
("executors" -> obj.executors.toList.map(writeExecutorRunner)) ~
("finishedexecutors" -> obj.finishedExecutors.toList.map(writeExecutorRunner))
}
}

View file

@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated}
import org.apache.spark.deploy.worker.Worker
import org.apache.spark.deploy.master.Master
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.{Logging}
import scala.collection.mutable.ArrayBuffer
/**
* Testing class that creates a Spark standalone process in-cluster (that is, running the
* spark.deploy.master.Master and spark.deploy.worker.Workers in the same JVMs). Executors launched
* by the Workers still run in separate JVMs. This can be used to test distributed operation and
* fault recovery without spinning up a lot of processes.
*/
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
def start(): String = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
workerActorSystems += workerSystem
}
return masterUrl
}
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
workerActorSystems.foreach(_.shutdown())
workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
masterActorSystems.foreach(_.awaitTermination())
}
}

View file

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
/**
* Contains util methods to interact with Hadoop from spark.
*/
class SparkHadoopUtil {
// Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop subsystems
def newConfiguration(): Configuration = new Configuration()
// add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
}

Some files were not shown because too many files have changed in this diff Show more