Merge branch 'scala210-master' of github.com:colorant/incubator-spark into scala-2.10

Conflicts:
	core/src/main/scala/org/apache/spark/deploy/client/Client.scala
	core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
	core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
	core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
Prashant Sharma 2013-11-21 11:55:48 +05:30
commit 199e9cf02d
268 changed files with 11616 additions and 5447 deletions


@@ -68,7 +68,7 @@ described below.
When developing a Spark application, specify the Hadoop version by adding the
"hadoop-client" artifact to your project's dependencies. For example, if you're
using Hadoop 1.0.1 and build your application using SBT, add this entry to
using Hadoop 1.2.1 and build your application using SBT, add this entry to
`libraryDependencies`:
"org.apache.hadoop" % "hadoop-client" % "1.2.1"


@@ -32,12 +32,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
# First check if we have a dependencies jar. If so, include binary classes with the deps jar
if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
# Else use spark-assembly jar from either RELEASE or assembly directory
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then


@@ -28,7 +28,7 @@
# SPARK_SSH_OPTS Options passed to ssh when running remote commands.
##
usage="Usage: slaves.sh [--config confdir] command..."
usage="Usage: slaves.sh [--config <conf-dir>] command..."
# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES
# Check if --config is passed as an argument. It is an optional parameter.
# Exit if the argument is not a directory.
if [ "$1" == "--config" ]
then
shift
conf_dir=$1
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
export SPARK_CONF_DIR=$conf_dir
fi
shift
fi
if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi


@@ -29,7 +29,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##
usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"
# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"
# get arguments
# Check if --config is passed as an argument. It is an optional parameter.
# Exit if the argument is not a directory.
if [ "$1" == "--config" ]
then
shift
conf_dir=$1
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
export SPARK_CONF_DIR=$conf_dir
fi
shift
fi
startStop=$1
shift
command=$1


@@ -19,7 +19,7 @@
# Run a Spark command on all slave hosts.
usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."
# if no args specified, show usage
if [ $# -le 1 ]; then


@@ -17,8 +17,6 @@
# limitations under the License.
#
# Starts the master on the machine this script is executed on.
bin=`dirname "$0"`
bin=`cd "$bin"; pwd`


@@ -48,6 +48,10 @@
<groupId>org.apache.avro</groupId>
<artifactId>avro-ipc</artifactId>
</dependency>
<dependency>
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>


@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;
import org.apache.spark.storage.BlockId;
abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
public abstract void handleError(BlockId blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {


@@ -24,6 +24,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,41 +36,36 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
}
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockId) {
String path = pResolver.getAbsolutePath(blockId);
// if getFilePath returns null, close the channel
if (path == null) {
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
FileSegment fileSegment = pResolver.getBlockLocation(blockId);
// if getBlockLocation returns null, close the channel
if (fileSegment == null) {
//ctx.close();
return;
}
File file = new File(path);
File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
.getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();


@@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
/** Get the file segment in which the given block resides. */
public FileSegment getBlockLocation(BlockId blockId);
}
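A minimal sketch of implementing the reworked interface in Scala appears below. The FileSegment(file, offset, length) constructor and blockId.name are assumptions inferred from how FileServerHandler uses these types above, not APIs confirmed by this diff:

import java.io.File

import org.apache.spark.network.netty.PathResolver
import org.apache.spark.storage.{BlockId, FileSegment}

// Hypothetical resolver that maps each block id to a whole file in one directory.
class DirectoryPathResolver(dir: File) extends PathResolver {
  override def getBlockLocation(blockId: BlockId): FileSegment = {
    val file = new File(dir, blockId.name) // assumes BlockId exposes its string name
    // Returning null tells the caller (e.g. FileServerHandler) to give up on the request.
    if (file.exists()) new FileSegment(file, 0, file.length()) else null
  }
}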


@@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}


@@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
import org.apache.hadoop.conf.Configuration
private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
// failed, look for the new ctor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
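firstAvailableClass is called here but not shown in this hunk; a plausible sketch of the fallback it performs, assuming it simply tries the class names in order, is:

def firstAvailableClass(first: String, second: String): Class[_] = {
  try {
    Class.forName(first) // prefer the Hadoop 2.x / YARN class name
  } catch {
    case e: ClassNotFoundException =>
      Class.forName(second) // fall back to the Hadoop 1.x class name
  }
}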


@@ -17,43 +17,42 @@
package org.apache.spark
import java.util.{HashMap => JHashMap}
import org.apache.spark.util.AppendOnlyMap
import scala.collection.JavaConversions._
/** A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
/**
* A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for (kv <- iter) {
val oldC = combiners.get(kv._1)
if (oldC == null) {
combiners.put(kv._1, createCombiner(kv._2))
} else {
combiners.put(kv._1, mergeValue(oldC, kv._2))
}
val combiners = new AppendOnlyMap[K, C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (iter.hasNext) {
kv = iter.next()
combiners.changeValue(kv._1, update)
}
combiners.iterator
}
def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
iter.foreach { case(k, c) =>
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, mergeCombiners(oldC, c))
}
val combiners = new AppendOnlyMap[K, C]
var kc: (K, C) = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
}
while (iter.hasNext) {
kc = iter.next()
combiners.changeValue(kc._1, update)
}
combiners.iterator
}
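The (hadValue, oldValue) => newValue contract that both loops above rely on can be isolated into a tiny word-count sketch. AppendOnlyMap is package-private, so this assumes code compiled under org.apache.spark:

import org.apache.spark.util.AppendOnlyMap

val counts = new AppendOnlyMap[String, Int]
val update = (hadValue: Boolean, oldValue: Int) => {
  if (hadValue) oldValue + 1 else 1 // create on first sight, merge afterwards
}
for (word <- Iterator("a", "b", "a")) {
  counts.changeValue(word, update)
}
counts.iterator.foreach { case (word, n) => println(word + " -> " + n) } // a -> 2, b -> 1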


@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
override def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}
val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, _) =>
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)
CompletionIterator[T, Iterator[T]](itr, {
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})
new InterruptibleIterator[T](context, completionIter)
}
}
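The typed ids keep the old on-disk naming, so the round trip below should hold; the exact name format and the parsing entry point are inferred from the string template this hunk removes and from the BlockId.apply(blockIdString) call in FileServerHandler above:

import org.apache.spark.storage.{BlockId, ShuffleBlockId}

val id = ShuffleBlockId(shuffleId = 2, mapId = 5, reduceId = 7)
println(id.name)                      // "shuffle_2_5_7", the old "shuffle_%d_%d_%d" layout
val parsed = BlockId("shuffle_2_5_7") // parse a raw name back into a typed id
println(parsed == id)                 // true: case-class equality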


@@ -18,7 +18,7 @@
package org.apache.spark
import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.storage.{BlockManager, StorageLevel}
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
/** Keys of RDD splits that are being computed/loaded. */
private val loading = new HashSet[String]()
private val loading = new HashSet[RDDBlockId]()
/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
val key = RDDBlockId(rdd.id, split.index)
logDebug("Looking for partition " + key)
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
return values.asInstanceOf[Iterator[T]]
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
blockManager.put(key, elements, storageLevel, true)
blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
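The elided surroundings of this hunk guard the computation with a handshake on the loading set; reduced to its core (and simplified, with no retry or cache re-check logic), the pattern is roughly:

import scala.collection.mutable.HashSet

val loading = new HashSet[String]

def getOrCompute(key: String)(compute: => Seq[Int]): Seq[Int] = {
  loading.synchronized {
    while (loading.contains(key)) loading.wait() // another thread is computing this key
    loading.add(key)
  }
  try compute
  finally loading.synchronized {
    loading.remove(key)
    loading.notifyAll() // wake threads blocked on this key
  }
}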


@@ -0,0 +1,250 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
import scala.concurrent._
import scala.concurrent.duration.Duration
import scala.util.Try
import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
import org.apache.spark.scheduler.JobFailed
import org.apache.spark.rdd.RDD
/**
* A future for the result of an action. This is an extension of the Scala Future interface to
* support cancellation.
*/
trait FutureAction[T] extends Future[T] {
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
// documentation (with reference to the word "action").
/**
* Cancels the execution of this action.
*/
def cancel()
/**
* Blocks until this action completes.
* @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
* for unbounded waiting, or a finite positive duration
* @return this FutureAction
*/
override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
/**
* Awaits and returns the result (of type T) of this action.
* @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
* for unbounded waiting, or a finite positive duration
* @throws Exception exception during action execution
* @return the result value if the action is completed within the specific maximum wait time
*/
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T
/**
* When this action is completed, either through an exception, or a value, applies the provided
* function.
*/
def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
/**
* Returns whether the action has already been completed with a value or an exception.
*/
override def isCompleted: Boolean
/**
* The value of this Future.
*
* If the future is not completed the returned value will be None. If the future is completed
* the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
* it contains an exception.
*/
override def value: Option[Try[T]]
/**
* Blocks and returns the result of this job.
*/
@throws(classOf[Exception])
def get(): T = Await.result(this, Duration.Inf)
}
/**
* The future holding the result of an action that triggers a single job. Examples include
* count, collect, reduce.
*/
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
extends FutureAction[T] {
override def cancel() {
jobWaiter.cancel()
}
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
} else {
val finishTime = System.currentTimeMillis() + atMost.toMillis
while (!isCompleted) {
val time = System.currentTimeMillis()
if (time >= finishTime) {
throw new TimeoutException
} else {
jobWaiter.wait(finishTime - time)
}
}
}
this
}
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
ready(atMost)(permit)
awaitResult() match {
case scala.util.Success(res) => res
case scala.util.Failure(e) => throw e
}
}
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
executor.execute(new Runnable {
override def run() {
func(awaitResult())
}
})
}
override def isCompleted: Boolean = jobWaiter.jobFinished
override def value: Option[Try[T]] = {
if (jobWaiter.jobFinished) {
Some(awaitResult())
} else {
None
}
}
private def awaitResult(): Try[T] = {
jobWaiter.awaitResult() match {
case JobSucceeded => scala.util.Success(resultFunc)
case JobFailed(e: Exception, _) => scala.util.Failure(e)
}
}
}
/**
* A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
* action thread if it is being blocked by a job.
*/
class ComplexFutureAction[T] extends FutureAction[T] {
// Pointer to the thread that is executing the action. It is set when the action is run.
@volatile private var thread: Thread = _
// A flag indicating whether the future has been cancelled. This is used in case the future
// is cancelled before the action was even run (and thus we have no thread to interrupt).
@volatile private var _cancelled: Boolean = false
// A promise used to signal the future.
private val p = promise[T]()
override def cancel(): Unit = this.synchronized {
_cancelled = true
if (thread != null) {
thread.interrupt()
}
}
/**
* Executes some action enclosed in the closure. To properly enable cancellation, the closure
* should use runJob implementation in this promise. See takeAsync for example.
*/
def run(func: => T)(implicit executor: ExecutionContext): this.type = {
scala.concurrent.future {
thread = Thread.currentThread
try {
p.success(func)
} catch {
case e: Exception => p.failure(e)
} finally {
thread = null
}
}
this
}
/**
* Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
* to enable cancellation.
*/
def runJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit,
resultFunc: => R) {
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
// command need to be in an atomic block.
val job = this.synchronized {
if (!cancelled) {
rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
} else {
throw new SparkException("Action has been cancelled")
}
}
// Wait for the job to complete. If the action is cancelled (with an interrupt),
// cancel the job and stop the execution. This is not in a synchronized block because
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
try {
Await.ready(job, Duration.Inf)
} catch {
case e: InterruptedException =>
job.cancel()
throw new SparkException("Action has been cancelled")
}
}
/**
* Returns whether the promise has been cancelled.
*/
def cancelled: Boolean = _cancelled
@throws(classOf[InterruptedException])
@throws(classOf[scala.concurrent.TimeoutException])
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
p.future.ready(atMost)(permit)
this
}
@throws(classOf[Exception])
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
p.future.result(atMost)(permit)
}
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
p.future.onComplete(func)(executor)
}
override def isCompleted: Boolean = p.isCompleted
override def value: Option[Try[T]] = p.future.value
}
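A hedged usage sketch of the new API (it assumes a live SparkContext named sc, plus the countAsync action this commit adds via AsyncRDDActions):

import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.SparkContext._ // brings rddToAsyncRDDActions into scope

val future = sc.parallelize(1 to 100000, 4).countAsync()
future.onComplete {
  case scala.util.Success(n) => println("counted " + n + " elements")
  case scala.util.Failure(e) => println("job failed: " + e)
}
// future.cancel() would cancel the underlying Spark job if it were still running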


@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark
/**
* An iterator that wraps around an existing iterator to provide task killing functionality.
* It works by checking the interrupted flag in TaskContext.
*/
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {
def hasNext: Boolean = !context.interrupted && delegate.hasNext
def next(): T = delegate.next()
}
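A standalone illustration of the semantics, built only from this file and the TaskContext signature changed elsewhere in this commit (the constructor arguments used here are the ones visible in that diff):

import org.apache.spark.{InterruptibleIterator, TaskContext}

val context = new TaskContext(stageId = 0, partitionId = 0, attemptId = 0)
val iter = new InterruptibleIterator(context, Iterator(1, 2, 3))
println(iter.next())       // 1
context.interrupted = true // e.g. set when the task is killed
println(iter.hasNext)      // false: iteration stops even though the delegate has more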


@@ -20,7 +20,6 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
@@ -34,7 +33,7 @@ import scala.concurrent.duration._
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
private[spark] sealed trait MapOutputTrackerMessage
@@ -42,11 +41,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@@ -62,22 +62,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: Either[ActorRef, ActorSelection] = _
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var epoch: Long = 0
private val epochLock = new java.lang.Object
protected var epoch: Long = 0
protected val epochLock = new java.lang.Object
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
private def askTracker(message: Any): Any = {
try {
val future = if (trackerActor.isLeft ) {
trackerActor.left.get.ask(message)(timeout)
@@ -92,50 +89,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@@ -174,7 +133,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@@ -200,9 +159,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
private def cleanup(cleanupTime: Long) {
protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@@ -212,15 +170,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
// Called on master to increment the epoch number
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
// Called on master or workers to get current epoch number
// Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@@ -234,14 +184,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
epoch = newEpoch
mapStatuses.clear()
}
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
private var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
val array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
val array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@@ -259,7 +257,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@@ -267,13 +265,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
bytes
}
protected override def cleanup(cleanupTime: Long) {
super.cleanup(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
override def stop() {
super.stop()
cachedSerializedStatuses.clear()
}
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@@ -284,18 +300,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
// Opposite of serializeMapStatuses.
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
objIn.readObject().asInstanceOf[Array[MapStatus]]
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),


@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
/** Stop the fetcher */


@@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.reflect.{ ClassTag, classTag}
@@ -53,21 +53,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
ClusterScheduler}
import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -121,9 +119,9 @@ class SparkContext(
// Keeps track of all persisted RDDs
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
// Initalize the Spark UI
// Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@@ -149,6 +147,14 @@ class SparkContext(
executorEnvs ++= environment
}
// Set SPARK_USER for user who is running SparkContext.
val sparkUser = Option {
Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
}.getOrElse {
SparkContext.SPARK_UNKNOWN_USER
}
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
private[spark] var taskScheduler: TaskScheduler = {
// Regular expression used for local[N] master format
@@ -158,9 +164,11 @@
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """(spark://.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
val SPARK_REGEX = """spark://(.*)""".r
// Regular expression for connection to Mesos cluster
val MESOS_REGEX = """mesos://(.*)""".r
// Regular expression for connection to Simr cluster
val SIMR_REGEX = """simr://(.*)""".r
master match {
case "local" =>
@@ -174,7 +182,14 @@
case SPARK_REGEX(sparkUrl) =>
val scheduler = new ClusterScheduler(this)
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
scheduler
case SIMR_REGEX(simrUrl) =>
val scheduler = new ClusterScheduler(this)
val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
scheduler.initialize(backend)
scheduler
@@ -190,8 +205,8 @@
val scheduler = new ClusterScheduler(this)
val localCluster = new LocalSparkCluster(
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
val sparkUrl = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
val masterUrls = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
scheduler.initialize(backend)
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
@@ -210,38 +225,36 @@
throw new SparkException("YARN mode not available ?", th)
}
}
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
scheduler.initialize(backend)
scheduler
case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
} else {
new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
}
scheduler.initialize(backend)
scheduler
throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
dagScheduler.start()
ui.start()
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = {
val env = SparkEnv.get
val conf = env.hadoop.newConfiguration()
val conf = SparkHadoopUtil.get.newConfiguration()
// Explicitly check for S3 environment variables
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
@@ -251,8 +264,10 @@
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
}
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
Utils.getSystemProperties.foreach { case (key, value) =>
if (key.startsWith("spark.hadoop.")) {
conf.set(key.substring("spark.hadoop.".length), value)
}
}
val bufferSize = System.getProperty("spark.buffer.size", "65536")
conf.set("io.file.buffer.size", bufferSize)
@@ -266,6 +281,12 @@
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
private[spark] def getLocalProperties(): Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
localProperties.set(props)
}
def initLocalProperties() {
localProperties.set(new Properties())
}
@@ -285,15 +306,46 @@
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
@deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
/**
* Assigns a group id to all the jobs started by this thread until the group id is set to a
* different value or cleared.
*
* Often, a unit of execution in an application consists of multiple Spark actions or jobs.
* Application programmers can use this method to group all those jobs together and give a
* group description. Once set, the Spark web UI will associate such jobs with this group.
*
* The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
* running jobs in this group. For example,
* {{{
* // In the main thread:
* sc.setJobGroup("some_job_to_cancel", "some job description")
* sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
*
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
*/
def setJobGroup(groupId: String, description: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
}
/** Clear the job group id and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@@ -332,7 +384,7 @@
}
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*/
@@ -344,7 +396,7 @@
minSplits: Int = defaultMinSplits
): RDD[(K, V)] = {
// Add necessary security credentials to the JobConf before broadcasting it.
SparkEnv.get.hadoop.addCredentials(conf)
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
}
@@ -358,24 +410,15 @@
): RDD[(K, V)] = {
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
}
/**
* Get an RDD for a Hadoop file with an arbitray InputFormat. Accept a Hadoop Configuration
* that has already been broadcast, assuming that it's safe to use it to construct a
* HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be resued).
*/
def hadoopFile[K, V](
path: String,
confBroadcast: Broadcast[SerializableWritable[Configuration]],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int
): RDD[(K, V)] = {
new HadoopFileRDD(
this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
new HadoopRDD(
this,
confBroadcast,
Some(setInputPathsFunc),
inputFormatClass,
keyClass,
valueClass,
minSplits)
}
/**
@@ -563,7 +606,8 @@
val uri = new URI(path)
val key = uri.getScheme match {
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
case _ => path
case "local" => "file:" + uri.getPath
case _ => path
}
addedFiles(key) = System.currentTimeMillis
@@ -657,12 +701,11 @@
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
logWarning("null specified as parameter to addJar",
new SparkException("null specified as parameter to addJar"))
logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@@ -671,12 +714,27 @@
} else {
val uri = new URI(path)
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
if (env.hadoop.isYarnMode()) {
logWarning("local jar specified as parameter to addJar under Yarn mode")
return
if (SparkHadoopUtil.get.isYarnMode()) {
// In order for this to work on yarn the user must specify the --addjars option to
// the client to upload the file into the distributed cache to make it show up in the
// current working directory.
val fileName = new Path(uri.getPath).getName()
try {
env.httpFileServer.addJar(new File(fileName))
} catch {
case e: Exception => {
logError("Error adding jar (" + e + "), was the --addJars option used?")
throw e
}
}
} else {
env.httpFileServer.addJar(new File(uri.getPath))
}
env.httpFileServer.addJar(new File(uri.getPath))
// A JAR file which exists locally on every worker node
case "local" =>
"file:" + uri.getPath
case _ =>
path
}
@@ -750,13 +808,13 @@
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.formatSparkCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
localProperties.get)
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
rdd.doCheckpoint()
result
}
/**
@@ -842,6 +900,42 @@
result
}
/**
* Submit a job for execution and return a FutureJob holding the result.
*/
def submitJob[T, U, R](
rdd: RDD[T],
processPartition: Iterator[T] => U,
partitions: Seq[Int],
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
val cleanF = clean(processPartition)
val callSite = Utils.formatSparkCallSite
val waiter = dagScheduler.submitJob(
rdd,
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
partitions,
callSite,
allowLocal = false,
resultHandler,
localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
}
/**
* Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
* for more information.
*/
def cancelJobGroup(groupId: String) {
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
/**
* Clean a closure to make it ready to serialized and send to tasks
* (removes unreferenced variables in $outer's, updates REPL variables)
@@ -859,9 +953,8 @@
* prevent accidental overriding of checkpoint files in the existing directory.
*/
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
val env = SparkEnv.get
val path = new Path(dir)
val fs = path.getFileSystem(env.hadoop.newConfiguration())
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
if (!useExisting) {
if (fs.exists(path)) {
throw new Exception("Checkpoint directory '" + path + "' already exists.")
@@ -898,7 +991,12 @@
* various Spark features.
*/
object SparkContext {
val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
@@ -925,6 +1023,8 @@ object SparkContext {
implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
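A hedged sketch of driving the new submitJob API directly (assumes a live SparkContext named sc; the per-partition function and handler are illustrative):

val rdd = sc.parallelize(1 to 1000, 10)
val future = sc.submitJob(
  rdd,
  (iter: Iterator[Int]) => iter.sum,         // processPartition
  Seq(0, 1, 2),                              // run only these partitions
  (index: Int, sum: Int) => println("partition " + index + " sum = " + sum),
  ()                                         // resultFunc: no aggregate result
)
future.get() // block until the job finishes; future.cancel() would abort it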


@@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
import org.apache.spark.network.ConnectionManager
import org.apache.spark.serializer.{Serializer, SerializerManager}
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.api.python.PythonWorkerFactory
import com.google.common.collect.MapMaker
/**
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
@@ -58,18 +58,9 @@ class SparkEnv (
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if(yarnMode) {
try {
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
} catch {
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
}
} else {
new SparkHadoopUtil
}
}
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -188,10 +179,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster()
} else {
new MapOutputTracker()
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")


@@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.util.Date
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
private[apache]
class SparkHadoopWriter(@transient jobConf: JobConf)
extends Logging
with SparkHadoopMapRedUtil
with Serializable {
private val now = new Date()
private val conf = new SerializableWritable(jobConf)
@@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
getOutputCommitter().setupTask(getTaskContext())
writer = getOutputFormat().getRecordWriter(
fs, conf.value, outputName, Reporter.NULL)
writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
if (writer!=null) {
//println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
}
}
private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
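
A minimal usage sketch of the contract described above, assuming an rdd: RDD[(Text, IntWritable)] in scope and an illustrative output path; the JobConf carries the key/value classes, the OutputFormat, and the target path, just as a MapReduce job would:

import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{IntWritable, Text}
import org.apache.hadoop.mapred.{FileOutputFormat, JobConf, TextOutputFormat}

val conf = new JobConf()
conf.setOutputKeyClass(classOf[Text])
conf.setOutputValueClass(classOf[IntWritable])
conf.setOutputFormat(classOf[TextOutputFormat[Text, IntWritable]])
FileOutputFormat.setOutputPath(conf, new Path("hdfs:///tmp/wordcounts"))  // illustrative path
rdd.saveAsHadoopDataset(conf)  // the usual entry point that exercises this writer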

View file

@ -17,21 +17,30 @@
package org.apache.spark
import executor.TaskMetrics
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.executor.TaskMetrics
class TaskContext(
val stageId: Int,
val splitId: Int,
val partitionId: Int,
val attemptId: Long,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty()
@volatile var interrupted: Boolean = false,
private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
) extends Serializable {
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
@deprecated("use partitionId", "0.8.1")
def splitId = partitionId
// Add a callback function to be executed on task completion. An example use
// is for HadoopRDD to register a callback to close the input stream.
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
/**
* Add a callback function to be executed on task completion. An example use
* is for HadoopRDD to register a callback to close the input stream.
* @param f Callback function.
*/
def addOnCompleteCallback(f: () => Unit) {
onCompleteCallbacks += f
}
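
A short sketch of the callback pattern documented above (the file name and compute signature are illustrative, not from this diff):

def compute(context: TaskContext): Iterator[String] = {
  val in = new java.io.BufferedReader(new java.io.FileReader("/tmp/part-00000"))
  // Register cleanup so the stream closes even if the consumer stops early.
  context.addOnCompleteCallback(() => in.close())
  Iterator.continually(in.readLine()).takeWhile(_ != null)
}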

View file

@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
*/
private[spark] case object TaskResultLost extends TaskEndReason
private[spark] case object TaskKilled extends TaskEndReason
private[spark] case class OtherFailure(message: String) extends TaskEndReason

View file

@ -51,6 +51,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
// first() has to be overridden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@ -83,6 +96,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*

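A quick illustration of the guidance in the repartition scaladoc above, assuming a SparkContext sc; the partition counts are arbitrary:

val wide = sc.parallelize(1 to 1000000, 100)
val narrow = wide.coalesce(10)       // 100 -> 10 partitions, no shuffle
val wider  = narrow.repartition(50)  // 10 -> 50 partitions, full shuffle
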
View file

@ -66,6 +66,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**
@ -95,6 +108,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
/**
* Return a sampled subset of this RDD.
*/
@ -599,4 +623,15 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
implicit val cmk: ClassTag[K] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
implicit val cmv: ClassTag[V] =
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
new JavaPairRDD[K, V](rdd.rdd)
}
}
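
A hedged usage sketch for fromJavaRDD, assuming a javaRdd: JavaRDD[(String, Int)] obtained elsewhere; wrapping it exposes the pair-specific operations (groupByKey, join, and so on):

val pairRdd: JavaPairRDD[String, Int] = JavaPairRDD.fromJavaRDD(javaRdd)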

View file

@ -43,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**
@ -75,6 +83,17 @@ JavaRDDLike[T, JavaRDD[T]] {
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
rdd.coalesce(numPartitions, shuffle)
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
/**
* Return a sampled subset of this RDD.
*/

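The two unpersist overloads above differ only in blocking behavior; a sketch with an illustrative cached: JavaRDD[String] that was previously persisted:

cached.unpersist()       // blocks until every block is removed
cached.unpersist(false)  // returns immediately; removal proceeds asynchronously
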
View file

@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
public abstract Iterable<Double> call(T t);
@Override
public final Iterable<Double> apply(T t) { return call(t); }
// Intentionally left blank
}

View file

@ -27,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
public abstract Double call(T t) throws Exception;
// Intentionally left blank
}

View file

@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
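
A minimal sketch of the contract above, mapping one input record to zero or more outputs; java.util.Arrays.asList supplies the java.lang.Iterable the signature expects:

val words = new FlatMapFunction[String, String] {
  def call(line: String): java.lang.Iterable[String] =
    java.util.Arrays.asList(line.split(" "): _*)
}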

View file

@ -23,8 +23,5 @@ import scala.reflect.ClassTag
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
@throws(classOf[Exception])
def call(a: A, b:B) : java.lang.Iterable[C]
def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}

View file

@ -29,10 +29,8 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
public ClassTag<R> returnType() {
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
return ClassTag$.MODULE$.apply(Object.class);
}
}

View file

@ -28,8 +28,6 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
public abstract R call(T1 t1, T2 t2) throws Exception;
public ClassTag<R> returnType() {
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}

View file

@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.reflect.ClassTag;
import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
/**
* A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
*/
public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
implements Serializable {
public ClassTag<R> returnType() {
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}

View file

@ -33,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}

View file

@ -28,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
extends WrappedFunction1<T, Tuple2<K, V>>
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
public abstract Tuple2<K, V> call(T t) throws Exception;
public ClassTag<K> keyType() {
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}

View file

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
import scala.runtime.AbstractFunction3
/**
* Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
* apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
* isn't marked to allow that).
*/
private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
extends AbstractFunction3[T1, T2, T3, R] {
@throws(classOf[Exception])
def call(t1: T1, t2: T2, t3: T3): R
final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
}
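
A small sketch of the plumbing described above: user code extends the public Function3 and implements call(), while the final apply() forwards to it (names and logic are illustrative):

val clamp = new Function3[Int, Int, Int, Int] {
  def call(lo: Int, hi: Int, x: Int): Int = math.max(lo, math.min(hi, x))
}
clamp(0, 10, 42)  // apply() delegates to call() and yields 10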

View file

@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")

View file

@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.{HttpServer, Logging, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId: String = "broadcast_" + id
def blockId = BroadcastBlockId(id)
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging {
private var server: HttpServer = null
private val files = new TimeStampedHashSet[String]
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
private lazy val compressionCodec = CompressionCodec.createCodec()
@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
}
def write(id: Long, value: Any) {
val file = new File(broadcastDir, "broadcast-" + id)
val file = new File(broadcastDir, BroadcastBlockId(id).name)
val out: OutputStream = {
if (compress) {
compressionCodec.compressedOutputStream(new FileOutputStream(file))
@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
}
def read[T](id: Long): T = {
val url = serverUri + "/broadcast-" + id
val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
if (compress) {
compressionCodec.compressedInputStream(new URL(url).openStream())

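An illustration of the BlockId change above, assuming the usual naming scheme in which a BroadcastBlockId renders as "broadcast_<id>":

val bid = BroadcastBlockId(42)
val fileName = bid.name  // "broadcast_42", shared by the block store and the HTTP server
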
View file

@ -1,410 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import java.util.Random
import scala.collection.mutable.Map
import org.apache.spark._
import org.apache.spark.util.Utils
private object MultiTracker
extends Logging {
// Tracker Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
_isDriver = __isDriver
if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// Set DriverHostAddress to the driver's IP address for the slaves to read
System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
}
}
}
def stop() {
stopBroadcast = true
}
// Load common parameters
private var DriverHostAddress_ = System.getProperty(
"spark.MultiTracker.DriverHostAddress", "")
private var DriverTrackerPort_ = System.getProperty(
"spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty(
"spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxChatSlots_ = System.getProperty(
"spark.broadcast.maxChatSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isDriver = _isDriver
// Common config params
def DriverHostAddress = DriverHostAddress_
def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxChatSlots = MaxChatSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(TrackerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
if (stopBroadcast) {
logInfo("Stopping TrackMultipleValues...")
}
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
// First, read message type
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (id -> gInfo)
}
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
oos.flush()
} else {
throw new SparkException("Undefined messageType at TrackMultipleValues")
}
} catch {
case e: Exception => {
logError("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send messageType/intention
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send Long and receive GuideInfo
oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
case e: Exception => logError("getGuideInfo had a " + e)
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
Thread.sleep(MultiTracker.ranGen.nextInt(
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
MultiTracker.MinKnockInterval)
retriesLeft -= 1
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
return gInfo
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
oosST.writeObject(gInfo)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
// Helper method to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BlockSize)
if (byteArray.length % BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BlockSize)) {
val thisBlockSize = math.min(BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper method to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}

View file

@ -1,54 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.util.BitSet
import org.apache.spark._
/**
* Used to keep and pass around information about the peers involved in a broadcast
*/
private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}
/**
* Helper Object of SourceInfo for its constants
*/
private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
// Broadcast has already finished. Try default mechanism.
val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}

View file

@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import scala.math
import scala.util.Random
import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def broadcastId = BroadcastBlockId(id)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@transient var totalBlocks = -1
@transient var totalBytes = -1
@transient var hasBlocks = 0
if (!isLocal) {
sendBroadcast()
}
def sendBroadcast() {
var tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
}
}
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(broadcastId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
if (receiveBroadcast(id)) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
} else {
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
private def resetWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
}
def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(metaId) match {
case Some(x) =>
val tInfo = x.asInstanceOf[TorrentInfo]
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
hasBlocks = 0
case None =>
Thread.sleep(500)
}
}
attemptId -= 1
}
if (totalBlocks == -1) {
return false
}
// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
}
(hasBlocks == totalBlocks)
}
}
private object TorrentBroadcast
extends Logging {
private var initialized = false
def initialize(_isDriver: Boolean) {
synchronized {
if (!initialized) {
initialized = true
}
}
}
def stop() {
initialized = false
}
val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BLOCK_SIZE)
if (byteArray.length % BLOCK_SIZE != 0)
blockNum += 1
var retVal = new Array[TorrentBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum
return tInfo
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
}
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}
}
private[spark] case class TorrentBlock(
blockID: Int,
byteArray: Array[Byte])
extends Serializable
private[spark] case class TorrentInfo(
@transient arrayOfBlocks : Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}
private[spark] class TorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
}

View file

@ -1,603 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import java.util.{Comparator, Random, UUID}
import scala.collection.mutable.{ListBuffer, Map, Set}
import scala.math
import org.apache.spark._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = "broadcast_" + id
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var listOfSources = ListBuffer[SourceInfo]()
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast()
}
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Create a variableInfo object and store it in valueInfos
var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Driver will only send null/0 values
// Only the first worker on a node can be here. Others will get it from the cache
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
private def initializeWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToDriver: Socket = null
var oosDriver: ObjectOutputStream = null
var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Driver and send this worker's Information
clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
oosDriver.flush()
oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
logDebug("Connected to Driver's guiding object")
// Send local source information
oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
oosDriver.flush()
// Receive source information from Driver
var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Driver
oosDriver.writeObject(sourceInfo)
if (oisDriver != null) {
oisDriver.close()
}
if (oosDriver != null) {
oosDriver.close()
}
if (clientSocketToDriver != null) {
clientSocketToDriver.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
/**
* Tries to receive broadcast from the source and returns Boolean status.
* This might be called multiple times to retry a defined number of times.
*/
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
logDebug("Inside receiveSingleTransmission")
logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
}
} catch {
case e: Exception => logError("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool()
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
listOfSources.synchronized {
setOfCompletedSources.synchronized {
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
}
}
}
}
}
if (clientSocket != null) {
logDebug("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// In failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications() {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo = listIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal
gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo: SourceInfo = null
override def run() {
try {
logInfo("new GuideSingleRequest is running")
// The connecting worker sends the hostAddress and listenPort it will be
// listening on. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
// (Currently removing a source based on just one failure notification!)
listOfSources = listOfSources - selectedSourceInfo
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
}
} catch {
case e: Exception => {
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
}
}
}
} finally {
logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
// Assumes the caller holds a synchronized block on listOfSources
// Select one with the most leechers. This will level-wise fill the tree
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if ((source.hostAddress != skipSourceInfo.hostAddress ||
source.listenPort != skipSourceInfo.listenPort) &&
source.currentLeechers < MultiTracker.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
}
}
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
var threadPool = Utils.newDaemonCachedThreadPool()
override def run() {
var serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// In failure, close socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run() {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
// If not a valid range, stop broadcast
if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
sendObject
}
} catch {
case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject() {
// Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized { hasBlocksLock.wait() }
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => logError("sendObject had a " + e)
}
logDebug("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}

View file

@ -21,12 +21,14 @@ import scala.collection.immutable.List
import org.apache.spark.deploy.ExecutorState.ExecutorState
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
import org.apache.spark.deploy.master.RecoveryState.MasterState
import org.apache.spark.deploy.worker.ExecutorRunner
import org.apache.spark.util.Utils
private[deploy] sealed trait DeployMessage extends Serializable
/** Contains messages sent between Scheduler actor nodes. */
private[deploy] object DeployMessages {
// Worker to Master
@ -52,17 +54,20 @@ private[deploy] object DeployMessages {
exitStatus: Option[Int])
extends DeployMessage
case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
case class Heartbeat(workerId: String) extends DeployMessage
// Master to Worker
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
case class RegisterWorkerFailed(message: String) extends DeployMessage
case class KillExecutor(appId: String, execId: Int) extends DeployMessage
case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
masterUrl: String,
appId: String,
execId: Int,
appDesc: ApplicationDescription,
@ -76,9 +81,11 @@ private[deploy] object DeployMessages {
case class RegisterApplication(appDescription: ApplicationDescription)
extends DeployMessage
case class MasterChangeAcknowledged(appId: String)
// Master to Client
case class RegisteredApplication(appId: String) extends DeployMessage
case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
// TODO(matei): replace hostPort with host
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
@ -94,6 +101,10 @@ private[deploy] object DeployMessages {
case object StopClient
// Master to Worker & Client
case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
// MasterWebUI To Master
case object RequestMasterState
@ -101,7 +112,8 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
@ -123,12 +135,7 @@ private[deploy] object DeployMessages {
assert (port > 0)
}
// Actor System to Master
case object CheckForWorkerTimeOut
case object RequestWebUIPort
case class WebUIPortResponse(webUIBoundPort: Int)
// Actor System to Worker
case object SendHeartbeat
}
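
A hedged sketch of the new failover messages in use, with placeholder URLs and application id; in practice these travel over the actor system rather than being constructed by hand:

val changed = DeployMessages.MasterChanged("spark://master2:7077", "http://master2:8080")
val ack = DeployMessages.MasterChangeAcknowledged("app-20131121-0001")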

View file

@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy
/**
* Used to send state about Executors over the wire from Worker to Master.
* This state is sufficient for the Master to reconstruct its internal data structures during
* failover.
*/
private[spark] class ExecutorDescription(
val appId: String,
val execId: Int,
val cores: Int,
val state: ExecutorState.Value)
extends Serializable {
override def toString: String =
"ExecutorState(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
}
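
A short example of the wire format above, with invented values (ExecutorState.RUNNING assumed to be one of the enumeration's states):

val desc = new ExecutorDescription("app-20131121-0001", 3, 4, ExecutorState.RUNNING)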

View file

@ -0,0 +1,420 @@
/*
*
* * Licensed to the Apache Software Foundation (ASF) under one or more
* * contributor license agreements. See the NOTICE file distributed with
* * this work for additional information regarding copyright ownership.
* * The ASF licenses this file to You under the Apache License, Version 2.0
* * (the "License"); you may not use this file except in compliance with
* * the License. You may obtain a copy of the License at
* *
* * http://www.apache.org/licenses/LICENSE-2.0
* *
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS,
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* * See the License for the specific language governing permissions and
* * limitations under the License.
*
*/
package org.apache.spark.deploy
import java.io._
import java.net.URL
import java.util.concurrent.TimeoutException
import scala.concurrent.{Await, future, promise}
import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import scala.collection.mutable.ListBuffer
import scala.sys.process._
import net.liftweb.json.JsonParser
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.deploy.master.RecoveryState
/**
* This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
* In order to mimic a real distributed cluster more closely, Docker is used.
* Execute using
* ./spark-class org.apache.spark.deploy.FaultToleranceTest
*
* Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
* - spark.deploy.recoveryMode=ZOOKEEPER
* - spark.deploy.zookeeper.url=172.17.42.1:2181
* Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
*
* Unfortunately, due to the Docker dependency, this suite cannot be run automatically; it
* requires a working installation of Docker. In addition to having Docker, the following are assumed:
* - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
* - The docker images tagged spark-test-master and spark-test-worker are built from the
* docker/ directory. Run 'docker/spark-test/build' to generate these.
*/
private[spark] object FaultToleranceTest extends App with Logging {
val masters = ListBuffer[TestMasterInfo]()
val workers = ListBuffer[TestWorkerInfo]()
var sc: SparkContext = _
var numPassed = 0
var numFailed = 0
val sparkHome = System.getenv("SPARK_HOME")
assertTrue(sparkHome != null, "Run with a valid SPARK_HOME")
val containerSparkHome = "/opt/spark"
val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome)
System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip
def afterEach() {
if (sc != null) {
sc.stop()
sc = null
}
terminateCluster()
}
test("sanity-basic") {
addMasters(1)
addWorkers(1)
createClient()
assertValidClusterState()
}
test("sanity-many-masters") {
addMasters(3)
addWorkers(3)
createClient()
assertValidClusterState()
}
test("single-master-halt") {
addMasters(3)
addWorkers(2)
createClient()
assertValidClusterState()
killLeader()
delay(30 seconds)
assertValidClusterState()
createClient()
assertValidClusterState()
}
test("single-master-restart") {
addMasters(1)
addWorkers(2)
createClient()
assertValidClusterState()
killLeader()
addMasters(1)
delay(30 seconds)
assertValidClusterState()
killLeader()
addMasters(1)
delay(30 seconds)
assertValidClusterState()
}
test("cluster-failure") {
addMasters(2)
addWorkers(2)
createClient()
assertValidClusterState()
terminateCluster()
addMasters(2)
addWorkers(2)
assertValidClusterState()
}
test("all-but-standby-failure") {
addMasters(2)
addWorkers(2)
createClient()
assertValidClusterState()
killLeader()
workers.foreach(_.kill())
workers.clear()
delay(30 seconds)
addWorkers(2)
assertValidClusterState()
}
test("rolling-outage") {
addMasters(1)
delay()
addMasters(1)
delay()
addMasters(1)
addWorkers(2)
createClient()
assertValidClusterState()
assertTrue(getLeader == masters.head)
(1 to 3).foreach { _ =>
killLeader()
delay(30 seconds)
assertValidClusterState()
assertTrue(getLeader == masters.head)
addMasters(1)
}
}
def test(name: String)(fn: => Unit) {
try {
fn
numPassed += 1
logInfo("Passed: " + name)
} catch {
case e: Exception =>
numFailed += 1
logError("FAILED: " + name, e)
}
afterEach()
}
def addMasters(num: Int) {
(1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
}
def addWorkers(num: Int) {
val masterUrls = getMasterUrls(masters)
(1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
}
/** Creates a SparkContext, which constructs a Client to interact with our cluster. */
def createClient() = {
if (sc != null) { sc.stop() }
// Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
// property, we need to reset it.
System.setProperty("spark.driver.port", "0")
sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
}
def getMasterUrls(masters: Seq[TestMasterInfo]): String = {
"spark://" + masters.map(master => master.ip + ":7077").mkString(",")
}
def getLeader: TestMasterInfo = {
val leaders = masters.filter(_.state == RecoveryState.ALIVE)
assertTrue(leaders.size == 1)
leaders(0)
}
def killLeader(): Unit = {
masters.foreach(_.readState())
val leader = getLeader
masters -= leader
leader.kill()
}
def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
def terminateCluster() {
masters.foreach(_.kill())
workers.foreach(_.kill())
masters.clear()
workers.clear()
}
/** This includes Client retry logic, so it may take a while if the cluster is recovering. */
def assertUsable() = {
val f = future {
try {
val res = sc.parallelize(0 until 10).collect()
assertTrue(res.toList == (0 until 10))
true
} catch {
case e: Exception =>
logError("assertUsable() had exception", e)
e.printStackTrace()
false
}
}
// Avoid waiting indefinitely (e.g., we could register but get no executors).
assertTrue(Await.result(f, 120 seconds))
}
/**
* Asserts that the cluster is usable and that the expected masters and workers
* are all alive in a proper configuration (e.g., only one leader).
*/
def assertValidClusterState() = {
assertUsable()
var numAlive = 0
var numStandby = 0
var numLiveApps = 0
var liveWorkerIPs: Seq[String] = List()
def stateValid(): Boolean = {
(workers.map(_.ip) -- liveWorkerIPs).isEmpty &&
numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1
}
val f = future {
try {
while (!stateValid()) {
Thread.sleep(1000)
numAlive = 0
numStandby = 0
numLiveApps = 0
masters.foreach(_.readState())
for (master <- masters) {
master.state match {
case RecoveryState.ALIVE =>
numAlive += 1
liveWorkerIPs = master.liveWorkerIPs
case RecoveryState.STANDBY =>
numStandby += 1
case _ => // ignore
}
numLiveApps += master.numLiveApps
}
}
true
} catch {
case e: Exception =>
logError("assertValidClusterState() had exception", e)
false
}
}
try {
assertTrue(Await.result(f, 120 seconds))
} catch {
case e: TimeoutException =>
logError("Master states: " + masters.map(_.state))
logError("Num apps: " + numLiveApps)
logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs)
throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e)
}
}
def assertTrue(bool: Boolean, message: String = "") {
if (!bool) {
throw new IllegalStateException("Assertion failed: " + message)
}
}
logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed))
}
private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
implicit val formats = net.liftweb.json.DefaultFormats
var state: RecoveryState.Value = _
var liveWorkerIPs: List[String] = _
var numLiveApps = 0
logDebug("Created master: " + this)
def readState() {
try {
val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
val json = JsonParser.parse(masterStream, closeAutomatically = true)
val workers = json \ "workers"
val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
numLiveApps = (json \ "activeapps").children.size
val status = json \\ "status"
val stateString = status.extract[String]
state = RecoveryState.values.filter(state => state.toString == stateString).head
} catch {
case e: Exception =>
// ignore, no state update
logWarning("Exception", e)
}
}
def kill() { Docker.kill(dockerId) }
override def toString: String =
"[ip=%s, id=%s, logFile=%s, state=%s]".
format(ip, dockerId.id, logFile.getAbsolutePath, state)
}
private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
extends Logging {
implicit val formats = net.liftweb.json.DefaultFormats
logDebug("Created worker: " + this)
def kill() { Docker.kill(dockerId) }
override def toString: String =
"[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
}
private[spark] object SparkDocker {
def startMaster(mountDir: String): TestMasterInfo = {
val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir)
val (ip, id, outFile) = startNode(cmd)
new TestMasterInfo(ip, id, outFile)
}
def startWorker(mountDir: String, masters: String): TestWorkerInfo = {
val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir)
val (ip, id, outFile) = startNode(cmd)
new TestWorkerInfo(ip, id, outFile)
}
private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = {
val ipPromise = promise[String]()
val outFile = File.createTempFile("fault-tolerance-test", "")
outFile.deleteOnExit()
val outStream: FileWriter = new FileWriter(outFile)
def findIpAndLog(line: String): Unit = {
if (line.startsWith("CONTAINER_IP=")) {
val ip = line.split("=")(1)
ipPromise.success(ip)
}
outStream.write(line + "\n")
outStream.flush()
}
dockerCmd.run(ProcessLogger(findIpAndLog _))
val ip = Await.result(ipPromise.future, 30 seconds)
val dockerId = Docker.getLastProcessId
(ip, dockerId, outFile)
}
}
private[spark] class DockerId(val id: String) {
override def toString = id
}
private[spark] object Docker extends Logging {
def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
logDebug("Run command: " + cmd)
cmd
}
def kill(dockerId: DockerId) : Unit = {
"docker kill %s".format(dockerId.id).!
}
def getLastProcessId: DockerId = {
var id: String = null
"docker ps -l -q".!(ProcessLogger(line => id = line))
new DockerId(id)
}
}
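The test(name)(fn) helper above is a tiny curried harness rather than a ScalaTest suite: the by-name second parameter lets each test body read like a DSL block, while pass/fail tallying and afterEach() teardown stay in one place. A self-contained sketch of the same pattern, with hypothetical names:
object HarnessSketch extends App {
  // By-name `body` is not evaluated until inside the try, so failures are caught here.
  def check(name: String)(body: => Unit) {
    try { body; println("Passed: " + name) }
    catch { case e: Exception => println("FAILED: " + name + " (" + e + ")") }
  }

  check("arithmetic") { assert(2 + 2 == 4) }
  check("deliberate-failure") { assert(1 == 2) }
}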

View file

@ -72,7 +72,8 @@ private[spark] object JsonProtocol {
("memory" -> obj.workers.map(_.memory).sum) ~
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
("status" -> obj.status.toString)
}
def writeWorkerState(obj: WorkerStateResponse) = {

View file

@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
def start(): String = {
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
/* Start the Master */
val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
val masters = Array(masterUrl)
/* Start the Workers */
for (workerNum <- 1 to numWorkers) {
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
memoryPerWorker, masterUrl, null, Some(workerNum))
memoryPerWorker, masters, null, Some(workerNum))
workerActorSystems += workerSystem
}
return masterUrl
return masters
}
def stop() {
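Returning every master URL from start() lets callers wire up failover-aware clients. A usage sketch (the URL shown is illustrative; actual ports are chosen at bind time):
val cluster = new LocalSparkCluster(numWorkers = 2, coresPerWorker = 1, memoryPerWorker = 512)
val masterUrls: Array[String] = cluster.start() // e.g. Array("spark://localhost:49152")
// ... run jobs against the cluster ...
cluster.stop()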

View file

@ -17,28 +17,70 @@
package org.apache.spark.deploy
import com.google.common.collect.MapMaker
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{SparkContext, SparkException}
/**
* Contains util methods to interact with Hadoop from spark.
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
val conf = newConfiguration()
UserGroupInformation.setConfiguration(conf)
// Return an appropriate (subclass of) Configuration. Creating a config can initialize some
// Hadoop subsystems.
def runAsUser(user: String)(func: () => Unit) {
// If we are already running as the intended user, there is no reason to do the doAs; it
// would actually break secure HDFS access, as it doesn't fill in the credentials. Also, if
// the user is UNKNOWN we shouldn't create a remote unknown user
// (this is actually the path Spark on YARN takes), since SPARK_USER is initialized only
// in SparkContext.
val currentUser = Option(System.getProperty("user.name")).
getOrElse(SparkContext.SPARK_UNKNOWN_USER)
if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
val ugi = UserGroupInformation.createRemoteUser(user)
ugi.doAs(new PrivilegedExceptionAction[Unit] {
def run: Unit = func()
})
} else {
func()
}
}
/**
* Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop
* subsystems.
*/
def newConfiguration(): Configuration = new Configuration()
// Add any user credentials to the job conf which are necessary for running on a secure Hadoop
// cluster
/**
* Add any user credentials to the job conf which are necessary for running on a secure Hadoop
* cluster.
*/
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
}
object SparkHadoopUtil {
private val hadoop = {
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
if (yarnMode) {
try {
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
} catch {
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
}
} else {
new SparkHadoopUtil
}
}
def get: SparkHadoopUtil = {
hadoop
}
}
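A sketch of the intended call pattern: resolve the environment-appropriate instance through the singleton, then wrap privileged work in runAsUser. The user name and configuration key here are hypothetical.
val util = SparkHadoopUtil.get // YarnSparkHadoopUtil iff SPARK_YARN_MODE is set
util.runAsUser("alice") { () =>
  // Executed under a remote UGI for "alice", unless we are already that user.
  println("fs.defaultFS = " + util.newConfiguration().get("fs.defaultFS"))
}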

View file

@ -37,38 +37,81 @@ import org.apache.spark.deploy.master.Master
/**
* The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
* and a listener for cluster events, and calls back the listener when various events occur.
*
* @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class Client(
actorSystem: ActorSystem,
masterUrl: String,
masterUrls: Array[String],
appDescription: ApplicationDescription,
listener: ClientListener)
extends Logging {
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
var prevMaster: ActorRef = null // kept so we can unwatch the old master when it fails
var actor: ActorRef = null
var appId: String = null
var registered = false
var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
var master: ActorSelection = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
logInfo("Connecting to master " + masterUrl)
try {
master = context.actorSelection(Master.toAkkaUrl(masterUrl))
master ! RegisterApplication(appDescription)
registerWithMaster()
} catch {
case e: Exception =>
logError("Failed to connect to master", e)
logWarning("Failed to connect to master", e)
markDisconnected()
context.stop(self)
}
}
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterApplication(appDescription)
}
}
def registerWithMaster() {
tryRegisterAllMasters()
import context.dispatcher
var retries = 0
lazy val retryTimer: Cancellable =
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
retries += 1
if (registered) {
retryTimer.cancel()
} else if (retries >= REGISTRATION_RETRIES) {
logError("All masters are unresponsive! Giving up.")
markDead()
} else {
tryRegisterAllMasters()
}
}
retryTimer // start timer
}
def changeMaster(url: String) {
activeMasterUrl = url
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
}
override def receive = {
case RegisteredApplication(appId_) =>
case RegisteredApplication(appId_, masterUrl) =>
context.watch(sender)
prevMaster = sender
appId = appId_
registered = true
changeMaster(masterUrl)
listener.connected(appId)
case ApplicationRemoved(message) =>
@ -89,13 +132,19 @@ private[spark] class Client(
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
}
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
context.unwatch(prevMaster)
changeMaster(masterUrl)
alreadyDisconnected = false
sender ! MasterChangeAcknowledged(appId)
case Terminated(actor_) =>
logError(s"Connection to $actor_ dropped, stopping client")
logWarning(s"Connection to $actor_ failed; waiting for master to reconnect...")
markDisconnected()
context.stop(self)
case StopClient =>
markDisconnected()
markDead()
sender ! true
context.stop(self)
}
@ -109,6 +158,13 @@ private[spark] class Client(
alreadyDisconnected = true
}
}
def markDead() {
if (!alreadyDead) {
listener.dead()
alreadyDead = true
}
}
}
def start() {
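The registerWithMaster() retry loop above is a self-cancelling timer: fire every REGISTRATION_TIMEOUT, stop on success, give up after REGISTRATION_RETRIES. A standalone sketch of the same shape, with java.util.Timer standing in for Akka's scheduler and all constants illustrative:
import java.util.{Timer, TimerTask}

object RetrySketch extends App {
  @volatile var registered = false
  var retries = 0
  val timer = new Timer(true) // daemon thread, like the Akka scheduler
  timer.schedule(new TimerTask {
    def run() {
      retries += 1
      if (registered) cancel()                 // a master answered: stop retrying
      else if (retries >= 3) { cancel(); println("All masters are unresponsive!") }
      else println("Re-sending registration to all masters...")
    }
  }, 1000, 1000)
  Thread.sleep(5000)
}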

View file

@ -27,8 +27,12 @@ package org.apache.spark.deploy.client
private[spark] trait ClientListener {
def connected(appId: String): Unit
/** Disconnection may be a temporary state, as we fail over to a new Master. */
def disconnected(): Unit
/** Dead means that we couldn't find any Masters to connect to, and have given up. */
def dead(): Unit
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
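A minimal sketch of an implementation that distinguishes the two failure modes (class name and log text hypothetical): disconnected() is survivable and should leave the app running, while dead() is terminal.
class LoggingListener extends ClientListener {
  def connected(appId: String) { println("Connected as " + appId) }
  def disconnected() { println("Lost master; waiting for failover...") }
  def dead() { println("No masters left; shutting down."); sys.exit(1) }
  def executorAdded(fullId: String, workerId: String, hostPort: String,
      cores: Int, memory: Int) {}
  def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]) {}
}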

View file

@ -33,6 +33,11 @@ private[spark] object TestClient {
System.exit(0)
}
def dead() {
logInfo("Could not connect to master")
System.exit(0)
}
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
@ -44,7 +49,7 @@ private[spark] object TestClient {
val desc = new ApplicationDescription(
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
val listener = new TestListener
val client = new Client(actorSystem, url, desc, listener)
val client = new Client(actorSystem, Array(url), desc, listener)
client.start()
actorSystem.awaitTermination()
}

View file

@ -29,23 +29,46 @@ private[spark] class ApplicationInfo(
val submitDate: Date,
val driver: ActorRef,
val appUiUrl: String)
{
var state = ApplicationState.WAITING
var executors = new mutable.HashMap[Int, ExecutorInfo]
var coresGranted = 0
var endTime = -1L
val appSource = new ApplicationSource(this)
extends Serializable {
private var nextExecutorId = 0
@transient var state: ApplicationState.Value = _
@transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
@transient var coresGranted: Int = _
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
def newExecutorId(): Int = {
val id = nextExecutorId
nextExecutorId += 1
id
@transient private var nextExecutorId: Int = _
init()
private def readObject(in: java.io.ObjectInputStream) : Unit = {
in.defaultReadObject()
init()
}
def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
private def init() {
state = ApplicationState.WAITING
executors = new mutable.HashMap[Int, ExecutorInfo]
coresGranted = 0
endTime = -1L
appSource = new ApplicationSource(this)
nextExecutorId = 0
}
private def newExecutorId(useID: Option[Int] = None): Int = {
useID match {
case Some(id) =>
nextExecutorId = math.max(nextExecutorId, id + 1)
id
case None =>
val id = nextExecutorId
nextExecutorId += 1
id
}
}
def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
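The @transient + readObject + init() idiom above keeps mutable bookkeeping out of the serialized form and rebuilds it after deserialization (WorkerInfo below uses the same trick). A self-contained sketch of the pattern, with a hypothetical class:
import scala.collection.mutable

class Tracked(val label: String) extends Serializable {
  @transient var cache: mutable.HashMap[Int, String] = _ // never serialized
  init()
  private def init() { cache = new mutable.HashMap[Int, String] }
  private def readObject(in: java.io.ObjectInputStream): Unit = {
    in.defaultReadObject() // restores `label`
    init()                 // rebuilds `cache`, as ApplicationInfo does
  }
}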

View file

@ -17,12 +17,11 @@
package org.apache.spark.deploy.master
private[spark] object ApplicationState
extends Enumeration {
private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
val WAITING, RUNNING, FINISHED, FAILED = Value
val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}

View file

@ -17,7 +17,7 @@
package org.apache.spark.deploy.master
import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
private[spark] class ExecutorInfo(
val id: Int,
@ -28,5 +28,10 @@ private[spark] class ExecutorInfo(
var state = ExecutorState.LAUNCHING
/** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
def copyState(execDesc: ExecutorDescription) {
state = execDesc.state
}
def fullId: String = application.id + "/" + id
}

View file

@ -0,0 +1,90 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
import java.io._
import scala.Serializable
import akka.serialization.Serialization
import org.apache.spark.Logging
/**
* Stores data in a single on-disk directory with one file per application and worker.
* Files are deleted when applications and workers are removed.
*
* @param dir Directory to store files. Created if non-existent (but not recursively).
* @param serialization Used to serialize our objects.
*/
private[spark] class FileSystemPersistenceEngine(
val dir: String,
val serialization: Serialization)
extends PersistenceEngine with Logging {
new File(dir).mkdir()
override def addApplication(app: ApplicationInfo) {
val appFile = new File(dir + File.separator + "app_" + app.id)
serializeIntoFile(appFile, app)
}
override def removeApplication(app: ApplicationInfo) {
new File(dir + File.separator + "app_" + app.id).delete()
}
override def addWorker(worker: WorkerInfo) {
val workerFile = new File(dir + File.separator + "worker_" + worker.id)
serializeIntoFile(workerFile, worker)
}
override def removeWorker(worker: WorkerInfo) {
new File(dir + File.separator + "worker_" + worker.id).delete()
}
override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
(apps, workers)
}
private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
val out = new FileOutputStream(file)
out.write(serialized)
out.close()
}
def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
dis.readFully(fileData)
dis.close()
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
}
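A sketch of wiring the engine up the way the Master's FILESYSTEM recovery mode does, with a hypothetical directory and actor-system name:
import akka.actor.ActorSystem
import akka.serialization.SerializationExtension

val system = ActorSystem("persistence-sketch")
val engine = new FileSystemPersistenceEngine(
  "/tmp/spark-recovery", // stands in for spark.deploy.recoveryDirectory
  SerializationExtension(system))
// addApplication/addWorker persist state; readPersistedData() replays it on failover.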

View file

@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
import akka.actor.{Actor, ActorRef}
import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
/**
* A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
* is the only Master serving requests.
* In addition to the API provided, the LeaderElectionAgent will use the following messages
* to inform the Master of leader changes:
* [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
* [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
*/
private[spark] trait LeaderElectionAgent extends Actor {
val masterActor: ActorRef
}
/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
override def preStart() {
masterActor ! ElectedLeader
}
override def receive = {
case _ =>
}
}

View file

@ -23,42 +23,39 @@ import java.text.SimpleDateFormat
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.concurrent.Await
import scala.concurrent.duration._
import scala.concurrent.duration.{Duration, FiniteDuration}
import akka.actor._
import akka.pattern.ask
import akka.remote._
import akka.util.Timeout
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
import akka.serialization.SerializationExtension
import java.util.concurrent.TimeUnit
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
import context.dispatcher
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
var nextAppNumber = 0
val workers = new HashSet[WorkerInfo]
val idToWorker = new HashMap[String, WorkerInfo]
@ -88,52 +85,114 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (envVar != null) envVar else host
}
val masterUrl = "spark://" + host + ":" + port
var masterWebUiUrl: String = _
var state = RecoveryState.STANDBY
var persistenceEngine: PersistenceEngine = _
var leaderElectionAgent: ActorRef = _
// As a temporary workaround before better ways of configuring memory, we allow users to set
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
// among all the nodes) instead of trying to consolidate each app onto a small number of nodes.
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
override def preStart() {
logInfo("Starting Spark master at spark://" + host + ":" + port)
logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
import context.dispatcher
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
persistenceEngine = RECOVERY_MODE match {
case "ZOOKEEPER" =>
logInfo("Persisting recovery state to ZooKeeper")
new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
case "FILESYSTEM" =>
logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
case _ =>
new BlackHolePersistenceEngine()
}
leaderElectionAgent = RECOVERY_MODE match {
case "ZOOKEEPER" =>
context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
case _ =>
context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
}
}
override def preRestart(reason: Throwable, message: Option[Any]) {
super.preRestart(reason, message) // calls postStop()!
logError("Master actor restarted due to exception", reason)
}
override def postStop() {
webUi.stop()
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
context.stop(leaderElectionAgent)
}
override def receive = {
case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
case ElectedLeader => {
val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
state = if (storedApps.isEmpty && storedWorkers.isEmpty)
RecoveryState.ALIVE
else
RecoveryState.RECOVERING
logInfo("I have been elected leader! New state: " + state)
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedWorkers)
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
}
}
case RevokedLeadership => {
logError("Leadership has been revoked -- master shutting down.")
System.exit(0)
}
case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
if (idToWorker.contains(id)) {
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
context.watch(sender) // This doesn't work with remote actors but helps for testing
sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
registerWorker(worker)
context.watch(sender)
persistenceEngine.addWorker(worker)
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
}
}
case RegisterApplication(description) => {
logInfo("Registering app " + description.name)
val app = addApplication(description, sender)
logInfo("Registered app " + description.name + " with ID " + app.id)
waitingApps += app
context.watch(sender) // This doesn't work with remote actors but helps for testing
sender ! RegisteredApplication(app.id)
schedule()
if (state == RecoveryState.STANDBY) {
// ignore, don't send response
} else {
logInfo("Registering app " + description.name)
val app = createApplication(description, sender)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
context.watch(sender)
persistenceEngine.addApplication(app)
sender ! RegisteredApplication(app.id, masterUrl)
schedule()
}
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
@ -173,27 +232,49 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
case MasterChangeAcknowledged(appId) => {
idToApp.get(appId) match {
case Some(app) =>
logInfo("Application has been re-registered: " + appId)
app.state = ApplicationState.WAITING
case None =>
logWarning("Master change ack from unknown app: " + appId)
}
if (canCompleteRecovery) { completeRecovery() }
}
case WorkerSchedulerStateResponse(workerId, executors) => {
idToWorker.get(workerId) match {
case Some(worker) =>
logInfo("Worker has been re-registered: " + workerId)
worker.state = WorkerState.ALIVE
val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined)
for (exec <- validExecutors) {
val app = idToApp.get(exec.appId).get
val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId))
worker.addExecutor(execInfo)
execInfo.copyState(exec)
}
case None =>
logWarning("Scheduler state from unknown worker: " + workerId)
}
if (canCompleteRecovery) { completeRecovery() }
}
case Terminated(actor) => {
// The disconnected actor could've been either a worker or an app; remove whichever of
// those we have an entry for in the corresponding actor hashmap
actorToWorker.get(actor).foreach(removeWorker)
actorToApp.get(actor).foreach(finishApplication)
}
case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
}
case AssociationErrorEvent(_, _, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
}
case RequestMasterState => {
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
state)
}
case CheckForWorkerTimeOut => {
@ -205,6 +286,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
}
def canCompleteRecovery =
workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
apps.count(_.state == ApplicationState.UNKNOWN) == 0
def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
for (app <- storedApps) {
logInfo("Trying to recover app: " + app.id)
try {
registerApplication(app)
app.state = ApplicationState.UNKNOWN
app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
} catch {
case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
}
}
for (worker <- storedWorkers) {
logInfo("Trying to recover worker: " + worker.id)
try {
registerWorker(worker)
worker.state = WorkerState.UNKNOWN
worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
} catch {
case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
}
}
}
def completeRecovery() {
// Ensure "only-once" recovery semantics using a short synchronization period.
synchronized {
if (state != RecoveryState.RECOVERING) { return }
state = RecoveryState.COMPLETING_RECOVERY
}
// Kill off any workers and apps that didn't respond to us.
workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
state = RecoveryState.ALIVE
schedule()
logInfo("Recovery complete - resuming operations!")
}
/**
* Can an app use the given worker? True if the worker has enough memory and we haven't already
* launched an executor for the app on it (right now the standalone backend doesn't like having
@ -219,6 +344,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
* every time a new app joins or resource availability changes.
*/
def schedule() {
if (state != RecoveryState.ALIVE) { return }
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
// in the queue, then the second app, etc.
if (spreadOutApps) {
@ -266,14 +392,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(
worker.actor ! LaunchExecutor(masterUrl,
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
exec.application.driver ! ExecutorAdded(
exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
}
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
publicAddress: String): WorkerInfo = {
def registerWorker(worker: WorkerInfo): Unit = {
// There may be one or more refs to dead workers on this same node (with different IDs);
// remove them.
workers.filter { w =>
@ -281,12 +406,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}.foreach { w =>
workers -= w
}
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
val workerAddress = worker.actor.path.address
if (addressToWorker.contains(workerAddress)) {
logInfo("Attempted to re-register worker at same address: " + workerAddress)
return
}
workers += worker
idToWorker(worker.id) = worker
actorToWorker(sender) = worker
addressToWorker(sender.path.address) = worker
worker
actorToWorker(worker.actor) = worker
addressToWorker(workerAddress) = worker
}
def removeWorker(worker: WorkerInfo) {
@ -301,25 +431,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
exec.id, ExecutorState.LOST, Some("worker lost"), None)
exec.application.removeExecutor(exec)
}
persistenceEngine.removeWorker(worker)
}
def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
val now = System.currentTimeMillis()
val date = new Date(now)
val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
}
def registerApplication(app: ApplicationInfo): Unit = {
val appAddress = app.driver.path.address
if (addressToApp.contains(appAddress)) {
logInfo("Attempted to re-register application at same address: " + appAddress)
return
}
applicationMetricsSystem.registerSource(app.appSource)
apps += app
idToApp(app.id) = app
actorToApp(driver) = app
addressToApp(driver.path.address) = app
actorToApp(app.driver) = app
addressToApp(appAddress) = app
if (firstApp == None) {
firstApp = Some(app)
}
// TODO: What is firstApp?? Can we remove it?
val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
}
app
waitingApps += app
}
def finishApplication(app: ApplicationInfo) {
@ -344,13 +485,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
waitingApps -= app
for (exec <- app.executors.values) {
exec.worker.removeExecutor(exec)
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
exec.state = ExecutorState.KILLED
}
app.markFinished(state)
if (state != ApplicationState.FINISHED) {
app.driver ! ApplicationRemoved(state.toString)
}
persistenceEngine.removeApplication(app)
schedule()
}
}
@ -404,8 +546,8 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName)
val timeoutDuration = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
val timeoutDuration: FiniteDuration = Duration.create(
System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS)
implicit val timeout = Timeout(timeoutDuration)
val respFuture = actor ? RequestWebUIPort // ask pattern
val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
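The tail of startSystemAndActor is the standard Akka ask pattern: send with ?, then block on the Future with the same timeout. A compact, self-contained sketch (message and actor names hypothetical):
import scala.concurrent.Await
import scala.concurrent.duration._
import akka.actor.{Actor, ActorSystem, Props}
import akka.pattern.ask
import akka.util.Timeout

case object Ping
case object Pong

class Ponger extends Actor {
  def receive = { case Ping => sender ! Pong }
}

object AskSketch extends App {
  val system = ActorSystem("ask-sketch")
  implicit val timeout = Timeout(10.seconds)
  val ref = system.actorOf(Props[Ponger])
  println(Await.result(ref ? Ping, timeout.duration)) // blocks, like RequestWebUIPort
  system.shutdown()
}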

View file

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
sealed trait MasterMessages extends Serializable
/** Contains messages seen only by the Master and its associated entities. */
private[master] object MasterMessages {
// LeaderElectionAgent to Master
case object ElectedLeader
case object RevokedLeadership
// Actor System to LeaderElectionAgent
case object CheckLeader
// Actor System to Master
case object CheckForWorkerTimeOut
case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo])
case object CompleteRecovery
case object RequestWebUIPort
case class WebUIPortResponse(webUIBoundPort: Int)
}

View file

@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
/**
* Allows Master to persist any state that is necessary in order to recover from a failure.
* The following semantics are required:
* - addApplication and addWorker are called before completing registration of a new app/worker.
* - removeApplication and removeWorker may be called at any time.
* Given these two requirements, we will have all apps and workers persisted, but
* we might not yet have deleted apps or workers that finished (so their liveness must be
* verified during recovery).
*/
private[spark] trait PersistenceEngine {
def addApplication(app: ApplicationInfo)
def removeApplication(app: ApplicationInfo)
def addWorker(worker: WorkerInfo)
def removeWorker(worker: WorkerInfo)
/**
* Returns the persisted data sorted by their respective ids (which implies that they're
* sorted by time of creation).
*/
def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])
def close() {}
}
private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
override def addApplication(app: ApplicationInfo) {}
override def removeApplication(app: ApplicationInfo) {}
override def addWorker(worker: WorkerInfo) {}
override def removeWorker(worker: WorkerInfo) {}
override def readPersistedData() = (Nil, Nil)
}
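For unit tests, an in-memory engine is a natural third implementation of the same contract; a sketch (not part of this commit; it assumes the same package, so ApplicationInfo and WorkerInfo are visible). Insertion-ordered maps approximate the by-id ordering that readPersistedData() promises.
import scala.collection.mutable

private[spark] class InMemoryPersistenceEngine extends PersistenceEngine {
  private val apps = mutable.LinkedHashMap[String, ApplicationInfo]()
  private val workers = mutable.LinkedHashMap[String, WorkerInfo]()
  override def addApplication(app: ApplicationInfo) { apps(app.id) = app }
  override def removeApplication(app: ApplicationInfo) { apps -= app.id }
  override def addWorker(worker: WorkerInfo) { workers(worker.id) = worker }
  override def removeWorker(worker: WorkerInfo) { workers -= worker.id }
  override def readPersistedData() = (apps.values.toSeq, workers.values.toSeq)
}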

View file

@ -0,0 +1,26 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
private[spark] object RecoveryState extends Enumeration {
type MasterState = Value
val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
}

View file

@ -0,0 +1,203 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
import scala.collection.JavaConversions._
import scala.concurrent.ops._
import org.apache.spark.Logging
import org.apache.zookeeper._
import org.apache.zookeeper.data.Stat
import org.apache.zookeeper.Watcher.Event.KeeperState
/**
* Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
* logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
* created. If ZooKeeper remains down after several retries, the given
* [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
* informed via zkDown().
*
* Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
* times or a semantic exception is thrown (e.g., "node already exists").
*/
private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
val ZK_TIMEOUT_MILLIS = 30000
val RETRY_WAIT_MILLIS = 5000
val ZK_CHECK_PERIOD_MILLIS = 10000
val MAX_RECONNECT_ATTEMPTS = 3
private var zk: ZooKeeper = _
private val watcher = new ZooKeeperWatcher()
private var reconnectAttempts = 0
private var closed = false
/** Connect to ZooKeeper to start the session. Must be called before anything else. */
def connect() {
connectToZooKeeper()
new Thread() {
override def run() = sessionMonitorThread()
}.start()
}
def sessionMonitorThread(): Unit = {
while (!closed) {
Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
if (zk.getState != ZooKeeper.States.CONNECTED) {
reconnectAttempts += 1
val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
if (attemptsLeft <= 0) {
logError("Could not connect to ZooKeeper: system failure")
zkWatcher.zkDown()
close()
} else {
logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
connectToZooKeeper()
}
}
}
}
def close() {
if (!closed && zk != null) { zk.close() }
closed = true
}
private def connectToZooKeeper() {
if (zk != null) zk.close()
zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
}
/**
* Attempts to maintain a live ZooKeeper session despite (very) transient failures.
* Mainly useful for handling the natural ZooKeeper session expiration.
*/
private class ZooKeeperWatcher extends Watcher {
def process(event: WatchedEvent) {
if (closed) { return }
event.getState match {
case KeeperState.SyncConnected =>
reconnectAttempts = 0
zkWatcher.zkSessionCreated()
case KeeperState.Expired =>
connectToZooKeeper()
case KeeperState.Disconnected =>
logWarning("ZooKeeper disconnected, will retry...")
}
}
}
def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
retry {
zk.create(path, bytes, ZK_ACL, createMode)
}
}
def exists(path: String, watcher: Watcher = null): Stat = {
retry {
zk.exists(path, watcher)
}
}
def getChildren(path: String, watcher: Watcher = null): List[String] = {
retry {
zk.getChildren(path, watcher).toList
}
}
def getData(path: String): Array[Byte] = {
retry {
zk.getData(path, false, null)
}
}
def delete(path: String, version: Int = -1): Unit = {
retry {
zk.delete(path, version)
}
}
/**
* Creates the given directory (non-recursively) if it doesn't exist.
* All znodes are created in PERSISTENT mode with no data.
*/
def mkdir(path: String) {
if (exists(path) == null) {
try {
create(path, "".getBytes, CreateMode.PERSISTENT)
} catch {
case e: Exception =>
// If the exception caused the directory not to be created, bubble it up,
// otherwise ignore it.
if (exists(path) == null) { throw e }
}
}
}
/**
* Recursively creates all directories up to the given one.
* All znodes are created in PERSISTENT mode with no data.
*/
def mkdirRecursive(path: String) {
var fullDir = ""
for (dentry <- path.split("/").tail) {
fullDir += "/" + dentry
mkdir(fullDir)
}
}
/**
* Retries the given function up to 3 times. The assumption is that failure is transient,
* UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
* in which case the exception will be thrown without retries.
*
* @param fn Block to execute, possibly multiple times.
*/
def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
try {
fn
} catch {
case e: KeeperException.NoNodeException => throw e
case e: KeeperException.NodeExistsException => throw e
case e if n > 0 =>
logError("ZooKeeper exception, " + n + " more retries...", e)
Thread.sleep(RETRY_WAIT_MILLIS)
retry(fn, n-1)
}
}
}
trait SparkZooKeeperWatcher {
/**
* Called whenever a ZK session is created --
* this will occur when we create our first session as well as each time
* the session expires or errors out.
*/
def zkSessionCreated()
/**
* Called if ZK appears to be completely down (i.e., not just a transient error).
* We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
*/
def zkDown()
}
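A sketch of a consumer combining retry (implicit in every call below) with the watcher contract; everything except the two types above is hypothetical:
class DirectoryMirror extends SparkZooKeeperWatcher {
  val zk = new SparkZooKeeperSession(this)
  zk.connect()
  override def zkSessionCreated() {
    zk.mkdirRecursive("/sketch/dir") // idempotent; re-run on each new session
  }
  override def zkDown() {
    println("ZooKeeper is gone for good; giving up.")
  }
  def list(): List[String] = zk.getChildren("/sketch/dir") // retried internally
}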

View file

@ -22,28 +22,44 @@ import scala.collection.mutable
import org.apache.spark.util.Utils
private[spark] class WorkerInfo(
val id: String,
val host: String,
val port: Int,
val cores: Int,
val memory: Int,
val actor: ActorRef,
val webUiPort: Int,
val publicAddress: String) {
val id: String,
val host: String,
val port: Int,
val cores: Int,
val memory: Int,
val actor: ActorRef,
val webUiPort: Int,
val publicAddress: String)
extends Serializable {
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
var state: WorkerState.Value = WorkerState.ALIVE
var coresUsed = 0
var memoryUsed = 0
@transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
@transient var state: WorkerState.Value = _
@transient var coresUsed: Int = _
@transient var memoryUsed: Int = _
var lastHeartbeat = System.currentTimeMillis()
@transient var lastHeartbeat: Long = _
init()
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
private def readObject(in: java.io.ObjectInputStream) : Unit = {
in.defaultReadObject()
init()
}
private def init() {
executors = new mutable.HashMap
state = WorkerState.ALIVE
coresUsed = 0
memoryUsed = 0
lastHeartbeat = System.currentTimeMillis()
}
def hostPort: String = {
assert (port > 0)
host + ":" + port

View file

@ -20,5 +20,5 @@ package org.apache.spark.deploy.master
private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
val ALIVE, DEAD, DECOMMISSIONED = Value
val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
}

View file

@ -0,0 +1,136 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
import akka.actor.ActorRef
import org.apache.zookeeper._
import org.apache.zookeeper.Watcher.Event.EventType
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.Logging
private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
private val watcher = new ZooKeeperWatcher()
private val zk = new SparkZooKeeperSession(this)
private var status = LeadershipStatus.NOT_LEADER
private var myLeaderFile: String = _
private var leaderUrl: String = _
override def preStart() {
logInfo("Starting ZooKeeper LeaderElection agent")
zk.connect()
}
override def zkSessionCreated() {
synchronized {
zk.mkdirRecursive(WORKING_DIR)
myLeaderFile =
zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
self ! CheckLeader
}
}
override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
super.preRestart(reason, message)
}
override def zkDown() {
logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
System.exit(1)
}
override def postStop() {
zk.close()
}
override def receive = {
case CheckLeader => checkLeader()
}
private class ZooKeeperWatcher extends Watcher {
def process(event: WatchedEvent) {
if (event.getType == EventType.NodeDeleted) {
logInfo("Leader file disappeared, a master is down!")
self ! CheckLeader
}
}
}
/** Uses ZK leader election. Navigates several ZK potholes along the way. */
def checkLeader() {
val masters = zk.getChildren(WORKING_DIR).toList
val leader = masters.sorted.head
val leaderFile = WORKING_DIR + "/" + leader
// Setup a watch for the current leader.
zk.exists(leaderFile, watcher)
try {
leaderUrl = new String(zk.getData(leaderFile))
} catch {
// A NoNodeException may be thrown if the old leader died since the start of this method call.
// This is fine -- just check again, since we're guaranteed to see the new values.
case e: KeeperException.NoNodeException =>
logInfo("Leader disappeared while reading it -- finding next leader")
checkLeader()
return
}
// Synchronization used to ensure no interleaving between the creation of a new session and the
// checking of a leader, which could cause us to delete our real leader file erroneously.
synchronized {
val isLeader = myLeaderFile == leaderFile
if (!isLeader && leaderUrl == masterUrl) {
// We found a different master file pointing to this process.
// This can happen in the following two cases:
// (1) The master process was restarted on the same node.
// (2) The ZK server died between creating the node and returning the name of the node.
// For this case, we will end up creating a second file, and MUST explicitly delete the
// first one, since our ZK session is still open.
// Note that this deletion will cause a NodeDeleted event to be fired so we check again for
// leader changes.
assert(leaderFile < myLeaderFile)
logWarning("Cleaning up old ZK master election file that points to this master.")
zk.delete(leaderFile)
} else {
updateLeadershipStatus(isLeader)
}
}
}
def updateLeadershipStatus(isLeader: Boolean) {
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
status = LeadershipStatus.LEADER
masterActor ! ElectedLeader
} else if (!isLeader && status == LeadershipStatus.LEADER) {
status = LeadershipStatus.NOT_LEADER
masterActor ! RevokedLeadership
}
}
private object LeadershipStatus extends Enumeration {
type LeadershipStatus = Value
val LEADER, NOT_LEADER = Value
}
}
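
For readers skimming the diff, here is a minimal, self-contained sketch of the ephemeral-sequential election pattern the agent above implements, written against the raw ZooKeeper client. The connect string, parent path, and object name are illustrative, and the parent znodes are assumed to already exist.

import org.apache.zookeeper.{CreateMode, WatchedEvent, Watcher, ZooDefs, ZooKeeper}
import scala.collection.JavaConverters._

object LeaderElectionSketch {
  def main(args: Array[String]) {
    // Session events are ignored for brevity; a real client must also handle
    // session expiry.
    val zk = new ZooKeeper("localhost:2181", 30000, new Watcher {
      def process(event: WatchedEvent) {}
    })
    val dir = "/sketch/leader_election" // assumed to exist already
    // Each candidate creates an ephemeral sequential node; ZooKeeper appends a
    // monotonically increasing counter (e.g. .../master_0000000003), and the
    // node vanishes automatically when the session dies, releasing leadership.
    val myFile = zk.create(dir + "/master_", Array[Byte](),
      ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL)
    // The candidate holding the lowest sequence number is the leader.
    val leader = zk.getChildren(dir, false).asScala.sorted.head
    println(if (myFile.endsWith(leader)) "elected leader" else "standing by")
  }
}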

View file

@ -0,0 +1,85 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.deploy.master
import org.apache.spark.Logging
import org.apache.zookeeper._
import akka.serialization.Serialization
class ZooKeeperPersistenceEngine(serialization: Serialization)
extends PersistenceEngine
with SparkZooKeeperWatcher
with Logging
{
val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
val zk = new SparkZooKeeperSession(this)
zk.connect()
override def zkSessionCreated() {
zk.mkdirRecursive(WORKING_DIR)
}
override def zkDown() {
logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
}
override def addApplication(app: ApplicationInfo) {
serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
}
override def removeApplication(app: ApplicationInfo) {
zk.delete(WORKING_DIR + "/app_" + app.id)
}
override def addWorker(worker: WorkerInfo) {
serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
}
override def removeWorker(worker: WorkerInfo) {
zk.delete(WORKING_DIR + "/worker_" + worker.id)
}
override def close() {
zk.close()
}
override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
val appFiles = sortedFiles.filter(_.startsWith("app_"))
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
(apps, workers)
}
private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create(path, serialized, CreateMode.PERSISTENT)
}
def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
val fileData = zk.getData(WORKING_DIR + "/" + filename)
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
}
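
The engine above delegates the byte round-trip to Akka's Serialization extension. A minimal sketch of that round-trip, assuming the default Java serializer and an illustrative payload class standing in for WorkerInfo:

import akka.actor.ActorSystem
import akka.serialization.SerializationExtension

// Illustrative stand-in for WorkerInfo; case classes are java.io.Serializable.
case class WorkerRecord(id: String, cores: Int)

object SerializationSketch {
  def main(args: Array[String]) {
    val system = ActorSystem("sketch")
    val serialization = SerializationExtension(system)
    val record = WorkerRecord("worker-20131121-0001", 8)
    // findSerializerFor picks a serializer by runtime class, as serializeIntoFile
    // does before writing the bytes into a persistent ZK node.
    val bytes = serialization.findSerializerFor(record).toBinary(record)
    // serializerFor + fromBinary mirrors deserializeFromFile's read path.
    val back = serialization.serializerFor(classOf[WorkerRecord])
      .fromBinary(bytes).asInstanceOf[WorkerRecord]
    assert(back == record)
    system.shutdown()
  }
}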

View file

@ -43,7 +43,8 @@ private[spark] class ExecutorRunner(
val workerId: String,
val host: String,
val sparkHome: File,
val workDir: File)
val workDir: File,
var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
@ -83,7 +84,8 @@ private[spark] class ExecutorRunner(
process.destroy()
process.waitFor()
}
worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
state = ExecutorState.KILLED
worker ! ExecutorStateChanged(appId, execId, state, None, None)
Runtime.getRuntime.removeShutdownHook(shutdownHook)
}
}
@ -102,7 +104,7 @@ private[spark] class ExecutorRunner(
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
command.arguments.map(substituteVariables)
(command.arguments ++ Seq(appId)).map(substituteVariables)
}
/**
@ -180,9 +182,9 @@ private[spark] class ExecutorRunner(
// long-lived processes only. However, in the future, we might restart the executor a few
// times on the same machine.
val exitCode = process.waitFor()
state = ExecutorState.FAILED
val message = "Command exited with code " + exitCode
worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
Some(exitCode))
worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
} catch {
case interrupted: InterruptedException =>
logInfo("Runner thread for executor " + fullId + " interrupted")
@ -192,8 +194,9 @@ private[spark] class ExecutorRunner(
if (process != null) {
process.destroy()
}
state = ExecutorState.FAILED
val message = e.getClass + ": " + e.getMessage
worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
worker ! ExecutorStateChanged(appId, execId, state, Some(message), None)
}
}
}

View file

@ -25,10 +25,8 @@ import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import akka.actor._
import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent}
import org.apache.spark.Logging
import org.apache.spark.deploy.ExecutorState
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
@ -45,16 +43,19 @@ import akka.remote.DisassociatedEvent
import org.apache.spark.deploy.DeployMessages.LaunchExecutor
import org.apache.spark.deploy.DeployMessages.RegisterWorker
/**
* @param masterUrls Each url should look like spark://host:port.
*/
private[spark] class Worker(
host: String,
port: Int,
webUiPort: Int,
cores: Int,
memory: Int,
masterUrl: String,
masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
import context.dispatcher
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
@ -64,8 +65,19 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
// Index into masterUrls that we're currently trying to register with.
var masterIndex = 0
val masterLock: Object = new Object()
var master: ActorSelection = null
var masterWebUiUrl : String = ""
var prevMaster: ActorRef = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
var sparkHome: File = null
var workDir: File = null
@ -105,6 +117,7 @@ private[spark] class Worker(
}
override def preStart() {
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
@ -113,46 +126,99 @@ private[spark] class Worker(
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
webUi.start()
connectToMaster()
registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
}
def connectToMaster() {
logInfo("Connecting to master " + masterUrl)
master = context.actorSelection(Master.toAkkaUrl(masterUrl))
master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
def changeMaster(url: String, uiUrl: String) {
masterLock.synchronized {
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
connected = true
}
}
import context.dispatcher
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
publicAddress)
}
}
def registerWithMaster() {
tryRegisterAllMasters()
var retries = 0
lazy val retryTimer: Cancellable =
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
retries += 1
if (registered) {
retryTimer.cancel()
} else if (retries >= REGISTRATION_RETRIES) {
logError("All masters are unresponsive! Giving up.")
System.exit(1)
} else {
tryRegisterAllMasters()
}
}
retryTimer // start timer
}
override def receive = {
case RegisteredWorker(url) =>
masterWebUiUrl = url
logInfo("Successfully registered with master")
case RegisteredWorker(masterUrl, masterWebUiUrl) =>
logInfo("Successfully registered with master " + masterUrl)
registered = true
context.watch(sender) // remote death watch for master
// TODO: Is the heartbeat really necessary? Akka does it anyway!
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
master ! Heartbeat(workerId)
prevMaster = sender
changeMaster(masterUrl, masterWebUiUrl)
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
case SendHeartbeat =>
masterLock.synchronized {
if (connected) { master ! Heartbeat(workerId) }
}
case RegisterWorkerFailed(message) =>
logError("Worker registration failed: " + message)
System.exit(1)
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
context.unwatch(prevMaster)
prevMaster = sender
changeMaster(masterUrl, masterWebUiUrl)
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(
appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
val execs = executors.values.
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
case RegisterWorkerFailed(message) =>
if (!registered) {
logError("Worker registration failed: " + message)
System.exit(1)
}
case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
if (masterUrl != activeMasterUrl) {
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
} else {
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
masterLock.synchronized {
master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
}
}
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
masterLock.synchronized {
master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
}
val fullId = appId + "/" + execId
if (ExecutorState.isFinished(state)) {
val executor = executors(fullId)
@ -165,14 +231,18 @@ private[spark] class Worker(
memoryUsed -= executor.memory
}
case KillExecutor(appId, execId) =>
val fullId = appId + "/" + execId
executors.get(fullId) match {
case Some(executor) =>
logInfo("Asked to kill executor " + fullId)
executor.kill()
case None =>
logInfo("Asked to kill unknown executor " + fullId)
case KillExecutor(masterUrl, appId, execId) =>
if (masterUrl != activeMasterUrl) {
logWarning("Invalid Master (" + masterUrl + ") attempted to kill executor " + execId)
} else {
val fullId = appId + "/" + execId
executors.get(fullId) match {
case Some(executor) =>
logInfo("Asked to kill executor " + fullId)
executor.kill()
case None =>
logInfo("Asked to kill unknown executor " + fullId)
}
}
case Terminated(actor_) =>
@ -181,17 +251,14 @@ private[spark] class Worker(
case RequestWorkerState => {
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, masterUrl, cores, memory,
coresUsed, memoryUsed, masterWebUiUrl)
finishedExecutors.values.toList, activeMasterUrl, cores, memory,
coresUsed, memoryUsed, activeMasterWebUiUrl)
}
}
def masterDisconnected() {
// TODO: It would be nice to try to reconnect to the master, but just shut down for now.
// (Note that if reconnecting we would also need to assign IDs differently.)
logError("Connection to master failed! Shutting down.")
executors.values.foreach(_.kill())
System.exit(1)
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
}
def generateWorkerId(): String = {
@ -209,17 +276,18 @@ private[spark] object Worker {
def main(argStrings: Array[String]) {
val args = new WorkerArguments(argStrings)
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
args.memory, args.master, args.workDir)
args.memory, args.masters, args.workDir)
actorSystem.awaitTermination()
}
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
: (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
masterUrl, workDir), name = "Worker")
masterUrls, workDir), name = "Worker")
(actorSystem, boundPort)
}
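
The subtlest piece of the Worker changes is registerWithMaster's self-cancelling retry timer: the Cancellable is declared lazy so that its own closure can refer to the handle the scheduler returns. A standalone sketch of the idiom, with illustrative delays and a simulated late acknowledgement:

import akka.actor.{ActorSystem, Cancellable}
import scala.concurrent.duration._

object RetryTimerSketch {
  @volatile var registered = false

  def main(args: Array[String]) {
    val system = ActorSystem("sketch")
    import system.dispatcher
    var retries = 0
    // Lazy so the body may refer to the timer it belongs to; the non-zero
    // initial delay guarantees the val is initialized before the first tick.
    lazy val retryTimer: Cancellable =
      system.scheduler.schedule(1.second, 1.second) {
        retries += 1
        if (registered) { retryTimer.cancel(); system.shutdown() }
        else if (retries >= 3) { println("giving up"); system.shutdown() }
        else println("retrying registration, attempt " + retries)
      }
    retryTimer // touching the lazy val starts the timer
    system.scheduler.scheduleOnce(2500.millis) { registered = true } // late ack
  }
}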

View file

@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
var webUiPort = 8081
var cores = inferDefaultCores()
var memory = inferDefaultMemory()
var master: String = null
var masters: Array[String] = null
var workDir: String = null
// Check for settings in environment variables
@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) {
printUsageAndExit(0)
case value :: tail =>
if (master != null) { // Two positional arguments were given
if (masters != null) { // Two positional arguments were given
printUsageAndExit(1)
}
master = value
masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
parse(tail)
case Nil =>
if (master == null) { // No positional argument was given
if (masters == null) { // No positional argument was given
printUsageAndExit(1)
}
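
As a worked example of the parsing above: a comma-separated positional argument such as spark://host1:7077,host2:7077 is normalized into one URL per master.

object MastersParseSketch {
  def main(args: Array[String]) {
    val value = "spark://host1:7077,host2:7077" // illustrative input
    // stripPrefix drops the scheme once so the split on ',' sees bare
    // host:port pairs; each pair then gets the scheme prepended again.
    val masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
    masters.foreach(println) // spark://host1:7077, then spark://host2:7077
  }
}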

View file

@ -108,7 +108,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>

View file

@ -24,23 +24,15 @@ import akka.remote._
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.DisassociatedEvent
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
import akka.remote.AssociationErrorEvent
import akka.remote.DisassociatedEvent
import akka.actor.Terminated
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
private[spark] class StandaloneExecutorBackend(
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
hostPort: String,
@ -75,15 +67,28 @@ private[spark] class StandaloneExecutorBackend(
case LaunchTask(taskDesc) =>
logInfo("Got assigned task " + taskDesc.taskId)
if (executor == null) {
logError("Received launchTask but executor was null")
logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
}
case KillTask(taskId, _) =>
if (executor == null) {
logError("Received KillTask command but executor was null")
System.exit(1)
} else {
executor.killTask(taskId)
}
case Terminated(actor) =>
logError("Driver terminated or disconnected! Shutting down.")
logError(s"Driver $actor terminated or disconnected! Shutting down.")
System.exit(1)
case StopExecutor =>
logInfo("Driver commanded a shutdown")
context.stop(self)
context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
@ -91,7 +96,7 @@ private[spark] class StandaloneExecutorBackend(
}
}
private[spark] object StandaloneExecutorBackend {
private[spark] object CoarseGrainedExecutorBackend {
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
// Debug code
Utils.checkHost(hostname)
@ -103,15 +108,17 @@ private[spark] object StandaloneExecutorBackend {
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
actorSystem.actorOf(
Props(classOf[StandaloneExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
def main(args: Array[String]) {
if (args.length < 4) {
//the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
//the reason we allow the last appid argument is to make it easy to kill rogue executors
System.err.println(
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
"[<appid>]")
System.exit(1)
}
run(args(0), args(1), args(2), args(3).toInt)

View file

@ -25,9 +25,10 @@ import java.util.concurrent._
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import org.apache.spark.scheduler._
import org.apache.spark._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
import org.apache.spark.util.Utils
/**
@ -36,7 +37,8 @@ import org.apache.spark.util.Utils
private[spark] class Executor(
executorId: String,
slaveHostname: String,
properties: Seq[(String, String)])
properties: Seq[(String, String)],
isLocal: Boolean = false)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@ -73,46 +75,79 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
}
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
}
)
)
}
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
SparkEnv.set(env)
env.metricsSystem.registerSource(executorSource)
private val env = {
if (!isLocal) {
val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
isDriver = false, isLocal = false)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
_env
} else {
SparkEnv.get
}
}
private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
}
// Start worker thread pool
val threadPool = new ThreadPoolExecutor(
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
val tr = new TaskRunner(context, taskId, serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
def killTask(taskId: Long) {
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
// We also remove the task in the finally block of TaskRunner.run.
// We still need to remove it here because killTask might be called before the task
// is even launched, in which case that finally block is never reached. ConcurrentHashMap's
// remove is idempotent, so the double removal is harmless.
runningTasks.remove(taskId)
}
}
/** Get the Yarn approved local directories. */
@ -124,56 +159,87 @@ private[spark] class Executor(
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
.getOrElse(""))
if (localDirs.isEmpty()) {
if (localDirs.isEmpty) {
throw new Exception("Yarn Local dirs can't be empty")
}
return localDirs
localDirs
}
class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
override def run() {
@volatile private var killed = false
@volatile private var task: Task[Any] = _
def kill() {
logInfo("Executor is trying to kill task " + taskId)
killed = true
if (task != null) {
task.kill()
}
}
override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var attemptedTask: Option[Task[Any]] = None
var taskStart: Long = 0
def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
val startGCTime = getTotalGCTime
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
val startGCTime = gcTime
try {
SparkEnv.set(env)
Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
return
}
attemptedTask = Some(task)
logInfo("Its epoch is " + task.epoch)
logDebug("Task " + taskId + "'s epoch is " + task.epoch)
env.mapOutputTracker.updateEpoch(task.epoch)
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
if (task.killed) {
logInfo("Executor killed task " + taskId)
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
return
}
for (m <- task.metrics) {
m.hostname = Utils.localHostName
m.hostname = Utils.localHostName()
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
m.jvmGCTime = getTotalGCTime - startGCTime
m.jvmGCTime = gcTime - startGCTime
}
//TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
// we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
// just change the relevants bytes in the byte buffer
// TODO: I'd also like to track the time it takes to serialize the task results, but that is
// a huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
// custom serialized format, we could just change the relevant bytes in the byte buffer.
val accumUpdates = Accumulators.values
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
val serializedDirectResult = ser.serialize(directResult)
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
val serializedResult = {
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
logInfo("Storing result for " + taskId + " in local BlockManager")
val blockId = "taskresult_" + taskId
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
ser.serialize(new IndirectTaskResult[Any](blockId))
@ -182,12 +248,13 @@ private[spark] class Executor(
serializedDirectResult
}
}
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
logInfo("Finished task ID " + taskId)
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
case t: Throwable => {
@ -195,10 +262,10 @@ private[spark] class Executor(
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
m.executorRunTime = serviceTime
m.jvmGCTime = getTotalGCTime - startGCTime
m.jvmGCTime = gcTime - startGCTime
}
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
// have left some weird state around depending on when the exception was thrown, but on
@ -206,6 +273,8 @@ private[spark] class Executor(
logError("Exception in task ID " + taskId, t)
//System.exit(1)
}
} finally {
runningTasks.remove(taskId)
}
}
}
@ -215,7 +284,7 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): ExecutorURLClassLoader = {
var loader = this.getClass.getClassLoader
val loader = this.getClass.getClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
@ -237,7 +306,7 @@ private[spark] class Executor(
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
return constructor.newInstance(classUri, parent)
constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
@ -245,7 +314,7 @@ private[spark] class Executor(
null
}
} else {
return parent
parent
}
}
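
The comment in killTask above leans on ConcurrentHashMap.remove being idempotent, since both the killer and the runner's finally block may remove the same entry. A tiny sketch of that property (names illustrative):

import java.util.concurrent.ConcurrentHashMap

object KillRaceSketch {
  def main(args: Array[String]) {
    val runningTasks = new ConcurrentHashMap[Long, String]
    runningTasks.put(42L, "task-runner-42")
    // The first removal returns the mapped value; the second call, whether it
    // comes from killTask or from the finally block, is a harmless no-op.
    println(runningTasks.remove(42L)) // task-runner-42
    println(runningTasks.remove(42L)) // null
  }
}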

View file

@ -18,14 +18,18 @@
package org.apache.spark.executor
import java.nio.ByteBuffer
import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
import org.apache.spark.TaskState.TaskState
import com.google.protobuf.ByteString
import org.apache.spark.{Logging}
import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.Logging
import org.apache.spark.TaskState
import org.apache.spark.TaskState.TaskState
import org.apache.spark.util.Utils
private[spark] class MesosExecutorBackend
extends MesosExecutor
with ExecutorBackend
@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
}
override def killTask(d: ExecutorDriver, t: TaskID) {
logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
if (executor == null) {
logError("Received KillTask but executor was null")
} else {
executor.killTask(t.getValue.toLong)
}
}
override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}

View file

@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
/**
* Time spent blocking on writes to disk or buffer cache, in nanoseconds.
*/
var shuffleWriteTime: Long = _
}

View file

@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
implicit val futureExecContext = ExecutionContext.fromExecutor(
Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null

View file

@ -20,17 +20,18 @@ package org.apache.spark.network.netty
import io.netty.buffer._
import org.apache.spark.Logging
import org.apache.spark.storage.{TestBlockId, BlockId}
private[spark] class FileHeader (
val fileLen: Int,
val blockId: String) extends Logging {
val blockId: BlockId) extends Logging {
lazy val buffer = {
val buf = Unpooled.buffer()
buf.capacity(FileHeader.HEADER_SIZE)
buf.writeInt(fileLen)
buf.writeInt(blockId.length)
blockId.foreach((x: Char) => buf.writeByte(x))
buf.writeInt(blockId.name.length)
blockId.name.foreach((x: Char) => buf.writeByte(x))
// pad the rest of the header
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
@ -57,18 +58,15 @@ private[spark] object FileHeader {
for (i <- 1 to idLength) {
idBuilder += buf.readByte().asInstanceOf[Char]
}
val blockId = idBuilder.toString()
val blockId = BlockId(idBuilder.toString())
new FileHeader(length, blockId)
}
def main (args:Array[String]){
val header = new FileHeader(25,"block_0");
val buf = header.buffer;
val newheader = FileHeader.create(buf);
System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
def main(args: Array[String]) {
val header = new FileHeader(25, TestBlockId("my_block"))
val buf = header.buffer
val newHeader = FileHeader.create(buf)
System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
}
}

View file

@ -27,12 +27,13 @@ import org.apache.spark.Logging
import org.apache.spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
import org.apache.spark.storage.BlockId
private[spark] class ShuffleCopier extends Logging {
def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
def getBlock(host: String, port: Int, blockId: BlockId,
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
try {
fc.init()
fc.connect(host, port)
fc.sendRequest(blockId)
fc.sendRequest(blockId.name)
fc.waitForClose()
fc.close()
} catch {
@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
}
}
def getBlock(cmId: ConnectionManagerId, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
blocks: Seq[(String, Long)],
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
blocks: Seq[(BlockId, Long)],
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
for ((blockId, size) <- blocks) {
getBlock(cmId, blockId, resultCollectCallback)
@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
private[spark] object ShuffleCopier extends Logging {
private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
override def handleError(blockId: String) {
override def handleError(blockId: BlockId) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
}
val host = args(0)
val port = args(1).toInt
val file = args(2)
val blockId = BlockId(args(2))
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
Executors.callable(new Runnable() {
def run() {
val copier = new ShuffleCopier()
copier.getBlock(host, port, file, echoResultCollectCallBack)
copier.getBlock(host, port, blockId, echoResultCollectCallBack)
}
})
}).asJava
copiers.invokeAll(tasks)
copiers.shutdown
copiers.shutdown()
System.exit(0)
}
}

View file

@ -21,6 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@ -53,8 +54,8 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResolver = new PathResolver {
override def getAbsolutePath(blockId: String): String = {
if (!blockId.startsWith("shuffle_")) {
override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
// Figure out which local directory it hashes to, and which subdirectory in that
@ -62,8 +63,8 @@ private[spark] object ShuffleSender {
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId)
return file.getAbsolutePath
val file = new File(subDir, blockId.name)
return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResolver)
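
The getBlockLocation override relies on the same two-level directory hash the disk store uses when writing; the hash computation itself is elided in the hunk above, so math.abs of hashCode stands in for it here. A worked example with illustrative values:

object DirHashSketch {
  def main(args: Array[String]) {
    val subDirsPerLocalDir = 64
    val localDirs = Array("/tmp/spark-local-a", "/tmp/spark-local-b")
    val hash = math.abs("shuffle_0_0_0".hashCode) // stand-in for the real hash
    val dirId = hash % localDirs.length           // picks the top-level dir
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
    println(localDirs(dirId) + "/" + "%02x".format(subDirId) + "/shuffle_0_0_0")
  }
}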

View file

@ -15,6 +15,8 @@
* limitations under the License.
*/
package org.apache
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,

View file

@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.rdd
import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
/**
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
* Returns a future for counting the number of elements in the RDD.
*/
def countAsync(): FutureAction[Long] = {
val totalCount = new AtomicLong
self.context.submitJob(
self,
(iter: Iterator[T]) => {
var result = 0L
while (iter.hasNext) {
result += 1L
iter.next()
}
result
},
Range(0, self.partitions.size),
(index: Int, data: Long) => totalCount.addAndGet(data),
totalCount.get())
}
/**
* Returns a future for retrieving all elements of this RDD.
*/
def collectAsync(): FutureAction[Seq[T]] = {
val results = new Array[Array[T]](self.partitions.size)
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
(index, data) => results(index) = data, results.flatten.toSeq)
}
/**
* Returns a future for retrieving the first num elements of the RDD.
*/
def takeAsync(num: Int): FutureAction[Seq[T]] = {
val f = new ComplexFutureAction[Seq[T]]
f.run {
val results = new ArrayBuffer[T](num)
val totalParts = self.partitions.length
var partsScanned = 0
while (results.size < num && partsScanned < totalParts) {
// The number of partitions to try in this iteration. It is ok for this number to be
// greater than totalParts because we actually cap it at totalParts in runJob.
var numPartsToTry = 1
if (partsScanned > 0) {
// If we didn't find any rows after the first iteration, just try all partitions next.
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
// by 50%.
if (results.size == 0) {
numPartsToTry = totalParts - 1
} else {
numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
}
}
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
val left = num - results.size
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
val buf = new Array[Array[T]](p.size)
f.runJob(self,
(it: Iterator[T]) => it.take(left).toArray,
p,
(index: Int, data: Array[T]) => buf(index) = data,
Unit)
buf.foreach(results ++= _.take(num - results.size))
partsScanned += numPartsToTry
}
results.toSeq
}
f
}
/**
* Applies a function f to all elements of this RDD.
*/
def foreachAsync(f: T => Unit): FutureAction[Unit] = {
self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
(index, data) => Unit, Unit)
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
(index, data) => Unit, Unit)
}
}
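
A short usage sketch for the actions above, assuming a local SparkContext; per the class doc, the implicit conversion is picked up by importing org.apache.spark.SparkContext._.

import scala.concurrent.Await
import scala.concurrent.duration._
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._

object AsyncActionsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "async-sketch")
    val rdd = sc.parallelize(1 to 1000, 4)
    val future = rdd.countAsync() // returns immediately with a FutureAction
    // A FutureAction is a scala.concurrent.Future, so it can be awaited or
    // composed with the usual combinators, and cancelled while the job runs.
    println(Await.result(future, 1.minute)) // 1000
    sc.stop()
  }
}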

View file

@ -17,16 +17,17 @@
package org.apache.spark.rdd
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
import org.apache.spark.storage.BlockManager
import scala.reflect.ClassTag
private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
import org.apache.spark.storage.{BlockId, BlockManager}
private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
val index = idx
}
private[spark]
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)

View file

@ -19,6 +19,7 @@ package org.apache.spark.rdd
import scala.reflect.ClassTag
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{NullWritable, BytesWritable}
@ -84,9 +85,9 @@ private[spark] object CheckpointRDD extends Logging {
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
val env = SparkEnv.get
val outputDir = new Path(path)
val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val finalOutputName = splitIdToFile(ctx.splitId)
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
@ -123,7 +124,7 @@ private[spark] object CheckpointRDD extends Logging {
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
val env = SparkEnv.get
val fs = path.getFileSystem(env.hadoop.newConfiguration())
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
val fileInputStream = fs.open(path, bufferSize)
val serializer = env.serializer.newInstance()
@ -146,7 +147,7 @@ private[spark] object CheckpointRDD extends Logging {
val sc = new SparkContext(cluster, "CheckpointRDD Test")
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
val path = new Path(hdfsPath, "temp")
val fs = path.getFileSystem(env.hadoop.newConfiguration())
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")

View file

@ -18,13 +18,12 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.util.AppendOnlyMap
private[spark] sealed trait CoGroupSplitDep extends Serializable
@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val split = s.asInstanceOf[CoGroupPartition]
val numRdds = split.deps.size
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
val seq = map.get(k)
if (seq != null) {
seq
} else {
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
map.put(k, seq)
seq
}
val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
}
val getSeq = (k: K) => {
map.changeValue(k, update)
}
val ser = SparkEnv.get.serializerManager.get(serializerClass)
@ -129,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
case ShuffleCoGroupSplitDep(shuffleId) => {
// Read map outputs of shuffle
val fetcher = SparkEnv.get.shuffleFetcher
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
kv => getSeq(kv._1)(depNum) += kv._2
}
}
}
JavaConversions.mapAsScalaMap(map).iterator
new InterruptibleIterator(context, map.iterator)
}
override def clearDependencies() {

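The rewrite above replaces the JHashMap get/put dance with AppendOnlyMap.changeValue, whose update callback receives (hadValue, oldValue) and returns the value to store. A self-contained sketch of that contract over a plain mutable map; the real AppendOnlyMap in org.apache.spark.util is a hash table optimized for append-only use.

import scala.collection.mutable

object ChangeValueSketch {
  // Mimics the changeValue contract: update is told whether a value existed
  // and, if so, gets the old value; whatever it returns is stored.
  def changeValue[K, V](m: mutable.Map[K, V], k: K)(update: (Boolean, V) => V): V = {
    val next = m.get(k) match {
      case Some(old) => update(true, old)
      case None      => update(false, null.asInstanceOf[V])
    }
    m(k) = next
    next
  }

  def main(args: Array[String]) {
    val counts = mutable.Map[String, Int]()
    val bump = (had: Boolean, old: Int) => if (had) old + 1 else 1
    changeValue(counts, "k")(bump)
    changeValue(counts, "k")(bump)
    println(counts("k")) // 2
  }
}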
View file

@ -27,54 +27,19 @@ import org.apache.hadoop.mapred.RecordReader
import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
TaskContext}
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.util.NextIterator
import org.apache.hadoop.conf.{Configuration, Configurable}
/**
* An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
* system, or S3).
* This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
* across multiple reads; the 'path' is the only variable that is different across new JobConfs
* created from the Configuration.
*/
class HadoopFileRDD[K, V](
sc: SparkContext,
path: String,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
minSplits: Int)
extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) {
override def getJobConf(): JobConf = {
if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
// getJobConf() has been called previously, so there is already a local cache of the JobConf
// needed by this RDD.
return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
} else {
// Create a new JobConf, set the input file/directory paths to read from, and cache the
// JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through
// HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple
// getJobConf() calls for this RDD in the local process.
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
val newJobConf = new JobConf(broadcastedConf.value.value)
FileInputFormat.setInputPaths(newJobConf, path)
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
return newJobConf
}
}
}
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSplit)
extends Partition {
val inputSplit = new SerializableWritable[InputSplit](s)
override def hashCode(): Int = (41 * (41 + rddId) + idx).toInt
@ -83,11 +48,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
}
/**
* A base class that provides core functionality for reading data partitions stored in Hadoop.
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
* sources in HBase, or S3).
*
* @param sc The SparkContext to associate the RDD with.
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
* variable references an instance of JobConf, then that JobConf will be used for the Hadoop job.
* Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
* @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
* creates.
* @param inputFormatClass Storage format of the data to be read.
* @param keyClass Class of the key associated with the inputFormatClass.
* @param valueClass Class of the value associated with the inputFormatClass.
* @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
*/
class HadoopRDD[K, V](
sc: SparkContext,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
initLocalJobConfFuncOpt: Option[JobConf => Unit],
inputFormatClass: Class[_ <: InputFormat[K, V]],
keyClass: Class[K],
valueClass: Class[V],
@ -105,6 +83,7 @@ class HadoopRDD[K, V](
sc,
sc.broadcast(new SerializableWritable(conf))
.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
None /* initLocalJobConfFuncOpt */,
inputFormatClass,
keyClass,
valueClass,
@ -130,6 +109,7 @@ class HadoopRDD[K, V](
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
val newJobConf = new JobConf(broadcastedConf.value.value)
initLocalJobConfFuncOpt.map(f => f(newJobConf))
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
return newJobConf
}
@ -152,6 +132,8 @@ class HadoopRDD[K, V](
override def getPartitions: Array[Partition] = {
val jobConf = getJobConf()
// Add the credentials here as this can be called before the SparkContext is initialized
SparkHadoopUtil.get.addCredentials(jobConf)
val inputFormat = getInputFormat(jobConf)
if (inputFormat.isInstanceOf[Configurable]) {
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
@ -164,38 +146,41 @@ class HadoopRDD[K, V](
array
}
override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
var reader: RecordReader[K, V] = null
override def compute(theSplit: Partition, context: TaskContext) = {
val iter = new NextIterator[(K, V)] {
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
val jobConf = getJobConf()
val inputFormat = getInputFormat(jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback{ () => closeIfNeeded() }
val key: K = reader.createKey()
val value: V = reader.createValue()
val key: K = reader.createKey()
val value: V = reader.createValue()
override def getNext() = {
try {
finished = !reader.next(key, value)
} catch {
case eof: EOFException =>
finished = true
override def getNext() = {
try {
finished = !reader.next(key, value)
} catch {
case eof: EOFException =>
finished = true
}
(key, value)
}
(key, value)
}
override def close() {
try {
reader.close()
} catch {
case e: Exception => logWarning("Exception in RecordReader.close()", e)
override def close() {
try {
reader.close()
} catch {
case e: Exception => logWarning("Exception in RecordReader.close()", e)
}
}
}
new InterruptibleIterator[(K, V)](context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
@ -216,10 +201,10 @@ private[spark] object HadoopRDD {
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
*/
def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
def putCachedMetadata(key: String, value: Any) =
SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
SparkEnv.get.hadoopJobMetadata.put(key, value)
}
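
The new initLocalJobConfFuncOpt parameter subsumes the deleted HadoopFileRDD subclass: rather than overriding getJobConf, callers pass a closure that is applied to each locally created JobConf copy. A sketch of such a closure, mirroring what HadoopFileRDD used to do; the input path is illustrative.

import org.apache.hadoop.mapred.{FileInputFormat, JobConf}

object JobConfFuncSketch {
  // Runs once per locally cached JobConf, at the
  // initLocalJobConfFuncOpt.map(f => f(newJobConf)) hook above.
  val setInputPath: JobConf => Unit =
    jobConf => FileInputFormat.setInputPaths(jobConf, "/data/input")

  def main(args: Array[String]) {
    val conf = new JobConf()
    setInputPath(conf)
    println(conf.get("mapred.input.dir")) // path recorded by setInputPaths
  }
}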

View file

@ -22,14 +22,14 @@ import scala.reflect.ClassTag
/**
* A variant of the MapPartitionsRDD that passes the partition index into the
* closure. This can be used to generate or collect partition specific
* information such as the number of tuples in a partition.
* A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
* TaskContext, the closure can either get access to the interruptible flag or get the index
* of the partition in the RDD.
*/
private[spark]
class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (Int, Iterator[T]) => Iterator[U],
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean
) extends RDD[U](prev) {
@ -38,5 +38,5 @@ class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
override val partitioner = if (preservesPartitioning) prev.partitioner else None
override def compute(split: Partition, context: TaskContext) =
f(split.index, firstParent[T].iterator(split, context))
f(context, firstParent[T].iterator(split, context))
}
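
A short usage sketch of the renamed operator, assuming a local SparkContext; the closure now receives the whole TaskContext, so the old index-based behavior is still available through partitionId.

import org.apache.spark.{SparkContext, TaskContext}

object WithContextSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "ctx-sketch")
    val rdd = sc.parallelize(1 to 8, 4)
    // Tag every element with the id of the partition that produced it.
    val tagged = rdd.mapPartitionsWithContext(
      (context: TaskContext, iter: Iterator[Int]) =>
        iter.map(x => (context.partitionId, x)),
      preservesPartitioning = true)
    tagged.collect().foreach(println) // (0,1), (0,2), (1,3), (1,4), ...
    sc.stop()
  }
}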

View file

@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
private[spark]
@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
result
}
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
format.asInstanceOf[Configurable].setConf(conf)
}
val reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => close())
var havePair = false
var finished = false
override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !reader.nextKeyValue
havePair = !finished
override def compute(theSplit: Partition, context: TaskContext) = {
val iter = new Iterator[(K, V)] {
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
if (format.isInstanceOf[Configurable]) {
format.asInstanceOf[Configurable].setConf(conf)
}
!finished
}
val reader = format.createRecordReader(
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
override def next: (K, V) = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
// Register an on-task-completion callback to close the input stream.
context.addOnCompleteCallback(() => close())
var havePair = false
var finished = false
override def hasNext: Boolean = {
if (!finished && !havePair) {
finished = !reader.nextKeyValue
havePair = !finished
}
!finished
}
havePair = false
return (reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
} catch {
case e: Exception => logWarning("Exception in RecordReader.close()", e)
override def next(): (K, V) = {
if (!hasNext) {
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
} catch {
case e: Exception => logWarning("Exception in RecordReader.close()", e)
}
}
}
new InterruptibleIterator(context, iter)
}
override def getPreferredLocations(split: Partition): Seq[String] = {
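The InterruptibleIterator wrapped around the record iterator above is not defined in these hunks; a plausible minimal sketch of the idea, assuming this era's TaskContext exposes an interrupted flag:

import org.apache.spark.TaskContext

// Hypothetical sketch: delegate to the underlying iterator, but stop producing
// elements once the task's context reports an interruption. The
// `context.interrupted` field is an assumption, not confirmed by these hunks.
class InterruptibleIteratorSketch[+T](context: TaskContext, delegate: Iterator[T])
  extends Iterator[T] {

  def hasNext: Boolean = !context.interrupted && delegate.hasNext

  def next(): T = delegate.next()
}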

View file

@ -85,18 +85,24 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
}
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
self.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
}, preservesPartitioning = true)
} else if (mapSideCombine) {
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
.setSerializer(serializerClass)
partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
partitioned.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
}, preservesPartitioning = true)
} else {
// Don't apply map-side combiner.
// A sanity check to make sure mergeCombiners is not defined.
assert(mergeCombiners == null)
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
values.mapPartitionsWithContext((context, iter) => {
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
}, preservesPartitioning = true)
}
}
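The three branches above cover: input already partitioned by the target partitioner (combine in place), map-side combine before the shuffle, and shuffle first then combine. A usage sketch of the three-argument combineByKey overload, computing a per-key mean with a local SparkContext:

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._  // implicit conversion to PairRDDFunctions

object CombineByKeyExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "combine-example")
    val scores = sc.parallelize(Seq(("a", 1.0), ("b", 4.0), ("a", 3.0)))
    // Build (sum, count) per key, then reduce to a mean. createCombiner and
    // mergeValue run before the shuffle, mergeCombiners after it.
    val means = scores.combineByKey(
      (v: Double) => (v, 1),
      (acc: (Double, Int), v: Double) => (acc._1 + v, acc._2 + 1),
      (a: (Double, Int), b: (Double, Int)) => (a._1 + b._1, a._2 + b._2)
    ).mapValues { case (sum, count) => sum / count }
    means.collect().foreach(println)
    sc.stop()
  }
}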
@ -565,7 +571,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" <split #> <attempt # = spark task #> */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
val format = outputFormatClass.newInstance
val committer = format.getOutputCommitter(hadoopContext)
@ -664,7 +670,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
writer.setup(context.stageId, context.splitId, attemptNumber)
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
var count = 0

View file

@ -96,8 +96,9 @@ private[spark] class ParallelCollectionRDD[T: ClassTag](
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
}
override def compute(s: Partition, context: TaskContext) =
s.asInstanceOf[ParallelCollectionPartition[T]].iterator
override def compute(s: Partition, context: TaskContext) = {
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
}
override def getPreferredLocations(s: Partition): Seq[String] = {
locationPrefs.getOrElse(s.index, Nil)

View file

@ -268,6 +268,19 @@ abstract class RDD[T: ClassTag](
def distinct(): RDD[T] = distinct(partitions.size)
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): RDD[T] = {
coalesce(numPartitions, true)
}
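A quick usage sketch of the new method next to its shuffle-free counterpart, assuming a local SparkContext:

import org.apache.spark.SparkContext

object RepartitionExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "repartition-example")
    val rdd = sc.parallelize(1 to 1000, 2)
    val wider = rdd.repartition(8)    // shuffles to raise parallelism
    val narrower = wider.coalesce(4)  // drops partitions without a shuffle
    println(wider.partitions.size + " -> " + narrower.partitions.size)  // 8 -> 4
    sc.stop()
  }
}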
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
@ -421,26 +434,39 @@ abstract class RDD[T: ClassTag](
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
}
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
*/
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
}
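A usage sketch of the new variant, tagging each element with the partition id read off the TaskContext (local SparkContext assumed):

import org.apache.spark.{SparkContext, TaskContext}

object WithContextExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "with-context-example")
    val data = sc.parallelize(1 to 10, 2)
    // The closure receives the live TaskContext alongside the partition's iterator.
    val tagged = data.mapPartitionsWithContext(
      (context: TaskContext, iter: Iterator[Int]) => iter.map(x => (context.partitionId, x)))
    tagged.collect().foreach(println)
    sc.stop()
  }
}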
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
@ -448,22 +474,23 @@ abstract class RDD[T: ClassTag](
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
def mapPartitionsWithSplit[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] =
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
mapPartitionsWithIndex(f, preservesPartitioning)
}
/**
* Maps f over this RDD, where f takes an additional parameter of type A. This
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => U): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.map(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
def mapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
val a = constructA(context.partitionId)
iter.map(t => f(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
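A typical use of mapWith: construct one per-partition object (here a Random seeded by the partition index) and reuse it for every element — a sketch assuming a local SparkContext:

import scala.util.Random
import org.apache.spark.SparkContext

object MapWithExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "mapwith-example")
    val data = sc.parallelize(1 to 8, 4)
    // constructA runs once per partition; f then sees (element, A) pairs.
    val jittered = data.mapWith(index => new Random(index))((x, rng) => x + rng.nextDouble())
    jittered.collect().foreach(println)
    sc.stop()
  }
}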
/**
@ -471,13 +498,14 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
(f:(T, A) => Seq[U]): RDD[U] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
val a = constructA(index)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
def flatMapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
val a = constructA(context.partitionId)
iter.flatMap(t => f(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}
/**
@ -485,13 +513,12 @@ abstract class RDD[T: ClassTag](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def foreachWith[A: ClassTag](constructA: Int => A)
(f:(T, A) => Unit) {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.map(t => {f(t, a); t})
}
(new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
val a = constructA(context.partitionId)
iter.map(t => {f(t, a); t})
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}
/**
@ -499,13 +526,12 @@ abstract class RDD[T: ClassTag](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
def filterWith[A: ClassTag](constructA: Int => A)
(p:(T, A) => Boolean): RDD[T] = {
def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
val a = constructA(index)
iter.filter(t => p(t, a))
}
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
val a = constructA(context.partitionId)
iter.filter(t => p(t, a))
}
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}
/**
@ -544,16 +570,14 @@ abstract class RDD[T: ClassTag](
* Applies a function f to all elements of this RDD.
*/
def foreach(f: T => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}
/**
* Applies a function f to each partition of this RDD.
*/
def foreachPartition(f: Iterator[T] => Unit) {
val cleanF = sc.clean(f)
sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
sc.runJob(this, (iter: Iterator[T]) => f(iter))
}
/**
@ -678,6 +702,8 @@ abstract class RDD[T: ClassTag](
*/
def count(): Long = {
sc.runJob(this, (iter: Iterator[T]) => {
// Use a while loop to count the number of elements rather than iter.size because
// iter.size uses a for loop, which is slightly slower in the current version of Scala.
var result = 0L
while (iter.hasNext) {
result += 1L
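The counting loop, restated as a tiny standalone sketch:

object CountSketch {
  // Count with an explicit while loop instead of iter.size, per the comment above.
  def countElements[T](iter: Iterator[T]): Long = {
    var result = 0L
    while (iter.hasNext) {
      result += 1L
      iter.next()
    }
    result
  }

  def main(args: Array[String]) {
    println(countElements(Iterator.fill(1000)(0)))  // prints 1000
  }
}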

View file

@ -57,7 +57,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
SparkEnv.get.serializerManager.get(serializerClass))
}
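The fetch call now receives the whole TaskContext rather than just context.taskMetrics. The trait itself is not in these hunks, so this is only a hedged sketch of the shape it presumably takes after the change:

import org.apache.spark.TaskContext
import org.apache.spark.serializer.Serializer

// Assumed shape of the updated trait: taking the TaskContext lets the fetcher
// both record metrics and return an interruptible iterator.
private[spark] trait ShuffleFetcherSketch {
  def fetch[T](
      shuffleId: Int,
      reduceId: Int,
      context: TaskContext,
      serializer: Serializer): Iterator[T]
}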

View file

@ -111,7 +111,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
}
case ShuffleCoGroupSplitDep(shuffleId) => {
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
context.taskMetrics, serializer)
context, serializer)
iter.foreach(op)
}
}

View file

@ -19,18 +19,21 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.reflect.ClassTag
import akka.actor._
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@ -42,76 +45,108 @@ import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
* locations to run each task on, based on the current cache status, and passes these to the
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
* not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
* not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
* a small number of times before cancelling the whole stage.
*
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
* events are submitted using a synchonized queue (eventQueue). The public API methods, such as
* events are submitted using a synchronized queue (eventQueue). The public API methods, such as
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
* should be private.
*/
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends TaskSchedulerListener with Logging {
extends Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setListener(this)
taskSched.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(BeginEvent(task, taskInfo))
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessActor ! BeginEvent(task, taskInfo)
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
eventProcessActor ! GettingResultEvent(task, taskInfo)
}
// Called by TaskScheduler to report task completions or failures.
override def taskEnded(
def taskEnded(
task: Task[_],
reason: TaskEndReason,
result: Any,
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
}
// Called by TaskScheduler when an executor fails.
override def executorLost(execId: String) {
eventQueue.put(ExecutorLost(execId))
def executorLost(execId: String) {
eventProcessActor ! ExecutorLost(execId)
}
// Called by TaskScheduler when a host is added
override def executorGained(execId: String, host: String) {
eventQueue.put(ExecutorGained(execId, host))
def executorGained(execId: String, host: String) {
eventProcessActor ! ExecutorGained(execId, host)
}
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
override def taskSetFailed(taskSet: TaskSet, reason: String) {
eventQueue.put(TaskSetFailed(taskSet, reason))
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
eventProcessActor ! TaskSetFailed(taskSet, reason)
}
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
val RESUBMIT_TIMEOUT = 50L
val RESUBMIT_TIMEOUT = 50.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
override def preStart() {
context.system.scheduler.schedule(RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT) {
if (failed.size > 0) {
resubmitFailedStages()
}
}
}
val nextJobId = new AtomicInteger(0)
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
def receive = {
case event: DAGSchedulerEvent =>
logDebug("Got event of type " + event.getClass.getName)
val nextStageId = new AtomicInteger(0)
if (!processEvent(event))
submitWaitingStages()
else
context.stop(self)
}
}))
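The dedicated event-loop thread is replaced by this actor: stage resubmission is driven by a periodic schedule set up in preStart, and the actor stops itself when processEvent signals shutdown. A stripped-down standalone sketch of the same pattern in plain Akka — Event, Tick, Work and Stop are stand-ins, not Spark types:

import scala.concurrent.duration._
import akka.actor.{Actor, ActorSystem, Props}

// Hypothetical stand-ins for the scheduler's event types.
sealed trait Event
case object Tick extends Event
case object Stop extends Event
case class Work(payload: String) extends Event

class EventLoopActor extends Actor {
  import context.dispatcher  // execution context for the scheduler

  override def preStart() {
    // Periodic self-message, mirroring the RESUBMIT_TIMEOUT schedule above.
    context.system.scheduler.schedule(50.milliseconds, 50.milliseconds, self, Tick)
  }

  def receive = {
    case Tick          => // e.g. resubmitFailedStages() in the real scheduler
    case Work(payload) => println("processing " + payload)
    case Stop          => context.stop(self)
  }
}

object EventLoopExample {
  def main(args: Array[String]) {
    val system = ActorSystem("sketch")
    val loop = system.actorOf(Props[EventLoopActor])
    loop ! Work("job-0")
    loop ! Stop
    system.shutdown()
  }
}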
val stageIdToStage = new TimeStampedHashMap[Int, Stage]
private[scheduler] val nextJobId = new AtomicInteger(0)
val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
def numTotalJobs: Int = nextJobId.get()
private val nextStageId = new AtomicInteger(0)
private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@ -128,6 +163,7 @@ class DAGScheduler(
// stray messages to detect.
val failedEpoch = new HashMap[String, Long]
// stage id to the active job
val idToActiveJob = new HashMap[Int, ActiveJob]
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
@ -139,17 +175,7 @@ class DAGScheduler(
val activeJobs = new HashSet[ActiveJob]
val resultStageToJob = new HashMap[Stage, ActiveJob]
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
// Start a thread to run the DAGScheduler event loop
def start() {
new Thread("DAGScheduler") {
setDaemon(true)
override def run() {
DAGScheduler.this.run()
}
}.start()
}
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
def addSparkListener(listener: SparkListener) {
listenerBus.addListener(listener)
@ -157,7 +183,7 @@ class DAGScheduler(
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
cacheLocs(rdd.id) = blockIds.map { id =>
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
@ -179,7 +205,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@ -192,6 +218,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@ -204,9 +231,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
stageToInfos(stage) = StageInfo(stage)
stageToInfos(stage) = new StageInfo(stage)
stage
}
@ -262,54 +290,48 @@ class DAGScheduler(
}
/**
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
* JobWaiter whose getResult() method will return the result of the job when it is complete.
*
* The job is assumed to have at least one partition; zero partition jobs should be handled
* without a JobSubmitted event.
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the job finishes executing or can be used to cancel the job.
*/
private[scheduler] def prepareJob[T, U: ClassTag](
finalRdd: RDD[T],
def submitJob[T, U](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties = null)
: (JobSubmitted, JobWaiter[U]) =
properties: Properties = null): JobWaiter[U] =
{
assert(partitions.size > 0)
val waiter = new JobWaiter(partitions.size, resultHandler)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
properties)
(toSubmit, waiter)
}
def runJob[T, U: ClassTag](
finalRdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
if (partitions.size == 0) {
return
}
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = finalRdd.partitions.length
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
throw new IllegalArgumentException(
"Attempting to access a non-existent partition: " + p + ". " +
"Total number of partitions: " + maxPartitions)
"Total number of partitions: " + maxPartitions)
}
val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
eventQueue.put(toSubmit)
val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
return new JobWaiter[U](this, jobId, 0, resultHandler)
}
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
waiter
}
def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
callSite: String,
allowLocal: Boolean,
resultHandler: (Int, U) => Unit,
properties: Properties = null)
{
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
waiter.awaitResult() match {
case JobSucceeded => {}
case JobFailed(exception: Exception, _) =>
@ -330,19 +352,39 @@ class DAGScheduler(
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
val jobId = nextJobId.getAndIncrement()
eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
listener.awaitResult() // Will throw an exception if the job fails
}
/**
* Cancel a job that is running or waiting in the queue.
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
eventProcessActor ! JobCancelled(jobId)
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
eventProcessActor ! JobGroupCancelled(groupId)
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
eventProcessActor ! AllJobsCancelled
}
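A usage sketch of group-based cancellation. setJobGroup and cancelJobGroup are assumed SparkContext front-ends for the scheduler methods above, matched through the SPARK_JOB_GROUP_ID property checked in processEvent:

import org.apache.spark.SparkContext

object CancelExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "cancel-example")
    sc.setJobGroup("etl", "nightly ETL jobs")  // assumed SparkContext helper
    val runner = new Thread() {
      override def run() {
        try {
          sc.parallelize(1 to 10000000, 8).map(identity).count()
        } catch {
          case e: Exception => println("job ended: " + e.getMessage)
        }
      }
    }
    runner.start()
    Thread.sleep(100)          // give the job a moment to start
    sc.cancelJobGroup("etl")   // posts JobGroupCancelled to the event actor
    runner.join()
    sc.stop()
  }
}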
/**
* Process one event retrieved from the event queue.
* Returns true if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
val jobId = nextJobId.getAndIncrement()
val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@ -361,18 +403,43 @@ class DAGScheduler(
submitStage(finalStage)
}
case JobCancelled(jobId) =>
// Cancel a job: find all the running stages that are linked to this job, and cancel them.
running.filter(_.jobId == jobId).foreach { stage =>
taskSched.cancelTasks(stage.id)
}
case JobGroupCancelled(groupId) =>
// Cancel all jobs belonging to this job group.
// First find all active jobs with this group id, and then kill the stages for them.
val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
.map(_.jobId)
if (!jobIds.isEmpty) {
running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
taskSched.cancelTasks(stage.id)
}
}
case AllJobsCancelled =>
// Cancel all running jobs.
running.foreach { stage =>
taskSched.cancelTasks(stage.id)
}
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
case ExecutorLost(execId) =>
handleExecutorLost(execId)
case begin: BeginEvent =>
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
case BeginEvent(task, taskInfo) =>
listenerBus.post(SparkListenerTaskStart(task, taskInfo))
case completion: CompletionEvent =>
listenerBus.post(SparkListenerTaskEnd(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
case GettingResultEvent(task, taskInfo) =>
listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@ -422,42 +489,6 @@ class DAGScheduler(
}
}
/**
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
* events and responds by launching tasks. This runs in a dedicated thread and receives events
* via the eventQueue.
*/
private def run() {
SparkEnv.set(env)
while (true) {
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
if (event != null) {
logDebug("Got event of type " + event.getClass.getName)
}
this.synchronized { // needed in case other threads makes calls into methods of this class
if (event != null) {
if (processEvent(event)) {
return
}
}
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
// Periodically resubmit failed stages if some map output fetches have failed and we have
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
// the same time, so we want to make sure we've identified all the reduce tasks that depend
// on the failed node.
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
resubmitFailedStages()
} else {
submitWaitingStages()
}
}
}
}
/**
* Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
@ -542,7 +573,7 @@ class DAGScheduler(
// The listener must be run before a possible NotSerializableException;
// "StageSubmitted" should be posted first and then "JobEnded"
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@ -563,9 +594,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@ -579,15 +608,20 @@ class DAGScheduler(
*/
private def handleTaskCompletion(event: CompletionEvent) {
val task = event.task
if (!stageIdToStage.contains(task.stageId)) {
// Skip all the actions if the stage has been cancelled.
return
}
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
case _ => "Unkown"
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.completionTime = Some(System.currentTimeMillis)
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@ -627,7 +661,7 @@ class DAGScheduler(
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
} else {
stage.addOutputLoc(smt.partition, status)
stage.addOutputLoc(smt.partitionId, status)
}
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
markStageAsFinished(stage)
@ -753,14 +787,14 @@ class DAGScheduler(
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
failedStage.completionTime = Some(System.currentTimeMillis())
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job failed: " + reason)
val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
idToActiveJob -= resultStage.jobId
@ -823,7 +857,7 @@ class DAGScheduler(
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
// that has any placement preferences. Ideally we would choose based on transfer sizes,
// but this will do for now.
rdd.dependencies.foreach(_ match {
rdd.dependencies.foreach {
case n: NarrowDependency[_] =>
for (inPart <- n.getParents(partition)) {
val locs = getPreferredLocs(n.rdd, inPart)
@ -831,7 +865,7 @@ class DAGScheduler(
return locs
}
case _ =>
})
}
Nil
}
@ -854,7 +888,7 @@ class DAGScheduler(
}
def stop() {
eventQueue.put(StopDAGScheduler)
eventProcessActor ! StopDAGScheduler
metadataCleaner.cancel()
taskSched.stop()
}

View file

@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
* submitted) but there is a single "logic" thread that reads these events and takes decisions.
* This greatly simplifies synchronization.
*/
private[spark] sealed trait DAGSchedulerEvent
private[scheduler] sealed trait DAGSchedulerEvent
private[spark] case class JobSubmitted(
private[scheduler] case class JobSubmitted(
jobId: Int,
finalRDD: RDD[_],
func: (TaskContext, Iterator[_]) => _,
partitions: Array[Int],
@ -43,9 +44,19 @@ private[spark] case class JobSubmitted(
properties: Properties = null)
extends DAGSchedulerEvent
private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
private[spark] case class CompletionEvent(
private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler]
case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,
result: Any,
@ -54,10 +65,12 @@ private[spark] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler]
case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent

View file

@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
})
metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
override def getValue: Int = dagScheduler.nextJobId.get()
override def getValue: Int = dagScheduler.numTotalJobs
})
metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {

View file

@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation
@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
val env = SparkEnv.get
val conf = new JobConf(configuration)
env.hadoop.addCredentials(conf)
SparkHadoopUtil.get.addCredentials(conf)
FileInputFormat.setInputPaths(conf, path)
val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =
@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl
// This method does not expect failures, since validate has already passed ...
private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
val env = SparkEnv.get
val jobConf = new JobConf(configuration)
env.hadoop.addCredentials(jobConf)
SparkHadoopUtil.get.addCredentials(jobConf)
FileInputFormat.setInputPaths(jobConf, path)
val instance: org.apache.hadoop.mapred.InputFormat[_, _] =

View file

@ -1,292 +1,384 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler
import java.io.PrintWriter
import java.io.File
import java.io.FileNotFoundException
import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import scala.io.Source
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
// Used to record runtime information for each job, including the RDD graph,
// tasks' start/stop, shuffle information, and information from outside
class JobLogger(val logDirName: String) extends SparkListener with Logging {
private val logDir =
if (System.getenv("SPARK_LOG_DIR") != null)
System.getenv("SPARK_LOG_DIR")
else
"/tmp/spark"
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
def getLogDir = logDir
def getJobIDtoPrintWriter = jobIDToPrintWriter
def getStageIDToJobID = stageIDToJobID
def getJobIDToStages = jobIDToStages
def getEventQueue = eventQueue
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
if (dir.mkdirs() == false) {
logError("create log directory error:" + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
try{
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
// Close log file, and clean the stage relationship in stageIDToJobID
protected def closeLogWriter(jobID: Int) =
jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
fileWriter.close()
jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
stageIDToJobID -= stage.id
})
jobIDToPrintWriter -= jobID
jobIDToStages -= jobID
}
// Write log information to the log file; the withTime parameter controls whether to record
// a time stamp for the information
protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
var writeInfo = info
if (withTime) {
val date = new Date(System.currentTimeMillis())
writeInfo = DATE_FORMAT.format(date) + ": " +info
}
jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
}
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
protected def buildJobDep(jobID: Int, stage: Stage) {
if (stage.jobId == jobID) {
jobIDToStages.get(jobID) match {
case Some(stageList) => stageList += stage
case None => val stageList = new ListBuffer[Stage]
stageList += stage
jobIDToStages += (jobID -> stageList)
}
stageIDToJobID += (stage.id -> jobID)
stage.parents.foreach(buildJobDep(jobID, _))
}
}
protected def recordStageDep(jobID: Int) {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
case _ => rddList ++= getRddsInStage(dep.rdd)
}
}
rddList
}
jobIDToStages.get(jobID).foreach {_.foreach { stage =>
var depRddDesc: String = ""
getRddsInStage(stage.rdd).foreach { rdd =>
depRddDesc += rdd.id + ","
}
var depStageDesc: String = ""
stage.parents.foreach { stage =>
depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
}
jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
" STAGE_DEP=" + depStageDesc, false)
}
}
}
// Generate indents and convert to String
protected def indentString(indent: Int) = {
val sb = new StringBuilder()
for (i <- 1 to indent) {
sb.append(" ")
}
sb.toString()
}
protected def getRddName(rdd: RDD[_]) = {
var rddName = rdd.getClass.getName
if (rdd.name != null) {
rddName = rdd.name
}
rddName
}
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
var stageInfo: String = ""
if (stage.isShuffleMap) {
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
stage.shuffleDep.get.shuffleId
}else{
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
} else
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
}
// Record task metrics into job log files
protected def recordTaskMetrics(stageID: Int, status: String,
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
val readMetrics =
taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics =
taskMetrics.shuffleWriteMetrics match {
case Some(metrics) =>
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
stageLogInfo(
stageSubmitted.stage.id,
"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.id, stageSubmitted.taskSize))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
stageLogInfo(
stageCompleted.stageInfo.stage.id,
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
val task = taskEnd.task
val taskInfo = taskEnd.taskInfo
var taskStatus = ""
task match {
case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
}
taskEnd.reason match {
case Success => taskStatus += " STATUS=SUCCESS"
recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
case Resubmitted =>
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId
stageLogInfo(task.stageId, taskStatus)
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
stageLogInfo(task.stageId, taskStatus)
case OtherFailure(message) =>
taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId + " INFO=" + message
stageLogInfo(task.stageId, taskStatus)
case _ =>
}
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
val job = jobEnd.job
var info = "JOB_ID=" + job.jobId
jobEnd.jobResult match {
case JobSucceeded => info += " STATUS=SUCCESS"
case JobFailed(exception, _) =>
info += " STATUS=FAILED REASON="
exception.getMessage.split("\\s+").foreach(info += _ + "_")
case _ =>
}
jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
closeLogWriter(job.jobId)
}
protected def recordJobProperties(jobID: Int, properties: Properties) {
if(properties != null) {
val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
jobLogInfo(jobID, description, false)
}
}
override def onJobStart(jobStart: SparkListenerJobStart) {
val job = jobStart.job
val properties = jobStart.properties
createLogWriter(job.jobId)
recordJobProperties(job.jobId, properties)
buildJobDep(job.jobId, job.finalStage)
recordStageDep(job.jobId)
recordStageDepGraph(job.jobId, job.finalStage)
jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
}
}
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler
import java.io.{IOException, File, FileNotFoundException, PrintWriter}
import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.StorageLevel
/**
* A logger class to record runtime information for jobs in Spark. This class outputs one log file
* for each Spark job, containing the RDD graph, tasks' start/stop events, and shuffle
* information. JobLogger is a subclass of SparkListener; use addSparkListener to add a
* JobLogger to a SparkContext after the SparkContext is created.
* Note that each JobLogger works for only one SparkContext.
* @param logDirName The base directory for the log files.
*/
class JobLogger(val user: String, val logDirName: String)
extends SparkListener with Logging {
def this() = this(System.getProperty("user.name", "<unknown>"),
String.valueOf(System.currentTimeMillis()))
private val logDir =
if (System.getenv("SPARK_LOG_DIR") != null)
System.getenv("SPARK_LOG_DIR")
else
"/tmp/spark-%s".format(user)
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
createLogDir()
// The following 5 functions are used only in testing.
private[scheduler] def getLogDir = logDir
private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
private[scheduler] def getStageIDToJobID = stageIDToJobID
private[scheduler] def getJobIDToStages = jobIDToStages
private[scheduler] def getEventQueue = eventQueue
/** Create a folder for log files; the folder's name is the creation time of the jobLogger */
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
if (dir.mkdirs() == false) {
// JobLogger should throw an exception rather than continue to construct this object.
throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
}
}
/**
* Create a log file for one job
* @param jobID ID of the job
* @exception FileNotFoundException Fail to create log file
*/
protected def createLogWriter(jobID: Int) {
try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
/**
* Close log file, and clean the stage relationship in stageIDToJobID
* @param jobID ID of the job
*/
protected def closeLogWriter(jobID: Int) {
jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
fileWriter.close()
jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
stageIDToJobID -= stage.id
})
jobIDToPrintWriter -= jobID
jobIDToStages -= jobID
}
}
/**
* Write info into log file
* @param jobID ID of the job
* @param info Info to be recorded
* @param withTime Controls whether to record time stamp before the info, default is true
*/
protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
var writeInfo = info
if (withTime) {
val date = new Date(System.currentTimeMillis())
writeInfo = DATE_FORMAT.format(date) + ": " +info
}
jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
}
/**
* Write info into log file
* @param stageID ID of the stage
* @param info Info to be recorded
* @param withTime Controls whether to record time stamp before the info, default is true
*/
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
}
/**
* Build stage dependency for a job
* @param jobID ID of the job
* @param stage Root stage of the job
*/
protected def buildJobDep(jobID: Int, stage: Stage) {
if (stage.jobId == jobID) {
jobIDToStages.get(jobID) match {
case Some(stageList) => stageList += stage
case None => val stageList = new ListBuffer[Stage]
stageList += stage
jobIDToStages += (jobID -> stageList)
}
stageIDToJobID += (stage.id -> jobID)
stage.parents.foreach(buildJobDep(jobID, _))
}
}
/**
* Record stage dependency and RDD dependency for a stage
* @param jobID Job ID of the stage
*/
protected def recordStageDep(jobID: Int) {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
jobIDToStages.get(jobID).foreach {_.foreach { stage =>
var depRddDesc: String = ""
getRddsInStage(stage.rdd).foreach { rdd =>
depRddDesc += rdd.id + ","
}
var depStageDesc: String = ""
stage.parents.foreach { stage =>
depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
}
jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
" STAGE_DEP=" + depStageDesc, false)
}
}
}
/**
* Generate indents and convert to String
* @param indent Number of indents
* @return string of indents
*/
protected def indentString(indent: Int): String = {
val sb = new StringBuilder()
for (i <- 1 to indent) {
sb.append(" ")
}
sb.toString()
}
/**
* Get RDD's name
* @param rdd Input RDD
* @return String of RDD's name
*/
protected def getRddName(rdd: RDD[_]): String = {
var rddName = rdd.getClass.getSimpleName
if (rdd.name != null) {
rddName = rdd.name
}
rddName
}
/**
* Record RDD dependency graph in a stage
* @param jobID Job ID of the stage
* @param rdd Root RDD of the stage
* @param indent Indent number before info
*/
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo =
if (rdd.getStorageLevel != StorageLevel.NONE) {
"RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
rdd.origin + " " + rdd.generator
} else {
"RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
rdd.origin + " " + rdd.generator
}
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
/**
* Record stage dependency graph of a job
* @param jobID Job ID of the stage
* @param stage Root stage of the job
* @param indent Indent number before info, default is 0
*/
protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
val stageInfo = if (stage.isShuffleMap) {
"STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
} else {
"STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
if (!idSet.contains(stage.id)) {
idSet += stage.id
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
}
} else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
}
}
/**
* Record task metrics into job log files, including execution info and shuffle metrics
* @param stageID Stage ID of the task
* @param status Status info of the task
* @param taskInfo Task description info
* @param taskMetrics Task running metrics
*/
protected def recordTaskMetrics(stageID: Int, status: String,
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
val readMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics = taskMetrics.shuffleWriteMetrics match {
case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
/**
* When stage is submitted, record stage submit info
* @param stageSubmitted Stage submitted event
*/
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
/**
* When stage is completed, record stage completion status
* @param stageCompleted Stage completed event
*/
override def onStageCompleted(stageCompleted: StageCompleted) {
stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* When task ends, record task completion status and metrics
* @param taskEnd Task end event
*/
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
val task = taskEnd.task
val taskInfo = taskEnd.taskInfo
var taskStatus = ""
task match {
case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
}
taskEnd.reason match {
case Success => taskStatus += " STATUS=SUCCESS"
recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskEnd.taskMetrics)
case Resubmitted =>
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId
stageLogInfo(task.stageId, taskStatus)
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
stageLogInfo(task.stageId, taskStatus)
case OtherFailure(message) =>
taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId + " INFO=" + message
stageLogInfo(task.stageId, taskStatus)
case _ =>
}
}
/**
* When a job ends, record its completion status and close the log file
* @param jobEnd Job end event
*/
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
val job = jobEnd.job
var info = "JOB_ID=" + job.jobId
jobEnd.jobResult match {
case JobSucceeded => info += " STATUS=SUCCESS"
case JobFailed(exception, _) =>
info += " STATUS=FAILED REASON="
exception.getMessage.split("\\s+").foreach(info += _ + "_")
case _ =>
}
jobLogInfo(job.jobId, info.substring(0, info.length - 1).toUpperCase)
closeLogWriter(job.jobId)
}
/**
* Record job properties into job log file
* @param jobID ID of the job
* @param properties Properties of the job
*/
protected def recordJobProperties(jobID: Int, properties: Properties) {
if(properties != null) {
val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
jobLogInfo(jobID, description, false)
}
}
/**
* When job starts, record job property and stage graph
* @param jobStart Job start event
*/
override def onJobStart(jobStart: SparkListenerJobStart) {
val job = jobStart.job
val properties = jobStart.properties
createLogWriter(job.jobId)
recordJobProperties(job.jobId, properties)
buildJobDep(job.jobId, job.finalStage)
recordStageDep(job.jobId)
recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
}
}
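As a minimal sketch of how a listener like this gets wired into a context (assuming the addSparkListener API of this era; the app name and workload below are illustrative):

val sc = new SparkContext("local", "JobLoggerDemo")
sc.addSparkListener(new JobLogger())
// Any action now drives the callbacks above: onJobStart, onStageSubmitted,
// onTaskEnd once per task, and finally onJobEnd.
sc.parallelize(1 to 100, 4).count()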

View file

@ -17,48 +17,58 @@
package org.apache.spark.scheduler
import scala.collection.mutable.ArrayBuffer
/**
* An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
* results to the given handler function.
*/
private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
private[spark] class JobWaiter[T](
dagScheduler: DAGScheduler,
jobId: Int,
totalTasks: Int,
resultHandler: (Int, T) => Unit)
extends JobListener {
private var finishedTasks = 0
private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
private var jobResult: JobResult = null // If the job is finished, this will be its result
// Is the job as a whole finished (succeeded or failed)?
private var _jobFinished = totalTasks == 0
override def taskSucceeded(index: Int, result: Any) {
synchronized {
if (jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
jobFinished = true
jobResult = JobSucceeded
this.notifyAll()
}
}
def jobFinished = _jobFinished
// If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
// partition RDDs), we set the jobResult directly to JobSucceeded.
private var jobResult: JobResult = if (jobFinished) JobSucceeded else null
/**
* Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
* asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
* will fail this job with a SparkException.
*/
def cancel() {
dagScheduler.cancelJob(jobId)
}
override def jobFailed(exception: Exception) {
synchronized {
if (jobFinished) {
throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
}
jobFinished = true
jobResult = JobFailed(exception, None)
override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
if (_jobFinished) {
throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
}
resultHandler(index, result.asInstanceOf[T])
finishedTasks += 1
if (finishedTasks == totalTasks) {
_jobFinished = true
jobResult = JobSucceeded
this.notifyAll()
}
}
override def jobFailed(exception: Exception): Unit = synchronized {
_jobFinished = true
jobResult = JobFailed(exception, None)
this.notifyAll()
}
def awaitResult(): JobResult = synchronized {
while (!jobFinished) {
while (!_jobFinished) {
this.wait()
}
return jobResult
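awaitResult is the standard Java monitor idiom: block in wait() inside a while loop held under the object's lock, while every completion path (taskSucceeded, jobFailed) sets the flag and calls notifyAll() under the same lock. A self-contained sketch of that pattern, with illustrative names rather than Spark API:

class Latch {
  private var done = false
  def finish(): Unit = synchronized { done = true; notifyAll() } // wake all waiters
  def await(): Unit = synchronized { while (!done) wait() }      // loop guards against spurious wakeups
}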

View file

@ -43,7 +43,10 @@ private[spark] class Pool(
var runningTasks = 0
var priority = 0
var stageId = 0
// A pool's stage id is used to break ties in scheduling.
var stageId = -1
var name = poolName
var parent: Pool = null
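The reason for moving stageId to -1 shows up in FIFO scheduling: entities are ordered by priority first, with stageId breaking ties, so a freshly created pool wants a sentinel that cannot collide with a real stage id. A simplified sketch of that comparison (an assumed, reduced form of the FIFO comparator, not the exact Spark code):

// FIFO order: lower priority value runs first; equal priorities fall back to stageId.
def fifoLessThan(priority1: Int, stageId1: Int, priority2: Int, stageId2: Int): Boolean =
  if (priority1 != priority2) priority1 < priority2 else stageId1 < stageId2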

View file

@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.RDDCheckpointData
import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
private[spark] object ResultTask {
@ -32,23 +32,23 @@ private[spark] object ResultTask {
// expensive on the master node if it needs to launch thousands of tasks.
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
return old
old
} else {
val out = new ByteArrayOutputStream
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objOut = ser.serializeStream(new GZIPOutputStream(out))
objOut.writeObject(rdd)
objOut.writeObject(func)
objOut.close()
val bytes = out.toByteArray
serializedInfoCache.put(stageId, bytes)
return bytes
bytes
}
}
}
@ -56,11 +56,11 @@ private[spark] object ResultTask {
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
val loader = Thread.currentThread.getContextClassLoader
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
return (rdd, func)
(rdd, func)
}
def clearCache() {
@ -71,29 +71,37 @@ private[spark] object ResultTask {
}
/**
* A task that sends back the output to the driver application.
*
* See [[org.apache.spark.scheduler.Task]] for more information.
*
* @param stageId id of the stage this task belongs to
* @param rdd input to func
* @param func a function to apply on a partition of the RDD
* @param _partitionId index of the partition this task computes in the RDD
* @param locs preferred task execution locations for locality scheduling
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
* input RDD's partitions).
*/
private[spark] class ResultTask[T, U](
stageId: Int,
var rdd: RDD[T],
var func: (TaskContext, Iterator[T]) => U,
var partition: Int,
_partitionId: Int,
@transient locs: Seq[TaskLocation],
var outputId: Int)
extends Task[U](stageId) with Externalizable {
extends Task[U](stageId, _partitionId) with Externalizable {
def this() = this(0, null, null, 0, null, 0)
var split = if (rdd == null) {
null
} else {
rdd.partitions(partition)
}
var split = if (rdd == null) null else rdd.partitions(partitionId)
@transient private val preferredLocs: Seq[TaskLocation] = {
if (locs == null) Nil else locs.toSet.toSeq
}
override def run(attemptId: Long): U = {
val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
override def runTask(context: TaskContext): U = {
metrics = Some(context.taskMetrics)
try {
func(context, rdd.iterator(split, context))
@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
override def preferredLocations: Seq[TaskLocation] = preferredLocs
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
override def writeExternal(out: ObjectOutput) {
RDDCheckpointData.synchronized {
split = rdd.partitions(partition)
split = rdd.partitions(partitionId)
out.writeInt(stageId)
val bytes = ResultTask.serializeInfo(
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
out.writeInt(bytes.length)
out.write(bytes)
out.writeInt(partition)
out.writeInt(partitionId)
out.writeInt(outputId)
out.writeLong(epoch)
out.writeObject(split)
@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
rdd = rdd_.asInstanceOf[RDD[T]]
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
partition = in.readInt()
partitionId = in.readInt()
outputId = in.readInt()
epoch = in.readLong()
split = in.readObject().asInstanceOf[Partition]
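serializeInfo/deserializeInfo above amount to a gzip-compressed serialization round trip keyed by stage id. A self-contained sketch of the same round trip, substituting plain Java serialization for Spark's closure serializer (illustrative only):

import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}

def gzipSerialize(obj: AnyRef): Array[Byte] = {
  val out = new ByteArrayOutputStream()
  val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
  objOut.writeObject(obj)
  objOut.close() // close() finishes the gzip trailer; reading toByteArray before this would truncate
  out.toByteArray
}

def gzipDeserialize[T](bytes: Array[Byte]): T = {
  val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
  try objIn.readObject().asInstanceOf[T] finally objIn.close()
}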

Some files were not shown because too many files have changed in this diff.