Merge branch 'scala210-master' of github.com:colorant/incubator-spark into scala-2.10

Conflicts:
    core/src/main/scala/org/apache/spark/deploy/client/Client.scala
    core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
    core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
    core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala

Commit 199e9cf02d
@@ -68,7 +68,7 @@ described below.

When developing a Spark application, specify the Hadoop version by adding the
"hadoop-client" artifact to your project's dependencies. For example, if you're
using Hadoop 1.0.1 and build your application using SBT, add this entry to
using Hadoop 1.2.1 and build your application using SBT, add this entry to
`libraryDependencies`:

    "org.apache.hadoop" % "hadoop-client" % "1.2.1"
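As an aside for readers following the documentation change above: the dependency line slots into an ordinary SBT build. A minimal illustrative build.sbt is sketched below; the project name and Scala version are assumptions made for the example (only the hadoop-client coordinate comes from the docs above), not part of this commit.

    // Illustrative build.sbt sketch -- not from this commit.
    name := "my-spark-app"                 // hypothetical project name
    scalaVersion := "2.10.3"               // assumed; this merge targets the Scala 2.10 branch
    libraryDependencies += "org.apache.hadoop" % "hadoop-client" % "1.2.1"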
@@ -32,12 +32,26 @@ fi

# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`

# First check if we have a dependencies jar. If so, include binary classes with the deps jar
if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"

DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
# Else use spark-assembly jar from either RELEASE or assembly directory
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"

# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
@@ -28,7 +28,7 @@
# SPARK_SSH_OPTS Options passed to ssh when running remote commands.
##

usage="Usage: slaves.sh [--config confdir] command..."
usage="Usage: slaves.sh [--config <conf-dir>] command..."

# if no args specified, show usage
if [ $# -le 0 ]; then
@@ -46,6 +46,23 @@ bin=`cd "$bin"; pwd`
# spark-env.sh. Save it here.
HOSTLIST=$SPARK_SLAVES

# Check if --config is passed as an argument. It is an optional parameter.
# Exit if the argument is not a directory.
if [ "$1" == "--config" ]
then
shift
conf_dir=$1
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
export SPARK_CONF_DIR=$conf_dir
fi
shift
fi

if [ -f "${SPARK_CONF_DIR}/spark-env.sh" ]; then
. "${SPARK_CONF_DIR}/spark-env.sh"
fi
@@ -29,7 +29,7 @@
# SPARK_NICENESS The scheduling priority for daemons. Defaults to 0.
##

usage="Usage: spark-daemon.sh [--config <conf-dir>] [--hosts hostlistfile] (start|stop) <spark-command> <spark-instance-number> <args...>"
usage="Usage: spark-daemon.sh [--config <conf-dir>] (start|stop) <spark-command> <spark-instance-number> <args...>"

# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -43,6 +43,25 @@ bin=`cd "$bin"; pwd`
. "$bin/spark-config.sh"

# get arguments

# Check if --config is passed as an argument. It is an optional parameter.
# Exit if the argument is not a directory.

if [ "$1" == "--config" ]
then
shift
conf_dir=$1
if [ ! -d "$conf_dir" ]
then
echo "ERROR : $conf_dir is not a directory"
echo $usage
exit 1
else
export SPARK_CONF_DIR=$conf_dir
fi
shift
fi

startStop=$1
shift
command=$1
@@ -19,7 +19,7 @@

# Run a Spark command on all slave hosts.

usage="Usage: spark-daemons.sh [--config confdir] [--hosts hostlistfile] [start|stop] command instance-number args..."
usage="Usage: spark-daemons.sh [--config <conf-dir>] [start|stop] command instance-number args..."

# if no args specified, show usage
if [ $# -le 1 ]; then
@@ -17,8 +17,6 @@
# limitations under the License.
#

# Starts the master on the machine this script is executed on.

bin=`dirname "$0"`
bin=`cd "$bin"; pwd`
@@ -48,6 +48,10 @@
<groupId>org.apache.avro</groupId>
<artifactId>avro-ipc</artifactId>
</dependency>
<dependency>
<groupId>org.apache.zookeeper</groupId>
<artifactId>zookeeper</artifactId>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
<artifactId>jetty-server</artifactId>
@@ -21,6 +21,7 @@ import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundByteHandlerAdapter;

import org.apache.spark.storage.BlockId;

abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
@@ -33,7 +34,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
}

public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
public abstract void handleError(BlockId blockId);

@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
@@ -24,6 +24,8 @@ import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;

import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;

class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@@ -34,41 +36,36 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
}

@Override
public void messageReceived(ChannelHandlerContext ctx, String blockId) {
String path = pResolver.getAbsolutePath(blockId);
// if getFilePath returns null, close the channel
if (path == null) {
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
FileSegment fileSegment = pResolver.getBlockLocation(blockId);
// if getBlockLocation returns null, close the channel
if (fileSegment == null) {
//ctx.close();
return;
}
File file = new File(path);
File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
.getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();
@@ -17,13 +17,10 @@

package org.apache.spark.network.netty;

import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;

public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
/** Get the file segment in which the given block resides. */
public FileSegment getBlockLocation(BlockId blockId);
}
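The PathResolver change above swaps the string-path lookup for a BlockId-to-FileSegment lookup, so the Netty file server can serve a byte range out of a larger file. Purely as an illustration of the new contract, a hypothetical Scala resolver that serves each block from its own file might look like the sketch below; the class name and directory layout are invented, and Spark's real resolver lives in the block manager rather than in user code.

    import java.io.File
    import org.apache.spark.network.netty.PathResolver
    import org.apache.spark.storage.{BlockId, FileSegment}

    // Hypothetical resolver: every block maps to a whole file under one root directory.
    class SingleDirPathResolver(rootDir: File) extends PathResolver {
      override def getBlockLocation(blockId: BlockId): FileSegment = {
        val file = new File(rootDir, blockId.name)   // blockId.name: the block's string form
        if (!file.exists()) {
          null  // the FileServerHandler above treats a null segment as "block not found"
        } else {
          new FileSegment(file, 0, file.length())    // serve the entire file as one segment
        }
      }
    }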
@@ -17,20 +17,29 @@

package org.apache.hadoop.mapred

private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}

def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}

def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}
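Both helpers above go through a firstAvailableClass lookup so the same source compiles and runs against Hadoop 1.x and 2.x, which renamed these context classes. The helper itself falls outside this hunk; a plausible sketch of what such a method does, with an assumed signature and given only for illustration, is:

    // Sketch of the class-probing helper referenced above (assumed signature).
    private def firstAvailableClass(first: String, second: String): Class[_] = {
      try {
        Class.forName(first)    // e.g. Hadoop 2.x: org.apache.hadoop.mapred.JobContextImpl
      } catch {
        case e: ClassNotFoundException =>
          Class.forName(second) // fall back to the Hadoop 1.x class name
      }
    }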
@@ -17,9 +17,10 @@

package org.apache.hadoop.mapreduce

import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
import org.apache.hadoop.conf.Configuration

private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}

def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
// failed, look for the new ctor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
@@ -17,43 +17,42 @@

package org.apache.spark

import java.util.{HashMap => JHashMap}
import org.apache.spark.util.AppendOnlyMap

import scala.collection.JavaConversions._

/** A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
/**
* A set of functions used to aggregate data.
*
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
*/
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {

def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
for (kv <- iter) {
val oldC = combiners.get(kv._1)
if (oldC == null) {
combiners.put(kv._1, createCombiner(kv._2))
} else {
combiners.put(kv._1, mergeValue(oldC, kv._2))
}
val combiners = new AppendOnlyMap[K, C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeValue(oldValue, kv._2) else createCombiner(kv._2)
}
while (iter.hasNext) {
kv = iter.next()
combiners.changeValue(kv._1, update)
}
combiners.iterator
}

def combineCombinersByKey(iter: Iterator[(K, C)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
iter.foreach { case(k, c) =>
val oldC = combiners.get(k)
if (oldC == null) {
combiners.put(k, c)
} else {
combiners.put(k, mergeCombiners(oldC, c))
}
val combiners = new AppendOnlyMap[K, C]
var kc: (K, C) = null
val update = (hadValue: Boolean, oldValue: C) => {
if (hadValue) mergeCombiners(oldValue, kc._2) else kc._2
}
while (iter.hasNext) {
kc = iter.next()
combiners.changeValue(kc._1, update)
}
combiners.iterator
}
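The Aggregator rewrite above replaces a java.util.HashMap with a get/put per record by AppendOnlyMap.changeValue, folding the create-or-merge decision into a single callback. For orientation, a small usage sketch of the Aggregator API follows; the sample data is invented for illustration.

    import org.apache.spark.Aggregator

    // Count occurrences per key: the combiner type C is a running Int count.
    val agg = Aggregator[String, Int, Int](
      (v: Int) => v,                  // createCombiner: first value starts the count
      (c: Int, v: Int) => c + v,      // mergeValue: fold another value into the count
      (c1: Int, c2: Int) => c1 + c2)  // mergeCombiners: merge partial counts

    val records = Iterator(("a", 1), ("b", 1), ("a", 1))  // invented sample input
    val counts = agg.combineValuesByKey(records).toMap    // Map(a -> 2, b -> 1)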
@@ -22,13 +22,17 @@ import scala.collection.mutable.HashMap

import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator


private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {

override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
override def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
@@ -45,12 +49,12 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
}

val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])] = splitsByAddress.toSeq.map {
val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])] = splitsByAddress.toSeq.map {
case (address, splits) =>
(address, splits.map(s => ("shuffle_%d_%d_%d".format(shuffleId, s._1, reduceId), s._2)))
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}

def unpackBlock(blockPair: (String, Option[Iterator[Any]])) : Iterator[T] = {
def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
@@ -58,9 +62,8 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
block.asInstanceOf[Iterator[T]]
}
case None => {
val regex = "shuffle_([0-9]*)_([0-9]*)_([0-9]*)".r
blockId match {
case regex(shufId, mapId, _) =>
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, null)
case _ =>
@@ -74,7 +77,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
val itr = blockFetcherItr.flatMap(unpackBlock)

CompletionIterator[T, Iterator[T]](itr, {
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +86,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
metrics.shuffleReadMetrics = Some(shuffleMetrics)
context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
})

new InterruptibleIterator[T](context, completionIter)
}
}
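A recurring theme in the hunks above is replacing hand-formatted block-name strings and regexes with typed BlockId case classes (ShuffleBlockId here, RDDBlockId in the cache manager below). A tiny made-up comparison of the two styles, for illustration only:

    import org.apache.spark.storage.ShuffleBlockId

    // Old style: assemble the name by hand and parse it back with a regex.
    val oldName = "shuffle_%d_%d_%d".format(2, 5, 0)   // "shuffle_2_5_0"

    // New style: a typed id carries the same fields and can be pattern matched.
    val blockId = ShuffleBlockId(2, 5, 0)
    blockId match {
      case ShuffleBlockId(shufId, mapId, reduceId) => (shufId, mapId, reduceId)  // (2, 5, 0)
    }
    // blockId.name is assumed to render the same "shuffle_2_5_0" form as before.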
@@ -18,7 +18,7 @@
package org.apache.spark

import scala.collection.mutable.{ArrayBuffer, HashSet}
import org.apache.spark.storage.{BlockManager, StorageLevel}
import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, RDDBlockId}
import org.apache.spark.rdd.RDD
@@ -28,17 +28,17 @@ import org.apache.spark.rdd.RDD
private[spark] class CacheManager(blockManager: BlockManager) extends Logging {

/** Keys of RDD splits that are being computed/loaded. */
private val loading = new HashSet[String]()
private val loading = new HashSet[RDDBlockId]()

/** Gets or computes an RDD split. Used by RDD.iterator() when an RDD is cached. */
def getOrCompute[T](rdd: RDD[T], split: Partition, context: TaskContext, storageLevel: StorageLevel)
: Iterator[T] = {
val key = "rdd_%d_%d".format(rdd.id, split.index)
val key = RDDBlockId(rdd.id, split.index)
logDebug("Looking for partition " + key)
blockManager.get(key) match {
case Some(values) =>
// Partition is already materialized, so just return its values
return values.asInstanceOf[Iterator[T]]
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])

case None =>
// Mark the split as loading (unless someone else marks it first)
@@ -56,7 +56,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
// downside of the current code is that threads wait serially if this does happen.
blockManager.get(key) match {
case Some(values) =>
return values.asInstanceOf[Iterator[T]]
return new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
case None =>
logInfo("Whoever was loading %s failed; we'll try it ourselves".format(key))
loading.add(key)
@@ -73,7 +73,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
if (context.runningLocally) { return computedValues }
val elements = new ArrayBuffer[Any]
elements ++= computedValues
blockManager.put(key, elements, storageLevel, true)
blockManager.put(key, elements, storageLevel, tellMaster = true)
return elements.iterator.asInstanceOf[Iterator[T]]
} finally {
loading.synchronized {
core/src/main/scala/org/apache/spark/FutureAction.scala (new file, 250 lines)

@@ -0,0 +1,250 @@
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark
|
||||
|
||||
import scala.concurrent._
|
||||
import scala.concurrent.duration.Duration
|
||||
import scala.util.Try
|
||||
|
||||
import org.apache.spark.scheduler.{JobSucceeded, JobWaiter}
|
||||
import org.apache.spark.scheduler.JobFailed
|
||||
import org.apache.spark.rdd.RDD
|
||||
|
||||
|
||||
/**
|
||||
* A future for the result of an action. This is an extension of the Scala Future interface to
|
||||
* support cancellation.
|
||||
*/
|
||||
trait FutureAction[T] extends Future[T] {
|
||||
// Note that we redefine methods of the Future trait here explicitly so we can specify a different
|
||||
// documentation (with reference to the word "action").
|
||||
|
||||
/**
|
||||
* Cancels the execution of this action.
|
||||
*/
|
||||
def cancel()
|
||||
|
||||
/**
|
||||
* Blocks until this action completes.
|
||||
* @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
|
||||
* for unbounded waiting, or a finite positive duration
|
||||
* @return this FutureAction
|
||||
*/
|
||||
override def ready(atMost: Duration)(implicit permit: CanAwait): FutureAction.this.type
|
||||
|
||||
/**
|
||||
* Awaits and returns the result (of type T) of this action.
|
||||
* @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf
|
||||
* for unbounded waiting, or a finite positive duration
|
||||
* @throws Exception exception during action execution
|
||||
* @return the result value if the action is completed within the specific maximum wait time
|
||||
*/
|
||||
@throws(classOf[Exception])
|
||||
override def result(atMost: Duration)(implicit permit: CanAwait): T
|
||||
|
||||
/**
|
||||
* When this action is completed, either through an exception, or a value, applies the provided
|
||||
* function.
|
||||
*/
|
||||
def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext)
|
||||
|
||||
/**
|
||||
* Returns whether the action has already been completed with a value or an exception.
|
||||
*/
|
||||
override def isCompleted: Boolean
|
||||
|
||||
/**
|
||||
* The value of this Future.
|
||||
*
|
||||
* If the future is not completed the returned value will be None. If the future is completed
|
||||
* the value will be Some(Success(t)) if it contains a valid result, or Some(Failure(error)) if
|
||||
* it contains an exception.
|
||||
*/
|
||||
override def value: Option[Try[T]]
|
||||
|
||||
/**
|
||||
* Blocks and returns the result of this job.
|
||||
*/
|
||||
@throws(classOf[Exception])
|
||||
def get(): T = Await.result(this, Duration.Inf)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The future holding the result of an action that triggers a single job. Examples include
|
||||
* count, collect, reduce.
|
||||
*/
|
||||
class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: => T)
|
||||
extends FutureAction[T] {
|
||||
|
||||
override def cancel() {
|
||||
jobWaiter.cancel()
|
||||
}
|
||||
|
||||
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
|
||||
if (!atMost.isFinite()) {
|
||||
awaitResult()
|
||||
} else {
|
||||
val finishTime = System.currentTimeMillis() + atMost.toMillis
|
||||
while (!isCompleted) {
|
||||
val time = System.currentTimeMillis()
|
||||
if (time >= finishTime) {
|
||||
throw new TimeoutException
|
||||
} else {
|
||||
jobWaiter.wait(finishTime - time)
|
||||
}
|
||||
}
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
@throws(classOf[Exception])
|
||||
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
|
||||
ready(atMost)(permit)
|
||||
awaitResult() match {
|
||||
case scala.util.Success(res) => res
|
||||
case scala.util.Failure(e) => throw e
|
||||
}
|
||||
}
|
||||
|
||||
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) {
|
||||
executor.execute(new Runnable {
|
||||
override def run() {
|
||||
func(awaitResult())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
override def isCompleted: Boolean = jobWaiter.jobFinished
|
||||
|
||||
override def value: Option[Try[T]] = {
|
||||
if (jobWaiter.jobFinished) {
|
||||
Some(awaitResult())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
private def awaitResult(): Try[T] = {
|
||||
jobWaiter.awaitResult() match {
|
||||
case JobSucceeded => scala.util.Success(resultFunc)
|
||||
case JobFailed(e: Exception, _) => scala.util.Failure(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A FutureAction for actions that could trigger multiple Spark jobs. Examples include take,
|
||||
* takeSample. Cancellation works by setting the cancelled flag to true and interrupting the
|
||||
* action thread if it is being blocked by a job.
|
||||
*/
|
||||
class ComplexFutureAction[T] extends FutureAction[T] {
|
||||
|
||||
// Pointer to the thread that is executing the action. It is set when the action is run.
|
||||
@volatile private var thread: Thread = _
|
||||
|
||||
// A flag indicating whether the future has been cancelled. This is used in case the future
|
||||
// is cancelled before the action was even run (and thus we have no thread to interrupt).
|
||||
@volatile private var _cancelled: Boolean = false
|
||||
|
||||
// A promise used to signal the future.
|
||||
private val p = promise[T]()
|
||||
|
||||
override def cancel(): Unit = this.synchronized {
|
||||
_cancelled = true
|
||||
if (thread != null) {
|
||||
thread.interrupt()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Executes some action enclosed in the closure. To properly enable cancellation, the closure
|
||||
* should use runJob implementation in this promise. See takeAsync for example.
|
||||
*/
|
||||
def run(func: => T)(implicit executor: ExecutionContext): this.type = {
|
||||
scala.concurrent.future {
|
||||
thread = Thread.currentThread
|
||||
try {
|
||||
p.success(func)
|
||||
} catch {
|
||||
case e: Exception => p.failure(e)
|
||||
} finally {
|
||||
thread = null
|
||||
}
|
||||
}
|
||||
this
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs a Spark job. This is a wrapper around the same functionality provided by SparkContext
|
||||
* to enable cancellation.
|
||||
*/
|
||||
def runJob[T, U, R](
|
||||
rdd: RDD[T],
|
||||
processPartition: Iterator[T] => U,
|
||||
partitions: Seq[Int],
|
||||
resultHandler: (Int, U) => Unit,
|
||||
resultFunc: => R) {
|
||||
// If the action hasn't been cancelled yet, submit the job. The check and the submitJob
|
||||
// command need to be in an atomic block.
|
||||
val job = this.synchronized {
|
||||
if (!cancelled) {
|
||||
rdd.context.submitJob(rdd, processPartition, partitions, resultHandler, resultFunc)
|
||||
} else {
|
||||
throw new SparkException("Action has been cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for the job to complete. If the action is cancelled (with an interrupt),
|
||||
// cancel the job and stop the execution. This is not in a synchronized block because
|
||||
// Await.ready eventually waits on the monitor in FutureJob.jobWaiter.
|
||||
try {
|
||||
Await.ready(job, Duration.Inf)
|
||||
} catch {
|
||||
case e: InterruptedException =>
|
||||
job.cancel()
|
||||
throw new SparkException("Action has been cancelled")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns whether the promise has been cancelled.
|
||||
*/
|
||||
def cancelled: Boolean = _cancelled
|
||||
|
||||
@throws(classOf[InterruptedException])
|
||||
@throws(classOf[scala.concurrent.TimeoutException])
|
||||
override def ready(atMost: Duration)(implicit permit: CanAwait): this.type = {
|
||||
p.future.ready(atMost)(permit)
|
||||
this
|
||||
}
|
||||
|
||||
@throws(classOf[Exception])
|
||||
override def result(atMost: Duration)(implicit permit: CanAwait): T = {
|
||||
p.future.result(atMost)(permit)
|
||||
}
|
||||
|
||||
override def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit = {
|
||||
p.future.onComplete(func)(executor)
|
||||
}
|
||||
|
||||
override def isCompleted: Boolean = p.isCompleted
|
||||
|
||||
override def value: Option[Try[T]] = p.future.value
|
||||
}
|
|
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark

/**
* An iterator that wraps around an existing iterator to provide task killing functionality.
* It works by checking the interrupted flag in TaskContext.
*/
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
extends Iterator[T] {

def hasNext: Boolean = !context.interrupted && delegate.hasNext

def next(): T = delegate.next()
}
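InterruptibleIterator is the hook that makes job cancellation visible inside per-record loops: once the task's interrupted flag is set, hasNext reports false and downstream consumers stop. A rough sketch of the wrapping pattern follows; the TaskContext itself is created by the scheduler, so its construction is deliberately left out of this illustration.

    import org.apache.spark.{InterruptibleIterator, TaskContext}

    // Illustration only: wrap a partition's raw iterator so a cancelled task stops early.
    def readPartition(context: TaskContext, raw: Iterator[Int]): Iterator[Int] =
      new InterruptibleIterator[Int](context, raw)

    // When context.interrupted becomes true (e.g. via SparkContext.cancelJobGroup),
    // the wrapped iterator stops yielding elements and the task can wind down.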
@ -20,7 +20,6 @@ package org.apache.spark
|
|||
import java.io._
|
||||
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
|
||||
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import akka.actor._
|
||||
|
@ -34,7 +33,7 @@ import scala.concurrent.duration._
|
|||
|
||||
import org.apache.spark.scheduler.MapStatus
|
||||
import org.apache.spark.storage.BlockManagerId
|
||||
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashMap}
|
||||
import org.apache.spark.util.{MetadataCleanerType, Utils, MetadataCleaner, TimeStampedHashMap}
|
||||
|
||||
|
||||
private[spark] sealed trait MapOutputTrackerMessage
|
||||
|
@ -42,11 +41,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
|
|||
extends MapOutputTrackerMessage
|
||||
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
|
||||
|
||||
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
|
||||
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
|
||||
extends Actor with Logging {
|
||||
def receive = {
|
||||
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
|
||||
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
|
||||
sender ! tracker.getSerializedLocations(shuffleId)
|
||||
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
|
||||
|
||||
case StopMapOutputTracker =>
|
||||
logInfo("MapOutputTrackerActor stopped!")
|
||||
|
@ -62,22 +62,19 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
// Set to the MapOutputTrackerActor living on the driver
|
||||
var trackerActor: Either[ActorRef, ActorSelection] = _
|
||||
|
||||
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
|
||||
// Incremented every time a fetch fails so that client nodes know to clear
|
||||
// their cache of map output locations if this happens.
|
||||
private var epoch: Long = 0
|
||||
private val epochLock = new java.lang.Object
|
||||
protected var epoch: Long = 0
|
||||
protected val epochLock = new java.lang.Object
|
||||
|
||||
// Cache a serialized version of the output statuses for each shuffle to send them out faster
|
||||
var cacheEpoch = epoch
|
||||
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
|
||||
|
||||
val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
|
||||
private val metadataCleaner =
|
||||
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
|
||||
|
||||
// Send a message to the trackerActor and get its result within a default timeout, or
|
||||
// throw a SparkException if this fails.
|
||||
def askTracker(message: Any): Any = {
|
||||
private def askTracker(message: Any): Any = {
|
||||
try {
|
||||
val future = if (trackerActor.isLeft ) {
|
||||
trackerActor.left.get.ask(message)(timeout)
|
||||
|
@ -92,50 +89,12 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
|
||||
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
|
||||
def communicate(message: Any) {
|
||||
private def communicate(message: Any) {
|
||||
if (askTracker(message) != true) {
|
||||
throw new SparkException("Error reply received from MapOutputTracker")
|
||||
}
|
||||
}
|
||||
|
||||
def registerShuffle(shuffleId: Int, numMaps: Int) {
|
||||
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
|
||||
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
|
||||
}
|
||||
}
|
||||
|
||||
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
|
||||
var array = mapStatuses(shuffleId)
|
||||
array.synchronized {
|
||||
array(mapId) = status
|
||||
}
|
||||
}
|
||||
|
||||
def registerMapOutputs(
|
||||
shuffleId: Int,
|
||||
statuses: Array[MapStatus],
|
||||
changeEpoch: Boolean = false) {
|
||||
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
|
||||
if (changeEpoch) {
|
||||
incrementEpoch()
|
||||
}
|
||||
}
|
||||
|
||||
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
|
||||
var arrayOpt = mapStatuses.get(shuffleId)
|
||||
if (arrayOpt.isDefined && arrayOpt.get != null) {
|
||||
var array = arrayOpt.get
|
||||
array.synchronized {
|
||||
if (array(mapId) != null && array(mapId).location == bmAddress) {
|
||||
array(mapId) = null
|
||||
}
|
||||
}
|
||||
incrementEpoch()
|
||||
} else {
|
||||
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
|
||||
}
|
||||
}
|
||||
|
||||
// Remembers which map output locations are currently being fetched on a worker
|
||||
private val fetching = new HashSet[Int]
|
||||
|
||||
|
@ -174,7 +133,7 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
try {
|
||||
val fetchedBytes =
|
||||
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
|
||||
fetchedStatuses = deserializeStatuses(fetchedBytes)
|
||||
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
|
||||
logInfo("Got the output locations")
|
||||
mapStatuses.put(shuffleId, fetchedStatuses)
|
||||
} finally {
|
||||
|
@ -200,9 +159,8 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
private def cleanup(cleanupTime: Long) {
|
||||
protected def cleanup(cleanupTime: Long) {
|
||||
mapStatuses.clearOldValues(cleanupTime)
|
||||
cachedSerializedStatuses.clearOldValues(cleanupTime)
|
||||
}
|
||||
|
||||
def stop() {
|
||||
|
@ -212,15 +170,7 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
trackerActor = null
|
||||
}
|
||||
|
||||
// Called on master to increment the epoch number
|
||||
def incrementEpoch() {
|
||||
epochLock.synchronized {
|
||||
epoch += 1
|
||||
logDebug("Increasing epoch to " + epoch)
|
||||
}
|
||||
}
|
||||
|
||||
// Called on master or workers to get current epoch number
|
||||
// Called to get current epoch number
|
||||
def getEpoch: Long = {
|
||||
epochLock.synchronized {
|
||||
return epoch
|
||||
|
@ -234,14 +184,62 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
epochLock.synchronized {
|
||||
if (newEpoch > epoch) {
|
||||
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
|
||||
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
|
||||
mapStatuses.clear()
|
||||
epoch = newEpoch
|
||||
mapStatuses.clear()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
|
||||
private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
|
||||
|
||||
// Cache a serialized version of the output statuses for each shuffle to send them out faster
|
||||
private var cacheEpoch = epoch
|
||||
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
|
||||
|
||||
def registerShuffle(shuffleId: Int, numMaps: Int) {
|
||||
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
|
||||
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
|
||||
}
|
||||
}
|
||||
|
||||
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
|
||||
val array = mapStatuses(shuffleId)
|
||||
array.synchronized {
|
||||
array(mapId) = status
|
||||
}
|
||||
}
|
||||
|
||||
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
|
||||
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
|
||||
if (changeEpoch) {
|
||||
incrementEpoch()
|
||||
}
|
||||
}
|
||||
|
||||
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
|
||||
val arrayOpt = mapStatuses.get(shuffleId)
|
||||
if (arrayOpt.isDefined && arrayOpt.get != null) {
|
||||
val array = arrayOpt.get
|
||||
array.synchronized {
|
||||
if (array(mapId) != null && array(mapId).location == bmAddress) {
|
||||
array(mapId) = null
|
||||
}
|
||||
}
|
||||
incrementEpoch()
|
||||
} else {
|
||||
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
|
||||
}
|
||||
}
|
||||
|
||||
def incrementEpoch() {
|
||||
epochLock.synchronized {
|
||||
epoch += 1
|
||||
logDebug("Increasing epoch to " + epoch)
|
||||
}
|
||||
}
|
||||
|
||||
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
|
||||
var statuses: Array[MapStatus] = null
|
||||
var epochGotten: Long = -1
|
||||
epochLock.synchronized {
|
||||
|
@ -259,7 +257,7 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
}
|
||||
// If we got here, we failed to find the serialized locations in the cache, so we pulled
|
||||
// out a snapshot of the locations as "locs"; let's serialize and return that
|
||||
val bytes = serializeStatuses(statuses)
|
||||
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
|
||||
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
|
||||
// Add them into the table only if the epoch hasn't changed while we were working
|
||||
epochLock.synchronized {
|
||||
|
@ -267,13 +265,31 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
cachedSerializedStatuses(shuffleId) = bytes
|
||||
}
|
||||
}
|
||||
return bytes
|
||||
bytes
|
||||
}
|
||||
|
||||
protected override def cleanup(cleanupTime: Long) {
|
||||
super.cleanup(cleanupTime)
|
||||
cachedSerializedStatuses.clearOldValues(cleanupTime)
|
||||
}
|
||||
|
||||
override def stop() {
|
||||
super.stop()
|
||||
cachedSerializedStatuses.clear()
|
||||
}
|
||||
|
||||
override def updateEpoch(newEpoch: Long) {
|
||||
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] object MapOutputTracker {
|
||||
private val LOG_BASE = 1.1
|
||||
|
||||
// Serialize an array of map output locations into an efficient byte format so that we can send
|
||||
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
|
||||
// generally be pretty compressible because many map outputs will be on the same hostname.
|
||||
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
|
||||
def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
|
||||
val out = new ByteArrayOutputStream
|
||||
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
|
||||
// Since statuses can be modified in parallel, sync on it
|
||||
|
@ -284,18 +300,11 @@ private[spark] class MapOutputTracker extends Logging {
|
|||
out.toByteArray
|
||||
}
|
||||
|
||||
// Opposite of serializeStatuses.
|
||||
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
|
||||
// Opposite of serializeMapStatuses.
|
||||
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
|
||||
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
|
||||
objIn.readObject().
|
||||
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
|
||||
// comment this out - nulls could be due to missing location ?
|
||||
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
|
||||
objIn.readObject().asInstanceOf[Array[MapStatus]]
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] object MapOutputTracker {
|
||||
private val LOG_BASE = 1.1
|
||||
|
||||
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
|
||||
// any of the statuses is null (indicating a missing location due to a failed mapper),
|
||||
|
|
|
@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
* Fetch the shuffle outputs for a given ShuffleDependency.
* @return An iterator over the elements of the fetched shuffle outputs.
*/
def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]

/** Stop the fetcher */
@ -24,7 +24,7 @@ import java.util.concurrent.atomic.AtomicInteger
|
|||
|
||||
import scala.collection.Map
|
||||
import scala.collection.generic.Growable
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.collection.JavaConverters._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.reflect.{ ClassTag, classTag}
|
||||
|
@ -53,21 +53,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
|
|||
|
||||
import org.apache.mesos.MesosNativeLibrary
|
||||
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.deploy.LocalSparkCluster
|
||||
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
|
||||
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
|
||||
import org.apache.spark.rdd._
|
||||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.scheduler.cluster.{StandaloneSchedulerBackend, SparkDeploySchedulerBackend,
|
||||
ClusterScheduler}
|
||||
import org.apache.spark.scheduler.local.LocalScheduler
|
||||
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
|
||||
SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
|
||||
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
|
||||
import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
|
||||
import org.apache.spark.ui.SparkUI
|
||||
import org.apache.spark.util.{ClosureCleaner, Utils, MetadataCleaner, TimeStampedHashMap}
|
||||
import org.apache.spark.scheduler.local.LocalScheduler
|
||||
import org.apache.spark.scheduler.StageInfo
|
||||
import org.apache.spark.storage.RDDInfo
|
||||
import org.apache.spark.storage.StorageStatus
|
||||
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
|
||||
import org.apache.spark.ui.SparkUI
|
||||
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
|
||||
TimeStampedHashMap, Utils}
|
||||
|
||||
/**
|
||||
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
|
||||
|
@ -121,9 +119,9 @@ class SparkContext(
|
|||
|
||||
// Keeps track of all persisted RDDs
|
||||
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
|
||||
private[spark] val metadataCleaner = new MetadataCleaner("SparkContext", this.cleanup)
|
||||
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
|
||||
|
||||
// Initalize the Spark UI
|
||||
// Initialize the Spark UI
|
||||
private[spark] val ui = new SparkUI(this)
|
||||
ui.bind()
|
||||
|
||||
|
@ -149,6 +147,14 @@ class SparkContext(
|
|||
executorEnvs ++= environment
|
||||
}
|
||||
|
||||
// Set SPARK_USER for user who is running SparkContext.
|
||||
val sparkUser = Option {
|
||||
Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
|
||||
}.getOrElse {
|
||||
SparkContext.SPARK_UNKNOWN_USER
|
||||
}
|
||||
executorEnvs("SPARK_USER") = sparkUser
|
||||
|
||||
// Create and start the scheduler
|
||||
private[spark] var taskScheduler: TaskScheduler = {
|
||||
// Regular expression used for local[N] master format
|
||||
|
@ -158,9 +164,11 @@ class SparkContext(
|
|||
// Regular expression for simulating a Spark cluster of [N, cores, memory] locally
|
||||
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
|
||||
// Regular expression for connecting to Spark deploy clusters
|
||||
val SPARK_REGEX = """(spark://.*)""".r
|
||||
//Regular expression for connection to Mesos cluster
|
||||
val MESOS_REGEX = """(mesos://.*)""".r
|
||||
val SPARK_REGEX = """spark://(.*)""".r
|
||||
// Regular expression for connection to Mesos cluster
|
||||
val MESOS_REGEX = """mesos://(.*)""".r
|
||||
// Regular expression for connection to Simr cluster
|
||||
val SIMR_REGEX = """simr://(.*)""".r
|
||||
|
||||
master match {
|
||||
case "local" =>
|
||||
|
@ -174,7 +182,14 @@ class SparkContext(
|
|||
|
||||
case SPARK_REGEX(sparkUrl) =>
|
||||
val scheduler = new ClusterScheduler(this)
|
||||
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
|
||||
val masterUrls = sparkUrl.split(",").map("spark://" + _)
|
||||
val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
|
||||
case SIMR_REGEX(simrUrl) =>
|
||||
val scheduler = new ClusterScheduler(this)
|
||||
val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
|
||||
|
@ -190,8 +205,8 @@ class SparkContext(
|
|||
val scheduler = new ClusterScheduler(this)
|
||||
val localCluster = new LocalSparkCluster(
|
||||
numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
|
||||
val sparkUrl = localCluster.start()
|
||||
val backend = new SparkDeploySchedulerBackend(scheduler, this, sparkUrl, appName)
|
||||
val masterUrls = localCluster.start()
|
||||
val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
|
||||
scheduler.initialize(backend)
|
||||
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
|
||||
localCluster.stop()
|
||||
|
@ -210,38 +225,36 @@ class SparkContext(
|
|||
throw new SparkException("YARN mode not available ?", th)
|
||||
}
|
||||
}
|
||||
val backend = new StandaloneSchedulerBackend(scheduler, this.env.actorSystem)
|
||||
val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
|
||||
case MESOS_REGEX(mesosUrl) =>
|
||||
MesosNativeLibrary.load()
|
||||
val scheduler = new ClusterScheduler(this)
|
||||
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
|
||||
val backend = if (coarseGrained) {
|
||||
new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
|
||||
} else {
|
||||
new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
|
||||
}
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
|
||||
case _ =>
|
||||
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
|
||||
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
|
||||
}
|
||||
MesosNativeLibrary.load()
|
||||
val scheduler = new ClusterScheduler(this)
|
||||
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
|
||||
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
|
||||
val backend = if (coarseGrained) {
|
||||
new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
|
||||
} else {
|
||||
new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
|
||||
}
|
||||
scheduler.initialize(backend)
|
||||
scheduler
|
||||
throw new SparkException("Could not parse Master URL: '" + master + "'")
|
||||
}
|
||||
}
|
||||
taskScheduler.start()
|
||||
|
||||
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
|
||||
dagScheduler.start()
|
||||
|
||||
ui.start()
|
||||
|
||||
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
|
||||
val hadoopConfiguration = {
|
||||
val env = SparkEnv.get
|
||||
val conf = env.hadoop.newConfiguration()
|
||||
val conf = SparkHadoopUtil.get.newConfiguration()
|
||||
// Explicitly check for S3 environment variables
|
||||
if (System.getenv("AWS_ACCESS_KEY_ID") != null &&
|
||||
System.getenv("AWS_SECRET_ACCESS_KEY") != null) {
|
||||
|
@ -251,8 +264,10 @@ class SparkContext(
|
|||
conf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY"))
|
||||
}
|
||||
// Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar"
|
||||
for (key <- System.getProperties.toMap[String, String].keys if key.startsWith("spark.hadoop.")) {
|
||||
conf.set(key.substring("spark.hadoop.".length), System.getProperty(key))
|
||||
Utils.getSystemProperties.foreach { case (key, value) =>
|
||||
if (key.startsWith("spark.hadoop.")) {
|
||||
conf.set(key.substring("spark.hadoop.".length), value)
|
||||
}
|
||||
}
|
||||
val bufferSize = System.getProperty("spark.buffer.size", "65536")
|
||||
conf.set("io.file.buffer.size", bufferSize)
|
||||
|
@ -266,6 +281,12 @@ class SparkContext(
|
|||
override protected def childValue(parent: Properties): Properties = new Properties(parent)
|
||||
}
|
||||
|
||||
private[spark] def getLocalProperties(): Properties = localProperties.get()
|
||||
|
||||
private[spark] def setLocalProperties(props: Properties) {
|
||||
localProperties.set(props)
|
||||
}
|
||||
|
||||
def initLocalProperties() {
|
||||
localProperties.set(new Properties())
|
||||
}
|
||||
|
@ -285,15 +306,46 @@ class SparkContext(
|
|||
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
|
||||
|
||||
/** Set a human readable description of the current job. */
|
||||
@deprecated("use setJobGroup", "0.8.1")
|
||||
def setJobDescription(value: String) {
|
||||
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Assigns a group id to all the jobs started by this thread until the group id is set to a
|
||||
* different value or cleared.
|
||||
*
|
||||
* Often, a unit of execution in an application consists of multiple Spark actions or jobs.
|
||||
* Application programmers can use this method to group all those jobs together and give a
|
||||
* group description. Once set, the Spark web UI will associate such jobs with this group.
|
||||
*
|
||||
* The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
|
||||
* running jobs in this group. For example,
|
||||
* {{{
|
||||
* // In the main thread:
|
||||
* sc.setJobGroup("some_job_to_cancel", "some job description")
|
||||
* sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
|
||||
*
|
||||
* // In a separate thread:
|
||||
* sc.cancelJobGroup("some_job_to_cancel")
|
||||
* }}}
|
||||
*/
|
||||
def setJobGroup(groupId: String, description: String) {
|
||||
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
|
||||
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
|
||||
}
|
||||
|
||||
/** Clear the job group id and its description. */
|
||||
def clearJobGroup() {
|
||||
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
|
||||
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
|
||||
}
|
||||
|
||||
// Post init
|
||||
taskScheduler.postStartHook()
|
||||
|
||||
val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
|
||||
val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
|
||||
private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
|
||||
private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
|
||||
|
||||
def initDriverMetrics() {
|
||||
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
|
||||
|
@ -332,7 +384,7 @@ class SparkContext(
|
|||
}
|
||||
|
||||
/**
|
||||
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf giving its InputFormat and any
|
||||
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and any
|
||||
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
|
||||
* etc).
|
||||
*/
|
||||
|
@ -344,7 +396,7 @@ class SparkContext(
|
|||
minSplits: Int = defaultMinSplits
|
||||
): RDD[(K, V)] = {
|
||||
// Add necessary security credentials to the JobConf before broadcasting it.
|
||||
SparkEnv.get.hadoop.addCredentials(conf)
|
||||
SparkHadoopUtil.get.addCredentials(conf)
|
||||
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minSplits)
|
||||
}
|
||||
|
||||
|
@ -358,24 +410,15 @@ class SparkContext(
|
|||
): RDD[(K, V)] = {
|
||||
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
|
||||
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
|
||||
hadoopFile(path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get an RDD for a Hadoop file with an arbitrary InputFormat. Accept a Hadoop Configuration
|
||||
* that has already been broadcast, assuming that it's safe to use it to construct a
|
||||
* HadoopFileRDD (i.e., except for file 'path', all other configuration properties can be reused).
|
||||
*/
|
||||
def hadoopFile[K, V](
|
||||
path: String,
|
||||
confBroadcast: Broadcast[SerializableWritable[Configuration]],
|
||||
inputFormatClass: Class[_ <: InputFormat[K, V]],
|
||||
keyClass: Class[K],
|
||||
valueClass: Class[V],
|
||||
minSplits: Int
|
||||
): RDD[(K, V)] = {
|
||||
new HadoopFileRDD(
|
||||
this, path, confBroadcast, inputFormatClass, keyClass, valueClass, minSplits)
|
||||
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
|
||||
new HadoopRDD(
|
||||
this,
|
||||
confBroadcast,
|
||||
Some(setInputPathsFunc),
|
||||
inputFormatClass,
|
||||
keyClass,
|
||||
valueClass,
|
||||
minSplits)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -563,7 +606,8 @@ class SparkContext(
|
|||
val uri = new URI(path)
|
||||
val key = uri.getScheme match {
|
||||
case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
|
||||
case _ => path
|
||||
case "local" => "file:" + uri.getPath
|
||||
case _ => path
|
||||
}
|
||||
addedFiles(key) = System.currentTimeMillis
|
||||
|
||||
|
@ -657,12 +701,11 @@ class SparkContext(
|
|||
/**
|
||||
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
|
||||
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
|
||||
* filesystems), or an HTTP, HTTPS or FTP URI.
|
||||
* filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
|
||||
*/
|
||||
def addJar(path: String) {
|
||||
if (path == null) {
|
||||
logWarning("null specified as parameter to addJar",
|
||||
new SparkException("null specified as parameter to addJar"))
|
||||
logWarning("null specified as parameter to addJar")
|
||||
} else {
|
||||
var key = ""
|
||||
if (path.contains("\\")) {
|
||||
|
@ -671,12 +714,27 @@ class SparkContext(
|
|||
} else {
|
||||
val uri = new URI(path)
|
||||
key = uri.getScheme match {
|
||||
// A JAR file which exists only on the driver node
|
||||
case null | "file" =>
|
||||
if (env.hadoop.isYarnMode()) {
|
||||
logWarning("local jar specified as parameter to addJar under Yarn mode")
|
||||
return
|
||||
if (SparkHadoopUtil.get.isYarnMode()) {
|
||||
// In order for this to work on YARN, the user must specify the --addJars option to
|
||||
// the client to upload the file into the distributed cache to make it show up in the
|
||||
// current working directory.
|
||||
val fileName = new Path(uri.getPath).getName()
|
||||
try {
|
||||
env.httpFileServer.addJar(new File(fileName))
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logError("Error adding jar (" + e + "), was the --addJars option used?")
|
||||
throw e
|
||||
}
|
||||
}
|
||||
} else {
|
||||
env.httpFileServer.addJar(new File(uri.getPath))
|
||||
}
|
||||
env.httpFileServer.addJar(new File(uri.getPath))
|
||||
// A JAR file which exists locally on every worker node
|
||||
case "local" =>
|
||||
"file:" + uri.getPath
|
||||
case _ =>
|
||||
path
|
||||
}
|
||||
|
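Putting the documented URI schemes for addJar side by side, a hedged sketch (all paths are illustrative):

    sc.addJar("/opt/jobs/lib/my-udfs.jar")              // driver-local file, served over the driver's HTTP server
    sc.addJar("hdfs://namenode:8020/libs/my-udfs.jar")  // fetched from HDFS by each executor
    sc.addJar("local:/opt/spark/lib/preinstalled.jar")  // assumed to already exist on every worker node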
@ -750,13 +808,13 @@ class SparkContext(
|
|||
allowLocal: Boolean,
|
||||
resultHandler: (Int, U) => Unit) {
|
||||
val callSite = Utils.formatSparkCallSite
|
||||
val cleanedFunc = clean(func)
|
||||
logInfo("Starting job: " + callSite)
|
||||
val start = System.nanoTime
|
||||
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler,
|
||||
localProperties.get)
|
||||
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
|
||||
resultHandler, localProperties.get)
|
||||
logInfo("Job finished: " + callSite + ", took " + (System.nanoTime - start) / 1e9 + " s")
|
||||
rdd.doCheckpoint()
|
||||
result
|
||||
}
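For context, a minimal sketch of calling runJob from user code, assuming the simpler overload that takes only an RDD and a per-partition function and returns one result per partition:

    val data = sc.parallelize(1 to 100, 4)
    val partitionSums: Array[Int] = sc.runJob(data, (iter: Iterator[Int]) => iter.sum)
    println(partitionSums.sum)   // 5050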
|
||||
|
||||
/**
|
||||
|
@ -842,6 +900,42 @@ class SparkContext(
|
|||
result
|
||||
}
|
||||
|
||||
/**
|
||||
* Submit a job for execution and return a FutureJob holding the result.
|
||||
*/
|
||||
def submitJob[T, U, R](
|
||||
rdd: RDD[T],
|
||||
processPartition: Iterator[T] => U,
|
||||
partitions: Seq[Int],
|
||||
resultHandler: (Int, U) => Unit,
|
||||
resultFunc: => R): SimpleFutureAction[R] =
|
||||
{
|
||||
val cleanF = clean(processPartition)
|
||||
val callSite = Utils.formatSparkCallSite
|
||||
val waiter = dagScheduler.submitJob(
|
||||
rdd,
|
||||
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
|
||||
partitions,
|
||||
callSite,
|
||||
allowLocal = false,
|
||||
resultHandler,
|
||||
localProperties.get)
|
||||
new SimpleFutureAction(waiter, resultFunc)
|
||||
}
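A hedged usage sketch for submitJob; the result bookkeeping below is illustrative, and cancellation is shown commented out because it is only needed when the caller decides to abort:

    val data = sc.parallelize(1 to 1000000, 8)
    val sizes = new Array[Int](data.partitions.size)

    val future = sc.submitJob(
      data,
      (iter: Iterator[Int]) => iter.size,                 // processPartition
      0 until data.partitions.size,                       // partitions
      (index: Int, size: Int) => sizes(index) = size,     // resultHandler
      sizes.sum)                                          // resultFunc, evaluated once all partitions finish
    // future.cancel()   // the returned SimpleFutureAction supports cancellation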
|
||||
|
||||
/**
|
||||
* Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
|
||||
* for more information.
|
||||
*/
|
||||
def cancelJobGroup(groupId: String) {
|
||||
dagScheduler.cancelJobGroup(groupId)
|
||||
}
|
||||
|
||||
/** Cancel all jobs that have been scheduled or are running. */
|
||||
def cancelAllJobs() {
|
||||
dagScheduler.cancelAllJobs()
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean a closure to make it ready to be serialized and sent to tasks
|
||||
* (removes unreferenced variables in $outer's, updates REPL variables)
|
||||
|
@ -859,9 +953,8 @@ class SparkContext(
|
|||
* prevent accidental overriding of checkpoint files in the existing directory.
|
||||
*/
|
||||
def setCheckpointDir(dir: String, useExisting: Boolean = false) {
|
||||
val env = SparkEnv.get
|
||||
val path = new Path(dir)
|
||||
val fs = path.getFileSystem(env.hadoop.newConfiguration())
|
||||
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
|
||||
if (!useExisting) {
|
||||
if (fs.exists(path)) {
|
||||
throw new Exception("Checkpoint directory '" + path + "' already exists.")
|
||||
|
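A minimal sketch of the checkpointing flow this method guards, with illustrative HDFS paths and an existing `sc`:

    sc.setCheckpointDir("hdfs://namenode:8020/tmp/spark-checkpoints")   // must not already exist unless useExisting = true

    val cleaned = sc.textFile("hdfs://namenode:8020/data/raw").filter(_.nonEmpty)
    cleaned.checkpoint()      // written to the checkpoint dir on the next action
    println(cleaned.count())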
@ -898,7 +991,12 @@ class SparkContext(
|
|||
* various Spark features.
|
||||
*/
|
||||
object SparkContext {
|
||||
val SPARK_JOB_DESCRIPTION = "spark.job.description"
|
||||
|
||||
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
|
||||
|
||||
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
|
||||
|
||||
private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
|
||||
|
||||
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
|
||||
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
|
||||
|
@ -925,6 +1023,8 @@ object SparkContext {
|
|||
implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
|
||||
new PairRDDFunctions(rdd)
|
||||
|
||||
implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
|
||||
|
||||
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
|
||||
rdd: RDD[(K, V)]) =
|
||||
new SequenceFileRDDFunctions(rdd)
|
||||
|
|
|
@ -25,13 +25,13 @@ import akka.remote.RemoteActorRefProvider
|
|||
|
||||
import org.apache.spark.broadcast.BroadcastManager
|
||||
import org.apache.spark.metrics.MetricsSystem
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.storage.{BlockManagerMasterActor, BlockManager, BlockManagerMaster}
|
||||
import org.apache.spark.network.ConnectionManager
|
||||
import org.apache.spark.serializer.{Serializer, SerializerManager}
|
||||
import org.apache.spark.util.{Utils, AkkaUtils}
|
||||
import org.apache.spark.api.python.PythonWorkerFactory
|
||||
|
||||
import com.google.common.collect.MapMaker
|
||||
|
||||
/**
|
||||
* Holds all the runtime environment objects for a running Spark instance (either master or worker),
|
||||
|
@ -58,18 +58,9 @@ class SparkEnv (
|
|||
|
||||
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
|
||||
|
||||
val hadoop = {
|
||||
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
|
||||
if(yarnMode) {
|
||||
try {
|
||||
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
|
||||
} catch {
|
||||
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
|
||||
}
|
||||
} else {
|
||||
new SparkHadoopUtil
|
||||
}
|
||||
}
|
||||
// A general, soft-reference map for metadata needed during HadoopRDD split computation
|
||||
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
|
||||
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
|
||||
|
||||
def stop() {
|
||||
pythonWorkers.foreach { case(key, worker) => worker.stop() }
|
||||
|
@ -188,10 +179,14 @@ object SparkEnv extends Logging {
|
|||
|
||||
// Have to assign trackerActor after initialization as MapOutputTrackerActor
|
||||
// requires the MapOutputTracker itself
|
||||
val mapOutputTracker = new MapOutputTracker()
|
||||
val mapOutputTracker = if (isDriver) {
|
||||
new MapOutputTrackerMaster()
|
||||
} else {
|
||||
new MapOutputTracker()
|
||||
}
|
||||
mapOutputTracker.trackerActor = registerOrLookup(
|
||||
"MapOutputTracker",
|
||||
new MapOutputTrackerActor(mapOutputTracker))
|
||||
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
|
||||
|
||||
val shuffleFetcher = instantiateClass[ShuffleFetcher](
|
||||
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
|
||||
|
|
|
@ -17,14 +17,14 @@
|
|||
|
||||
package org.apache.hadoop.mapred
|
||||
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import java.io.IOException
|
||||
import java.text.SimpleDateFormat
|
||||
import java.text.NumberFormat
|
||||
import java.io.IOException
|
||||
import java.util.Date
|
||||
|
||||
import org.apache.hadoop.fs.FileSystem
|
||||
import org.apache.hadoop.fs.Path
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.SerializableWritable
|
||||
|
||||
|
@ -36,7 +36,11 @@ import org.apache.spark.SerializableWritable
|
|||
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
|
||||
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
|
||||
*/
|
||||
class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkHadoopMapRedUtil with Serializable {
|
||||
private[apache]
|
||||
class SparkHadoopWriter(@transient jobConf: JobConf)
|
||||
extends Logging
|
||||
with SparkHadoopMapRedUtil
|
||||
with Serializable {
|
||||
|
||||
private val now = new Date()
|
||||
private val conf = new SerializableWritable(jobConf)
|
||||
|
@ -83,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
|
|||
}
|
||||
|
||||
getOutputCommitter().setupTask(getTaskContext())
|
||||
writer = getOutputFormat().getRecordWriter(
|
||||
fs, conf.value, outputName, Reporter.NULL)
|
||||
writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
|
||||
}
|
||||
|
||||
def write(key: AnyRef, value: AnyRef) {
|
||||
if (writer!=null) {
|
||||
//println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
|
||||
if (writer != null) {
|
||||
writer.write(key, value)
|
||||
} else {
|
||||
throw new IOException("Writer is null, open() has not been called")
|
||||
|
@ -179,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf) extends Logging with SparkH
|
|||
}
|
||||
}
|
||||
|
||||
private[apache]
|
||||
object SparkHadoopWriter {
|
||||
def createJobID(time: Date, id: Int): JobID = {
|
||||
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
|
||||
|
|
|
@ -17,21 +17,30 @@
|
|||
|
||||
package org.apache.spark
|
||||
|
||||
import executor.TaskMetrics
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.executor.TaskMetrics
|
||||
|
||||
class TaskContext(
|
||||
val stageId: Int,
|
||||
val splitId: Int,
|
||||
val partitionId: Int,
|
||||
val attemptId: Long,
|
||||
val runningLocally: Boolean = false,
|
||||
val taskMetrics: TaskMetrics = TaskMetrics.empty()
|
||||
@volatile var interrupted: Boolean = false,
|
||||
private[spark] val taskMetrics: TaskMetrics = TaskMetrics.empty()
|
||||
) extends Serializable {
|
||||
|
||||
@transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
|
||||
@deprecated("use partitionId", "0.8.1")
|
||||
def splitId = partitionId
|
||||
|
||||
// Add a callback function to be executed on task completion. An example use
|
||||
// is for HadoopRDD to register a callback to close the input stream.
|
||||
// List of callback functions to execute when the task completes.
|
||||
@transient private val onCompleteCallbacks = new ArrayBuffer[() => Unit]
|
||||
|
||||
/**
|
||||
* Add a callback function to be executed on task completion. An example use
|
||||
* is for HadoopRDD to register a callback to close the input stream.
|
||||
* @param f Callback function.
|
||||
*/
|
||||
def addOnCompleteCallback(f: () => Unit) {
|
||||
onCompleteCallbacks += f
|
||||
}
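This hook is what a custom RDD's compute() can use to release per-task resources, mirroring the HadoopRDD case mentioned above. A hedged, self-contained sketch; LoggingRDD and OnePartition are hypothetical:

    import org.apache.spark.{Partition, SparkContext, TaskContext}
    import org.apache.spark.rdd.RDD

    case class OnePartition(index: Int) extends Partition

    class LoggingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
      override def getPartitions: Array[Partition] = Array(OnePartition(0))

      override def compute(split: Partition, context: TaskContext): Iterator[Int] = {
        // Runs when the task finishes, e.g. to close a stream opened for this partition.
        context.addOnCompleteCallback(() => println("task " + context.attemptId + " done"))
        Iterator(1, 2, 3)
      }
    }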
|
||||
|
|
|
@ -52,4 +52,6 @@ private[spark] case class ExceptionFailure(
|
|||
*/
|
||||
private[spark] case object TaskResultLost extends TaskEndReason
|
||||
|
||||
private[spark] case object TaskKilled extends TaskEndReason
|
||||
|
||||
private[spark] case class OtherFailure(message: String) extends TaskEndReason
|
||||
|
|
|
@ -51,6 +51,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
|
|||
*/
|
||||
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
|
||||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
* This method blocks until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
|
||||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
*
|
||||
* @param blocking Whether to block until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
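To make the blocking semantics concrete, a hedged sketch against the underlying Scala RDD API that these Java wrappers delegate to:

    import org.apache.spark.storage.StorageLevel

    val cached = sc.parallelize(1 to 1000).persist(StorageLevel.MEMORY_ONLY)
    println(cached.count())               // first action populates the cache

    cached.unpersist()                    // default: blocks until every block is removed
    // cached.unpersist(blocking = false) // fire-and-forget variant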
|
||||
|
||||
// first() has to be overridden here in order for its return type to be Double instead of Object.
|
||||
override def first(): Double = srdd.first()
|
||||
|
||||
|
@ -83,6 +96,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
|
|||
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
|
||||
fromRDD(srdd.coalesce(numPartitions, shuffle))
|
||||
|
||||
/**
|
||||
* Return a new RDD that has exactly numPartitions partitions.
|
||||
*
|
||||
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
|
||||
* a shuffle to redistribute data.
|
||||
*
|
||||
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
|
||||
* which can avoid performing a shuffle.
|
||||
*/
|
||||
def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
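A short sketch of when to reach for each call; the input path is illustrative:

    val wide = sc.textFile("hdfs://namenode:8020/logs/*.gz")

    val rebalanced = wide.repartition(200)                   // always shuffles; can grow parallelism
    val compact = rebalanced.coalesce(20, shuffle = false)   // shrinks by merging partitions, no shuffle
    println(compact.partitions.size)                         // 20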
|
||||
|
||||
/**
|
||||
* Return an RDD with the elements from `this` that are not in `other`.
|
||||
*
|
||||
|
|
|
@ -66,6 +66,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
|
|||
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
|
||||
new JavaPairRDD[K, V](rdd.persist(newLevel))
|
||||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
* This method blocks until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
|
||||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
*
|
||||
* @param blocking Whether to block until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
|
||||
|
||||
// Transformations (return a new RDD)
|
||||
|
||||
/**
|
||||
|
@ -95,6 +108,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
|
|||
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
|
||||
fromRDD(rdd.coalesce(numPartitions, shuffle))
|
||||
|
||||
/**
|
||||
* Return a new RDD that has exactly numPartitions partitions.
|
||||
*
|
||||
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
|
||||
* a shuffle to redistribute data.
|
||||
*
|
||||
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
|
||||
* which can avoid performing a shuffle.
|
||||
*/
|
||||
def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
|
||||
|
||||
/**
|
||||
* Return a sampled subset of this RDD.
|
||||
*/
|
||||
|
@ -599,4 +623,15 @@ object JavaPairRDD {
|
|||
new JavaPairRDD[K, V](rdd)
|
||||
|
||||
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
|
||||
|
||||
|
||||
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
|
||||
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
|
||||
implicit val cmk: ClassTag[K] =
|
||||
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
|
||||
implicit val cmv: ClassTag[V] =
|
||||
implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
|
||||
new JavaPairRDD[K, V](rdd.rdd)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -43,9 +43,17 @@ JavaRDDLike[T, JavaRDD[T]] {
|
|||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
* This method blocks until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
|
||||
|
||||
/**
|
||||
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
|
||||
*
|
||||
* @param blocking Whether to block until all blocks are deleted.
|
||||
*/
|
||||
def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
|
||||
|
||||
// Transformations (return a new RDD)
|
||||
|
||||
/**
|
||||
|
@ -75,6 +83,17 @@ JavaRDDLike[T, JavaRDD[T]] {
|
|||
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
|
||||
rdd.coalesce(numPartitions, shuffle)
|
||||
|
||||
/**
|
||||
* Return a new RDD that has exactly numPartitions partitions.
|
||||
*
|
||||
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
|
||||
* a shuffle to redistribute data.
|
||||
*
|
||||
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
|
||||
* which can avoid performing a shuffle.
|
||||
*/
|
||||
def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
|
||||
|
||||
/**
|
||||
* Return a sampled subset of this RDD.
|
||||
*/
|
||||
|
|
|
@ -18,8 +18,6 @@
|
|||
package org.apache.spark.api.java.function;
|
||||
|
||||
|
||||
import scala.runtime.AbstractFunction1;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
|
@ -27,11 +25,7 @@ import java.io.Serializable;
|
|||
*/
|
||||
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
|
||||
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
|
||||
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
|
||||
public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
|
||||
implements Serializable {
|
||||
|
||||
public abstract Iterable<Double> call(T t);
|
||||
|
||||
@Override
|
||||
public final Iterable<Double> apply(T t) { return call(t); }
|
||||
// Intentionally left blank
|
||||
}
|
||||
|
|
|
@ -27,6 +27,5 @@ import java.io.Serializable;
|
|||
// are overloaded for both Function and DoubleFunction.
|
||||
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
|
||||
implements Serializable {
|
||||
|
||||
public abstract Double call(T t) throws Exception;
|
||||
// Intentionally left blank
|
||||
}
|
||||
|
|
|
@ -23,8 +23,5 @@ import scala.reflect.ClassTag
|
|||
* A function that returns zero or more output records from each input record.
|
||||
*/
|
||||
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
|
||||
@throws(classOf[Exception])
|
||||
def call(x: T) : java.lang.Iterable[R]
|
||||
|
||||
def elementType() : ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
|
||||
}
|
||||
|
|
|
@ -23,8 +23,5 @@ import scala.reflect.ClassTag
|
|||
* A function that takes two inputs and returns zero or more output records.
|
||||
*/
|
||||
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
|
||||
@throws(classOf[Exception])
|
||||
def call(a: A, b:B) : java.lang.Iterable[C]
|
||||
|
||||
def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
|
||||
}
|
||||
|
|
|
@ -29,10 +29,8 @@ import java.io.Serializable;
|
|||
* when mapping RDDs of other types.
|
||||
*/
|
||||
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
|
||||
public abstract R call(T t) throws Exception;
|
||||
|
||||
public ClassTag<R> returnType() {
|
||||
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
|
||||
return ClassTag$.MODULE$.apply(Object.class);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -28,8 +28,6 @@ import java.io.Serializable;
|
|||
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
|
||||
implements Serializable {
|
||||
|
||||
public abstract R call(T1 t1, T2 t2) throws Exception;
|
||||
|
||||
public ClassTag<R> returnType() {
|
||||
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.java.function;
|
||||
|
||||
import scala.reflect.ClassTag;
|
||||
import scala.reflect.ClassTag$;
|
||||
import scala.runtime.AbstractFunction2;
|
||||
|
||||
import java.io.Serializable;
|
||||
|
||||
/**
|
||||
* A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
|
||||
*/
|
||||
public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
|
||||
implements Serializable {
|
||||
|
||||
public ClassTag<R> returnType() {
|
||||
return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
|
||||
}
|
||||
}
|
||||
|
|
@ -33,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
|
|||
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
|
||||
implements Serializable {
|
||||
|
||||
public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
|
||||
|
||||
public ClassTag<K> keyType() {
|
||||
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
|
||||
}
|
||||
|
|
|
@ -28,12 +28,9 @@ import java.io.Serializable;
|
|||
*/
|
||||
// PairFunction does not extend Function because some UDF functions, like map,
|
||||
// are overloaded for both Function and PairFunction.
|
||||
public abstract class PairFunction<T, K, V>
|
||||
extends WrappedFunction1<T, Tuple2<K, V>>
|
||||
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
|
||||
implements Serializable {
|
||||
|
||||
public abstract Tuple2<K, V> call(T t) throws Exception;
|
||||
|
||||
public ClassTag<K> keyType() {
|
||||
return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
|
||||
}
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.api.java.function
|
||||
|
||||
import scala.runtime.AbstractFunction3
|
||||
|
||||
/**
|
||||
* Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
|
||||
* apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
|
||||
* isn't marked to allow that).
|
||||
*/
|
||||
private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
|
||||
extends AbstractFunction3[T1, T2, T3, R] {
|
||||
@throws(classOf[Exception])
|
||||
def call(t1: T1, t2: T2, t3: T3): R
|
||||
|
||||
final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
|
||||
}
|
||||
|
|
@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
|
|||
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
|
||||
* collects a list of pickled strings that we pass to Python through a socket.
|
||||
*/
|
||||
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
|
||||
private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
|
||||
extends AccumulatorParam[JList[Array[Byte]]] {
|
||||
|
||||
Utils.checkHost(serverHost, "Expected hostname")
|
||||
|
|
File diff suppressed because it is too large
|
@ -25,16 +25,15 @@ import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
|
|||
|
||||
import org.apache.spark.{HttpServer, Logging, SparkEnv}
|
||||
import org.apache.spark.io.CompressionCodec
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.{Utils, MetadataCleaner, TimeStampedHashSet}
|
||||
|
||||
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
|
||||
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils}
|
||||
|
||||
private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
|
||||
extends Broadcast[T](id) with Logging with Serializable {
|
||||
|
||||
def value = value_
|
||||
|
||||
def blockId: String = "broadcast_" + id
|
||||
def blockId = BroadcastBlockId(id)
|
||||
|
||||
HttpBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
|
||||
|
@ -82,7 +81,7 @@ private object HttpBroadcast extends Logging {
|
|||
private var server: HttpServer = null
|
||||
|
||||
private val files = new TimeStampedHashSet[String]
|
||||
private val cleaner = new MetadataCleaner("HttpBroadcast", cleanup)
|
||||
private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
|
||||
|
||||
private lazy val compressionCodec = CompressionCodec.createCodec()
|
||||
|
||||
|
@ -121,7 +120,7 @@ private object HttpBroadcast extends Logging {
|
|||
}
|
||||
|
||||
def write(id: Long, value: Any) {
|
||||
val file = new File(broadcastDir, "broadcast-" + id)
|
||||
val file = new File(broadcastDir, BroadcastBlockId(id).name)
|
||||
val out: OutputStream = {
|
||||
if (compress) {
|
||||
compressionCodec.compressedOutputStream(new FileOutputStream(file))
|
||||
|
@ -137,7 +136,7 @@ private object HttpBroadcast extends Logging {
|
|||
}
|
||||
|
||||
def read[T](id: Long): T = {
|
||||
val url = serverUri + "/broadcast-" + id
|
||||
val url = serverUri + "/" + BroadcastBlockId(id).name
|
||||
val in = {
|
||||
if (compress) {
|
||||
compressionCodec.compressedInputStream(new URL(url).openStream())
|
||||
|
|
|
@ -1,410 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.broadcast
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.Random
|
||||
|
||||
import scala.collection.mutable.Map
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private object MultiTracker
|
||||
extends Logging {
|
||||
|
||||
// Tracker Messages
|
||||
val REGISTER_BROADCAST_TRACKER = 0
|
||||
val UNREGISTER_BROADCAST_TRACKER = 1
|
||||
val FIND_BROADCAST_TRACKER = 2
|
||||
|
||||
// Map to keep track of guides of ongoing broadcasts
|
||||
var valueToGuideMap = Map[Long, SourceInfo]()
|
||||
|
||||
// Random number generator
|
||||
var ranGen = new Random
|
||||
|
||||
private var initialized = false
|
||||
private var _isDriver = false
|
||||
|
||||
private var stopBroadcast = false
|
||||
|
||||
private var trackMV: TrackMultipleValues = null
|
||||
|
||||
def initialize(__isDriver: Boolean) {
|
||||
synchronized {
|
||||
if (!initialized) {
|
||||
_isDriver = __isDriver
|
||||
|
||||
if (isDriver) {
|
||||
trackMV = new TrackMultipleValues
|
||||
trackMV.setDaemon(true)
|
||||
trackMV.start()
|
||||
|
||||
// Set DriverHostAddress to the driver's IP address for the slaves to read
|
||||
System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
|
||||
}
|
||||
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def stop() {
|
||||
stopBroadcast = true
|
||||
}
|
||||
|
||||
// Load common parameters
|
||||
private var DriverHostAddress_ = System.getProperty(
|
||||
"spark.MultiTracker.DriverHostAddress", "")
|
||||
private var DriverTrackerPort_ = System.getProperty(
|
||||
"spark.broadcast.driverTrackerPort", "11111").toInt
|
||||
private var BlockSize_ = System.getProperty(
|
||||
"spark.broadcast.blockSize", "4096").toInt * 1024
|
||||
private var MaxRetryCount_ = System.getProperty(
|
||||
"spark.broadcast.maxRetryCount", "2").toInt
|
||||
|
||||
private var TrackerSocketTimeout_ = System.getProperty(
|
||||
"spark.broadcast.trackerSocketTimeout", "50000").toInt
|
||||
private var ServerSocketTimeout_ = System.getProperty(
|
||||
"spark.broadcast.serverSocketTimeout", "10000").toInt
|
||||
|
||||
private var MinKnockInterval_ = System.getProperty(
|
||||
"spark.broadcast.minKnockInterval", "500").toInt
|
||||
private var MaxKnockInterval_ = System.getProperty(
|
||||
"spark.broadcast.maxKnockInterval", "999").toInt
|
||||
|
||||
// Load TreeBroadcast config params
|
||||
private var MaxDegree_ = System.getProperty(
|
||||
"spark.broadcast.maxDegree", "2").toInt
|
||||
|
||||
// Load BitTorrentBroadcast config params
|
||||
private var MaxPeersInGuideResponse_ = System.getProperty(
|
||||
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
|
||||
|
||||
private var MaxChatSlots_ = System.getProperty(
|
||||
"spark.broadcast.maxChatSlots", "4").toInt
|
||||
private var MaxChatTime_ = System.getProperty(
|
||||
"spark.broadcast.maxChatTime", "500").toInt
|
||||
private var MaxChatBlocks_ = System.getProperty(
|
||||
"spark.broadcast.maxChatBlocks", "1024").toInt
|
||||
|
||||
private var EndGameFraction_ = System.getProperty(
|
||||
"spark.broadcast.endGameFraction", "0.95").toDouble
|
||||
|
||||
def isDriver = _isDriver
|
||||
|
||||
// Common config params
|
||||
def DriverHostAddress = DriverHostAddress_
|
||||
def DriverTrackerPort = DriverTrackerPort_
|
||||
def BlockSize = BlockSize_
|
||||
def MaxRetryCount = MaxRetryCount_
|
||||
|
||||
def TrackerSocketTimeout = TrackerSocketTimeout_
|
||||
def ServerSocketTimeout = ServerSocketTimeout_
|
||||
|
||||
def MinKnockInterval = MinKnockInterval_
|
||||
def MaxKnockInterval = MaxKnockInterval_
|
||||
|
||||
// TreeBroadcast configs
|
||||
def MaxDegree = MaxDegree_
|
||||
|
||||
// BitTorrentBroadcast configs
|
||||
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
|
||||
|
||||
def MaxChatSlots = MaxChatSlots_
|
||||
def MaxChatTime = MaxChatTime_
|
||||
def MaxChatBlocks = MaxChatBlocks_
|
||||
|
||||
def EndGameFraction = EndGameFraction_
|
||||
|
||||
class TrackMultipleValues
|
||||
extends Thread with Logging {
|
||||
override def run() {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(DriverTrackerPort)
|
||||
logInfo("TrackMultipleValues started at " + serverSocket)
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(TrackerSocketTimeout)
|
||||
clientSocket = serverSocket.accept()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
if (stopBroadcast) {
|
||||
logInfo("Stopping TrackMultipleValues...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (clientSocket != null) {
|
||||
try {
|
||||
threadPool.execute(new Thread {
|
||||
override def run() {
|
||||
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
try {
|
||||
// First, read message type
|
||||
val messageType = ois.readObject.asInstanceOf[Int]
|
||||
|
||||
if (messageType == REGISTER_BROADCAST_TRACKER) {
|
||||
// Receive Long
|
||||
val id = ois.readObject.asInstanceOf[Long]
|
||||
// Receive hostAddress and listenPort
|
||||
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
// Add to the map
|
||||
valueToGuideMap.synchronized {
|
||||
valueToGuideMap += (id -> gInfo)
|
||||
}
|
||||
|
||||
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
|
||||
|
||||
// Send dummy ACK
|
||||
oos.writeObject(-1)
|
||||
oos.flush()
|
||||
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
|
||||
// Receive Long
|
||||
val id = ois.readObject.asInstanceOf[Long]
|
||||
|
||||
// Remove from the map
|
||||
valueToGuideMap.synchronized {
|
||||
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
|
||||
}
|
||||
|
||||
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
|
||||
|
||||
// Send dummy ACK
|
||||
oos.writeObject(-1)
|
||||
oos.flush()
|
||||
} else if (messageType == FIND_BROADCAST_TRACKER) {
|
||||
// Receive Long
|
||||
val id = ois.readObject.asInstanceOf[Long]
|
||||
|
||||
var gInfo =
|
||||
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
|
||||
else SourceInfo("", SourceInfo.TxNotStartedRetry)
|
||||
|
||||
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
|
||||
|
||||
// Send reply back
|
||||
oos.writeObject(gInfo)
|
||||
oos.flush()
|
||||
} else {
|
||||
throw new SparkException("Undefined messageType at TrackMultipleValues")
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logError("TrackMultipleValues had a " + e)
|
||||
}
|
||||
} finally {
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
})
|
||||
} catch {
|
||||
// In failure, close socket here; else, client thread will close
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
serverSocket.close()
|
||||
}
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
}
|
||||
|
||||
def getGuideInfo(variableLong: Long): SourceInfo = {
|
||||
var clientSocketToTracker: Socket = null
|
||||
var oosTracker: ObjectOutputStream = null
|
||||
var oisTracker: ObjectInputStream = null
|
||||
|
||||
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
|
||||
|
||||
var retriesLeft = MultiTracker.MaxRetryCount
|
||||
do {
|
||||
try {
|
||||
// Connect to the tracker to find out GuideInfo
|
||||
clientSocketToTracker =
|
||||
new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
|
||||
oosTracker =
|
||||
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
|
||||
oosTracker.flush()
|
||||
oisTracker =
|
||||
new ObjectInputStream(clientSocketToTracker.getInputStream)
|
||||
|
||||
// Send messageType/intention
|
||||
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
|
||||
oosTracker.flush()
|
||||
|
||||
// Send Long and receive GuideInfo
|
||||
oosTracker.writeObject(variableLong)
|
||||
oosTracker.flush()
|
||||
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
|
||||
} catch {
|
||||
case e: Exception => logError("getGuideInfo had a " + e)
|
||||
} finally {
|
||||
if (oisTracker != null) {
|
||||
oisTracker.close()
|
||||
}
|
||||
if (oosTracker != null) {
|
||||
oosTracker.close()
|
||||
}
|
||||
if (clientSocketToTracker != null) {
|
||||
clientSocketToTracker.close()
|
||||
}
|
||||
}
|
||||
|
||||
Thread.sleep(MultiTracker.ranGen.nextInt(
|
||||
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
|
||||
MultiTracker.MinKnockInterval)
|
||||
|
||||
retriesLeft -= 1
|
||||
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
|
||||
|
||||
logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
|
||||
return gInfo
|
||||
}
|
||||
|
||||
def registerBroadcast(id: Long, gInfo: SourceInfo) {
|
||||
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
|
||||
val oosST = new ObjectOutputStream(socket.getOutputStream)
|
||||
oosST.flush()
|
||||
val oisST = new ObjectInputStream(socket.getInputStream)
|
||||
|
||||
// Send messageType/intention
|
||||
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
|
||||
oosST.flush()
|
||||
|
||||
// Send Long of this broadcast
|
||||
oosST.writeObject(id)
|
||||
oosST.flush()
|
||||
|
||||
// Send this tracker's information
|
||||
oosST.writeObject(gInfo)
|
||||
oosST.flush()
|
||||
|
||||
// Receive ACK and throw it away
|
||||
oisST.readObject.asInstanceOf[Int]
|
||||
|
||||
// Shut stuff down
|
||||
oisST.close()
|
||||
oosST.close()
|
||||
socket.close()
|
||||
}
|
||||
|
||||
def unregisterBroadcast(id: Long) {
|
||||
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
|
||||
val oosST = new ObjectOutputStream(socket.getOutputStream)
|
||||
oosST.flush()
|
||||
val oisST = new ObjectInputStream(socket.getInputStream)
|
||||
|
||||
// Send messageType/intention
|
||||
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
|
||||
oosST.flush()
|
||||
|
||||
// Send Long of this broadcast
|
||||
oosST.writeObject(id)
|
||||
oosST.flush()
|
||||
|
||||
// Receive ACK and throw it away
|
||||
oisST.readObject.asInstanceOf[Int]
|
||||
|
||||
// Shut stuff down
|
||||
oisST.close()
|
||||
oosST.close()
|
||||
socket.close()
|
||||
}
|
||||
|
||||
// Helper method to convert an object to Array[BroadcastBlock]
|
||||
def blockifyObject[IN](obj: IN): VariableInfo = {
|
||||
val baos = new ByteArrayOutputStream
|
||||
val oos = new ObjectOutputStream(baos)
|
||||
oos.writeObject(obj)
|
||||
oos.close()
|
||||
baos.close()
|
||||
val byteArray = baos.toByteArray
|
||||
val bais = new ByteArrayInputStream(byteArray)
|
||||
|
||||
var blockNum = (byteArray.length / BlockSize)
|
||||
if (byteArray.length % BlockSize != 0)
|
||||
blockNum += 1
|
||||
|
||||
var retVal = new Array[BroadcastBlock](blockNum)
|
||||
var blockID = 0
|
||||
|
||||
for (i <- 0 until (byteArray.length, BlockSize)) {
|
||||
val thisBlockSize = math.min(BlockSize, byteArray.length - i)
|
||||
var tempByteArray = new Array[Byte](thisBlockSize)
|
||||
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
|
||||
|
||||
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
|
||||
blockID += 1
|
||||
}
|
||||
bais.close()
|
||||
|
||||
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
|
||||
variableInfo.hasBlocks = blockNum
|
||||
|
||||
return variableInfo
|
||||
}
|
||||
|
||||
// Helper method to convert Array[BroadcastBlock] to object
|
||||
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
|
||||
totalBytes: Int,
|
||||
totalBlocks: Int): OUT = {
|
||||
|
||||
var retByteArray = new Array[Byte](totalBytes)
|
||||
for (i <- 0 until totalBlocks) {
|
||||
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
|
||||
i * BlockSize, arrayOfBlocks(i).byteArray.length)
|
||||
}
|
||||
byteArrayToObject(retByteArray)
|
||||
}
|
||||
|
||||
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
|
||||
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
|
||||
override def resolveClass(desc: ObjectStreamClass) =
|
||||
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
|
||||
}
|
||||
val retVal = in.readObject.asInstanceOf[OUT]
|
||||
in.close()
|
||||
return retVal
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
|
||||
extends Serializable
|
||||
|
||||
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
|
||||
totalBlocks: Int,
|
||||
totalBytes: Int)
|
||||
extends Serializable {
|
||||
@transient var hasBlocks = 0
|
||||
}
|
|
@ -1,54 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.broadcast
|
||||
|
||||
import java.util.BitSet
|
||||
|
||||
import org.apache.spark._
|
||||
|
||||
/**
|
||||
* Used to keep and pass around information of peers involved in a broadcast
|
||||
*/
|
||||
private[spark] case class SourceInfo (hostAddress: String,
|
||||
listenPort: Int,
|
||||
totalBlocks: Int = SourceInfo.UnusedParam,
|
||||
totalBytes: Int = SourceInfo.UnusedParam)
|
||||
extends Comparable[SourceInfo] with Logging {
|
||||
|
||||
var currentLeechers = 0
|
||||
var receptionFailed = false
|
||||
|
||||
var hasBlocks = 0
|
||||
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
|
||||
|
||||
// Ascending sort based on leecher count
|
||||
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper Object of SourceInfo for its constants
|
||||
*/
|
||||
private[spark] object SourceInfo {
|
||||
// Broadcast has not started yet! Should never happen.
|
||||
val TxNotStartedRetry = -1
|
||||
// Broadcast has already finished. Try default mechanism.
|
||||
val TxOverGoToDefault = -3
|
||||
// Other constants
|
||||
val StopBroadcast = -2
|
||||
val UnusedParam = 0
|
||||
}
|
|
@ -0,0 +1,247 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.broadcast
|
||||
|
||||
import java.io._
|
||||
|
||||
import scala.math
|
||||
import scala.util.Random
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
|
||||
extends Broadcast[T](id) with Logging with Serializable {
|
||||
|
||||
def value = value_
|
||||
|
||||
def broadcastId = BroadcastBlockId(id)
|
||||
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
|
||||
}
|
||||
|
||||
@transient var arrayOfBlocks: Array[TorrentBlock] = null
|
||||
@transient var totalBlocks = -1
|
||||
@transient var totalBytes = -1
|
||||
@transient var hasBlocks = 0
|
||||
|
||||
if (!isLocal) {
|
||||
sendBroadcast()
|
||||
}
|
||||
|
||||
def sendBroadcast() {
|
||||
var tInfo = TorrentBroadcast.blockifyObject(value_)
|
||||
|
||||
totalBlocks = tInfo.totalBlocks
|
||||
totalBytes = tInfo.totalBytes
|
||||
hasBlocks = tInfo.totalBlocks
|
||||
|
||||
// Store meta-info
|
||||
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
|
||||
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.putSingle(
|
||||
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
|
||||
}
|
||||
|
||||
// Store individual pieces
|
||||
for (i <- 0 until totalBlocks) {
|
||||
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.putSingle(
|
||||
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Called by JVM when deserializing an object
|
||||
private def readObject(in: ObjectInputStream) {
|
||||
in.defaultReadObject()
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.getSingle(broadcastId) match {
|
||||
case Some(x) =>
|
||||
value_ = x.asInstanceOf[T]
|
||||
|
||||
case None =>
|
||||
val start = System.nanoTime
|
||||
logInfo("Started reading broadcast variable " + id)
|
||||
|
||||
// Initialize @transient variables that will receive garbage values from the master.
|
||||
resetWorkerVariables()
|
||||
|
||||
if (receiveBroadcast(id)) {
|
||||
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
|
||||
|
||||
// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
|
||||
// This creates a tradeoff between memory usage and latency.
|
||||
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
|
||||
SparkEnv.get.blockManager.putSingle(
|
||||
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
|
||||
|
||||
// Remove arrayOfBlocks from memory once value_ is on local cache
|
||||
resetWorkerVariables()
|
||||
} else {
|
||||
logError("Reading broadcast variable " + id + " failed")
|
||||
}
|
||||
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
logInfo("Reading broadcast variable " + id + " took " + time + " s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def resetWorkerVariables() {
|
||||
arrayOfBlocks = null
|
||||
totalBytes = -1
|
||||
totalBlocks = -1
|
||||
hasBlocks = 0
|
||||
}
|
||||
|
||||
def receiveBroadcast(variableID: Long): Boolean = {
|
||||
// Receive meta-info
|
||||
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
|
||||
var attemptId = 10
|
||||
while (attemptId > 0 && totalBlocks == -1) {
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.getSingle(metaId) match {
|
||||
case Some(x) =>
|
||||
val tInfo = x.asInstanceOf[TorrentInfo]
|
||||
totalBlocks = tInfo.totalBlocks
|
||||
totalBytes = tInfo.totalBytes
|
||||
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
|
||||
hasBlocks = 0
|
||||
|
||||
case None =>
|
||||
Thread.sleep(500)
|
||||
}
|
||||
}
|
||||
attemptId -= 1
|
||||
}
|
||||
if (totalBlocks == -1) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Receive actual blocks
|
||||
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
|
||||
for (pid <- recvOrder) {
|
||||
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
|
||||
TorrentBroadcast.synchronized {
|
||||
SparkEnv.get.blockManager.getSingle(pieceId) match {
|
||||
case Some(x) =>
|
||||
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
|
||||
hasBlocks += 1
|
||||
SparkEnv.get.blockManager.putSingle(
|
||||
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
|
||||
|
||||
case None =>
|
||||
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(hasBlocks == totalBlocks)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
private object TorrentBroadcast
|
||||
extends Logging {
|
||||
|
||||
private var initialized = false
|
||||
|
||||
def initialize(_isDriver: Boolean) {
|
||||
synchronized {
|
||||
if (!initialized) {
|
||||
initialized = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def stop() {
|
||||
initialized = false
|
||||
}
|
||||
|
||||
val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
|
||||
|
||||
def blockifyObject[T](obj: T): TorrentInfo = {
|
||||
val byteArray = Utils.serialize[T](obj)
|
||||
val bais = new ByteArrayInputStream(byteArray)
|
||||
|
||||
var blockNum = (byteArray.length / BLOCK_SIZE)
|
||||
if (byteArray.length % BLOCK_SIZE != 0)
|
||||
blockNum += 1
|
||||
|
||||
var retVal = new Array[TorrentBlock](blockNum)
|
||||
var blockID = 0
|
||||
|
||||
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
|
||||
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
|
||||
var tempByteArray = new Array[Byte](thisBlockSize)
|
||||
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
|
||||
|
||||
retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
|
||||
blockID += 1
|
||||
}
|
||||
bais.close()
|
||||
|
||||
var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
|
||||
tInfo.hasBlocks = blockNum
|
||||
|
||||
return tInfo
|
||||
}
|
||||
|
||||
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
|
||||
totalBytes: Int,
|
||||
totalBlocks: Int): T = {
|
||||
var retByteArray = new Array[Byte](totalBytes)
|
||||
for (i <- 0 until totalBlocks) {
|
||||
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
|
||||
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
|
||||
}
|
||||
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
|
||||
}
|
||||
|
||||
}
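The blockifyObject/unBlockifyObject pair above is a chunk-and-reassemble routine over a serialized payload. A hedged standalone sketch of the same arithmetic with a deliberately tiny block size (every name here is hypothetical):

    val blockSize = 4
    val payload: Array[Byte] = (1 to 10).map(_.toByte).toArray

    val blocks: Array[Array[Byte]] = payload.grouped(blockSize).toArray   // 4 + 4 + 2 bytes

    val reassembled = new Array[Byte](payload.length)
    var offset = 0
    for (block <- blocks) {
      System.arraycopy(block, 0, reassembled, offset, block.length)
      offset += block.length
    }
    assert(reassembled.sameElements(payload))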
|
||||
|
||||
private[spark] case class TorrentBlock(
|
||||
blockID: Int,
|
||||
byteArray: Array[Byte])
|
||||
extends Serializable
|
||||
|
||||
private[spark] case class TorrentInfo(
|
||||
@transient arrayOfBlocks : Array[TorrentBlock],
|
||||
totalBlocks: Int,
|
||||
totalBytes: Int)
|
||||
extends Serializable {
|
||||
|
||||
@transient var hasBlocks = 0
|
||||
}
|
||||
|
||||
private[spark] class TorrentBroadcastFactory
|
||||
extends BroadcastFactory {
|
||||
|
||||
def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
|
||||
|
||||
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
|
||||
new TorrentBroadcast[T](value_, isLocal, id)
|
||||
|
||||
def stop() { TorrentBroadcast.stop() }
|
||||
}
|
|
@ -1,603 +0,0 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.broadcast
|
||||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{Comparator, Random, UUID}
|
||||
|
||||
import scala.collection.mutable.{ListBuffer, Map, Set}
|
||||
import scala.math
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
|
||||
extends Broadcast[T](id) with Logging with Serializable {
|
||||
|
||||
def value = value_
|
||||
|
||||
def blockId = "broadcast_" + id
|
||||
|
||||
MultiTracker.synchronized {
|
||||
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
|
||||
}
|
||||
|
||||
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
|
||||
@transient var totalBytes = -1
|
||||
@transient var totalBlocks = -1
|
||||
@transient var hasBlocks = 0
|
||||
|
||||
@transient var listenPortLock = new Object
|
||||
@transient var guidePortLock = new Object
|
||||
@transient var totalBlocksLock = new Object
|
||||
@transient var hasBlocksLock = new Object
|
||||
|
||||
@transient var listOfSources = ListBuffer[SourceInfo]()
|
||||
|
||||
@transient var serveMR: ServeMultipleRequests = null
|
||||
@transient var guideMR: GuideMultipleRequests = null
|
||||
|
||||
@transient var hostAddress = Utils.localIpAddress
|
||||
@transient var listenPort = -1
|
||||
@transient var guidePort = -1
|
||||
|
||||
@transient var stopBroadcast = false
|
||||
|
||||
// Must call this after all the variables have been created/initialized
|
||||
if (!isLocal) {
|
||||
sendBroadcast()
|
||||
}
|
||||
|
||||
def sendBroadcast() {
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
// Create a variableInfo object and store it in valueInfos
|
||||
var variableInfo = MultiTracker.blockifyObject(value_)
|
||||
|
||||
// Prepare the value being broadcasted
|
||||
arrayOfBlocks = variableInfo.arrayOfBlocks
|
||||
totalBytes = variableInfo.totalBytes
|
||||
totalBlocks = variableInfo.totalBlocks
|
||||
hasBlocks = variableInfo.totalBlocks
|
||||
|
||||
guideMR = new GuideMultipleRequests
|
||||
guideMR.setDaemon(true)
|
||||
guideMR.start()
|
||||
logInfo("GuideMultipleRequests started...")
|
||||
|
||||
// Must always come AFTER guideMR is created
|
||||
while (guidePort == -1) {
|
||||
guidePortLock.synchronized { guidePortLock.wait() }
|
||||
}
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start()
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
// Must always come AFTER serveMR is created
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized { listenPortLock.wait() }
|
||||
}
|
||||
|
||||
// Must always come AFTER listenPort is created
|
||||
val masterSource =
|
||||
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
|
||||
listOfSources += masterSource
|
||||
|
||||
// Register with the Tracker
|
||||
MultiTracker.registerBroadcast(id,
|
||||
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
|
||||
}
|
||||
|
||||
private def readObject(in: ObjectInputStream) {
|
||||
in.defaultReadObject()
|
||||
MultiTracker.synchronized {
|
||||
SparkEnv.get.blockManager.getSingle(blockId) match {
|
||||
case Some(x) =>
|
||||
value_ = x.asInstanceOf[T]
|
||||
|
||||
case None =>
|
||||
logInfo("Started reading broadcast variable " + id)
|
||||
// Initializing everything because Driver will only send null/0 values
|
||||
// Only the 1st worker in a node can be here. Others will get from cache
|
||||
initializeWorkerVariables()
|
||||
|
||||
logInfo("Local host address: " + hostAddress)
|
||||
|
||||
serveMR = new ServeMultipleRequests
|
||||
serveMR.setDaemon(true)
|
||||
serveMR.start()
|
||||
logInfo("ServeMultipleRequests started...")
|
||||
|
||||
val start = System.nanoTime
|
||||
|
||||
val receptionSucceeded = receiveBroadcast(id)
|
||||
if (receptionSucceeded) {
|
||||
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
|
||||
SparkEnv.get.blockManager.putSingle(
|
||||
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
|
||||
} else {
|
||||
logError("Reading broadcast variable " + id + " failed")
|
||||
}
|
||||
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
logInfo("Reading broadcast variable " + id + " took " + time + " s")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def initializeWorkerVariables() {
|
||||
arrayOfBlocks = null
|
||||
totalBytes = -1
|
||||
totalBlocks = -1
|
||||
hasBlocks = 0
|
||||
|
||||
listenPortLock = new Object
|
||||
totalBlocksLock = new Object
|
||||
hasBlocksLock = new Object
|
||||
|
||||
serveMR = null
|
||||
|
||||
hostAddress = Utils.localIpAddress
|
||||
listenPort = -1
|
||||
|
||||
stopBroadcast = false
|
||||
}
|
||||
|
||||
def receiveBroadcast(variableID: Long): Boolean = {
|
||||
val gInfo = MultiTracker.getGuideInfo(variableID)
|
||||
|
||||
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Wait until hostAddress and listenPort are created by the
|
||||
// ServeMultipleRequests thread
|
||||
while (listenPort == -1) {
|
||||
listenPortLock.synchronized { listenPortLock.wait() }
|
||||
}
|
||||
|
||||
var clientSocketToDriver: Socket = null
|
||||
var oosDriver: ObjectOutputStream = null
|
||||
var oisDriver: ObjectInputStream = null
|
||||
|
||||
// Connect and receive broadcast from the specified source, retrying the
|
||||
// specified number of times in case of failures
|
||||
var retriesLeft = MultiTracker.MaxRetryCount
|
||||
do {
|
||||
// Connect to Driver and send this worker's Information
|
||||
clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
|
||||
oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
|
||||
oosDriver.flush()
|
||||
oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
|
||||
|
||||
logDebug("Connected to Driver's guiding object")
|
||||
|
||||
// Send local source information
|
||||
oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
|
||||
oosDriver.flush()
|
||||
|
||||
// Receive source information from Driver
|
||||
var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
|
||||
totalBlocks = sourceInfo.totalBlocks
|
||||
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
|
||||
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
|
||||
totalBytes = sourceInfo.totalBytes
|
||||
|
||||
logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
|
||||
|
||||
val start = System.nanoTime
|
||||
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
|
||||
val time = (System.nanoTime - start) / 1e9
|
||||
|
||||
// Updating some statistics in sourceInfo. Driver will be using them later
|
||||
if (!receptionSucceeded) {
|
||||
sourceInfo.receptionFailed = true
|
||||
}
|
||||
|
||||
// Send back statistics to the Driver
|
||||
oosDriver.writeObject(sourceInfo)
|
||||
|
||||
if (oisDriver != null) {
|
||||
oisDriver.close()
|
||||
}
|
||||
if (oosDriver != null) {
|
||||
oosDriver.close()
|
||||
}
|
||||
if (clientSocketToDriver != null) {
|
||||
clientSocketToDriver.close()
|
||||
}
|
||||
|
||||
retriesLeft -= 1
|
||||
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
|
||||
|
||||
return (hasBlocks == totalBlocks)
|
||||
}
|
||||
|
||||
/**
|
||||
* Tries to receive broadcast from the source and returns Boolean status.
|
||||
* This might be called multiple times to retry a defined number of times.
|
||||
*/
|
||||
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
|
||||
var clientSocketToSource: Socket = null
|
||||
var oosSource: ObjectOutputStream = null
|
||||
var oisSource: ObjectInputStream = null
|
||||
|
||||
var receptionSucceeded = false
|
||||
try {
|
||||
// Connect to the source to get the object itself
|
||||
clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
|
||||
oosSource.flush()
|
||||
oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
|
||||
|
||||
logDebug("Inside receiveSingleTransmission")
|
||||
logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
|
||||
|
||||
// Send the range
|
||||
oosSource.writeObject((hasBlocks, totalBlocks))
|
||||
oosSource.flush()
|
||||
|
||||
for (i <- hasBlocks until totalBlocks) {
|
||||
val recvStartTime = System.currentTimeMillis
|
||||
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
|
||||
val receptionTime = (System.currentTimeMillis - recvStartTime)
|
||||
|
||||
logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
|
||||
|
||||
arrayOfBlocks(hasBlocks) = bcBlock
|
||||
hasBlocks += 1
|
||||
|
||||
// Set to true if at least one block is received
|
||||
receptionSucceeded = true
|
||||
hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => logError("receiveSingleTransmission had a " + e)
|
||||
} finally {
|
||||
if (oisSource != null) {
|
||||
oisSource.close()
|
||||
}
|
||||
if (oosSource != null) {
|
||||
oosSource.close()
|
||||
}
|
||||
if (clientSocketToSource != null) {
|
||||
clientSocketToSource.close()
|
||||
}
|
||||
}
|
||||
|
||||
return receptionSucceeded
|
||||
}
|
||||
|
||||
class GuideMultipleRequests
|
||||
extends Thread with Logging {
|
||||
// Keep track of sources that have completed reception
|
||||
private var setOfCompletedSources = Set[SourceInfo]()
|
||||
|
||||
override def run() {
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
var serverSocket: ServerSocket = null
|
||||
|
||||
serverSocket = new ServerSocket(0)
|
||||
guidePort = serverSocket.getLocalPort
|
||||
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
|
||||
|
||||
guidePortLock.synchronized { guidePortLock.notifyAll() }
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
// Stop broadcast if at least one worker has connected and
|
||||
// everyone connected so far are done. Comparing with
|
||||
// listOfSources.size - 1, because it includes the Guide itself
|
||||
listOfSources.synchronized {
|
||||
setOfCompletedSources.synchronized {
|
||||
if (listOfSources.size > 1 &&
|
||||
setOfCompletedSources.size == listOfSources.size - 1) {
|
||||
stopBroadcast = true
|
||||
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if (clientSocket != null) {
|
||||
logDebug("Guide: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new GuideSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close() the socket here; else, the thread will close() it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logInfo("Sending stopBroadcast notifications...")
|
||||
sendStopBroadcastNotifications
|
||||
|
||||
MultiTracker.unregisterBroadcast(id)
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("GuideMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
|
||||
private def sendStopBroadcastNotifications() {
|
||||
listOfSources.synchronized {
|
||||
var listIter = listOfSources.iterator
|
||||
while (listIter.hasNext) {
|
||||
var sourceInfo = listIter.next
|
||||
|
||||
var guideSocketToSource: Socket = null
|
||||
var gosSource: ObjectOutputStream = null
|
||||
var gisSource: ObjectInputStream = null
|
||||
|
||||
try {
|
||||
// Connect to the source
|
||||
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
|
||||
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
|
||||
gosSource.flush()
|
||||
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
|
||||
|
||||
// Send stopBroadcast signal
|
||||
gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
|
||||
gosSource.flush()
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
logError("sendStopBroadcastNotifications had a " + e)
|
||||
}
|
||||
} finally {
|
||||
if (gisSource != null) {
|
||||
gisSource.close()
|
||||
}
|
||||
if (gosSource != null) {
|
||||
gosSource.close()
|
||||
}
|
||||
if (guideSocketToSource != null) {
|
||||
guideSocketToSource.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class GuideSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var selectedSourceInfo: SourceInfo = null
|
||||
private var thisWorkerInfo: SourceInfo = null
|
||||
|
||||
override def run() {
|
||||
try {
|
||||
logInfo("new GuideSingleRequest is running")
|
||||
// Connecting worker is sending in its hostAddress and listenPort it will
|
||||
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
|
||||
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
listOfSources.synchronized {
|
||||
// Select a suitable source and send it back to the worker
|
||||
selectedSourceInfo = selectSuitableSource(sourceInfo)
|
||||
logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
|
||||
oos.writeObject(selectedSourceInfo)
|
||||
oos.flush()
|
||||
|
||||
// Add this new (if it can finish) source to the list of sources
|
||||
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
|
||||
sourceInfo.listenPort, totalBlocks, totalBytes)
|
||||
logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
|
||||
listOfSources += thisWorkerInfo
|
||||
}
|
||||
|
||||
// Wait till the whole transfer is done. Then receive and update source
|
||||
// statistics in listOfSources
|
||||
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
|
||||
|
||||
listOfSources.synchronized {
|
||||
// This should work since SourceInfo is a case class
|
||||
assert(listOfSources.contains(selectedSourceInfo))
|
||||
|
||||
// Remove first
|
||||
// (Currently removing a source based on just one failure notification!)
|
||||
listOfSources = listOfSources - selectedSourceInfo
|
||||
|
||||
// Update sourceInfo and put it back in, IF reception succeeded
|
||||
if (!sourceInfo.receptionFailed) {
|
||||
// Add thisWorkerInfo to sources that have completed reception
|
||||
setOfCompletedSources.synchronized {
|
||||
setOfCompletedSources += thisWorkerInfo
|
||||
}
|
||||
|
||||
// Update leecher count and put it back in
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
listOfSources += selectedSourceInfo
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => {
|
||||
// Remove failed worker from listOfSources and update leecherCount of
|
||||
// corresponding source worker
|
||||
listOfSources.synchronized {
|
||||
if (selectedSourceInfo != null) {
|
||||
// Remove first
|
||||
listOfSources = listOfSources - selectedSourceInfo
|
||||
// Update leecher count and put it back in
|
||||
selectedSourceInfo.currentLeechers -= 1
|
||||
listOfSources += selectedSourceInfo
|
||||
}
|
||||
|
||||
// Remove thisWorkerInfo
|
||||
if (listOfSources != null) {
|
||||
listOfSources = listOfSources - thisWorkerInfo
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
logInfo("GuideSingleRequest is closing streams and sockets")
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
// Assuming the caller to have a synchronized block on listOfSources
|
||||
// Select one with the most leechers. This will level-wise fill the tree
|
||||
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
|
||||
var maxLeechers = -1
|
||||
var selectedSource: SourceInfo = null
|
||||
|
||||
listOfSources.foreach { source =>
|
||||
if ((source.hostAddress != skipSourceInfo.hostAddress ||
|
||||
source.listenPort != skipSourceInfo.listenPort) &&
|
||||
source.currentLeechers < MultiTracker.MaxDegree &&
|
||||
source.currentLeechers > maxLeechers) {
|
||||
selectedSource = source
|
||||
maxLeechers = source.currentLeechers
|
||||
}
|
||||
}
|
||||
|
||||
// Update leecher count
|
||||
selectedSource.currentLeechers += 1
|
||||
return selectedSource
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class ServeMultipleRequests
|
||||
extends Thread with Logging {
|
||||
|
||||
var threadPool = Utils.newDaemonCachedThreadPool()
|
||||
|
||||
override def run() {
|
||||
var serverSocket = new ServerSocket(0)
|
||||
listenPort = serverSocket.getLocalPort
|
||||
|
||||
logInfo("ServeMultipleRequests started with " + serverSocket)
|
||||
|
||||
listenPortLock.synchronized { listenPortLock.notifyAll() }
|
||||
|
||||
try {
|
||||
while (!stopBroadcast) {
|
||||
var clientSocket: Socket = null
|
||||
try {
|
||||
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
|
||||
clientSocket = serverSocket.accept
|
||||
} catch {
|
||||
case e: Exception => { }
|
||||
}
|
||||
|
||||
if (clientSocket != null) {
|
||||
logDebug("Serve: Accepted new client connection: " + clientSocket)
|
||||
try {
|
||||
threadPool.execute(new ServeSingleRequest(clientSocket))
|
||||
} catch {
|
||||
// In failure, close socket here; else, the thread will close it
|
||||
case ioe: IOException => clientSocket.close()
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (serverSocket != null) {
|
||||
logInfo("ServeMultipleRequests now stopping...")
|
||||
serverSocket.close()
|
||||
}
|
||||
}
|
||||
// Shutdown the thread pool
|
||||
threadPool.shutdown()
|
||||
}
|
||||
|
||||
class ServeSingleRequest(val clientSocket: Socket)
|
||||
extends Thread with Logging {
|
||||
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
|
||||
oos.flush()
|
||||
private val ois = new ObjectInputStream(clientSocket.getInputStream)
|
||||
|
||||
private var sendFrom = 0
|
||||
private var sendUntil = totalBlocks
|
||||
|
||||
override def run() {
|
||||
try {
|
||||
logInfo("new ServeSingleRequest is running")
|
||||
|
||||
// Receive range to send
|
||||
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
|
||||
sendFrom = rangeToSend._1
|
||||
sendUntil = rangeToSend._2
|
||||
|
||||
// If not a valid range, stop broadcast
|
||||
if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
|
||||
stopBroadcast = true
|
||||
} else {
|
||||
sendObject
|
||||
}
|
||||
} catch {
|
||||
case e: Exception => logError("ServeSingleRequest had a " + e)
|
||||
} finally {
|
||||
logInfo("ServeSingleRequest is closing streams and sockets")
|
||||
ois.close()
|
||||
oos.close()
|
||||
clientSocket.close()
|
||||
}
|
||||
}
|
||||
|
||||
private def sendObject() {
|
||||
// Wait till receiving the SourceInfo from Driver
|
||||
while (totalBlocks == -1) {
|
||||
totalBlocksLock.synchronized { totalBlocksLock.wait() }
|
||||
}
|
||||
|
||||
for (i <- sendFrom until sendUntil) {
|
||||
while (i == hasBlocks) {
|
||||
hasBlocksLock.synchronized { hasBlocksLock.wait() }
|
||||
}
|
||||
try {
|
||||
oos.writeObject(arrayOfBlocks(i))
|
||||
oos.flush()
|
||||
} catch {
|
||||
case e: Exception => logError("sendObject had a " + e)
|
||||
}
|
||||
logDebug("Sent block: " + i + " to " + clientSocket)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class TreeBroadcastFactory
|
||||
extends BroadcastFactory {
|
||||
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
|
||||
|
||||
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
|
||||
new TreeBroadcast[T](value_, isLocal, id)
|
||||
|
||||
def stop() { MultiTracker.stop() }
|
||||
}
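For orientation, the factory above is driven through three calls: initialize once per JVM, newBroadcast per variable, and stop on shutdown. A minimal driver-side sketch of that lifecycle, assuming a live SparkEnv/BlockManager as on a running driver; the payload and id are illustrative, and the BroadcastManager that normally issues these calls is not shown in this patch:

// Illustrative lifecycle only; in Spark the broadcast manager makes these calls,
// and a live SparkEnv/BlockManager is assumed (the constructor stores the value there).
val factory = new TreeBroadcastFactory
factory.initialize(isDriver = true)                 // starts the MultiTracker on the driver

val numbers = Seq(1, 2, 3)                          // hypothetical broadcast payload
val broadcastVar = factory.newBroadcast(numbers, isLocal = false, id = 0L)
println(broadcastVar.value)                         // the same value is served to workers

factory.stop()                                      // tears the tracker down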
|
|
@@ -21,12 +21,14 @@ import scala.collection.immutable.List
|
|||
|
||||
import org.apache.spark.deploy.ExecutorState.ExecutorState
|
||||
import org.apache.spark.deploy.master.{WorkerInfo, ApplicationInfo}
|
||||
import org.apache.spark.deploy.master.RecoveryState.MasterState
|
||||
import org.apache.spark.deploy.worker.ExecutorRunner
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
private[deploy] sealed trait DeployMessage extends Serializable
|
||||
|
||||
/** Contains messages sent between Scheduler actor nodes. */
|
||||
private[deploy] object DeployMessages {
|
||||
|
||||
// Worker to Master
|
||||
|
@@ -52,17 +54,20 @@ private[deploy] object DeployMessages {
|
|||
exitStatus: Option[Int])
|
||||
extends DeployMessage
|
||||
|
||||
case class WorkerSchedulerStateResponse(id: String, executors: List[ExecutorDescription])
|
||||
|
||||
case class Heartbeat(workerId: String) extends DeployMessage
|
||||
|
||||
// Master to Worker
|
||||
|
||||
case class RegisteredWorker(masterWebUiUrl: String) extends DeployMessage
|
||||
case class RegisteredWorker(masterUrl: String, masterWebUiUrl: String) extends DeployMessage
|
||||
|
||||
case class RegisterWorkerFailed(message: String) extends DeployMessage
|
||||
|
||||
case class KillExecutor(appId: String, execId: Int) extends DeployMessage
|
||||
case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
|
||||
|
||||
case class LaunchExecutor(
|
||||
masterUrl: String,
|
||||
appId: String,
|
||||
execId: Int,
|
||||
appDesc: ApplicationDescription,
|
||||
|
@@ -76,9 +81,11 @@ private[deploy] object DeployMessages {
|
|||
case class RegisterApplication(appDescription: ApplicationDescription)
|
||||
extends DeployMessage
|
||||
|
||||
case class MasterChangeAcknowledged(appId: String)
|
||||
|
||||
// Master to Client
|
||||
|
||||
case class RegisteredApplication(appId: String) extends DeployMessage
|
||||
case class RegisteredApplication(appId: String, masterUrl: String) extends DeployMessage
|
||||
|
||||
// TODO(matei): replace hostPort with host
|
||||
case class ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) {
|
||||
|
@@ -94,6 +101,10 @@ private[deploy] object DeployMessages {
|
|||
|
||||
case object StopClient
|
||||
|
||||
// Master to Worker & Client
|
||||
|
||||
case class MasterChanged(masterUrl: String, masterWebUiUrl: String)
|
||||
|
||||
// MasterWebUI To Master
|
||||
|
||||
case object RequestMasterState
|
||||
|
@@ -101,7 +112,8 @@ private[deploy] object DeployMessages {
|
|||
// Master to MasterWebUI
|
||||
|
||||
case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
|
||||
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo]) {
|
||||
activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
|
||||
status: MasterState) {
|
||||
|
||||
Utils.checkHost(host, "Required hostname")
|
||||
assert (port > 0)
|
||||
|
@@ -123,12 +135,7 @@ private[deploy] object DeployMessages {
|
|||
assert (port > 0)
|
||||
}
|
||||
|
||||
// Actor System to Master
|
||||
|
||||
case object CheckForWorkerTimeOut
|
||||
|
||||
case object RequestWebUIPort
|
||||
|
||||
case class WebUIPortResponse(webUIBoundPort: Int)
|
||||
// Actor System to Worker
|
||||
|
||||
case object SendHeartbeat
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,34 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy
|
||||
|
||||
/**
 * Used to send state on-the-wire about Executors from Worker to Master.
 * This state is sufficient for the Master to reconstruct its internal data structures during
 * failover.
 */
|
||||
private[spark] class ExecutorDescription(
|
||||
val appId: String,
|
||||
val execId: Int,
|
||||
val cores: Int,
|
||||
val state: ExecutorState.Value)
|
||||
extends Serializable {
|
||||
|
||||
override def toString: String =
|
||||
"ExecutorState(appId=%s, execId=%d, cores=%d, state=%s)".format(appId, execId, cores, state)
|
||||
}
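To make the wire format concrete: the Worker builds one ExecutorDescription per running executor and ships the batch to the Master in the WorkerSchedulerStateResponse message shown earlier, and a recovering Master replays them into its ExecutorInfo records via copyState. The identifiers below are invented for the sketch.

// Invented ids; only the shape of the message matters here.
val runningExecutors = List(
  new ExecutorDescription("app-20131031120000-0001", execId = 0, cores = 4, state = ExecutorState.RUNNING),
  new ExecutorDescription("app-20131031120000-0001", execId = 1, cores = 4, state = ExecutorState.RUNNING))

// Worker -> Master, as declared in DeployMessages above.
val schedulerState = DeployMessages.WorkerSchedulerStateResponse("worker-20131031120000-host-1234", runningExecutors)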
|
|
@@ -0,0 +1,420 @@
|
|||
/*
|
||||
*
|
||||
* * Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* * contributor license agreements. See the NOTICE file distributed with
|
||||
* * this work for additional information regarding copyright ownership.
|
||||
* * The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* * (the "License"); you may not use this file except in compliance with
|
||||
* * the License. You may obtain a copy of the License at
|
||||
* *
|
||||
* * http://www.apache.org/licenses/LICENSE-2.0
|
||||
* *
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* * See the License for the specific language governing permissions and
|
||||
* * limitations under the License.
|
||||
*
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy
|
||||
|
||||
import java.io._
|
||||
import java.net.URL
|
||||
import java.util.concurrent.TimeoutException
|
||||
|
||||
import scala.concurrent.{Await, future, promise}
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.collection.mutable.ListBuffer
|
||||
import scala.sys.process._
|
||||
|
||||
import net.liftweb.json.JsonParser
|
||||
|
||||
import org.apache.spark.{Logging, SparkContext}
|
||||
import org.apache.spark.deploy.master.RecoveryState
|
||||
|
||||
/**
 * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master.
 * In order to mimic a real distributed cluster more closely, Docker is used.
 * Execute using
 *   ./spark-class org.apache.spark.deploy.FaultToleranceTest
 *
 * Make sure that the environment includes the following properties in SPARK_DAEMON_JAVA_OPTS:
 *   - spark.deploy.recoveryMode=ZOOKEEPER
 *   - spark.deploy.zookeeper.url=172.17.42.1:2181
 * Note that 172.17.42.1 is the default docker ip for the host and 2181 is the default ZK port.
 *
 * Unfortunately, due to the Docker dependency this suite cannot be run automatically without a
 * working installation of Docker. In addition to having Docker, the following are assumed:
 *   - Docker can run without sudo (see http://docs.docker.io/en/latest/use/basics/)
 *   - The docker images tagged spark-test-master and spark-test-worker are built from the
 *     docker/ directory. Run 'docker/spark-test/build' to generate these.
 */
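To make that setup concrete, the two recovery properties can be assembled into the daemon options roughly as below before the master and worker containers are launched; the values are the defaults mentioned above, and the launcher-side wiring is an assumption rather than part of this patch.

// Assumed launcher-side helper: the JVM options each daemon needs for ZooKeeper recovery.
// The resulting string is what ends up in SPARK_DAEMON_JAVA_OPTS for the test images.
val recoveryOpts = Seq(
  "-Dspark.deploy.recoveryMode=ZOOKEEPER",
  "-Dspark.deploy.zookeeper.url=172.17.42.1:2181")  // default docker host IP, default ZK port
val sparkDaemonJavaOpts = recoveryOpts.mkString(" ")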
|
||||
private[spark] object FaultToleranceTest extends App with Logging {
|
||||
val masters = ListBuffer[TestMasterInfo]()
|
||||
val workers = ListBuffer[TestWorkerInfo]()
|
||||
var sc: SparkContext = _
|
||||
|
||||
var numPassed = 0
|
||||
var numFailed = 0
|
||||
|
||||
val sparkHome = System.getenv("SPARK_HOME")
|
||||
assertTrue(sparkHome != null, "Run with a valid SPARK_HOME")
|
||||
|
||||
val containerSparkHome = "/opt/spark"
|
||||
val dockerMountDir = "%s:%s".format(sparkHome, containerSparkHome)
|
||||
|
||||
System.setProperty("spark.driver.host", "172.17.42.1") // default docker host ip
|
||||
|
||||
def afterEach() {
|
||||
if (sc != null) {
|
||||
sc.stop()
|
||||
sc = null
|
||||
}
|
||||
terminateCluster()
|
||||
}
|
||||
|
||||
test("sanity-basic") {
|
||||
addMasters(1)
|
||||
addWorkers(1)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("sanity-many-masters") {
|
||||
addMasters(3)
|
||||
addWorkers(3)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("single-master-halt") {
|
||||
addMasters(3)
|
||||
addWorkers(2)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
|
||||
killLeader()
|
||||
delay(30 seconds)
|
||||
assertValidClusterState()
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("single-master-restart") {
|
||||
addMasters(1)
|
||||
addWorkers(2)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
|
||||
killLeader()
|
||||
addMasters(1)
|
||||
delay(30 seconds)
|
||||
assertValidClusterState()
|
||||
|
||||
killLeader()
|
||||
addMasters(1)
|
||||
delay(30 seconds)
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("cluster-failure") {
|
||||
addMasters(2)
|
||||
addWorkers(2)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
|
||||
terminateCluster()
|
||||
addMasters(2)
|
||||
addWorkers(2)
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("all-but-standby-failure") {
|
||||
addMasters(2)
|
||||
addWorkers(2)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
|
||||
killLeader()
|
||||
workers.foreach(_.kill())
|
||||
workers.clear()
|
||||
delay(30 seconds)
|
||||
addWorkers(2)
|
||||
assertValidClusterState()
|
||||
}
|
||||
|
||||
test("rolling-outage") {
|
||||
addMasters(1)
|
||||
delay()
|
||||
addMasters(1)
|
||||
delay()
|
||||
addMasters(1)
|
||||
addWorkers(2)
|
||||
createClient()
|
||||
assertValidClusterState()
|
||||
assertTrue(getLeader == masters.head)
|
||||
|
||||
(1 to 3).foreach { _ =>
|
||||
killLeader()
|
||||
delay(30 seconds)
|
||||
assertValidClusterState()
|
||||
assertTrue(getLeader == masters.head)
|
||||
addMasters(1)
|
||||
}
|
||||
}
|
||||
|
||||
def test(name: String)(fn: => Unit) {
|
||||
try {
|
||||
fn
|
||||
numPassed += 1
|
||||
logInfo("Passed: " + name)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
numFailed += 1
|
||||
logError("FAILED: " + name, e)
|
||||
}
|
||||
afterEach()
|
||||
}
|
||||
|
||||
def addMasters(num: Int) {
|
||||
(1 to num).foreach { _ => masters += SparkDocker.startMaster(dockerMountDir) }
|
||||
}
|
||||
|
||||
def addWorkers(num: Int) {
|
||||
val masterUrls = getMasterUrls(masters)
|
||||
(1 to num).foreach { _ => workers += SparkDocker.startWorker(dockerMountDir, masterUrls) }
|
||||
}
|
||||
|
||||
/** Creates a SparkContext, which constructs a Client to interact with our cluster. */
|
||||
def createClient() = {
|
||||
if (sc != null) { sc.stop() }
|
||||
// Counter-hack: Because of a hack in SparkEnv#createFromSystemProperties() that changes this
|
||||
// property, we need to reset it.
|
||||
System.setProperty("spark.driver.port", "0")
|
||||
sc = new SparkContext(getMasterUrls(masters), "fault-tolerance", containerSparkHome)
|
||||
}
|
||||
|
||||
def getMasterUrls(masters: Seq[TestMasterInfo]): String = {
|
||||
"spark://" + masters.map(master => master.ip + ":7077").mkString(",")
|
||||
}
|
||||
|
||||
def getLeader: TestMasterInfo = {
|
||||
val leaders = masters.filter(_.state == RecoveryState.ALIVE)
|
||||
assertTrue(leaders.size == 1)
|
||||
leaders(0)
|
||||
}
|
||||
|
||||
def killLeader(): Unit = {
|
||||
masters.foreach(_.readState())
|
||||
val leader = getLeader
|
||||
masters -= leader
|
||||
leader.kill()
|
||||
}
|
||||
|
||||
def delay(secs: Duration = 5.seconds) = Thread.sleep(secs.toMillis)
|
||||
|
||||
def terminateCluster() {
|
||||
masters.foreach(_.kill())
|
||||
workers.foreach(_.kill())
|
||||
masters.clear()
|
||||
workers.clear()
|
||||
}
|
||||
|
||||
/** This includes Client retry logic, so it may take a while if the cluster is recovering. */
|
||||
def assertUsable() = {
|
||||
val f = future {
|
||||
try {
|
||||
val res = sc.parallelize(0 until 10).collect()
|
||||
assertTrue(res.toList == (0 until 10))
|
||||
true
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError("assertUsable() had exception", e)
|
||||
e.printStackTrace()
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// Avoid waiting indefinitely (e.g., we could register but get no executors).
|
||||
assertTrue(Await.result(f, 120 seconds))
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that the cluster is usable and that the expected masters and workers
|
||||
* are all alive in a proper configuration (e.g., only one leader).
|
||||
*/
|
||||
def assertValidClusterState() = {
|
||||
assertUsable()
|
||||
var numAlive = 0
|
||||
var numStandby = 0
|
||||
var numLiveApps = 0
|
||||
var liveWorkerIPs: Seq[String] = List()
|
||||
|
||||
def stateValid(): Boolean = {
|
||||
(workers.map(_.ip) -- liveWorkerIPs).isEmpty &&
|
||||
numAlive == 1 && numStandby == masters.size - 1 && numLiveApps >= 1
|
||||
}
|
||||
|
||||
val f = future {
|
||||
try {
|
||||
while (!stateValid()) {
|
||||
Thread.sleep(1000)
|
||||
|
||||
numAlive = 0
|
||||
numStandby = 0
|
||||
numLiveApps = 0
|
||||
|
||||
masters.foreach(_.readState())
|
||||
|
||||
for (master <- masters) {
|
||||
master.state match {
|
||||
case RecoveryState.ALIVE =>
|
||||
numAlive += 1
|
||||
liveWorkerIPs = master.liveWorkerIPs
|
||||
case RecoveryState.STANDBY =>
|
||||
numStandby += 1
|
||||
case _ => // ignore
|
||||
}
|
||||
|
||||
numLiveApps += master.numLiveApps
|
||||
}
|
||||
}
|
||||
true
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError("assertValidClusterState() had exception", e)
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
assertTrue(Await.result(f, 120 seconds))
|
||||
} catch {
|
||||
case e: TimeoutException =>
|
||||
logError("Master states: " + masters.map(_.state))
|
||||
logError("Num apps: " + numLiveApps)
|
||||
logError("IPs expected: " + workers.map(_.ip) + " / found: " + liveWorkerIPs)
|
||||
throw new RuntimeException("Failed to get into acceptable cluster state after 2 min.", e)
|
||||
}
|
||||
}
|
||||
|
||||
def assertTrue(bool: Boolean, message: String = "") {
|
||||
if (!bool) {
|
||||
throw new IllegalStateException("Assertion failed: " + message)
|
||||
}
|
||||
}
|
||||
|
||||
logInfo("Ran %s tests, %s passed and %s failed".format(numPassed+numFailed, numPassed, numFailed))
|
||||
}
|
||||
|
||||
private[spark] class TestMasterInfo(val ip: String, val dockerId: DockerId, val logFile: File)
|
||||
extends Logging {
|
||||
|
||||
implicit val formats = net.liftweb.json.DefaultFormats
|
||||
var state: RecoveryState.Value = _
|
||||
var liveWorkerIPs: List[String] = _
|
||||
var numLiveApps = 0
|
||||
|
||||
logDebug("Created master: " + this)
|
||||
|
||||
def readState() {
|
||||
try {
|
||||
val masterStream = new InputStreamReader(new URL("http://%s:8080/json".format(ip)).openStream)
|
||||
val json = JsonParser.parse(masterStream, closeAutomatically = true)
|
||||
|
||||
val workers = json \ "workers"
|
||||
val liveWorkers = workers.children.filter(w => (w \ "state").extract[String] == "ALIVE")
|
||||
liveWorkerIPs = liveWorkers.map(w => (w \ "host").extract[String])
|
||||
|
||||
numLiveApps = (json \ "activeapps").children.size
|
||||
|
||||
val status = json \\ "status"
|
||||
val stateString = status.extract[String]
|
||||
state = RecoveryState.values.filter(state => state.toString == stateString).head
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
// ignore, no state update
|
||||
logWarning("Exception", e)
|
||||
}
|
||||
}
|
||||
|
||||
def kill() { Docker.kill(dockerId) }
|
||||
|
||||
override def toString: String =
|
||||
"[ip=%s, id=%s, logFile=%s, state=%s]".
|
||||
format(ip, dockerId.id, logFile.getAbsolutePath, state)
|
||||
}
|
||||
|
||||
private[spark] class TestWorkerInfo(val ip: String, val dockerId: DockerId, val logFile: File)
|
||||
extends Logging {
|
||||
|
||||
implicit val formats = net.liftweb.json.DefaultFormats
|
||||
|
||||
logDebug("Created worker: " + this)
|
||||
|
||||
def kill() { Docker.kill(dockerId) }
|
||||
|
||||
override def toString: String =
|
||||
"[ip=%s, id=%s, logFile=%s]".format(ip, dockerId, logFile.getAbsolutePath)
|
||||
}
|
||||
|
||||
private[spark] object SparkDocker {
|
||||
def startMaster(mountDir: String): TestMasterInfo = {
|
||||
val cmd = Docker.makeRunCmd("spark-test-master", mountDir = mountDir)
|
||||
val (ip, id, outFile) = startNode(cmd)
|
||||
new TestMasterInfo(ip, id, outFile)
|
||||
}
|
||||
|
||||
def startWorker(mountDir: String, masters: String): TestWorkerInfo = {
|
||||
val cmd = Docker.makeRunCmd("spark-test-worker", args = masters, mountDir = mountDir)
|
||||
val (ip, id, outFile) = startNode(cmd)
|
||||
new TestWorkerInfo(ip, id, outFile)
|
||||
}
|
||||
|
||||
private def startNode(dockerCmd: ProcessBuilder) : (String, DockerId, File) = {
|
||||
val ipPromise = promise[String]()
|
||||
val outFile = File.createTempFile("fault-tolerance-test", "")
|
||||
outFile.deleteOnExit()
|
||||
val outStream: FileWriter = new FileWriter(outFile)
|
||||
def findIpAndLog(line: String): Unit = {
|
||||
if (line.startsWith("CONTAINER_IP=")) {
|
||||
val ip = line.split("=")(1)
|
||||
ipPromise.success(ip)
|
||||
}
|
||||
|
||||
outStream.write(line + "\n")
|
||||
outStream.flush()
|
||||
}
|
||||
|
||||
dockerCmd.run(ProcessLogger(findIpAndLog _))
|
||||
val ip = Await.result(ipPromise.future, 30 seconds)
|
||||
val dockerId = Docker.getLastProcessId
|
||||
(ip, dockerId, outFile)
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class DockerId(val id: String) {
|
||||
override def toString = id
|
||||
}
|
||||
|
||||
private[spark] object Docker extends Logging {
|
||||
def makeRunCmd(imageTag: String, args: String = "", mountDir: String = ""): ProcessBuilder = {
|
||||
val mountCmd = if (mountDir != "") { " -v " + mountDir } else ""
|
||||
|
||||
val cmd = "docker run %s %s %s".format(mountCmd, imageTag, args)
|
||||
logDebug("Run command: " + cmd)
|
||||
cmd
|
||||
}
|
||||
|
||||
def kill(dockerId: DockerId) : Unit = {
|
||||
"docker kill %s".format(dockerId.id).!
|
||||
}
|
||||
|
||||
def getLastProcessId: DockerId = {
|
||||
var id: String = null
|
||||
"docker ps -l -q".!(ProcessLogger(line => id = line))
|
||||
new DockerId(id)
|
||||
}
|
||||
}
|
|
@@ -72,7 +72,8 @@ private[spark] object JsonProtocol {
|
|||
("memory" -> obj.workers.map(_.memory).sum) ~
|
||||
("memoryused" -> obj.workers.map(_.memoryUsed).sum) ~
|
||||
("activeapps" -> obj.activeApps.toList.map(writeApplicationInfo)) ~
|
||||
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo))
|
||||
("completedapps" -> obj.completedApps.toList.map(writeApplicationInfo)) ~
|
||||
("status" -> obj.status.toString)
|
||||
}
|
||||
|
||||
def writeWorkerState(obj: WorkerStateResponse) = {
|
||||
|
|
|
@@ -39,22 +39,23 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
|
|||
private val masterActorSystems = ArrayBuffer[ActorSystem]()
|
||||
private val workerActorSystems = ArrayBuffer[ActorSystem]()
|
||||
|
||||
def start(): String = {
|
||||
def start(): Array[String] = {
|
||||
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
|
||||
|
||||
/* Start the Master */
|
||||
val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0)
|
||||
masterActorSystems += masterSystem
|
||||
val masterUrl = "spark://" + localHostname + ":" + masterPort
|
||||
val masters = Array(masterUrl)
|
||||
|
||||
/* Start the Workers */
|
||||
for (workerNum <- 1 to numWorkers) {
|
||||
val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker,
|
||||
memoryPerWorker, masterUrl, null, Some(workerNum))
|
||||
memoryPerWorker, masters, null, Some(workerNum))
|
||||
workerActorSystems += workerSystem
|
||||
}
|
||||
|
||||
return masterUrl
|
||||
return masters
|
||||
}
|
||||
|
||||
def stop() {
|
||||
|
|
|
@@ -17,28 +17,70 @@
|
|||
|
||||
package org.apache.spark.deploy
|
||||
|
||||
import com.google.common.collect.MapMaker
|
||||
import java.security.PrivilegedExceptionAction
|
||||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.mapred.JobConf
|
||||
import org.apache.hadoop.security.UserGroupInformation
|
||||
|
||||
import org.apache.spark.{SparkContext, SparkException}
|
||||
|
||||
/**
|
||||
* Contains util methods to interact with Hadoop from spark.
|
||||
* Contains util methods to interact with Hadoop from Spark.
|
||||
*/
|
||||
private[spark]
|
||||
class SparkHadoopUtil {
|
||||
// A general, soft-reference map for metadata needed during HadoopRDD split computation
|
||||
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
|
||||
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
|
||||
val conf = newConfiguration()
|
||||
UserGroupInformation.setConfiguration(conf)
|
||||
|
||||
// Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop
|
||||
// subsystems
|
||||
def runAsUser(user: String)(func: () => Unit) {
|
||||
// if we are already running as the user intended there is no reason to do the doAs. It
|
||||
// will actually break secure HDFS access as it doesn't fill in the credentials. Also if
|
||||
// the user is UNKNOWN then we shouldn't be creating a remote unknown user
|
||||
// (this is actually the path spark on yarn takes) since SPARK_USER is initialized only
|
||||
// in SparkContext.
|
||||
val currentUser = Option(System.getProperty("user.name")).
|
||||
getOrElse(SparkContext.SPARK_UNKNOWN_USER)
|
||||
if (user != SparkContext.SPARK_UNKNOWN_USER && currentUser != user) {
|
||||
val ugi = UserGroupInformation.createRemoteUser(user)
|
||||
ugi.doAs(new PrivilegedExceptionAction[Unit] {
|
||||
def run: Unit = func()
|
||||
})
|
||||
} else {
|
||||
func()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Return an appropriate (subclass of) Configuration. Creating a config can initialize some Hadoop
|
||||
* subsystems.
|
||||
*/
|
||||
def newConfiguration(): Configuration = new Configuration()
|
||||
|
||||
// Add any user credentials to the job conf which are necessary for running on a secure Hadoop
|
||||
// cluster
|
||||
/**
|
||||
* Add any user credentials to the job conf which are necessary for running on a secure Hadoop
|
||||
* cluster.
|
||||
*/
|
||||
def addCredentials(conf: JobConf) {}
|
||||
|
||||
def isYarnMode(): Boolean = { false }
|
||||
|
||||
}
|
||||
|
||||
object SparkHadoopUtil {
|
||||
private val hadoop = {
|
||||
val yarnMode = java.lang.Boolean.valueOf(System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))
|
||||
if (yarnMode) {
|
||||
try {
|
||||
Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil]
|
||||
} catch {
|
||||
case th: Throwable => throw new SparkException("Unable to load YARN support", th)
|
||||
}
|
||||
} else {
|
||||
new SparkHadoopUtil
|
||||
}
|
||||
}
|
||||
|
||||
def get: SparkHadoopUtil = {
|
||||
hadoop
|
||||
}
|
||||
}
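A short usage sketch based only on the API above: callers go through SparkHadoopUtil.get so that the YARN subclass is substituted automatically when SPARK_YARN_MODE is set. The user name and the closure body are placeholders.

// Placeholder user; in Spark this would be SPARK_USER or the submitting user.
val hadoopUtil = SparkHadoopUtil.get
hadoopUtil.runAsUser("alice") { () =>
  val hadoopConf = hadoopUtil.newConfiguration()    // fresh Hadoop Configuration
  // ... access HDFS or run the task body under alice's UGI here ...
}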
|
||||
|
|
|
@@ -37,38 +37,81 @@ import org.apache.spark.deploy.master.Master
|
|||
/**
 * The main class used to talk to a Spark deploy cluster. Takes a master URL, an app description,
 * and a listener for cluster events, and calls back the listener when various events occur.
 *
 * @param masterUrls Each url should look like spark://host:port.
 */
|
||||
private[spark] class Client(
|
||||
actorSystem: ActorSystem,
|
||||
masterUrl: String,
|
||||
masterUrls: Array[String],
|
||||
appDescription: ApplicationDescription,
|
||||
listener: ClientListener)
|
||||
extends Logging {
|
||||
|
||||
val REGISTRATION_TIMEOUT = 20.seconds
|
||||
val REGISTRATION_RETRIES = 3
|
||||
|
||||
var prevMaster: ActorRef = null // set for unwatching, when it fails.
|
||||
var actor: ActorRef = null
|
||||
var appId: String = null
|
||||
var registered = false
|
||||
var activeMasterUrl: String = null
|
||||
|
||||
class ClientActor extends Actor with Logging {
|
||||
var master: ActorSelection = null
|
||||
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
|
||||
var alreadyDead = false // To avoid calling listener.dead() multiple times
|
||||
|
||||
override def preStart() {
|
||||
logInfo("Connecting to master " + masterUrl)
|
||||
try {
|
||||
master = context.actorSelection(Master.toAkkaUrl(masterUrl))
|
||||
master ! RegisterApplication(appDescription)
|
||||
registerWithMaster()
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError("Failed to connect to master", e)
|
||||
logWarning("Failed to connect to master", e)
|
||||
markDisconnected()
|
||||
context.stop(self)
|
||||
}
|
||||
}
|
||||
|
||||
def tryRegisterAllMasters() {
|
||||
for (masterUrl <- masterUrls) {
|
||||
logInfo("Connecting to master " + masterUrl + "...")
|
||||
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
|
||||
actor ! RegisterApplication(appDescription)
|
||||
}
|
||||
}
|
||||
|
||||
def registerWithMaster() {
|
||||
tryRegisterAllMasters()
|
||||
|
||||
import context.dispatcher
|
||||
var retries = 0
|
||||
lazy val retryTimer: Cancellable =
|
||||
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
|
||||
retries += 1
|
||||
if (registered) {
|
||||
retryTimer.cancel()
|
||||
} else if (retries >= REGISTRATION_RETRIES) {
|
||||
logError("All masters are unresponsive! Giving up.")
|
||||
markDead()
|
||||
} else {
|
||||
tryRegisterAllMasters()
|
||||
}
|
||||
}
|
||||
retryTimer // start timer
|
||||
}
|
||||
|
||||
def changeMaster(url: String) {
|
||||
activeMasterUrl = url
|
||||
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
|
||||
}
|
||||
|
||||
override def receive = {
|
||||
case RegisteredApplication(appId_) =>
|
||||
case RegisteredApplication(appId_, masterUrl) =>
|
||||
context.watch(sender)
|
||||
prevMaster = sender
|
||||
appId = appId_
|
||||
registered = true
|
||||
changeMaster(masterUrl)
|
||||
listener.connected(appId)
|
||||
|
||||
case ApplicationRemoved(message) =>
|
||||
|
@@ -89,13 +132,19 @@ private[spark] class Client(
|
|||
listener.executorRemoved(fullId, message.getOrElse(""), exitStatus)
|
||||
}
|
||||
|
||||
case MasterChanged(masterUrl, masterWebUiUrl) =>
|
||||
logInfo("Master has changed, new master is at " + masterUrl)
|
||||
context.unwatch(prevMaster)
|
||||
changeMaster(masterUrl)
|
||||
alreadyDisconnected = false
|
||||
sender ! MasterChangeAcknowledged(appId)
|
||||
|
||||
case Terminated(actor_) =>
|
||||
logError(s"Connection to $actor_ dropped, stopping client")
|
||||
logWarning(s"Connection to $actor_ failed; waiting for master to reconnect...")
|
||||
markDisconnected()
|
||||
context.stop(self)
|
||||
|
||||
case StopClient =>
|
||||
markDisconnected()
|
||||
markDead()
|
||||
sender ! true
|
||||
context.stop(self)
|
||||
}
|
||||
|
@@ -109,6 +158,13 @@ private[spark] class Client(
|
|||
alreadyDisconnected = true
|
||||
}
|
||||
}
|
||||
|
||||
def markDead() {
|
||||
if (!alreadyDead) {
|
||||
listener.dead()
|
||||
alreadyDead = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def start() {
|
||||
|
|
|
@@ -27,8 +27,12 @@ package org.apache.spark.deploy.client
|
|||
private[spark] trait ClientListener {
|
||||
def connected(appId: String): Unit
|
||||
|
||||
/** Disconnection may be a temporary state, as we fail over to a new Master. */
|
||||
def disconnected(): Unit
|
||||
|
||||
/** Dead means that we couldn't find any Masters to connect to, and have given up. */
|
||||
def dead(): Unit
|
||||
|
||||
def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit
|
||||
|
||||
def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit
|
||||
|
|
|
@@ -33,6 +33,11 @@ private[spark] object TestClient {
|
|||
System.exit(0)
|
||||
}
|
||||
|
||||
def dead() {
|
||||
logInfo("Could not connect to master")
|
||||
System.exit(0)
|
||||
}
|
||||
|
||||
def executorAdded(id: String, workerId: String, hostPort: String, cores: Int, memory: Int) {}
|
||||
|
||||
def executorRemoved(id: String, message: String, exitStatus: Option[Int]) {}
|
||||
|
@@ -44,7 +49,7 @@ private[spark] object TestClient {
|
|||
val desc = new ApplicationDescription(
|
||||
"TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored")
|
||||
val listener = new TestListener
|
||||
val client = new Client(actorSystem, url, desc, listener)
|
||||
val client = new Client(actorSystem, Array(url), desc, listener)
|
||||
client.start()
|
||||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
|
|
@@ -29,23 +29,46 @@ private[spark] class ApplicationInfo(
|
|||
val submitDate: Date,
|
||||
val driver: ActorRef,
|
||||
val appUiUrl: String)
|
||||
{
|
||||
var state = ApplicationState.WAITING
|
||||
var executors = new mutable.HashMap[Int, ExecutorInfo]
|
||||
var coresGranted = 0
|
||||
var endTime = -1L
|
||||
val appSource = new ApplicationSource(this)
|
||||
extends Serializable {
|
||||
|
||||
private var nextExecutorId = 0
|
||||
@transient var state: ApplicationState.Value = _
|
||||
@transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
|
||||
@transient var coresGranted: Int = _
|
||||
@transient var endTime: Long = _
|
||||
@transient var appSource: ApplicationSource = _
|
||||
|
||||
def newExecutorId(): Int = {
|
||||
val id = nextExecutorId
|
||||
nextExecutorId += 1
|
||||
id
|
||||
@transient private var nextExecutorId: Int = _
|
||||
|
||||
init()
|
||||
|
||||
private def readObject(in: java.io.ObjectInputStream) : Unit = {
|
||||
in.defaultReadObject()
|
||||
init()
|
||||
}
|
||||
|
||||
def addExecutor(worker: WorkerInfo, cores: Int): ExecutorInfo = {
|
||||
val exec = new ExecutorInfo(newExecutorId(), this, worker, cores, desc.memoryPerSlave)
|
||||
private def init() {
|
||||
state = ApplicationState.WAITING
|
||||
executors = new mutable.HashMap[Int, ExecutorInfo]
|
||||
coresGranted = 0
|
||||
endTime = -1L
|
||||
appSource = new ApplicationSource(this)
|
||||
nextExecutorId = 0
|
||||
}
|
||||
|
||||
private def newExecutorId(useID: Option[Int] = None): Int = {
|
||||
useID match {
|
||||
case Some(id) =>
|
||||
nextExecutorId = math.max(nextExecutorId, id + 1)
|
||||
id
|
||||
case None =>
|
||||
val id = nextExecutorId
|
||||
nextExecutorId += 1
|
||||
id
|
||||
}
|
||||
}
|
||||
|
||||
def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
|
||||
val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
|
||||
executors(exec.id) = exec
|
||||
coresGranted += cores
|
||||
exec
|
||||
|
|
|
@@ -17,12 +17,11 @@
|
|||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
private[spark] object ApplicationState
|
||||
extends Enumeration {
|
||||
private[spark] object ApplicationState extends Enumeration {
|
||||
|
||||
type ApplicationState = Value
|
||||
|
||||
val WAITING, RUNNING, FINISHED, FAILED = Value
|
||||
val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
|
||||
|
||||
val MAX_NUM_RETRY = 10
|
||||
}
|
||||
|
|
|
@@ -17,7 +17,7 @@
|
|||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import org.apache.spark.deploy.ExecutorState
|
||||
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
|
||||
|
||||
private[spark] class ExecutorInfo(
|
||||
val id: Int,
|
||||
|
@@ -28,5 +28,10 @@ private[spark] class ExecutorInfo(
|
|||
|
||||
var state = ExecutorState.LAUNCHING
|
||||
|
||||
/** Copy all state (non-val) variables from the given on-the-wire ExecutorDescription. */
|
||||
def copyState(execDesc: ExecutorDescription) {
|
||||
state = execDesc.state
|
||||
}
|
||||
|
||||
def fullId: String = application.id + "/" + id
|
||||
}
|
||||
|
|
|
@@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import java.io._
|
||||
|
||||
import scala.Serializable
|
||||
|
||||
import akka.serialization.Serialization
|
||||
import org.apache.spark.Logging
|
||||
|
||||
/**
 * Stores data in a single on-disk directory with one file per application and worker.
 * Files are deleted when applications and workers are removed.
 *
 * @param dir Directory to store files. Created if non-existent (but not recursively).
 * @param serialization Used to serialize our objects.
 */
|
||||
private[spark] class FileSystemPersistenceEngine(
|
||||
val dir: String,
|
||||
val serialization: Serialization)
|
||||
extends PersistenceEngine with Logging {
|
||||
|
||||
new File(dir).mkdir()
|
||||
|
||||
override def addApplication(app: ApplicationInfo) {
|
||||
val appFile = new File(dir + File.separator + "app_" + app.id)
|
||||
serializeIntoFile(appFile, app)
|
||||
}
|
||||
|
||||
override def removeApplication(app: ApplicationInfo) {
|
||||
new File(dir + File.separator + "app_" + app.id).delete()
|
||||
}
|
||||
|
||||
override def addWorker(worker: WorkerInfo) {
|
||||
val workerFile = new File(dir + File.separator + "worker_" + worker.id)
|
||||
serializeIntoFile(workerFile, worker)
|
||||
}
|
||||
|
||||
override def removeWorker(worker: WorkerInfo) {
|
||||
new File(dir + File.separator + "worker_" + worker.id).delete()
|
||||
}
|
||||
|
||||
override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
|
||||
val sortedFiles = new File(dir).listFiles().sortBy(_.getName)
|
||||
val appFiles = sortedFiles.filter(_.getName.startsWith("app_"))
|
||||
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
|
||||
val workerFiles = sortedFiles.filter(_.getName.startsWith("worker_"))
|
||||
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
|
||||
(apps, workers)
|
||||
}
|
||||
|
||||
private def serializeIntoFile(file: File, value: AnyRef) {
|
||||
val created = file.createNewFile()
|
||||
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
|
||||
|
||||
val serializer = serialization.findSerializerFor(value)
|
||||
val serialized = serializer.toBinary(value)
|
||||
|
||||
val out = new FileOutputStream(file)
|
||||
out.write(serialized)
|
||||
out.close()
|
||||
}
|
||||
|
||||
def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
|
||||
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
|
||||
val dis = new DataInputStream(new FileInputStream(file))
|
||||
dis.readFully(fileData)
|
||||
dis.close()
|
||||
|
||||
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
|
||||
val serializer = serialization.serializerFor(clazz)
|
||||
serializer.fromBinary(fileData).asInstanceOf[T]
|
||||
}
|
||||
}
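For orientation, the engine is built from a recovery directory and Akka's SerializationExtension; the Master's preStart further down does exactly this, so only the directory path in this standalone sketch is a stand-in.

// Standalone sketch; the directory is a placeholder for spark.deploy.recoveryDirectory.
import akka.actor.ActorSystem
import akka.serialization.SerializationExtension

val system = ActorSystem("persistence-sketch")
val engine = new FileSystemPersistenceEngine("/tmp/spark-recovery", SerializationExtension(system))

// The Master calls addApplication/addWorker as things register, and replays on (re)start:
val (apps, workers) = engine.readPersistedData()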
|
|
@@ -0,0 +1,45 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import akka.actor.{Actor, ActorRef}
|
||||
|
||||
import org.apache.spark.deploy.master.MasterMessages.ElectedLeader
|
||||
|
||||
/**
 * A LeaderElectionAgent keeps track of whether the current Master is the leader, meaning it
 * is the only Master serving requests.
 * In addition to the API provided, the LeaderElectionAgent will use the following messages
 * to inform the Master of leader changes:
 * [[org.apache.spark.deploy.master.MasterMessages.ElectedLeader ElectedLeader]]
 * [[org.apache.spark.deploy.master.MasterMessages.RevokedLeadership RevokedLeadership]]
 */
|
||||
private[spark] trait LeaderElectionAgent extends Actor {
|
||||
val masterActor: ActorRef
|
||||
}
|
||||
|
||||
/** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */
|
||||
private[spark] class MonarchyLeaderAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
|
||||
override def preStart() {
|
||||
masterActor ! ElectedLeader
|
||||
}
|
||||
|
||||
override def receive = {
|
||||
case _ =>
|
||||
}
|
||||
}
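Put differently, the only obligation on an agent is to eventually send ElectedLeader (and, where leadership can be lost, RevokedLeadership) to masterActor; MonarchyLeaderAgent is the trivial single-node case, and the ZooKeeper-backed agent referenced by the Master below does the same through real election. A compressed sketch of a custom agent, with the election backend left abstract:

// Sketch only: how the backend decides to deliver these triggers is deliberately left out.
private[spark] class SketchLeaderElectionAgent(val masterActor: ActorRef) extends LeaderElectionAgent {
  override def receive = {
    case "elected" => masterActor ! ElectedLeader                       // this Master now leads
    case "revoked" => masterActor ! MasterMessages.RevokedLeadership    // leadership was lost
  }
}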
|
|
@@ -23,41 +23,38 @@ import java.text.SimpleDateFormat
|
|||
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
|
||||
import scala.concurrent.Await
|
||||
import scala.concurrent.duration._
|
||||
import scala.concurrent.duration.{Duration, FiniteDuration}
|
||||
|
||||
import akka.actor._
|
||||
import akka.pattern.ask
|
||||
import akka.remote._
|
||||
import akka.util.Timeout
|
||||
|
||||
import org.apache.spark.{Logging, SparkException}
|
||||
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
|
||||
import org.apache.spark.deploy.DeployMessages._
|
||||
import org.apache.spark.deploy.master.MasterMessages._
|
||||
import org.apache.spark.deploy.master.ui.MasterWebUI
|
||||
import org.apache.spark.metrics.MetricsSystem
|
||||
import org.apache.spark.util.{Utils, AkkaUtils}
|
||||
import akka.util.Timeout
|
||||
import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
|
||||
import org.apache.spark.deploy.DeployMessages.KillExecutor
|
||||
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
|
||||
import scala.Some
|
||||
import org.apache.spark.deploy.DeployMessages.WebUIPortResponse
|
||||
import org.apache.spark.deploy.DeployMessages.LaunchExecutor
|
||||
import org.apache.spark.deploy.DeployMessages.RegisteredApplication
|
||||
import org.apache.spark.deploy.DeployMessages.RegisterWorker
|
||||
import org.apache.spark.deploy.DeployMessages.ExecutorUpdated
|
||||
import org.apache.spark.deploy.DeployMessages.MasterStateResponse
|
||||
import org.apache.spark.deploy.DeployMessages.ExecutorAdded
|
||||
import org.apache.spark.deploy.DeployMessages.RegisterApplication
|
||||
import org.apache.spark.deploy.DeployMessages.ApplicationRemoved
|
||||
import org.apache.spark.deploy.DeployMessages.Heartbeat
|
||||
import org.apache.spark.deploy.DeployMessages.RegisteredWorker
|
||||
import akka.actor.Terminated
|
||||
import akka.serialization.SerializationExtension
|
||||
import java.util.concurrent.TimeUnit
|
||||
|
||||
|
||||
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
|
||||
import context.dispatcher
|
||||
|
||||
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
|
||||
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
|
||||
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
|
||||
val REAPER_ITERATIONS = System.getProperty("spark.dead.worker.persistence", "15").toInt
|
||||
val RECOVERY_DIR = System.getProperty("spark.deploy.recoveryDirectory", "")
|
||||
val RECOVERY_MODE = System.getProperty("spark.deploy.recoveryMode", "NONE")
|
||||
|
||||
var nextAppNumber = 0
|
||||
val workers = new HashSet[WorkerInfo]
|
||||
|
@ -88,52 +85,114 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
if (envVar != null) envVar else host
|
||||
}
|
||||
|
||||
val masterUrl = "spark://" + host + ":" + port
|
||||
var masterWebUiUrl: String = _
|
||||
|
||||
var state = RecoveryState.STANDBY
|
||||
|
||||
var persistenceEngine: PersistenceEngine = _
|
||||
|
||||
var leaderElectionAgent: ActorRef = _
|
||||
|
||||
// As a temporary workaround before better ways of configuring memory, we allow users to set
|
||||
// a flag that will perform round-robin scheduling across the nodes (spreading out each app
|
||||
// among all the nodes) instead of trying to consolidate each app onto a small # of nodes.
|
||||
val spreadOutApps = System.getProperty("spark.deploy.spreadOut", "true").toBoolean
|
||||
|
||||
override def preStart() {
|
||||
logInfo("Starting Spark master at spark://" + host + ":" + port)
|
||||
logInfo("Starting Spark master at " + masterUrl)
|
||||
// Listen for remote client disconnection events, since they don't go through Akka's watch()
|
||||
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
|
||||
webUi.start()
|
||||
import context.dispatcher
|
||||
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
|
||||
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
|
||||
|
||||
masterMetricsSystem.registerSource(masterSource)
|
||||
masterMetricsSystem.start()
|
||||
applicationMetricsSystem.start()
|
||||
|
||||
persistenceEngine = RECOVERY_MODE match {
|
||||
case "ZOOKEEPER" =>
|
||||
logInfo("Persisting recovery state to ZooKeeper")
|
||||
new ZooKeeperPersistenceEngine(SerializationExtension(context.system))
|
||||
case "FILESYSTEM" =>
|
||||
logInfo("Persisting recovery state to directory: " + RECOVERY_DIR)
|
||||
new FileSystemPersistenceEngine(RECOVERY_DIR, SerializationExtension(context.system))
|
||||
case _ =>
|
||||
new BlackHolePersistenceEngine()
|
||||
}
|
||||
|
||||
leaderElectionAgent = RECOVERY_MODE match {
|
||||
case "ZOOKEEPER" =>
|
||||
context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
|
||||
case _ =>
|
||||
context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
|
||||
}
|
||||
}
|
||||
|
||||
override def preRestart(reason: Throwable, message: Option[Any]) {
|
||||
super.preRestart(reason, message) // calls postStop()!
|
||||
logError("Master actor restarted due to exception", reason)
|
||||
}
|
||||
|
||||
override def postStop() {
|
||||
webUi.stop()
|
||||
masterMetricsSystem.stop()
|
||||
applicationMetricsSystem.stop()
|
||||
persistenceEngine.close()
|
||||
context.stop(leaderElectionAgent)
|
||||
}
|
||||
|
||||
override def receive = {
|
||||
case RegisterWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress) => {
|
||||
case ElectedLeader => {
|
||||
val (storedApps, storedWorkers) = persistenceEngine.readPersistedData()
|
||||
state = if (storedApps.isEmpty && storedWorkers.isEmpty)
|
||||
RecoveryState.ALIVE
|
||||
else
|
||||
RecoveryState.RECOVERING
|
||||
|
||||
logInfo("I have been elected leader! New state: " + state)
|
||||
|
||||
if (state == RecoveryState.RECOVERING) {
|
||||
beginRecovery(storedApps, storedWorkers)
|
||||
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
|
||||
}
|
||||
}
|
||||
|
||||
case RevokedLeadership => {
|
||||
logError("Leadership has been revoked -- master shutting down.")
|
||||
System.exit(0)
|
||||
}
|
||||
|
||||
case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
|
||||
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
|
||||
host, workerPort, cores, Utils.megabytesToString(memory)))
|
||||
if (idToWorker.contains(id)) {
|
||||
if (state == RecoveryState.STANDBY) {
|
||||
// ignore, don't send response
|
||||
} else if (idToWorker.contains(id)) {
|
||||
sender ! RegisterWorkerFailed("Duplicate worker ID")
|
||||
} else {
|
||||
addWorker(id, host, workerPort, cores, memory, worker_webUiPort, publicAddress)
|
||||
context.watch(sender) // This doesn't work with remote actors but helps for testing
|
||||
sender ! RegisteredWorker("http://" + masterPublicAddress + ":" + webUi.boundPort.get)
|
||||
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
|
||||
registerWorker(worker)
|
||||
context.watch(sender)
|
||||
persistenceEngine.addWorker(worker)
|
||||
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
|
||||
schedule()
|
||||
}
|
||||
}
|
||||
|
||||
case RegisterApplication(description) => {
|
||||
logInfo("Registering app " + description.name)
|
||||
val app = addApplication(description, sender)
|
||||
logInfo("Registered app " + description.name + " with ID " + app.id)
|
||||
waitingApps += app
|
||||
context.watch(sender) // This doesn't work with remote actors but helps for testing
|
||||
sender ! RegisteredApplication(app.id)
|
||||
schedule()
|
||||
if (state == RecoveryState.STANDBY) {
|
||||
// ignore, don't send response
|
||||
} else {
|
||||
logInfo("Registering app " + description.name)
|
||||
val app = createApplication(description, sender)
|
||||
registerApplication(app)
|
||||
logInfo("Registered app " + description.name + " with ID " + app.id)
|
||||
context.watch(sender)
|
||||
persistenceEngine.addApplication(app)
|
||||
sender ! RegisteredApplication(app.id, masterUrl)
|
||||
schedule()
|
||||
}
|
||||
}
|
||||
|
||||
case ExecutorStateChanged(appId, execId, state, message, exitStatus) => {
|
||||
|
@ -173,27 +232,49 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
}
|
||||
}
|
||||
|
||||
case MasterChangeAcknowledged(appId) => {
|
||||
idToApp.get(appId) match {
|
||||
case Some(app) =>
|
||||
logInfo("Application has been re-registered: " + appId)
|
||||
app.state = ApplicationState.WAITING
|
||||
case None =>
|
||||
logWarning("Master change ack from unknown app: " + appId)
|
||||
}
|
||||
|
||||
if (canCompleteRecovery) { completeRecovery() }
|
||||
}
|
||||
|
||||
case WorkerSchedulerStateResponse(workerId, executors) => {
|
||||
idToWorker.get(workerId) match {
|
||||
case Some(worker) =>
|
||||
logInfo("Worker has been re-registered: " + workerId)
|
||||
worker.state = WorkerState.ALIVE
|
||||
|
||||
val validExecutors = executors.filter(exec => idToApp.get(exec.appId).isDefined)
|
||||
for (exec <- validExecutors) {
|
||||
val app = idToApp.get(exec.appId).get
|
||||
val execInfo = app.addExecutor(worker, exec.cores, Some(exec.execId))
|
||||
worker.addExecutor(execInfo)
|
||||
execInfo.copyState(exec)
|
||||
}
|
||||
case None =>
|
||||
logWarning("Scheduler state from unknown worker: " + workerId)
|
||||
}
|
||||
|
||||
if (canCompleteRecovery) { completeRecovery() }
|
||||
}
|
||||
|
||||
case Terminated(actor) => {
|
||||
// The disconnected actor could've been either a worker or an app; remove whichever of
|
||||
// those we have an entry for in the corresponding actor hashmap
|
||||
actorToWorker.get(actor).foreach(removeWorker)
|
||||
actorToApp.get(actor).foreach(finishApplication)
|
||||
}
|
||||
|
||||
case DisassociatedEvent(_, address, _) => {
|
||||
// The disconnected client could've been either a worker or an app; remove whichever it was
|
||||
addressToWorker.get(address).foreach(removeWorker)
|
||||
addressToApp.get(address).foreach(finishApplication)
|
||||
}
|
||||
|
||||
case AssociationErrorEvent(_, _, address, _) => {
|
||||
// The disconnected client could've been either a worker or an app; remove whichever it was
|
||||
addressToWorker.get(address).foreach(removeWorker)
|
||||
addressToApp.get(address).foreach(finishApplication)
|
||||
if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
|
||||
}
|
||||
|
||||
case RequestMasterState => {
|
||||
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray)
|
||||
sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
|
||||
state)
|
||||
}
|
||||
|
||||
case CheckForWorkerTimeOut => {
|
||||
|
@ -205,6 +286,50 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
}
|
||||
}
|
||||
|
||||
def canCompleteRecovery =
|
||||
workers.count(_.state == WorkerState.UNKNOWN) == 0 &&
|
||||
apps.count(_.state == ApplicationState.UNKNOWN) == 0
|
||||
|
||||
def beginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo]) {
|
||||
for (app <- storedApps) {
|
||||
logInfo("Trying to recover app: " + app.id)
|
||||
try {
|
||||
registerApplication(app)
|
||||
app.state = ApplicationState.UNKNOWN
|
||||
app.driver ! MasterChanged(masterUrl, masterWebUiUrl)
|
||||
} catch {
|
||||
case e: Exception => logInfo("App " + app.id + " had exception on reconnect")
|
||||
}
|
||||
}
|
||||
|
||||
for (worker <- storedWorkers) {
|
||||
logInfo("Trying to recover worker: " + worker.id)
|
||||
try {
|
||||
registerWorker(worker)
|
||||
worker.state = WorkerState.UNKNOWN
|
||||
worker.actor ! MasterChanged(masterUrl, masterWebUiUrl)
|
||||
} catch {
|
||||
case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def completeRecovery() {
|
||||
// Ensure "only-once" recovery semantics using a short synchronization period.
|
||||
synchronized {
|
||||
if (state != RecoveryState.RECOVERING) { return }
|
||||
state = RecoveryState.COMPLETING_RECOVERY
|
||||
}
|
||||
|
||||
// Kill off any workers and apps that didn't respond to us.
|
||||
workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
|
||||
apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
|
||||
|
||||
state = RecoveryState.ALIVE
|
||||
schedule()
|
||||
logInfo("Recovery complete - resuming operations!")
|
||||
}
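The "only-once" guard in completeRecovery() is the usual check-and-set-under-lock idiom; a tiny standalone sketch of the same shape (state names reused from RecoveryState, everything else illustrative):

object OnlyOnceSketch {
  private var state = "RECOVERING"

  // Many callers may race here, but the body runs at most once: the first caller
  // flips the state inside the lock, every later caller returns early.
  def completeOnce(body: => Unit): Unit = {
    synchronized {
      if (state != "RECOVERING") return
      state = "COMPLETING_RECOVERY"
    }
    body
    state = "ALIVE"
  }
}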
|
||||
|
||||
/**
|
||||
* Can an app use the given worker? True if the worker has enough memory and we haven't already
|
||||
* launched an executor for the app on it (right now the standalone backend doesn't like having
|
||||
|
@ -219,6 +344,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
* every time a new app joins or resource availability changes.
|
||||
*/
|
||||
def schedule() {
|
||||
if (state != RecoveryState.ALIVE) { return }
|
||||
// Right now this is a very simple FIFO scheduler. We keep trying to fit in the first app
|
||||
// in the queue, then the second app, etc.
|
||||
if (spreadOutApps) {
|
||||
|
@ -266,14 +392,13 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo, sparkHome: String) {
|
||||
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
|
||||
worker.addExecutor(exec)
|
||||
worker.actor ! LaunchExecutor(
|
||||
worker.actor ! LaunchExecutor(masterUrl,
|
||||
exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory, sparkHome)
|
||||
exec.application.driver ! ExecutorAdded(
|
||||
exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)
|
||||
}
|
||||
|
||||
def addWorker(id: String, host: String, port: Int, cores: Int, memory: Int, webUiPort: Int,
|
||||
publicAddress: String): WorkerInfo = {
|
||||
def registerWorker(worker: WorkerInfo): Unit = {
|
||||
// There may be one or more refs to dead workers on this same node (w/ different ID's),
|
||||
// remove them.
|
||||
workers.filter { w =>
|
||||
|
@ -281,12 +406,17 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
}.foreach { w =>
|
||||
workers -= w
|
||||
}
|
||||
val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
|
||||
|
||||
val workerAddress = worker.actor.path.address
|
||||
if (addressToWorker.contains(workerAddress)) {
|
||||
logInfo("Attempted to re-register worker at same address: " + workerAddress)
|
||||
return
|
||||
}
|
||||
|
||||
workers += worker
|
||||
idToWorker(worker.id) = worker
|
||||
actorToWorker(sender) = worker
|
||||
addressToWorker(sender.path.address) = worker
|
||||
worker
|
||||
actorToWorker(worker.actor) = worker
|
||||
addressToWorker(workerAddress) = worker
|
||||
}
|
||||
|
||||
def removeWorker(worker: WorkerInfo) {
|
||||
|
@ -301,25 +431,36 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
exec.id, ExecutorState.LOST, Some("worker lost"), None)
|
||||
exec.application.removeExecutor(exec)
|
||||
}
|
||||
persistenceEngine.removeWorker(worker)
|
||||
}
|
||||
|
||||
def addApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
|
||||
def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = {
|
||||
val now = System.currentTimeMillis()
|
||||
val date = new Date(now)
|
||||
val app = new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
|
||||
new ApplicationInfo(now, newApplicationId(date), desc, date, driver, desc.appUiUrl)
|
||||
}
|
||||
|
||||
def registerApplication(app: ApplicationInfo): Unit = {
|
||||
val appAddress = app.driver.path.address
|
||||
if (addressToApp.contains(appAddress)) {
|
||||
logInfo("Attempted to re-register application at same address: " + appAddress)
|
||||
return
|
||||
}
|
||||
|
||||
applicationMetricsSystem.registerSource(app.appSource)
|
||||
apps += app
|
||||
idToApp(app.id) = app
|
||||
actorToApp(driver) = app
|
||||
addressToApp(driver.path.address) = app
|
||||
actorToApp(app.driver) = app
|
||||
addressToApp(appAddress) = app
|
||||
if (firstApp == None) {
|
||||
firstApp = Some(app)
|
||||
}
|
||||
// TODO: What is firstApp?? Can we remove it?
|
||||
val workersAlive = workers.filter(_.state == WorkerState.ALIVE).toArray
|
||||
if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= desc.memoryPerSlave)) {
|
||||
if (workersAlive.size > 0 && !workersAlive.exists(_.memoryFree >= app.desc.memoryPerSlave)) {
|
||||
logWarning("Could not find any workers with enough memory for " + firstApp.get.id)
|
||||
}
|
||||
app
|
||||
waitingApps += app
|
||||
}
|
||||
|
||||
def finishApplication(app: ApplicationInfo) {
|
||||
|
@ -344,13 +485,14 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
|
|||
waitingApps -= app
|
||||
for (exec <- app.executors.values) {
|
||||
exec.worker.removeExecutor(exec)
|
||||
exec.worker.actor ! KillExecutor(exec.application.id, exec.id)
|
||||
exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id)
|
||||
exec.state = ExecutorState.KILLED
|
||||
}
|
||||
app.markFinished(state)
|
||||
if (state != ApplicationState.FINISHED) {
|
||||
app.driver ! ApplicationRemoved(state.toString)
|
||||
}
|
||||
persistenceEngine.removeApplication(app)
|
||||
schedule()
|
||||
}
|
||||
}
|
||||
|
@ -404,8 +546,8 @@ private[spark] object Master {
|
|||
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
|
||||
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
|
||||
val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName)
|
||||
val timeoutDuration = Duration.create(
|
||||
System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
|
||||
val timeoutDuration: FiniteDuration = Duration.create(
|
||||
System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS)
|
||||
implicit val timeout = Timeout(timeoutDuration)
|
||||
val respFuture = actor ? RequestWebUIPort // ask pattern
|
||||
val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
|
||||
|
@@ -0,0 +1,46 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy.master

sealed trait MasterMessages extends Serializable

/** Contains messages seen only by the Master and its associated entities. */
private[master] object MasterMessages {

  // LeaderElectionAgent to Master

  case object ElectedLeader

  case object RevokedLeadership

  // Actor System to LeaderElectionAgent

  case object CheckLeader

  // Actor System to Master

  case object CheckForWorkerTimeOut

  case class BeginRecovery(storedApps: Seq[ApplicationInfo], storedWorkers: Seq[WorkerInfo])

  case object CompleteRecovery

  case object RequestWebUIPort

  case class WebUIPortResponse(webUIBoundPort: Int)
}
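RequestWebUIPort/WebUIPortResponse are used with the ask pattern, as in Master.startSystemAndActor above; a condensed sketch of that query (the `master` ActorRef is assumed to have been created elsewhere):

import scala.concurrent.Await
import scala.concurrent.duration._
import akka.actor.ActorRef
import akka.pattern.ask
import akka.util.Timeout
import org.apache.spark.deploy.master.MasterMessages.{RequestWebUIPort, WebUIPortResponse}

object WebUIPortQuerySketch {
  // Ask the Master actor which port its web UI bound to and block for the reply.
  def boundWebUIPort(master: ActorRef): Int = {
    implicit val timeout = Timeout(10.seconds)
    val respFuture = master ? RequestWebUIPort            // ask pattern, returns a Future
    Await.result(respFuture, timeout.duration)
      .asInstanceOf[WebUIPortResponse].webUIBoundPort
  }
}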
@@ -0,0 +1,53 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy.master

/**
 * Allows Master to persist any state that is necessary in order to recover from a failure.
 * The following semantics are required:
 * - addApplication and addWorker are called before completing registration of a new app/worker.
 * - removeApplication and removeWorker are called at any time.
 * Given these two requirements, we will have all apps and workers persisted, but
 * we might not have yet deleted apps or workers that finished (so their liveness must be verified
 * during recovery).
 */
private[spark] trait PersistenceEngine {
  def addApplication(app: ApplicationInfo)

  def removeApplication(app: ApplicationInfo)

  def addWorker(worker: WorkerInfo)

  def removeWorker(worker: WorkerInfo)

  /**
   * Returns the persisted data sorted by their respective ids (which implies that they're
   * sorted by time of creation).
   */
  def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo])

  def close() {}
}

private[spark] class BlackHolePersistenceEngine extends PersistenceEngine {
  override def addApplication(app: ApplicationInfo) {}
  override def removeApplication(app: ApplicationInfo) {}
  override def addWorker(worker: WorkerInfo) {}
  override def removeWorker(worker: WorkerInfo) {}
  override def readPersistedData() = (Nil, Nil)
}
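Alongside BlackHolePersistenceEngine, a hypothetical in-memory implementation (not part of this patch) makes the contract easy to see: everything added before a failure must come back from readPersistedData():

package org.apache.spark.deploy.master

import scala.collection.mutable

// Illustrative only: keeps state in maps, so it survives a Master actor restart in the
// same JVM but obviously not a process loss. Mainly useful for testing the recovery path.
private[spark] class InMemoryPersistenceEngine extends PersistenceEngine {
  private val apps = new mutable.LinkedHashMap[String, ApplicationInfo]
  private val workers = new mutable.LinkedHashMap[String, WorkerInfo]

  override def addApplication(app: ApplicationInfo) { apps(app.id) = app }
  override def removeApplication(app: ApplicationInfo) { apps -= app.id }
  override def addWorker(worker: WorkerInfo) { workers(worker.id) = worker }
  override def removeWorker(worker: WorkerInfo) { workers -= worker.id }

  // Insertion order stands in for the trait's "sorted by id / time of creation" requirement.
  override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) =
    (apps.values.toSeq, workers.values.toSeq)
}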
@@ -0,0 +1,26 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.deploy.master

private[spark] object RecoveryState
  extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {

  type MasterState = Value

  val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
}
@ -0,0 +1,203 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.concurrent.ops._
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.zookeeper._
|
||||
import org.apache.zookeeper.data.Stat
|
||||
import org.apache.zookeeper.Watcher.Event.KeeperState
|
||||
|
||||
/**
|
||||
* Provides a Scala-side interface to the standard ZooKeeper client, with the addition of retry
|
||||
* logic. If the ZooKeeper session expires or otherwise dies, a new ZooKeeper session will be
|
||||
* created. If ZooKeeper remains down after several retries, the given
|
||||
* [[org.apache.spark.deploy.master.SparkZooKeeperWatcher SparkZooKeeperWatcher]] will be
|
||||
* informed via zkDown().
|
||||
*
|
||||
* Additionally, all commands sent to ZooKeeper will be retried until they either fail too many
|
||||
* times or a semantic exception is thrown (e.g., "node already exists").
|
||||
*/
|
||||
private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher) extends Logging {
|
||||
val ZK_URL = System.getProperty("spark.deploy.zookeeper.url", "")
|
||||
|
||||
val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE
|
||||
val ZK_TIMEOUT_MILLIS = 30000
|
||||
val RETRY_WAIT_MILLIS = 5000
|
||||
val ZK_CHECK_PERIOD_MILLIS = 10000
|
||||
val MAX_RECONNECT_ATTEMPTS = 3
|
||||
|
||||
private var zk: ZooKeeper = _
|
||||
|
||||
private val watcher = new ZooKeeperWatcher()
|
||||
private var reconnectAttempts = 0
|
||||
private var closed = false
|
||||
|
||||
/** Connect to ZooKeeper to start the session. Must be called before anything else. */
|
||||
def connect() {
|
||||
connectToZooKeeper()
|
||||
|
||||
new Thread() {
|
||||
override def run() = sessionMonitorThread()
|
||||
}.start()
|
||||
}
|
||||
|
||||
def sessionMonitorThread(): Unit = {
|
||||
while (!closed) {
|
||||
Thread.sleep(ZK_CHECK_PERIOD_MILLIS)
|
||||
if (zk.getState != ZooKeeper.States.CONNECTED) {
|
||||
reconnectAttempts += 1
|
||||
val attemptsLeft = MAX_RECONNECT_ATTEMPTS - reconnectAttempts
|
||||
if (attemptsLeft <= 0) {
|
||||
logError("Could not connect to ZooKeeper: system failure")
|
||||
zkWatcher.zkDown()
|
||||
close()
|
||||
} else {
|
||||
logWarning("ZooKeeper connection failed, retrying " + attemptsLeft + " more times...")
|
||||
connectToZooKeeper()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def close() {
|
||||
if (!closed && zk != null) { zk.close() }
|
||||
closed = true
|
||||
}
|
||||
|
||||
private def connectToZooKeeper() {
|
||||
if (zk != null) zk.close()
|
||||
zk = new ZooKeeper(ZK_URL, ZK_TIMEOUT_MILLIS, watcher)
|
||||
}
|
||||
|
||||
/**
|
||||
* Attempts to maintain a live ZooKeeper session despite (very) transient failures.
|
||||
* Mainly useful for handling the natural ZooKeeper session expiration.
|
||||
*/
|
||||
private class ZooKeeperWatcher extends Watcher {
|
||||
def process(event: WatchedEvent) {
|
||||
if (closed) { return }
|
||||
|
||||
event.getState match {
|
||||
case KeeperState.SyncConnected =>
|
||||
reconnectAttempts = 0
|
||||
zkWatcher.zkSessionCreated()
|
||||
case KeeperState.Expired =>
|
||||
connectToZooKeeper()
|
||||
case KeeperState.Disconnected =>
|
||||
logWarning("ZooKeeper disconnected, will retry...")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def create(path: String, bytes: Array[Byte], createMode: CreateMode): String = {
|
||||
retry {
|
||||
zk.create(path, bytes, ZK_ACL, createMode)
|
||||
}
|
||||
}
|
||||
|
||||
def exists(path: String, watcher: Watcher = null): Stat = {
|
||||
retry {
|
||||
zk.exists(path, watcher)
|
||||
}
|
||||
}
|
||||
|
||||
def getChildren(path: String, watcher: Watcher = null): List[String] = {
|
||||
retry {
|
||||
zk.getChildren(path, watcher).toList
|
||||
}
|
||||
}
|
||||
|
||||
def getData(path: String): Array[Byte] = {
|
||||
retry {
|
||||
zk.getData(path, false, null)
|
||||
}
|
||||
}
|
||||
|
||||
def delete(path: String, version: Int = -1): Unit = {
|
||||
retry {
|
||||
zk.delete(path, version)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates the given directory (non-recursively) if it doesn't exist.
|
||||
* All znodes are created in PERSISTENT mode with no data.
|
||||
*/
|
||||
def mkdir(path: String) {
|
||||
if (exists(path) == null) {
|
||||
try {
|
||||
create(path, "".getBytes, CreateMode.PERSISTENT)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
// If the exception caused the directory not to be created, bubble it up,
|
||||
// otherwise ignore it.
|
||||
if (exists(path) == null) { throw e }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Recursively creates all directories up to the given one.
|
||||
* All znodes are created in PERSISTENT mode with no data.
|
||||
*/
|
||||
def mkdirRecursive(path: String) {
|
||||
var fullDir = ""
|
||||
for (dentry <- path.split("/").tail) {
|
||||
fullDir += "/" + dentry
|
||||
mkdir(fullDir)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Retries the given function up to 3 times. The assumption is that failure is transient,
|
||||
* UNLESS it is a semantic exception (i.e., trying to get data from a node that doesn't exist),
|
||||
* in which case the exception will be thrown without retries.
|
||||
*
|
||||
* @param fn Block to execute, possibly multiple times.
|
||||
*/
|
||||
def retry[T](fn: => T, n: Int = MAX_RECONNECT_ATTEMPTS): T = {
|
||||
try {
|
||||
fn
|
||||
} catch {
|
||||
case e: KeeperException.NoNodeException => throw e
|
||||
case e: KeeperException.NodeExistsException => throw e
|
||||
case e if n > 0 =>
|
||||
logError("ZooKeeper exception, " + n + " more retries...", e)
|
||||
Thread.sleep(RETRY_WAIT_MILLIS)
|
||||
retry(fn, n-1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait SparkZooKeeperWatcher {
|
||||
/**
|
||||
* Called whenever a ZK session is created --
|
||||
* this will occur when we create our first session as well as each time
|
||||
* the session expires or errors out.
|
||||
*/
|
||||
def zkSessionCreated()
|
||||
|
||||
/**
|
||||
* Called if ZK appears to be completely down (i.e., not just a transient error).
|
||||
* We will no longer attempt to reconnect to ZK, and the SparkZooKeeperSession is considered dead.
|
||||
*/
|
||||
def zkDown()
|
||||
}
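The retry helper in SparkZooKeeperSession is a small generic pattern; a self-contained sketch of the same idea outside of ZooKeeper (names, delays, and the fail-fast exception type are illustrative):

object RetrySketch {
  // Re-run `fn` up to `n` times, sleeping between attempts, but let a "semantic"
  // failure (here: IllegalArgumentException) escape immediately instead of retrying.
  def retry[T](fn: => T, n: Int = 3, waitMillis: Long = 100): T = {
    try {
      fn
    } catch {
      case e: IllegalArgumentException => throw e
      case e: Exception if n > 0 =>
        Thread.sleep(waitMillis)
        retry(fn, n - 1, waitMillis)
    }
  }

  def main(args: Array[String]) {
    var attempts = 0
    val value = retry({ attempts += 1; if (attempts < 3) sys.error("transient") else 42 })
    println(s"succeeded after $attempts attempts: $value")   // succeeded after 3 attempts: 42
  }
}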
|
|
@ -22,28 +22,44 @@ import scala.collection.mutable
|
|||
import org.apache.spark.util.Utils
|
||||
|
||||
private[spark] class WorkerInfo(
|
||||
val id: String,
|
||||
val host: String,
|
||||
val port: Int,
|
||||
val cores: Int,
|
||||
val memory: Int,
|
||||
val actor: ActorRef,
|
||||
val webUiPort: Int,
|
||||
val publicAddress: String) {
|
||||
val id: String,
|
||||
val host: String,
|
||||
val port: Int,
|
||||
val cores: Int,
|
||||
val memory: Int,
|
||||
val actor: ActorRef,
|
||||
val webUiPort: Int,
|
||||
val publicAddress: String)
|
||||
extends Serializable {
|
||||
|
||||
Utils.checkHost(host, "Expected hostname")
|
||||
assert (port > 0)
|
||||
|
||||
var executors = new mutable.HashMap[String, ExecutorInfo] // fullId => info
|
||||
var state: WorkerState.Value = WorkerState.ALIVE
|
||||
var coresUsed = 0
|
||||
var memoryUsed = 0
|
||||
@transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // fullId => info
|
||||
@transient var state: WorkerState.Value = _
|
||||
@transient var coresUsed: Int = _
|
||||
@transient var memoryUsed: Int = _
|
||||
|
||||
var lastHeartbeat = System.currentTimeMillis()
|
||||
@transient var lastHeartbeat: Long = _
|
||||
|
||||
init()
|
||||
|
||||
def coresFree: Int = cores - coresUsed
|
||||
def memoryFree: Int = memory - memoryUsed
|
||||
|
||||
private def readObject(in: java.io.ObjectInputStream) : Unit = {
|
||||
in.defaultReadObject()
|
||||
init()
|
||||
}
|
||||
|
||||
private def init() {
|
||||
executors = new mutable.HashMap
|
||||
state = WorkerState.ALIVE
|
||||
coresUsed = 0
|
||||
memoryUsed = 0
|
||||
lastHeartbeat = System.currentTimeMillis()
|
||||
}
|
||||
|
||||
def hostPort: String = {
|
||||
assert (port > 0)
|
||||
host + ":" + port
|
||||
|
|
|
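The @transient + readObject + init() combination in WorkerInfo above is plain Java serialization; a minimal standalone sketch of the same pattern (class and field names are made up):

import java.io._

// Runtime-only state is marked @transient so it is skipped when the object is written out,
// and rebuilt by init(), which readObject calls after the non-transient fields are restored.
class SketchInfo(val id: String) extends Serializable {
  @transient var lastSeen: Long = _

  init()

  private def init() {
    lastSeen = System.currentTimeMillis()
  }

  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    init()
  }
}

object SketchInfoDemo extends App {
  val bos = new ByteArrayOutputStream()
  val oos = new ObjectOutputStream(bos)
  oos.writeObject(new SketchInfo("worker-1"))
  oos.close()
  val copy = new ObjectInputStream(new ByteArrayInputStream(bos.toByteArray))
    .readObject().asInstanceOf[SketchInfo]
  println(copy.id + " revived at " + copy.lastSeen)   // lastSeen was re-initialized, not restored
}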
@@ -20,5 +20,5 @@ package org.apache.spark.deploy.master
private[spark] object WorkerState extends Enumeration {
  type WorkerState = Value

  val ALIVE, DEAD, DECOMMISSIONED = Value
  val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
}
@ -0,0 +1,136 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import akka.actor.ActorRef
|
||||
import org.apache.zookeeper._
|
||||
import org.apache.zookeeper.Watcher.Event.EventType
|
||||
|
||||
import org.apache.spark.deploy.master.MasterMessages._
|
||||
import org.apache.spark.Logging
|
||||
|
||||
private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String)
|
||||
extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging {
|
||||
|
||||
val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/leader_election"
|
||||
|
||||
private val watcher = new ZooKeeperWatcher()
|
||||
private val zk = new SparkZooKeeperSession(this)
|
||||
private var status = LeadershipStatus.NOT_LEADER
|
||||
private var myLeaderFile: String = _
|
||||
private var leaderUrl: String = _
|
||||
|
||||
override def preStart() {
|
||||
logInfo("Starting ZooKeeper LeaderElection agent")
|
||||
zk.connect()
|
||||
}
|
||||
|
||||
override def zkSessionCreated() {
|
||||
synchronized {
|
||||
zk.mkdirRecursive(WORKING_DIR)
|
||||
myLeaderFile =
|
||||
zk.create(WORKING_DIR + "/master_", masterUrl.getBytes, CreateMode.EPHEMERAL_SEQUENTIAL)
|
||||
self ! CheckLeader
|
||||
}
|
||||
}
|
||||
|
||||
override def preRestart(reason: scala.Throwable, message: scala.Option[scala.Any]) {
|
||||
logError("LeaderElectionAgent failed, waiting " + zk.ZK_TIMEOUT_MILLIS + "...", reason)
|
||||
Thread.sleep(zk.ZK_TIMEOUT_MILLIS)
|
||||
super.preRestart(reason, message)
|
||||
}
|
||||
|
||||
override def zkDown() {
|
||||
logError("ZooKeeper down! LeaderElectionAgent shutting down Master.")
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
override def postStop() {
|
||||
zk.close()
|
||||
}
|
||||
|
||||
override def receive = {
|
||||
case CheckLeader => checkLeader()
|
||||
}
|
||||
|
||||
private class ZooKeeperWatcher extends Watcher {
|
||||
def process(event: WatchedEvent) {
|
||||
if (event.getType == EventType.NodeDeleted) {
|
||||
logInfo("Leader file disappeared, a master is down!")
|
||||
self ! CheckLeader
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/** Uses ZK leader election. Navigates several ZK potholes along the way. */
|
||||
def checkLeader() {
|
||||
val masters = zk.getChildren(WORKING_DIR).toList
|
||||
val leader = masters.sorted.head
|
||||
val leaderFile = WORKING_DIR + "/" + leader
|
||||
|
||||
// Setup a watch for the current leader.
|
||||
zk.exists(leaderFile, watcher)
|
||||
|
||||
try {
|
||||
leaderUrl = new String(zk.getData(leaderFile))
|
||||
} catch {
|
||||
// A NoNodeException may be thrown if old leader died since the start of this method call.
|
||||
// This is fine -- just check again, since we're guaranteed to see the new values.
|
||||
case e: KeeperException.NoNodeException =>
|
||||
logInfo("Leader disappeared while reading it -- finding next leader")
|
||||
checkLeader()
|
||||
return
|
||||
}
|
||||
|
||||
// Synchronization used to ensure no interleaving between the creation of a new session and the
|
||||
// checking of a leader, which could cause us to delete our real leader file erroneously.
|
||||
synchronized {
|
||||
val isLeader = myLeaderFile == leaderFile
|
||||
if (!isLeader && leaderUrl == masterUrl) {
|
||||
// We found a different master file pointing to this process.
|
||||
// This can happen in the following two cases:
|
||||
// (1) The master process was restarted on the same node.
|
||||
// (2) The ZK server died between creating the node and returning the name of the node.
|
||||
// For this case, we will end up creating a second file, and MUST explicitly delete the
|
||||
// first one, since our ZK session is still open.
|
||||
// Note that this deletion will cause a NodeDeleted event to be fired so we check again for
|
||||
// leader changes.
|
||||
assert(leaderFile < myLeaderFile)
|
||||
logWarning("Cleaning up old ZK master election file that points to this master.")
|
||||
zk.delete(leaderFile)
|
||||
} else {
|
||||
updateLeadershipStatus(isLeader)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def updateLeadershipStatus(isLeader: Boolean) {
|
||||
if (isLeader && status == LeadershipStatus.NOT_LEADER) {
|
||||
status = LeadershipStatus.LEADER
|
||||
masterActor ! ElectedLeader
|
||||
} else if (!isLeader && status == LeadershipStatus.LEADER) {
|
||||
status = LeadershipStatus.NOT_LEADER
|
||||
masterActor ! RevokedLeadership
|
||||
}
|
||||
}
|
||||
|
||||
private object LeadershipStatus extends Enumeration {
|
||||
type LeadershipStatus = Value
|
||||
val LEADER, NOT_LEADER = Value
|
||||
}
|
||||
}
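The agent above builds on ZooKeeper's EPHEMERAL_SEQUENTIAL znodes; stripped of the retry and watch plumbing, the election idea is roughly the following sketch (the paths and the already-connected `zk` handle are assumptions):

import scala.collection.JavaConversions._
import org.apache.zookeeper.{CreateMode, ZooDefs, ZooKeeper}

object ElectionSketch {
  // `zk` is assumed to be an already-connected ZooKeeper client, with the parent dir created.
  def runOnce(zk: ZooKeeper, masterUrl: String): Boolean = {
    // Every candidate creates an ephemeral sequential node; ZooKeeper appends a counter,
    // and the node vanishes automatically if the candidate's session dies.
    val myFile = zk.create("/spark/leader_election/master_", masterUrl.getBytes,
      ZooDefs.Ids.OPEN_ACL_UNSAFE, CreateMode.EPHEMERAL_SEQUENTIAL)

    // The candidate whose node sorts first is the leader; everyone else watches that node.
    val leader = zk.getChildren("/spark/leader_election", false).toList.sorted.head
    myFile == "/spark/leader_election/" + leader
  }
}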
@ -0,0 +1,85 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.deploy.master
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.zookeeper._
|
||||
|
||||
import akka.serialization.Serialization
|
||||
|
||||
class ZooKeeperPersistenceEngine(serialization: Serialization)
|
||||
extends PersistenceEngine
|
||||
with SparkZooKeeperWatcher
|
||||
with Logging
|
||||
{
|
||||
val WORKING_DIR = System.getProperty("spark.deploy.zookeeper.dir", "/spark") + "/master_status"
|
||||
|
||||
val zk = new SparkZooKeeperSession(this)
|
||||
|
||||
zk.connect()
|
||||
|
||||
override def zkSessionCreated() {
|
||||
zk.mkdirRecursive(WORKING_DIR)
|
||||
}
|
||||
|
||||
override def zkDown() {
|
||||
logError("PersistenceEngine disconnected from ZooKeeper -- ZK looks down.")
|
||||
}
|
||||
|
||||
override def addApplication(app: ApplicationInfo) {
|
||||
serializeIntoFile(WORKING_DIR + "/app_" + app.id, app)
|
||||
}
|
||||
|
||||
override def removeApplication(app: ApplicationInfo) {
|
||||
zk.delete(WORKING_DIR + "/app_" + app.id)
|
||||
}
|
||||
|
||||
override def addWorker(worker: WorkerInfo) {
|
||||
serializeIntoFile(WORKING_DIR + "/worker_" + worker.id, worker)
|
||||
}
|
||||
|
||||
override def removeWorker(worker: WorkerInfo) {
|
||||
zk.delete(WORKING_DIR + "/worker_" + worker.id)
|
||||
}
|
||||
|
||||
override def close() {
|
||||
zk.close()
|
||||
}
|
||||
|
||||
override def readPersistedData(): (Seq[ApplicationInfo], Seq[WorkerInfo]) = {
|
||||
val sortedFiles = zk.getChildren(WORKING_DIR).toList.sorted
|
||||
val appFiles = sortedFiles.filter(_.startsWith("app_"))
|
||||
val apps = appFiles.map(deserializeFromFile[ApplicationInfo])
|
||||
val workerFiles = sortedFiles.filter(_.startsWith("worker_"))
|
||||
val workers = workerFiles.map(deserializeFromFile[WorkerInfo])
|
||||
(apps, workers)
|
||||
}
|
||||
|
||||
private def serializeIntoFile(path: String, value: AnyRef) {
|
||||
val serializer = serialization.findSerializerFor(value)
|
||||
val serialized = serializer.toBinary(value)
|
||||
zk.create(path, serialized, CreateMode.PERSISTENT)
|
||||
}
|
||||
|
||||
def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
|
||||
val fileData = zk.getData("/spark/master_status/" + filename)
|
||||
val clazz = m.runtimeClass.asInstanceOf[Class[T]]
|
||||
val serializer = serialization.serializerFor(clazz)
|
||||
serializer.fromBinary(fileData).asInstanceOf[T]
|
||||
}
|
||||
}
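The serializeIntoFile / deserializeFromFile pair above leans on Akka's Serialization extension; the round trip on its own looks roughly like this sketch (the throwaway actor system and the example payload are assumptions):

import akka.actor.ActorSystem
import akka.serialization.SerializationExtension

object AkkaSerializationSketch extends App {
  val system = ActorSystem("sketch")
  val serialization = SerializationExtension(system)

  // Any Serializable value works; WorkerInfo / ApplicationInfo are handled the same way.
  val original: java.util.Date = new java.util.Date()

  // Pick a serializer for the value and turn it into the bytes that would go into a znode...
  val bytes = serialization.findSerializerFor(original).toBinary(original)

  // ...and back, by asking for the serializer registered for that class.
  val copy = serialization.serializerFor(classOf[java.util.Date])
    .fromBinary(bytes).asInstanceOf[java.util.Date]

  println(original == copy)   // true
  system.shutdown()
}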
@ -43,7 +43,8 @@ private[spark] class ExecutorRunner(
|
|||
val workerId: String,
|
||||
val host: String,
|
||||
val sparkHome: File,
|
||||
val workDir: File)
|
||||
val workDir: File,
|
||||
var state: ExecutorState.Value)
|
||||
extends Logging {
|
||||
|
||||
val fullId = appId + "/" + execId
|
||||
|
@ -83,7 +84,8 @@ private[spark] class ExecutorRunner(
|
|||
process.destroy()
|
||||
process.waitFor()
|
||||
}
|
||||
worker ! ExecutorStateChanged(appId, execId, ExecutorState.KILLED, None, None)
|
||||
state = ExecutorState.KILLED
|
||||
worker ! ExecutorStateChanged(appId, execId, state, None, None)
|
||||
Runtime.getRuntime.removeShutdownHook(shutdownHook)
|
||||
}
|
||||
}
|
||||
|
@ -102,7 +104,7 @@ private[spark] class ExecutorRunner(
|
|||
// SPARK-698: do not call the run.cmd script, as process.destroy()
|
||||
// fails to kill a process tree on Windows
|
||||
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
|
||||
command.arguments.map(substituteVariables)
|
||||
(command.arguments ++ Seq(appId)).map(substituteVariables)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -180,9 +182,9 @@ private[spark] class ExecutorRunner(
|
|||
// long-lived processes only. However, in the future, we might restart the executor a few
|
||||
// times on the same machine.
|
||||
val exitCode = process.waitFor()
|
||||
state = ExecutorState.FAILED
|
||||
val message = "Command exited with code " + exitCode
|
||||
worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message),
|
||||
Some(exitCode))
|
||||
worker ! ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))
|
||||
} catch {
|
||||
case interrupted: InterruptedException =>
|
||||
logInfo("Runner thread for executor " + fullId + " interrupted")
|
||||
|
@ -192,8 +194,9 @@ private[spark] class ExecutorRunner(
|
|||
if (process != null) {
|
||||
process.destroy()
|
||||
}
|
||||
state = ExecutorState.FAILED
|
||||
val message = e.getClass + ": " + e.getMessage
|
||||
worker ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(message), None)
|
||||
worker ! ExecutorStateChanged(appId, execId, state, Some(message), None)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -25,10 +25,8 @@ import scala.collection.mutable.HashMap
|
|||
import scala.concurrent.duration._
|
||||
|
||||
import akka.actor._
|
||||
import akka.remote.{RemotingLifecycleEvent, AssociationErrorEvent, DisassociatedEvent}
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.deploy.ExecutorState
|
||||
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
|
||||
import org.apache.spark.deploy.DeployMessages._
|
||||
import org.apache.spark.deploy.master.Master
|
||||
import org.apache.spark.deploy.worker.ui.WorkerWebUI
|
||||
|
@ -45,16 +43,19 @@ import akka.remote.DisassociatedEvent
|
|||
import org.apache.spark.deploy.DeployMessages.LaunchExecutor
|
||||
import org.apache.spark.deploy.DeployMessages.RegisterWorker
|
||||
|
||||
|
||||
/**
|
||||
* @param masterUrls Each url should look like spark://host:port.
|
||||
*/
|
||||
private[spark] class Worker(
|
||||
host: String,
|
||||
port: Int,
|
||||
webUiPort: Int,
|
||||
cores: Int,
|
||||
memory: Int,
|
||||
masterUrl: String,
|
||||
masterUrls: Array[String],
|
||||
workDirPath: String = null)
|
||||
extends Actor with Logging {
|
||||
import context.dispatcher
|
||||
|
||||
Utils.checkHost(host, "Expected hostname")
|
||||
assert (port > 0)
|
||||
|
@ -64,8 +65,19 @@ private[spark] class Worker(
|
|||
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
|
||||
val HEARTBEAT_MILLIS = System.getProperty("spark.worker.timeout", "60").toLong * 1000 / 4
|
||||
|
||||
val REGISTRATION_TIMEOUT = 20.seconds
|
||||
val REGISTRATION_RETRIES = 3
|
||||
|
||||
// Index into masterUrls that we're currently trying to register with.
|
||||
var masterIndex = 0
|
||||
|
||||
val masterLock: Object = new Object()
|
||||
var master: ActorSelection = null
|
||||
var masterWebUiUrl : String = ""
|
||||
var prevMaster: ActorRef = null
|
||||
var activeMasterUrl: String = ""
|
||||
var activeMasterWebUiUrl : String = ""
|
||||
@volatile var registered = false
|
||||
@volatile var connected = false
|
||||
val workerId = generateWorkerId()
|
||||
var sparkHome: File = null
|
||||
var workDir: File = null
|
||||
|
@ -105,6 +117,7 @@ private[spark] class Worker(
|
|||
}
|
||||
|
||||
override def preStart() {
|
||||
assert(!registered)
|
||||
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
|
||||
host, port, cores, Utils.megabytesToString(memory)))
|
||||
sparkHome = new File(Option(System.getenv("SPARK_HOME")).getOrElse("."))
|
||||
|
@ -113,46 +126,99 @@ private[spark] class Worker(
|
|||
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
|
||||
|
||||
webUi.start()
|
||||
connectToMaster()
|
||||
registerWithMaster()
|
||||
|
||||
metricsSystem.registerSource(workerSource)
|
||||
metricsSystem.start()
|
||||
}
|
||||
|
||||
def connectToMaster() {
|
||||
logInfo("Connecting to master " + masterUrl)
|
||||
master = context.actorSelection(Master.toAkkaUrl(masterUrl))
|
||||
master ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress)
|
||||
def changeMaster(url: String, uiUrl: String) {
|
||||
masterLock.synchronized {
|
||||
activeMasterUrl = url
|
||||
activeMasterWebUiUrl = uiUrl
|
||||
master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
|
||||
connected = true
|
||||
}
|
||||
}
|
||||
|
||||
import context.dispatcher
|
||||
def tryRegisterAllMasters() {
|
||||
for (masterUrl <- masterUrls) {
|
||||
logInfo("Connecting to master " + masterUrl + "...")
|
||||
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
|
||||
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
|
||||
publicAddress)
|
||||
}
|
||||
}
|
||||
|
||||
def registerWithMaster() {
|
||||
tryRegisterAllMasters()
|
||||
|
||||
var retries = 0
|
||||
lazy val retryTimer: Cancellable =
|
||||
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
|
||||
retries += 1
|
||||
if (registered) {
|
||||
retryTimer.cancel()
|
||||
} else if (retries >= REGISTRATION_RETRIES) {
|
||||
logError("All masters are unresponsive! Giving up.")
|
||||
System.exit(1)
|
||||
} else {
|
||||
tryRegisterAllMasters()
|
||||
}
|
||||
}
|
||||
retryTimer // start timer
|
||||
}
|
||||
|
||||
override def receive = {
|
||||
case RegisteredWorker(url) =>
|
||||
masterWebUiUrl = url
|
||||
logInfo("Successfully registered with master")
|
||||
case RegisteredWorker(masterUrl, masterWebUiUrl) =>
|
||||
logInfo("Successfully registered with master " + masterUrl)
|
||||
registered = true
|
||||
context.watch(sender) // remote death watch for master
|
||||
//TODO: Is heartbeat really necessary akka does it anyway !
|
||||
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis) {
|
||||
master ! Heartbeat(workerId)
|
||||
prevMaster = sender
|
||||
changeMaster(masterUrl, masterWebUiUrl)
|
||||
context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat)
|
||||
|
||||
case SendHeartbeat =>
|
||||
masterLock.synchronized {
|
||||
if (connected) { master ! Heartbeat(workerId) }
|
||||
}
|
||||
|
||||
case RegisterWorkerFailed(message) =>
|
||||
logError("Worker registration failed: " + message)
|
||||
System.exit(1)
|
||||
case MasterChanged(masterUrl, masterWebUiUrl) =>
|
||||
logInfo("Master has changed, new master is at " + masterUrl)
|
||||
context.unwatch(prevMaster)
|
||||
prevMaster = sender
|
||||
changeMaster(masterUrl, masterWebUiUrl)
|
||||
|
||||
case LaunchExecutor(appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
|
||||
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
|
||||
val manager = new ExecutorRunner(
|
||||
appId, execId, appDesc, cores_, memory_, self, workerId, host, new File(execSparkHome_), workDir)
|
||||
executors(appId + "/" + execId) = manager
|
||||
manager.start()
|
||||
coresUsed += cores_
|
||||
memoryUsed += memory_
|
||||
master ! ExecutorStateChanged(appId, execId, ExecutorState.RUNNING, None, None)
|
||||
val execs = executors.values.
|
||||
map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state))
|
||||
sender ! WorkerSchedulerStateResponse(workerId, execs.toList)
|
||||
|
||||
case RegisterWorkerFailed(message) =>
|
||||
if (!registered) {
|
||||
logError("Worker registration failed: " + message)
|
||||
System.exit(1)
|
||||
}
|
||||
|
||||
case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_, execSparkHome_) =>
|
||||
if (masterUrl != activeMasterUrl) {
|
||||
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
|
||||
} else {
|
||||
logInfo("Asked to launch executor %s/%d for %s".format(appId, execId, appDesc.name))
|
||||
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
|
||||
self, workerId, host, new File(execSparkHome_), workDir, ExecutorState.RUNNING)
|
||||
executors(appId + "/" + execId) = manager
|
||||
manager.start()
|
||||
coresUsed += cores_
|
||||
memoryUsed += memory_
|
||||
masterLock.synchronized {
|
||||
master ! ExecutorStateChanged(appId, execId, manager.state, None, None)
|
||||
}
|
||||
}
|
||||
|
||||
case ExecutorStateChanged(appId, execId, state, message, exitStatus) =>
|
||||
master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
|
||||
masterLock.synchronized {
|
||||
master ! ExecutorStateChanged(appId, execId, state, message, exitStatus)
|
||||
}
|
||||
val fullId = appId + "/" + execId
|
||||
if (ExecutorState.isFinished(state)) {
|
||||
val executor = executors(fullId)
|
||||
|
@ -165,14 +231,18 @@ private[spark] class Worker(
|
|||
memoryUsed -= executor.memory
|
||||
}
|
||||
|
||||
case KillExecutor(appId, execId) =>
|
||||
val fullId = appId + "/" + execId
|
||||
executors.get(fullId) match {
|
||||
case Some(executor) =>
|
||||
logInfo("Asked to kill executor " + fullId)
|
||||
executor.kill()
|
||||
case None =>
|
||||
logInfo("Asked to kill unknown executor " + fullId)
|
||||
case KillExecutor(masterUrl, appId, execId) =>
|
||||
if (masterUrl != activeMasterUrl) {
|
||||
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor " + execId)
|
||||
} else {
|
||||
val fullId = appId + "/" + execId
|
||||
executors.get(fullId) match {
|
||||
case Some(executor) =>
|
||||
logInfo("Asked to kill executor " + fullId)
|
||||
executor.kill()
|
||||
case None =>
|
||||
logInfo("Asked to kill unknown executor " + fullId)
|
||||
}
|
||||
}
|
||||
|
||||
case Terminated(actor_) =>
|
||||
|
@ -181,17 +251,14 @@ private[spark] class Worker(
|
|||
|
||||
case RequestWorkerState => {
|
||||
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
|
||||
finishedExecutors.values.toList, masterUrl, cores, memory,
|
||||
coresUsed, memoryUsed, masterWebUiUrl)
|
||||
finishedExecutors.values.toList, activeMasterUrl, cores, memory,
|
||||
coresUsed, memoryUsed, activeMasterWebUiUrl)
|
||||
}
|
||||
}
|
||||
|
||||
def masterDisconnected() {
|
||||
// TODO: It would be nice to try to reconnect to the master, but just shut down for now.
|
||||
// (Note that if reconnecting we would also need to assign IDs differently.)
|
||||
logError("Connection to master failed! Shutting down.")
|
||||
executors.values.foreach(_.kill())
|
||||
System.exit(1)
|
||||
logError("Connection to master failed! Waiting for master to reconnect...")
|
||||
connected = false
|
||||
}
|
||||
|
||||
def generateWorkerId(): String = {
|
||||
|
@ -209,17 +276,18 @@ private[spark] object Worker {
|
|||
def main(argStrings: Array[String]) {
|
||||
val args = new WorkerArguments(argStrings)
|
||||
val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores,
|
||||
args.memory, args.master, args.workDir)
|
||||
args.memory, args.masters, args.workDir)
|
||||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
||||
def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int,
|
||||
masterUrl: String, workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
|
||||
masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None)
|
||||
: (ActorSystem, Int) = {
|
||||
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
|
||||
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
|
||||
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
|
||||
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
|
||||
masterUrl, workDir), name = "Worker")
|
||||
masterUrls, workDir), name = "Worker")
|
||||
(actorSystem, boundPort)
|
||||
}
|
||||
|
||||
|
|
|
@@ -29,7 +29,7 @@ private[spark] class WorkerArguments(args: Array[String]) {
  var webUiPort = 8081
  var cores = inferDefaultCores()
  var memory = inferDefaultMemory()
  var master: String = null
  var masters: Array[String] = null
  var workDir: String = null

  // Check for settings in environment variables

@@ -86,14 +86,14 @@ private[spark] class WorkerArguments(args: Array[String]) {
      printUsageAndExit(0)

    case value :: tail =>
      if (master != null) {  // Two positional arguments were given
      if (masters != null) {  // Two positional arguments were given
        printUsageAndExit(1)
      }
      master = value
      masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
      parse(tail)

    case Nil =>
      if (master == null) {  // No positional argument was given
      if (masters == null) {  // No positional argument was given
        printUsageAndExit(1)
      }
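The one-liner that expands the positional master argument accepts a comma-separated list; a quick sketch of what it produces (the host names are made up):

object MasterListSketch extends App {
  // "spark://host1:7077,host2:7077" -> Array("spark://host1:7077", "spark://host2:7077")
  val value = "spark://host1:7077,host2:7077"
  val masters = value.stripPrefix("spark://").split(",").map("spark://" + _)
  masters.foreach(println)
}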
|
|
@ -108,7 +108,7 @@ class WorkerWebUI(val worker: Worker, val workDir: File, requestedPort: Option[I
|
|||
|
||||
val logText = <node>{Utils.offsetBytes(path, startByte, endByte)}</node>
|
||||
|
||||
val linkToMaster = <p><a href={worker.masterWebUiUrl}>Back to Master</a></p>
|
||||
val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
|
||||
|
||||
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
|
||||
|
||||
|
|
|
@ -24,23 +24,15 @@ import akka.remote._
|
|||
|
||||
import org.apache.spark.{Logging, SparkEnv}
|
||||
import org.apache.spark.TaskState.TaskState
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages._
|
||||
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
|
||||
import org.apache.spark.util.{Utils, AkkaUtils}
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
|
||||
import akka.remote.DisassociatedEvent
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisteredExecutor
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.LaunchTask
|
||||
import akka.remote.AssociationErrorEvent
|
||||
import akka.remote.DisassociatedEvent
|
||||
import akka.actor.Terminated
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutor
|
||||
import org.apache.spark.scheduler.cluster.StandaloneClusterMessages.RegisterExecutorFailed
|
||||
|
||||
|
||||
private[spark] class StandaloneExecutorBackend(
|
||||
private[spark] class CoarseGrainedExecutorBackend(
|
||||
driverUrl: String,
|
||||
executorId: String,
|
||||
hostPort: String,
|
||||
|
@ -75,15 +67,28 @@ private[spark] class StandaloneExecutorBackend(
|
|||
case LaunchTask(taskDesc) =>
|
||||
logInfo("Got assigned task " + taskDesc.taskId)
|
||||
if (executor == null) {
|
||||
logError("Received launchTask but executor was null")
|
||||
logError("Received LaunchTask command but executor was null")
|
||||
System.exit(1)
|
||||
} else {
|
||||
executor.launchTask(this, taskDesc.taskId, taskDesc.serializedTask)
|
||||
}
|
||||
|
||||
case KillTask(taskId, _) =>
|
||||
if (executor == null) {
|
||||
logError("Received KillTask command but executor was null")
|
||||
System.exit(1)
|
||||
} else {
|
||||
executor.killTask(taskId)
|
||||
}
|
||||
|
||||
case Terminated(actor) =>
|
||||
logError("Driver terminated or disconnected! Shutting down.")
|
||||
logError(s"Driver $actor terminated or disconnected! Shutting down.")
|
||||
System.exit(1)
|
||||
|
||||
case StopExecutor =>
|
||||
logInfo("Driver commanded a shutdown")
|
||||
context.stop(self)
|
||||
context.system.shutdown()
|
||||
}
|
||||
|
||||
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {
|
||||
|
@ -91,7 +96,7 @@ private[spark] class StandaloneExecutorBackend(
|
|||
}
|
||||
}
|
||||
|
||||
private[spark] object StandaloneExecutorBackend {
|
||||
private[spark] object CoarseGrainedExecutorBackend {
|
||||
def run(driverUrl: String, executorId: String, hostname: String, cores: Int) {
|
||||
// Debug code
|
||||
Utils.checkHost(hostname)
|
||||
|
@ -103,15 +108,17 @@ private[spark] object StandaloneExecutorBackend {
|
|||
val sparkHostPort = hostname + ":" + boundPort
|
||||
System.setProperty("spark.hostPort", sparkHostPort)
|
||||
actorSystem.actorOf(
|
||||
Props(classOf[StandaloneExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
|
||||
Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
|
||||
name = "Executor")
|
||||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
||||
def main(args: Array[String]) {
|
||||
if (args.length < 4) {
|
||||
//the reason we allow the last frameworkId argument is to make it easy to kill rogue executors
|
||||
System.err.println("Usage: StandaloneExecutorBackend <driverUrl> <executorId> <hostname> <cores> [<appid>]")
|
||||
//the reason we allow the last appid argument is to make it easy to kill rogue executors
|
||||
System.err.println(
|
||||
"Usage: CoarseGrainedExecutorBackend <driverUrl> <executorId> <hostname> <cores> " +
|
||||
"[<appid>]")
|
||||
System.exit(1)
|
||||
}
|
||||
run(args(0), args(1), args(2), args(3).toInt)
|
|
@ -25,9 +25,10 @@ import java.util.concurrent._
|
|||
import scala.collection.JavaConversions._
|
||||
import scala.collection.mutable.HashMap
|
||||
|
||||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.scheduler._
|
||||
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -36,7 +37,8 @@ import org.apache.spark.util.Utils
|
|||
private[spark] class Executor(
|
||||
executorId: String,
|
||||
slaveHostname: String,
|
||||
properties: Seq[(String, String)])
|
||||
properties: Seq[(String, String)],
|
||||
isLocal: Boolean = false)
|
||||
extends Logging
|
||||
{
|
||||
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
|
||||
|
@ -73,46 +75,79 @@ private[spark] class Executor(
|
|||
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
|
||||
Thread.currentThread.setContextClassLoader(replClassLoader)
|
||||
|
||||
// Make any thread terminations due to uncaught exceptions kill the entire
|
||||
// executor process to avoid surprising stalls.
|
||||
Thread.setDefaultUncaughtExceptionHandler(
|
||||
new Thread.UncaughtExceptionHandler {
|
||||
override def uncaughtException(thread: Thread, exception: Throwable) {
|
||||
try {
|
||||
logError("Uncaught exception in thread " + thread, exception)
|
||||
if (!isLocal) {
|
||||
// Setup an uncaught exception handler for non-local mode.
|
||||
// Make any thread terminations due to uncaught exceptions kill the entire
|
||||
// executor process to avoid surprising stalls.
|
||||
Thread.setDefaultUncaughtExceptionHandler(
|
||||
new Thread.UncaughtExceptionHandler {
|
||||
override def uncaughtException(thread: Thread, exception: Throwable) {
|
||||
try {
|
||||
logError("Uncaught exception in thread " + thread, exception)
|
||||
|
||||
// We may have been called from a shutdown hook. If so, we must not call System.exit().
|
||||
// (If we do, we will deadlock.)
|
||||
if (!Utils.inShutdown()) {
|
||||
if (exception.isInstanceOf[OutOfMemoryError]) {
|
||||
System.exit(ExecutorExitCode.OOM)
|
||||
} else {
|
||||
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
|
||||
// We may have been called from a shutdown hook. If so, we must not call System.exit().
|
||||
// (If we do, we will deadlock.)
|
||||
if (!Utils.inShutdown()) {
|
||||
if (exception.isInstanceOf[OutOfMemoryError]) {
|
||||
System.exit(ExecutorExitCode.OOM)
|
||||
} else {
|
||||
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
|
||||
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
|
||||
}
|
||||
} catch {
|
||||
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
|
||||
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
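The guarded block above installs a JVM-wide handler only in non-local mode, so an uncaught exception in any executor thread kills the whole process instead of leaving it wedged. A standalone sketch of the same pattern; the exit-code constants here are illustrative, not Spark's actual values:

    object FailFast {
      val OOM_EXIT = 52          // illustrative exit codes only
      val UNCAUGHT_EXIT = 50

      def install(): Unit = {
        Thread.setDefaultUncaughtExceptionHandler(new Thread.UncaughtExceptionHandler {
          override def uncaughtException(thread: Thread, exception: Throwable): Unit = {
            try {
              System.err.println(s"Uncaught exception in thread $thread: $exception")
              exception match {
                case _: OutOfMemoryError => System.exit(OOM_EXIT)
                case _                   => System.exit(UNCAUGHT_EXIT)
              }
            } catch {
              // If even the handler fails, halt without running shutdown hooks.
              case _: Throwable => Runtime.getRuntime.halt(UNCAUGHT_EXIT)
            }
          }
        })
      }
    }
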
val executorSource = new ExecutorSource(this, executorId)
|
||||
|
||||
// Initialize Spark environment (using system properties read above)
|
||||
val env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0, false, false)
|
||||
SparkEnv.set(env)
|
||||
env.metricsSystem.registerSource(executorSource)
|
||||
private val env = {
|
||||
if (!isLocal) {
|
||||
val _env = SparkEnv.createFromSystemProperties(executorId, slaveHostname, 0,
|
||||
isDriver = false, isLocal = false)
|
||||
SparkEnv.set(_env)
|
||||
_env.metricsSystem.registerSource(executorSource)
|
||||
_env
|
||||
} else {
|
||||
SparkEnv.get
|
||||
}
|
||||
}
|
||||
|
||||
private val akkaFrameSize = env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
|
||||
// Akka's message frame size. If task result is bigger than this, we use the block manager
|
||||
// to send the result back.
|
||||
private val akkaFrameSize = {
|
||||
env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
|
||||
}
|
||||
|
||||
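The comment above explains why the result size is compared against the Akka frame size: anything that would not fit in a single control-plane message goes through the block manager and only a small handle is sent back. A simplified sketch of that routing decision, with stand-in types rather than Spark's task-result classes:

    // Sketch: route a serialized result either inline or via out-of-band storage.
    sealed trait ResultEnvelope
    case class DirectResult(bytes: Array[Byte]) extends ResultEnvelope
    case class IndirectResult(blockId: String) extends ResultEnvelope

    def packageResult(serialized: Array[Byte],
                      frameSize: Int,
                      store: (String, Array[Byte]) => Unit,   // e.g. a block-manager put
                      taskId: Long): ResultEnvelope = {
      if (serialized.length >= frameSize - 1024) {             // leave headroom for headers
        val blockId = s"taskresult_$taskId"
        store(blockId, serialized)
        IndirectResult(blockId)
      } else {
        DirectResult(serialized)
      }
    }
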
// Start worker thread pool
|
||||
val threadPool = new ThreadPoolExecutor(
|
||||
1, 128, 600, TimeUnit.SECONDS, new SynchronousQueue[Runnable])
|
||||
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
|
||||
|
||||
// Maintains the list of running tasks.
|
||||
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
|
||||
|
||||
val sparkUser = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
|
||||
|
||||
def launchTask(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer) {
|
||||
threadPool.execute(new TaskRunner(context, taskId, serializedTask))
|
||||
val tr = new TaskRunner(context, taskId, serializedTask)
|
||||
runningTasks.put(taskId, tr)
|
||||
threadPool.execute(tr)
|
||||
}
|
||||
|
||||
def killTask(taskId: Long) {
|
||||
val tr = runningTasks.get(taskId)
|
||||
if (tr != null) {
|
||||
tr.kill()
|
||||
// We remove the task also in the finally block in TaskRunner.run.
|
||||
// The reason we need to remove it here is because killTask might be called before the task
|
||||
// is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
|
||||
// idempotent.
|
||||
runningTasks.remove(taskId)
|
||||
}
|
||||
}
|
||||
|
||||
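killTask above removes the entry eagerly because a kill can race with launch, and ConcurrentHashMap.remove is idempotent, so the later removal in the runner's finally block is harmless. A minimal sketch of that bookkeeping, with a placeholder runner type standing in for the real TaskRunner:

    import java.util.concurrent.ConcurrentHashMap

    class CancellableRunner { @volatile var killed = false; def kill(): Unit = killed = true }

    class TaskTable {
      private val running = new ConcurrentHashMap[Long, CancellableRunner]

      def register(taskId: Long, r: CancellableRunner): Unit = running.put(taskId, r)

      def kill(taskId: Long): Unit = {
        val r = running.get(taskId)
        if (r != null) {
          r.kill()
          running.remove(taskId)   // safe even if the runner's finally block removes it again
        }
      }

      def finished(taskId: Long): Unit = running.remove(taskId)
    }
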
/** Get the Yarn approved local directories. */
|
||||
|
@ -124,56 +159,87 @@ private[spark] class Executor(
|
|||
.getOrElse(Option(System.getenv("LOCAL_DIRS"))
|
||||
.getOrElse(""))
|
||||
|
||||
if (localDirs.isEmpty()) {
|
||||
if (localDirs.isEmpty) {
|
||||
throw new Exception("Yarn Local dirs can't be empty")
|
||||
}
|
||||
return localDirs
|
||||
localDirs
|
||||
}
|
||||
|
||||
class TaskRunner(context: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
|
||||
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
|
||||
extends Runnable {
|
||||
|
||||
override def run() {
|
||||
@volatile private var killed = false
|
||||
@volatile private var task: Task[Any] = _
|
||||
|
||||
def kill() {
|
||||
logInfo("Executor is trying to kill task " + taskId)
|
||||
killed = true
|
||||
if (task != null) {
|
||||
task.kill()
|
||||
}
|
||||
}
|
||||
|
||||
override def run(): Unit = SparkHadoopUtil.get.runAsUser(sparkUser) { () =>
|
||||
val startTime = System.currentTimeMillis()
|
||||
SparkEnv.set(env)
|
||||
Thread.currentThread.setContextClassLoader(replClassLoader)
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
logInfo("Running task ID " + taskId)
|
||||
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
|
||||
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
|
||||
var attemptedTask: Option[Task[Any]] = None
|
||||
var taskStart: Long = 0
|
||||
def getTotalGCTime = ManagementFactory.getGarbageCollectorMXBeans.map(g => g.getCollectionTime).sum
|
||||
val startGCTime = getTotalGCTime
|
||||
def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
|
||||
val startGCTime = gcTime
|
||||
|
||||
try {
|
||||
SparkEnv.set(env)
|
||||
Accumulators.clear()
|
||||
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
|
||||
updateDependencies(taskFiles, taskJars)
|
||||
val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
|
||||
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
|
||||
|
||||
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
|
||||
// continue executing the task.
|
||||
if (killed) {
|
||||
logInfo("Executor killed task " + taskId)
|
||||
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
|
||||
return
|
||||
}
|
||||
|
||||
attemptedTask = Some(task)
|
||||
logInfo("Its epoch is " + task.epoch)
|
||||
logDebug("Task " + taskId +"'s epoch is " + task.epoch)
|
||||
env.mapOutputTracker.updateEpoch(task.epoch)
|
||||
|
||||
// Run the actual task and measure its runtime.
|
||||
taskStart = System.currentTimeMillis()
|
||||
val value = task.run(taskId.toInt)
|
||||
val taskFinish = System.currentTimeMillis()
|
||||
|
||||
// If the task has been killed, let's fail it.
|
||||
if (task.killed) {
|
||||
logInfo("Executor killed task " + taskId)
|
||||
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
|
||||
return
|
||||
}
|
||||
|
||||
for (m <- task.metrics) {
|
||||
m.hostname = Utils.localHostName
|
||||
m.hostname = Utils.localHostName()
|
||||
m.executorDeserializeTime = (taskStart - startTime).toInt
|
||||
m.executorRunTime = (taskFinish - taskStart).toInt
|
||||
m.jvmGCTime = getTotalGCTime - startGCTime
|
||||
m.jvmGCTime = gcTime - startGCTime
|
||||
}
|
||||
//TODO I'd also like to track the time it takes to serialize the task results, but that is huge headache, b/c
|
||||
// we need to serialize the task metrics first. If TaskMetrics had a custom serialized format, we could
|
||||
// just change the relevants bytes in the byte buffer
|
||||
// TODO I'd also like to track the time it takes to serialize the task results, but that is
|
||||
// huge headache, b/c we need to serialize the task metrics first. If TaskMetrics had a
|
||||
// custom serialized format, we could just change the relevants bytes in the byte buffer
|
||||
val accumUpdates = Accumulators.values
|
||||
|
||||
val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
|
||||
val serializedDirectResult = ser.serialize(directResult)
|
||||
logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
|
||||
val serializedResult = {
|
||||
if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
|
||||
logInfo("Storing result for " + taskId + " in local BlockManager")
|
||||
val blockId = "taskresult_" + taskId
|
||||
val blockId = TaskResultBlockId(taskId)
|
||||
env.blockManager.putBytes(
|
||||
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
|
||||
ser.serialize(new IndirectTaskResult[Any](blockId))
|
||||
|
@ -182,12 +248,13 @@ private[spark] class Executor(
|
|||
serializedDirectResult
|
||||
}
|
||||
}
|
||||
context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
|
||||
|
||||
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
|
||||
logInfo("Finished task ID " + taskId)
|
||||
} catch {
|
||||
case ffe: FetchFailedException => {
|
||||
val reason = ffe.toTaskEndReason
|
||||
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
|
||||
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
|
||||
}
|
||||
|
||||
case t: Throwable => {
|
||||
|
@ -195,10 +262,10 @@ private[spark] class Executor(
|
|||
val metrics = attemptedTask.flatMap(t => t.metrics)
|
||||
for (m <- metrics) {
|
||||
m.executorRunTime = serviceTime
|
||||
m.jvmGCTime = getTotalGCTime - startGCTime
|
||||
m.jvmGCTime = gcTime - startGCTime
|
||||
}
|
||||
val reason = ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace, metrics)
|
||||
context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
|
||||
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
|
||||
|
||||
// TODO: Should we exit the whole executor here? On the one hand, the failed task may
|
||||
// have left some weird state around depending on when the exception was thrown, but on
|
||||
|
@ -206,6 +273,8 @@ private[spark] class Executor(
|
|||
logError("Exception in task ID " + taskId, t)
|
||||
//System.exit(1)
|
||||
}
|
||||
} finally {
|
||||
runningTasks.remove(taskId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -215,7 +284,7 @@ private[spark] class Executor(
|
|||
* created by the interpreter to the search path
|
||||
*/
|
||||
private def createClassLoader(): ExecutorURLClassLoader = {
|
||||
var loader = this.getClass.getClassLoader
|
||||
val loader = this.getClass.getClassLoader
|
||||
|
||||
// For each of the jars in the jarSet, add them to the class loader.
|
||||
// We assume each of the files has already been fetched.
|
||||
|
@ -237,7 +306,7 @@ private[spark] class Executor(
|
|||
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
|
||||
.asInstanceOf[Class[_ <: ClassLoader]]
|
||||
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
|
||||
return constructor.newInstance(classUri, parent)
|
||||
constructor.newInstance(classUri, parent)
|
||||
} catch {
|
||||
case _: ClassNotFoundException =>
|
||||
logError("Could not find org.apache.spark.repl.ExecutorClassLoader on classpath!")
|
||||
|
@ -245,7 +314,7 @@ private[spark] class Executor(
|
|||
null
|
||||
}
|
||||
} else {
|
||||
return parent
|
||||
parent
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,14 +18,18 @@
|
|||
package org.apache.spark.executor
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
|
||||
import org.apache.mesos.Protos.{TaskState => MesosTaskState, TaskStatus => MesosTaskStatus, _}
|
||||
import org.apache.spark.TaskState.TaskState
|
||||
|
||||
import com.google.protobuf.ByteString
|
||||
import org.apache.spark.{Logging}
|
||||
|
||||
import org.apache.mesos.{Executor => MesosExecutor, MesosExecutorDriver, MesosNativeLibrary, ExecutorDriver}
|
||||
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.TaskState
|
||||
import org.apache.spark.TaskState.TaskState
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
|
||||
private[spark] class MesosExecutorBackend
|
||||
extends MesosExecutor
|
||||
with ExecutorBackend
|
||||
|
@@ -71,7 +75,11 @@ private[spark] class MesosExecutorBackend
   }

   override def killTask(d: ExecutorDriver, t: TaskID) {
-    logWarning("Mesos asked us to kill task " + t.getValue + "; ignoring (not yet implemented)")
+    if (executor == null) {
+      logError("Received KillTask but executor was null")
+    } else {
+      executor.killTask(t.getValue.toLong)
+    }
   }

   override def reregistered(d: ExecutorDriver, p2: SlaveInfo) {}
@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
|
|||
* Number of bytes written for a shuffle
|
||||
*/
|
||||
var shuffleBytesWritten: Long = _
|
||||
|
||||
/**
|
||||
* Time spent blocking on writes to disk or buffer cache, in nanoseconds.
|
||||
*/
|
||||
var shuffleWriteTime: Long = _
|
||||
}
|
||||
|
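shuffleWriteTime is recorded in nanoseconds. A hedged sketch of how such a metric can be accumulated around a blocking write; the write function here is hypothetical:

    // Sketch: accumulate nanoseconds spent in a blocking write call.
    class WriteMetrics { var bytesWritten: Long = 0L; var writeTimeNanos: Long = 0L }

    def timedWrite(metrics: WriteMetrics, bytes: Array[Byte], write: Array[Byte] => Unit): Unit = {
      val start = System.nanoTime()
      write(bytes)                                        // e.g. flush to disk or buffer cache
      metrics.writeTimeNanos += System.nanoTime() - start
      metrics.bytesWritten += bytes.length
    }
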
|
|
@ -79,7 +79,8 @@ private[spark] class ConnectionManager(port: Int) extends Logging {
|
|||
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
|
||||
private val registerRequests = new SynchronizedQueue[SendingConnection]
|
||||
|
||||
implicit val futureExecContext = ExecutionContext.fromExecutor(Utils.newDaemonCachedThreadPool())
|
||||
implicit val futureExecContext = ExecutionContext.fromExecutor(
|
||||
Utils.newDaemonCachedThreadPool("Connection manager future execution context"))
|
||||
|
||||
private var onReceiveCallback: (BufferMessage, ConnectionManagerId) => Option[Message]= null
|
||||
|
||||
|
|
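Passing a name to the cached pool makes its daemon threads identifiable in stack dumps. A sketch of what such a factory can look like using only java.util.concurrent (this is an assumption about the shape of the helper, not its actual implementation):

    import java.util.concurrent.{Executors, ThreadFactory}
    import java.util.concurrent.atomic.AtomicInteger

    // Sketch: daemon threads with a stable, numbered name prefix.
    def namedDaemonCachedPool(prefix: String) = {
      val counter = new AtomicInteger(0)
      val factory = new ThreadFactory {
        override def newThread(r: Runnable): Thread = {
          val t = new Thread(r, s"$prefix-${counter.incrementAndGet()}")
          t.setDaemon(true)
          t
        }
      }
      Executors.newCachedThreadPool(factory)
    }
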
|
@ -20,17 +20,18 @@ package org.apache.spark.network.netty
|
|||
import io.netty.buffer._
|
||||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.storage.{TestBlockId, BlockId}
|
||||
|
||||
private[spark] class FileHeader (
|
||||
val fileLen: Int,
|
||||
val blockId: String) extends Logging {
|
||||
val blockId: BlockId) extends Logging {
|
||||
|
||||
lazy val buffer = {
|
||||
val buf = Unpooled.buffer()
|
||||
buf.capacity(FileHeader.HEADER_SIZE)
|
||||
buf.writeInt(fileLen)
|
||||
buf.writeInt(blockId.length)
|
||||
blockId.foreach((x: Char) => buf.writeByte(x))
|
||||
buf.writeInt(blockId.name.length)
|
||||
blockId.name.foreach((x: Char) => buf.writeByte(x))
|
||||
//padding the rest of header
|
||||
if (FileHeader.HEADER_SIZE - buf.readableBytes > 0 ) {
|
||||
buf.writeZero(FileHeader.HEADER_SIZE - buf.readableBytes)
|
||||
|
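The lazy buffer above lays the header out as [fileLen: int][id length: int][id bytes][zero padding]. A dependency-free sketch of the same layout with java.nio; the fixed header size is an assumption for illustration:

    import java.nio.ByteBuffer

    // Sketch: same field order as the Netty buffer above; HEADER_SIZE is assumed.
    val HEADER_SIZE = 40

    def encodeHeader(fileLen: Int, blockIdName: String): ByteBuffer = {
      val buf = ByteBuffer.allocate(HEADER_SIZE)
      buf.putInt(fileLen)
      buf.putInt(blockIdName.length)
      blockIdName.foreach(c => buf.put(c.toByte))
      while (buf.position() < HEADER_SIZE) buf.put(0.toByte)   // pad the remainder
      buf.flip()
      buf
    }
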
@ -57,18 +58,15 @@ private[spark] object FileHeader {
|
|||
for (i <- 1 to idLength) {
|
||||
idBuilder += buf.readByte().asInstanceOf[Char]
|
||||
}
|
||||
val blockId = idBuilder.toString()
|
||||
val blockId = BlockId(idBuilder.toString())
|
||||
new FileHeader(length, blockId)
|
||||
}
|
||||
|
||||
|
||||
-  def main (args:Array[String]){
-
-    val header = new FileHeader(25,"block_0");
-    val buf = header.buffer;
-    val newheader = FileHeader.create(buf);
-    System.out.println("id="+newheader.blockId+",size="+newheader.fileLen)
-
+  def main (args:Array[String]) {
+    val header = new FileHeader(25, TestBlockId("my_block"))
+    val buf = header.buffer
+    val newHeader = FileHeader.create(buf)
+    System.out.println("id=" + newHeader.blockId + ",size=" + newHeader.fileLen)
   }
 }

|
@ -27,12 +27,13 @@ import org.apache.spark.Logging
|
|||
import org.apache.spark.network.ConnectionManagerId
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
import org.apache.spark.storage.BlockId
|
||||
|
||||
|
||||
private[spark] class ShuffleCopier extends Logging {
|
||||
|
||||
def getBlock(host: String, port: Int, blockId: String,
|
||||
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
|
||||
def getBlock(host: String, port: Int, blockId: BlockId,
|
||||
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
|
||||
|
||||
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
|
||||
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
|
||||
|
@ -41,7 +42,7 @@ private[spark] class ShuffleCopier extends Logging {
|
|||
try {
|
||||
fc.init()
|
||||
fc.connect(host, port)
|
||||
fc.sendRequest(blockId)
|
||||
fc.sendRequest(blockId.name)
|
||||
fc.waitForClose()
|
||||
fc.close()
|
||||
} catch {
|
||||
|
@ -53,14 +54,14 @@ private[spark] class ShuffleCopier extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
def getBlock(cmId: ConnectionManagerId, blockId: String,
|
||||
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
|
||||
def getBlock(cmId: ConnectionManagerId, blockId: BlockId,
|
||||
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
|
||||
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
|
||||
}
|
||||
|
||||
def getBlocks(cmId: ConnectionManagerId,
|
||||
blocks: Seq[(String, Long)],
|
||||
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
|
||||
blocks: Seq[(BlockId, Long)],
|
||||
resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) {
|
||||
|
||||
for ((blockId, size) <- blocks) {
|
||||
getBlock(cmId, blockId, resultCollectCallback)
|
||||
|
@ -71,7 +72,7 @@ private[spark] class ShuffleCopier extends Logging {
|
|||
|
||||
private[spark] object ShuffleCopier extends Logging {
|
||||
|
||||
private class ShuffleClientHandler(resultCollectCallBack: (String, Long, ByteBuf) => Unit)
|
||||
private class ShuffleClientHandler(resultCollectCallBack: (BlockId, Long, ByteBuf) => Unit)
|
||||
extends FileClientHandler with Logging {
|
||||
|
||||
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
|
||||
|
@ -79,14 +80,14 @@ private[spark] object ShuffleCopier extends Logging {
|
|||
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
|
||||
}
|
||||
|
||||
override def handleError(blockId: String) {
|
||||
override def handleError(blockId: BlockId) {
|
||||
if (!isComplete) {
|
||||
resultCollectCallBack(blockId, -1, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
|
||||
def echoResultCollectCallBack(blockId: BlockId, size: Long, content: ByteBuf) {
|
||||
if (size != -1) {
|
||||
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
|
||||
}
|
||||
|
@ -99,7 +100,7 @@ private[spark] object ShuffleCopier extends Logging {
|
|||
}
|
||||
val host = args(0)
|
||||
val port = args(1).toInt
|
||||
val file = args(2)
|
||||
val blockId = BlockId(args(2))
|
||||
val threads = if (args.length > 3) args(3).toInt else 10
|
||||
|
||||
val copiers = Executors.newFixedThreadPool(80)
|
||||
|
@ -107,12 +108,12 @@ private[spark] object ShuffleCopier extends Logging {
|
|||
Executors.callable(new Runnable() {
|
||||
def run() {
|
||||
val copier = new ShuffleCopier()
|
||||
copier.getBlock(host, port, file, echoResultCollectCallBack)
|
||||
copier.getBlock(host, port, blockId, echoResultCollectCallBack)
|
||||
}
|
||||
})
|
||||
}).asJava
|
||||
copiers.invokeAll(tasks)
|
||||
copiers.shutdown
|
||||
copiers.shutdown()
|
||||
System.exit(0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.io.File
|
|||
|
||||
import org.apache.spark.Logging
|
||||
import org.apache.spark.util.Utils
|
||||
import org.apache.spark.storage.{BlockId, FileSegment}
|
||||
|
||||
|
||||
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
|
||||
|
@ -53,8 +54,8 @@ private[spark] object ShuffleSender {
|
|||
val localDirs = args.drop(2).map(new File(_))
|
||||
|
||||
val pResovler = new PathResolver {
|
||||
override def getAbsolutePath(blockId: String): String = {
|
||||
if (!blockId.startsWith("shuffle_")) {
|
||||
override def getBlockLocation(blockId: BlockId): FileSegment = {
|
||||
if (!blockId.isShuffle) {
|
||||
throw new Exception("Block " + blockId + " is not a shuffle block")
|
||||
}
|
||||
// Figure out which local directory it hashes to, and which subdirectory in that
|
||||
|
@ -62,8 +63,8 @@ private[spark] object ShuffleSender {
|
|||
val dirId = hash % localDirs.length
|
||||
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
|
||||
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
|
||||
val file = new File(subDir, blockId)
|
||||
return file.getAbsolutePath
|
||||
val file = new File(subDir, blockId.name)
|
||||
return new FileSegment(file, 0, file.length())
|
||||
}
|
||||
}
|
||||
val sender = new ShuffleSender(port, pResovler)
|
||||
|
|
|
@ -15,6 +15,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache
|
||||
|
||||
/**
|
||||
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
|
||||
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,
|
||||
|
|
core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala (new file, 123 lines added)
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.rdd
|
||||
|
||||
import java.util.concurrent.atomic.AtomicLong
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
|
||||
|
||||
/**
|
||||
* A set of asynchronous RDD actions available through an implicit conversion.
|
||||
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
|
||||
*/
|
||||
class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
|
||||
|
||||
/**
|
||||
* Returns a future for counting the number of elements in the RDD.
|
||||
*/
|
||||
def countAsync(): FutureAction[Long] = {
|
||||
val totalCount = new AtomicLong
|
||||
self.context.submitJob(
|
||||
self,
|
||||
(iter: Iterator[T]) => {
|
||||
var result = 0L
|
||||
while (iter.hasNext) {
|
||||
result += 1L
|
||||
iter.next()
|
||||
}
|
||||
result
|
||||
},
|
||||
Range(0, self.partitions.size),
|
||||
(index: Int, data: Long) => totalCount.addAndGet(data),
|
||||
totalCount.get())
|
||||
}
|
||||
|
||||
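These actions return a FutureAction instead of blocking the caller. A brief usage sketch, assuming an existing SparkContext named sc and that the FutureAction can be awaited like a standard scala.concurrent.Future:

    import scala.concurrent.Await
    import scala.concurrent.duration.Duration
    import org.apache.spark.SparkContext._

    val rdd = sc.parallelize(1 to 1000000, 10)
    val pending = rdd.countAsync()                 // returns immediately with a FutureAction[Long]
    // ... do other work here; pending.cancel() would abort the job ...
    val n = Await.result(pending, Duration.Inf)    // block only when the count is finally needed
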
/**
|
||||
* Returns a future for retrieving all elements of this RDD.
|
||||
*/
|
||||
def collectAsync(): FutureAction[Seq[T]] = {
|
||||
val results = new Array[Array[T]](self.partitions.size)
|
||||
self.context.submitJob[T, Array[T], Seq[T]](self, _.toArray, Range(0, self.partitions.size),
|
||||
(index, data) => results(index) = data, results.flatten.toSeq)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a future for retrieving the first num elements of the RDD.
|
||||
*/
|
||||
def takeAsync(num: Int): FutureAction[Seq[T]] = {
|
||||
val f = new ComplexFutureAction[Seq[T]]
|
||||
|
||||
f.run {
|
||||
val results = new ArrayBuffer[T](num)
|
||||
val totalParts = self.partitions.length
|
||||
var partsScanned = 0
|
||||
while (results.size < num && partsScanned < totalParts) {
|
||||
// The number of partitions to try in this iteration. It is ok for this number to be
|
||||
// greater than totalParts because we actually cap it at totalParts in runJob.
|
||||
var numPartsToTry = 1
|
||||
if (partsScanned > 0) {
|
||||
// If we didn't find any rows after the first iteration, just try all partitions next.
|
||||
// Otherwise, interpolate the number of partitions we need to try, but overestimate it
|
||||
// by 50%.
|
||||
if (results.size == 0) {
|
||||
numPartsToTry = totalParts - 1
|
||||
} else {
|
||||
numPartsToTry = (1.5 * num * partsScanned / results.size).toInt
|
||||
}
|
||||
}
|
||||
numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions
|
||||
|
||||
val left = num - results.size
|
||||
val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
|
||||
|
||||
val buf = new Array[Array[T]](p.size)
|
||||
f.runJob(self,
|
||||
(it: Iterator[T]) => it.take(left).toArray,
|
||||
p,
|
||||
(index: Int, data: Array[T]) => buf(index) = data,
|
||||
Unit)
|
||||
|
||||
buf.foreach(results ++= _.take(num - results.size))
|
||||
partsScanned += numPartsToTry
|
||||
}
|
||||
results.toSeq
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
|
||||
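The loop above grows the number of partitions to scan based on how productive the previous pass was, overestimating by 50%. A small worked example of that estimate with made-up numbers:

    val num = 100            // rows requested (hypothetical)
    val partsScanned = 1     // partitions already scanned
    val resultsSoFar = 20    // rows found in those partitions
    val numPartsToTry = (1.5 * num * partsScanned / resultsSoFar).toInt   // = 7 here
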
/**
|
||||
* Applies a function f to all elements of this RDD.
|
||||
*/
|
||||
def foreachAsync(f: T => Unit): FutureAction[Unit] = {
|
||||
self.context.submitJob[T, Unit, Unit](self, _.foreach(f), Range(0, self.partitions.size),
|
||||
(index, data) => Unit, Unit)
|
||||
}
|
||||
|
||||
/**
|
||||
* Applies a function f to each partition of this RDD.
|
||||
*/
|
||||
def foreachPartitionAsync(f: Iterator[T] => Unit): FutureAction[Unit] = {
|
||||
self.context.submitJob[T, Unit, Unit](self, f, Range(0, self.partitions.size),
|
||||
(index, data) => Unit, Unit)
|
||||
}
|
||||
}
|
|
@ -17,16 +17,17 @@
|
|||
|
||||
package org.apache.spark.rdd
|
||||
|
||||
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
|
||||
import org.apache.spark.storage.BlockManager
|
||||
import scala.reflect.ClassTag
|
||||
|
||||
private[spark] class BlockRDDPartition(val blockId: String, idx: Int) extends Partition {
|
||||
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
|
||||
import org.apache.spark.storage.{BlockId, BlockManager}
|
||||
|
||||
private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition {
|
||||
val index = idx
|
||||
}
|
||||
|
||||
private[spark]
|
||||
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[String])
|
||||
class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
|
||||
extends RDD[T](sc, Nil) {
|
||||
|
||||
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
|
||||
|
|
|
@ -19,6 +19,7 @@ package org.apache.spark.rdd
|
|||
|
||||
import scala.reflect.ClassTag
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.io.{NullWritable, BytesWritable}
|
||||
|
@ -84,9 +85,9 @@ private[spark] object CheckpointRDD extends Logging {
|
|||
def writeToFile[T](path: String, blockSize: Int = -1)(ctx: TaskContext, iterator: Iterator[T]) {
|
||||
val env = SparkEnv.get
|
||||
val outputDir = new Path(path)
|
||||
val fs = outputDir.getFileSystem(env.hadoop.newConfiguration())
|
||||
val fs = outputDir.getFileSystem(SparkHadoopUtil.get.newConfiguration())
|
||||
|
||||
val finalOutputName = splitIdToFile(ctx.splitId)
|
||||
val finalOutputName = splitIdToFile(ctx.partitionId)
|
||||
val finalOutputPath = new Path(outputDir, finalOutputName)
|
||||
val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
|
||||
|
||||
|
@ -123,7 +124,7 @@ private[spark] object CheckpointRDD extends Logging {
|
|||
|
||||
def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = {
|
||||
val env = SparkEnv.get
|
||||
val fs = path.getFileSystem(env.hadoop.newConfiguration())
|
||||
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
|
||||
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
|
||||
val fileInputStream = fs.open(path, bufferSize)
|
||||
val serializer = env.serializer.newInstance()
|
||||
|
@ -146,7 +147,7 @@ private[spark] object CheckpointRDD extends Logging {
|
|||
val sc = new SparkContext(cluster, "CheckpointRDD Test")
|
||||
val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
|
||||
val path = new Path(hdfsPath, "temp")
|
||||
val fs = path.getFileSystem(env.hadoop.newConfiguration())
|
||||
val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration())
|
||||
sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 1024) _)
|
||||
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
|
||||
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
|
||||
|
|
|
@ -18,13 +18,12 @@
|
|||
package org.apache.spark.rdd
|
||||
|
||||
import java.io.{ObjectOutputStream, IOException}
|
||||
import java.util.{HashMap => JHashMap}
|
||||
|
||||
import scala.collection.JavaConversions
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
|
||||
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
|
||||
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
|
||||
import org.apache.spark.util.AppendOnlyMap
|
||||
|
||||
|
||||
private[spark] sealed trait CoGroupSplitDep extends Serializable
|
||||
|
@ -105,17 +104,14 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
|
|||
val split = s.asInstanceOf[CoGroupPartition]
|
||||
val numRdds = split.deps.size
|
||||
// e.g. for `(k, a) cogroup (k, b)`, K -> Seq(ArrayBuffer as, ArrayBuffer bs)
|
||||
val map = new JHashMap[K, Seq[ArrayBuffer[Any]]]
|
||||
val map = new AppendOnlyMap[K, Seq[ArrayBuffer[Any]]]
|
||||
|
||||
def getSeq(k: K): Seq[ArrayBuffer[Any]] = {
|
||||
val seq = map.get(k)
|
||||
if (seq != null) {
|
||||
seq
|
||||
} else {
|
||||
val seq = Array.fill(numRdds)(new ArrayBuffer[Any])
|
||||
map.put(k, seq)
|
||||
seq
|
||||
}
|
||||
val update: (Boolean, Seq[ArrayBuffer[Any]]) => Seq[ArrayBuffer[Any]] = (hadVal, oldVal) => {
|
||||
if (hadVal) oldVal else Array.fill(numRdds)(new ArrayBuffer[Any])
|
||||
}
|
||||
|
||||
val getSeq = (k: K) => {
|
||||
map.changeValue(k, update)
|
||||
}
|
||||
|
||||
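The update closure above creates the per-RDD buffers only when the key is seen for the first time, much like getOrElseUpdate on an ordinary map. A plain-Scala sketch of the same grouping pattern; AppendOnlyMap itself is Spark-internal, so mutable.HashMap stands in here:

    import scala.collection.mutable

    // Sketch: one ArrayBuffer per co-grouped RDD, created lazily per key.
    val numRdds = 2
    val map = mutable.HashMap.empty[String, Seq[mutable.ArrayBuffer[Any]]]

    def getSeq(k: String): Seq[mutable.ArrayBuffer[Any]] =
      map.getOrElseUpdate(k, Seq.fill(numRdds)(new mutable.ArrayBuffer[Any]))

    getSeq("a")(0) += 1     // value from the first RDD
    getSeq("a")(1) += "x"   // value from the second RDD
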
val ser = SparkEnv.get.serializerManager.get(serializerClass)
|
||||
|
@ -129,12 +125,12 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
|
|||
case ShuffleCoGroupSplitDep(shuffleId) => {
|
||||
// Read map outputs of shuffle
|
||||
val fetcher = SparkEnv.get.shuffleFetcher
|
||||
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
|
||||
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
|
||||
kv => getSeq(kv._1)(depNum) += kv._2
|
||||
}
|
||||
}
|
||||
}
|
||||
JavaConversions.mapAsScalaMap(map).iterator
|
||||
new InterruptibleIterator(context, map.iterator)
|
||||
}
|
||||
|
||||
override def clearDependencies() {
|
||||
|
|
|
@ -27,47 +27,12 @@ import org.apache.hadoop.mapred.RecordReader
|
|||
import org.apache.hadoop.mapred.Reporter
|
||||
import org.apache.hadoop.util.ReflectionUtils
|
||||
|
||||
import org.apache.spark.{Logging, Partition, SerializableWritable, SparkContext, SparkEnv,
|
||||
TaskContext}
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.broadcast.Broadcast
|
||||
import org.apache.spark.deploy.SparkHadoopUtil
|
||||
import org.apache.spark.util.NextIterator
|
||||
import org.apache.hadoop.conf.{Configuration, Configurable}
|
||||
|
||||
/**
|
||||
* An RDD that reads a file (or multiple files) from Hadoop (e.g. files in HDFS, the local file
|
||||
* system, or S3).
|
||||
* This accepts a general, broadcasted Hadoop Configuration because those tend to remain the same
|
||||
* across multiple reads; the 'path' is the only variable that is different across new JobConfs
|
||||
* created from the Configuration.
|
||||
*/
|
||||
class HadoopFileRDD[K, V](
|
||||
sc: SparkContext,
|
||||
path: String,
|
||||
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
|
||||
inputFormatClass: Class[_ <: InputFormat[K, V]],
|
||||
keyClass: Class[K],
|
||||
valueClass: Class[V],
|
||||
minSplits: Int)
|
||||
extends HadoopRDD[K, V](sc, broadcastedConf, inputFormatClass, keyClass, valueClass, minSplits) {
|
||||
|
||||
override def getJobConf(): JobConf = {
|
||||
if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) {
|
||||
// getJobConf() has been called previously, so there is already a local cache of the JobConf
|
||||
// needed by this RDD.
|
||||
return HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf]
|
||||
} else {
|
||||
// Create a new JobConf, set the input file/directory paths to read from, and cache the
|
||||
// JobConf (i.e., in a shared hash map in the slave's JVM process that's accessible through
|
||||
// HadoopRDD.putCachedMetadata()), so that we only create one copy across multiple
|
||||
// getJobConf() calls for this RDD in the local process.
|
||||
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
|
||||
val newJobConf = new JobConf(broadcastedConf.value.value)
|
||||
FileInputFormat.setInputPaths(newJobConf, path)
|
||||
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
|
||||
return newJobConf
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A Spark split class that wraps around a Hadoop InputSplit.
|
||||
|
@ -83,11 +48,24 @@ private[spark] class HadoopPartition(rddId: Int, idx: Int, @transient s: InputSp
|
|||
}
|
||||
|
||||
/**
|
||||
* A base class that provides core functionality for reading data partitions stored in Hadoop.
|
||||
* An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS,
|
||||
* sources in HBase, or S3).
|
||||
*
|
||||
* @param sc The SparkContext to associate the RDD with.
|
||||
* @param broadcastedConf A general Hadoop Configuration, or a subclass of it. If the enclosed
|
||||
* variabe references an instance of JobConf, then that JobConf will be used for the Hadoop job.
|
||||
* Otherwise, a new JobConf will be created on each slave using the enclosed Configuration.
|
||||
* @param initLocalJobConfFuncOpt Optional closure used to initialize any JobConf that HadoopRDD
|
||||
* creates.
|
||||
* @param inputFormatClass Storage format of the data to be read.
|
||||
* @param keyClass Class of the key associated with the inputFormatClass.
|
||||
* @param valueClass Class of the value associated with the inputFormatClass.
|
||||
* @param minSplits Minimum number of Hadoop Splits (HadoopRDD partitions) to generate.
|
||||
*/
|
||||
class HadoopRDD[K, V](
|
||||
sc: SparkContext,
|
||||
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
|
||||
initLocalJobConfFuncOpt: Option[JobConf => Unit],
|
||||
inputFormatClass: Class[_ <: InputFormat[K, V]],
|
||||
keyClass: Class[K],
|
||||
valueClass: Class[V],
|
||||
|
@ -105,6 +83,7 @@ class HadoopRDD[K, V](
|
|||
sc,
|
||||
sc.broadcast(new SerializableWritable(conf))
|
||||
.asInstanceOf[Broadcast[SerializableWritable[Configuration]]],
|
||||
None /* initLocalJobConfFuncOpt */,
|
||||
inputFormatClass,
|
||||
keyClass,
|
||||
valueClass,
|
||||
|
@ -130,6 +109,7 @@ class HadoopRDD[K, V](
|
|||
// local process. The local cache is accessed through HadoopRDD.putCachedMetadata().
|
||||
// The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects.
|
||||
val newJobConf = new JobConf(broadcastedConf.value.value)
|
||||
initLocalJobConfFuncOpt.map(f => f(newJobConf))
|
||||
HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf)
|
||||
return newJobConf
|
||||
}
|
||||
|
@ -152,6 +132,8 @@ class HadoopRDD[K, V](
|
|||
|
||||
override def getPartitions: Array[Partition] = {
|
||||
val jobConf = getJobConf()
|
||||
// add the credentials here as this can be called before SparkContext initialized
|
||||
SparkHadoopUtil.get.addCredentials(jobConf)
|
||||
val inputFormat = getInputFormat(jobConf)
|
||||
if (inputFormat.isInstanceOf[Configurable]) {
|
||||
inputFormat.asInstanceOf[Configurable].setConf(jobConf)
|
||||
|
@ -164,38 +146,41 @@ class HadoopRDD[K, V](
|
|||
array
|
||||
}
|
||||
|
||||
override def compute(theSplit: Partition, context: TaskContext) = new NextIterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[HadoopPartition]
|
||||
logInfo("Input split: " + split.inputSplit)
|
||||
var reader: RecordReader[K, V] = null
|
||||
override def compute(theSplit: Partition, context: TaskContext) = {
|
||||
val iter = new NextIterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[HadoopPartition]
|
||||
logInfo("Input split: " + split.inputSplit)
|
||||
var reader: RecordReader[K, V] = null
|
||||
|
||||
val jobConf = getJobConf()
|
||||
val inputFormat = getInputFormat(jobConf)
|
||||
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
|
||||
val jobConf = getJobConf()
|
||||
val inputFormat = getInputFormat(jobConf)
|
||||
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
|
||||
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback{ () => closeIfNeeded() }
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback{ () => closeIfNeeded() }
|
||||
|
||||
val key: K = reader.createKey()
|
||||
val value: V = reader.createValue()
|
||||
val key: K = reader.createKey()
|
||||
val value: V = reader.createValue()
|
||||
|
||||
override def getNext() = {
|
||||
try {
|
||||
finished = !reader.next(key, value)
|
||||
} catch {
|
||||
case eof: EOFException =>
|
||||
finished = true
|
||||
override def getNext() = {
|
||||
try {
|
||||
finished = !reader.next(key, value)
|
||||
} catch {
|
||||
case eof: EOFException =>
|
||||
finished = true
|
||||
}
|
||||
(key, value)
|
||||
}
|
||||
(key, value)
|
||||
}
|
||||
|
||||
override def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
override def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
new InterruptibleIterator[(K, V)](context, iter)
|
||||
}
|
||||
|
||||
override def getPreferredLocations(split: Partition): Seq[String] = {
|
||||
|
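Wrapping the record iterator in an InterruptibleIterator is what lets a killed task stop between records. A self-contained sketch of such a wrapper, using a plain boolean flag where Spark would consult the TaskContext:

    // Sketch: stop delivering elements once the enclosing task is flagged as killed.
    class InterruptibleIteratorSketch[T](isKilled: () => Boolean, delegate: Iterator[T])
      extends Iterator[T] {
      override def hasNext: Boolean = {
        if (isKilled()) throw new InterruptedException("task killed")
        delegate.hasNext
      }
      override def next(): T = delegate.next()
    }

    // Usage: wrap any iterator produced by compute().
    var killed = false
    val wrapped = new InterruptibleIteratorSketch(() => killed, Iterator(1, 2, 3))
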
@ -216,10 +201,10 @@ private[spark] object HadoopRDD {
|
|||
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
|
||||
* the local process.
|
||||
*/
|
||||
def getCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.get(key)
|
||||
def getCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.get(key)
|
||||
|
||||
def containsCachedMetadata(key: String) = SparkEnv.get.hadoop.hadoopJobMetadata.containsKey(key)
|
||||
def containsCachedMetadata(key: String) = SparkEnv.get.hadoopJobMetadata.containsKey(key)
|
||||
|
||||
def putCachedMetadata(key: String, value: Any) =
|
||||
SparkEnv.get.hadoop.hadoopJobMetadata.put(key, value)
|
||||
SparkEnv.get.hadoopJobMetadata.put(key, value)
|
||||
}
|
||||
|
|
|
@ -22,14 +22,14 @@ import scala.reflect.ClassTag
|
|||
|
||||
|
||||
/**
|
||||
* A variant of the MapPartitionsRDD that passes the partition index into the
|
||||
* closure. This can be used to generate or collect partition specific
|
||||
* information such as the number of tuples in a partition.
|
||||
* A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
|
||||
* TaskContext, the closure can either get access to the interruptible flag or get the index
|
||||
* of the partition in the RDD.
|
||||
*/
|
||||
private[spark]
|
||||
class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
|
||||
class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
|
||||
prev: RDD[T],
|
||||
f: (Int, Iterator[T]) => Iterator[U],
|
||||
f: (TaskContext, Iterator[T]) => Iterator[U],
|
||||
preservesPartitioning: Boolean
|
||||
) extends RDD[U](prev) {
|
||||
|
||||
|
@ -38,5 +38,5 @@ class MapPartitionsWithIndexRDD[U: ClassTag, T: ClassTag](
|
|||
override val partitioner = if (preservesPartitioning) prev.partitioner else None
|
||||
|
||||
override def compute(split: Partition, context: TaskContext) =
|
||||
f(split.index, firstParent[T].iterator(split, context))
|
||||
f(context, firstParent[T].iterator(split, context))
|
||||
}
|
|
@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
|
|||
import org.apache.hadoop.io.Writable
|
||||
import org.apache.hadoop.mapreduce._
|
||||
|
||||
import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
|
||||
import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
|
||||
|
||||
|
||||
private[spark]
|
||||
|
@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
|
|||
result
|
||||
}
|
||||
|
||||
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[NewHadoopPartition]
|
||||
logInfo("Input split: " + split.serializableHadoopSplit)
|
||||
val conf = confBroadcast.value.value
|
||||
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
|
||||
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
|
||||
val format = inputFormatClass.newInstance
|
||||
if (format.isInstanceOf[Configurable]) {
|
||||
format.asInstanceOf[Configurable].setConf(conf)
|
||||
}
|
||||
val reader = format.createRecordReader(
|
||||
split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback(() => close())
|
||||
|
||||
var havePair = false
|
||||
var finished = false
|
||||
|
||||
override def hasNext: Boolean = {
|
||||
if (!finished && !havePair) {
|
||||
finished = !reader.nextKeyValue
|
||||
havePair = !finished
|
||||
override def compute(theSplit: Partition, context: TaskContext) = {
|
||||
val iter = new Iterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[NewHadoopPartition]
|
||||
logInfo("Input split: " + split.serializableHadoopSplit)
|
||||
val conf = confBroadcast.value.value
|
||||
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
|
||||
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
|
||||
val format = inputFormatClass.newInstance
|
||||
if (format.isInstanceOf[Configurable]) {
|
||||
format.asInstanceOf[Configurable].setConf(conf)
|
||||
}
|
||||
!finished
|
||||
}
|
||||
val reader = format.createRecordReader(
|
||||
split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
|
||||
override def next: (K, V) = {
|
||||
if (!hasNext) {
|
||||
throw new java.util.NoSuchElementException("End of stream")
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback(() => close())
|
||||
|
||||
var havePair = false
|
||||
var finished = false
|
||||
|
||||
override def hasNext: Boolean = {
|
||||
if (!finished && !havePair) {
|
||||
finished = !reader.nextKeyValue
|
||||
havePair = !finished
|
||||
}
|
||||
!finished
|
||||
}
|
||||
havePair = false
|
||||
return (reader.getCurrentKey, reader.getCurrentValue)
|
||||
}
|
||||
|
||||
private def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
override def next(): (K, V) = {
|
||||
if (!hasNext) {
|
||||
throw new java.util.NoSuchElementException("End of stream")
|
||||
}
|
||||
havePair = false
|
||||
(reader.getCurrentKey, reader.getCurrentValue)
|
||||
}
|
||||
|
||||
private def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
new InterruptibleIterator(context, iter)
|
||||
}
|
||||
|
||||
override def getPreferredLocations(split: Partition): Seq[String] = {
|
||||
|
|
|
@ -85,18 +85,24 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
|
|||
}
|
||||
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
||||
if (self.partitioner == Some(partitioner)) {
|
||||
self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||
self.mapPartitionsWithContext((context, iter) => {
|
||||
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
|
||||
}, preservesPartitioning = true)
|
||||
} else if (mapSideCombine) {
|
||||
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
|
||||
.setSerializer(serializerClass)
|
||||
partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
|
||||
partitioned.mapPartitionsWithContext((context, iter) => {
|
||||
new InterruptibleIterator(context, aggregator.combineCombinersByKey(iter))
|
||||
}, preservesPartitioning = true)
|
||||
} else {
|
||||
// Don't apply map-side combiner.
|
||||
// A sanity check to make sure mergeCombiners is not defined.
|
||||
assert(mergeCombiners == null)
|
||||
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
|
||||
values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||
values.mapPartitionsWithContext((context, iter) => {
|
||||
new InterruptibleIterator(context, aggregator.combineValuesByKey(iter))
|
||||
}, preservesPartitioning = true)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -565,7 +571,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
|
|||
// around by taking a mod. We expect that no task will be attempted 2 billion times.
|
||||
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
|
||||
/* "reduce task" <split #> <attempt # = spark task #> */
|
||||
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.splitId, attemptNumber)
|
||||
val attemptId = newTaskAttemptID(jobtrackerID, stageId, false, context.partitionId, attemptNumber)
|
||||
val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
|
||||
val format = outputFormatClass.newInstance
|
||||
val committer = format.getOutputCommitter(hadoopContext)
|
||||
|
@ -664,7 +670,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
|
|||
// around by taking a mod. We expect that no task will be attempted 2 billion times.
|
||||
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
|
||||
|
||||
writer.setup(context.stageId, context.splitId, attemptNumber)
|
||||
writer.setup(context.stageId, context.partitionId, attemptNumber)
|
||||
writer.open()
|
||||
|
||||
var count = 0
|
||||
|
|
|
@ -96,8 +96,9 @@ private[spark] class ParallelCollectionRDD[T: ClassTag](
|
|||
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
|
||||
}
|
||||
|
||||
override def compute(s: Partition, context: TaskContext) =
|
||||
s.asInstanceOf[ParallelCollectionPartition[T]].iterator
|
||||
override def compute(s: Partition, context: TaskContext) = {
|
||||
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
|
||||
}
|
||||
|
||||
override def getPreferredLocations(s: Partition): Seq[String] = {
|
||||
locationPrefs.getOrElse(s.index, Nil)
|
||||
|
|
|
@ -268,6 +268,19 @@ abstract class RDD[T: ClassTag](
|
|||
|
||||
def distinct(): RDD[T] = distinct(partitions.size)
|
||||
|
||||
/**
|
||||
* Return a new RDD that has exactly numPartitions partitions.
|
||||
*
|
||||
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
|
||||
* a shuffle to redistribute data.
|
||||
*
|
||||
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
|
||||
* which can avoid performing a shuffle.
|
||||
*/
|
||||
def repartition(numPartitions: Int): RDD[T] = {
|
||||
coalesce(numPartitions, true)
|
||||
}
|
||||
|
||||
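As the new doc comment says, repartition always shuffles while coalesce can shrink the partition count without one. A short usage sketch, assuming an existing SparkContext named sc:

    // Sketch: grow parallelism with a shuffle, shrink it without one.
    val wide = sc.parallelize(1 to 10000, 4).repartition(16)   // shuffle, 4 -> 16 partitions
    val narrow = wide.coalesce(2)                              // no shuffle, 16 -> 2 partitions
    println(wide.partitions.size + " / " + narrow.partitions.size)
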
/**
|
||||
* Return a new RDD that is reduced into `numPartitions` partitions.
|
||||
*
|
||||
|
@ -421,26 +434,39 @@ abstract class RDD[T: ClassTag](
|
|||
command: Seq[String],
|
||||
env: Map[String, String] = Map(),
|
||||
printPipeContext: (String => Unit) => Unit = null,
|
||||
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
|
||||
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] = {
|
||||
new PipedRDD(this, command, env,
|
||||
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
|
||||
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new RDD by applying a function to each partition of this RDD.
|
||||
*/
|
||||
def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
|
||||
preservesPartitioning: Boolean = false): RDD[U] =
|
||||
preservesPartitioning: Boolean = false): RDD[U] = {
|
||||
new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
|
||||
* of the original partition.
|
||||
*/
|
||||
def mapPartitionsWithIndex[U: ClassTag](
|
||||
f: (Int, Iterator[T]) => Iterator[U],
|
||||
preservesPartitioning: Boolean = false): RDD[U] =
|
||||
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
|
||||
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
|
||||
val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
|
||||
new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
|
||||
* mapPartitions that also passes the TaskContext into the closure.
|
||||
*/
|
||||
def mapPartitionsWithContext[U: ClassTag](
|
||||
f: (TaskContext, Iterator[T]) => Iterator[U],
|
||||
preservesPartitioning: Boolean = false): RDD[U] = {
|
||||
new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
|
||||
}
|
||||
|
||||
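A short usage sketch of the new variant, assuming an existing SparkContext named sc; the closure reads the partition id off the TaskContext instead of taking it as a separate argument:

    // Sketch: tag each element with the partition that produced it.
    val tagged = sc.parallelize(1 to 8, 4).mapPartitionsWithContext(
      (context, iter) => iter.map(x => (context.partitionId, x)))
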
/**
|
||||
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
|
||||
|
@ -448,22 +474,23 @@ abstract class RDD[T: ClassTag](
|
|||
*/
|
||||
@deprecated("use mapPartitionsWithIndex", "0.7.0")
|
||||
def mapPartitionsWithSplit[U: ClassTag](
|
||||
f: (Int, Iterator[T]) => Iterator[U],
|
||||
preservesPartitioning: Boolean = false): RDD[U] =
|
||||
new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
|
||||
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
|
||||
mapPartitionsWithIndex(f, preservesPartitioning)
|
||||
}
|
||||
|
||||
/**
|
||||
* Maps f over this RDD, where f takes an additional parameter of type A. This
|
||||
* additional parameter is produced by constructA, which is called in each
|
||||
* partition with the index of that partition.
|
||||
*/
|
||||
def mapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
|
||||
(f:(T, A) => U): RDD[U] = {
|
||||
def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
|
||||
val a = constructA(index)
|
||||
iter.map(t => f(t, a))
|
||||
}
|
||||
new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
|
||||
def mapWith[A: ClassTag, U: ClassTag]
|
||||
(constructA: Int => A, preservesPartitioning: Boolean = false)
|
||||
(f: (T, A) => U): RDD[U] = {
|
||||
def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
|
||||
val a = constructA(context.partitionId)
|
||||
iter.map(t => f(t, a))
|
||||
}
|
||||
new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@@ -471,13 +498,14 @@ abstract class RDD[T: ClassTag](
 * additional parameter is produced by constructA, which is called in each
 * partition with the index of that partition.
 */
def flatMapWith[A: ClassTag, U: ClassTag](constructA: Int => A, preservesPartitioning: Boolean = false)
  (f:(T, A) => Seq[U]): RDD[U] = {
  def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
    val a = constructA(index)
    iter.flatMap(t => f(t, a))
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
def flatMapWith[A: ClassTag, U: ClassTag]
    (constructA: Int => A, preservesPartitioning: Boolean = false)
    (f: (T, A) => Seq[U]): RDD[U] = {
  def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
    val a = constructA(context.partitionId)
    iter.flatMap(t => f(t, a))
  }
  new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
}

/**
@@ -485,13 +513,12 @@ abstract class RDD[T: ClassTag](
 * This additional parameter is produced by constructA, which is called in each
 * partition with the index of that partition.
 */
def foreachWith[A: ClassTag](constructA: Int => A)
  (f:(T, A) => Unit) {
  def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
    val a = constructA(index)
    iter.map(t => {f(t, a); t})
  }
  (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
  def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
    val a = constructA(context.partitionId)
    iter.map(t => {f(t, a); t})
  }
  new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
}

/**
@@ -499,13 +526,12 @@ abstract class RDD[T: ClassTag](
 * additional parameter is produced by constructA, which is called in each
 * partition with the index of that partition.
 */
def filterWith[A: ClassTag](constructA: Int => A)
  (p:(T, A) => Boolean): RDD[T] = {
  def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
    val a = constructA(index)
    iter.filter(t => p(t, a))
  }
  new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
  def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
    val a = constructA(context.partitionId)
    iter.filter(t => p(t, a))
  }
  new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
}

/**
@@ -544,16 +570,14 @@ abstract class RDD[T: ClassTag](
 * Applies a function f to all elements of this RDD.
 */
def foreach(f: T => Unit) {
  val cleanF = sc.clean(f)
  sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
  sc.runJob(this, (iter: Iterator[T]) => iter.foreach(f))
}

/**
 * Applies a function f to each partition of this RDD.
 */
def foreachPartition(f: Iterator[T] => Unit) {
  val cleanF = sc.clean(f)
  sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
  sc.runJob(this, (iter: Iterator[T]) => f(iter))
}

/**
@@ -678,6 +702,8 @@ abstract class RDD[T: ClassTag](
 */
def count(): Long = {
  sc.runJob(this, (iter: Iterator[T]) => {
    // Use a while loop to count the number of elements rather than iter.size because
    // iter.size uses a for loop, which is slightly slower in current version of Scala.
    var result = 0L
    while (iter.hasNext) {
      result += 1L
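For reference, a minimal usage sketch of the two partition-aware mappers touched above (this sketch is not part of the patch and assumes an existing SparkContext named sc):

// Hypothetical sketch only: tag each element with the id of the partition that produced it.
val rdd = sc.parallelize(1 to 100, 4)

val byIndex = rdd.mapPartitionsWithIndex { (index, iter) =>
  iter.map(x => (index, x))
}

// The TaskContext variant exposes the same partition id plus task-level metadata.
val byContext = rdd.mapPartitionsWithContext { (context, iter) =>
  iter.map(x => (context.partitionId, x))
}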
@@ -57,7 +57,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](

  override def compute(split: Partition, context: TaskContext): Iterator[P] = {
    val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
      SparkEnv.get.serializerManager.get(serializerClass))
  }
@@ -111,7 +111,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
      }
      case ShuffleCoGroupSplitDep(shuffleId) => {
        val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
          context.taskMetrics, serializer)
          context, serializer)
        iter.foreach(op)
      }
    }
@@ -19,18 +19,21 @@ package org.apache.spark.scheduler

import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.{LinkedBlockingQueue, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.reflect.ClassTag

import akka.actor._

import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.storage.{BlockManager, BlockManagerMaster}
import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerMaster, RDDBlockId}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}

/**
 * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@ -42,76 +45,108 @@ import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
|
|||
* locations to run each task on, based on the current cache status, and passes these to the
|
||||
* low-level TaskScheduler. Furthermore, it handles failures due to shuffle output files being
|
||||
* lost, in which case old stages may need to be resubmitted. Failures *within* a stage that are
|
||||
* not caused by shuffie file loss are handled by the TaskScheduler, which will retry each task
|
||||
* not caused by shuffle file loss are handled by the TaskScheduler, which will retry each task
|
||||
* a small number of times before cancelling the whole stage.
|
||||
*
|
||||
* THREADING: This class runs all its logic in a single thread executing the run() method, to which
|
||||
* events are submitted using a synchonized queue (eventQueue). The public API methods, such as
|
||||
* events are submitted using a synchronized queue (eventQueue). The public API methods, such as
|
||||
* runJob, taskEnded and executorLost, post events asynchronously to this queue. All other methods
|
||||
* should be private.
|
||||
*/
|
||||
private[spark]
|
||||
class DAGScheduler(
|
||||
taskSched: TaskScheduler,
|
||||
mapOutputTracker: MapOutputTracker,
|
||||
mapOutputTracker: MapOutputTrackerMaster,
|
||||
blockManagerMaster: BlockManagerMaster,
|
||||
env: SparkEnv)
|
||||
extends TaskSchedulerListener with Logging {
|
||||
extends Logging {
|
||||
|
||||
def this(taskSched: TaskScheduler) {
|
||||
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
|
||||
this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
|
||||
SparkEnv.get.blockManager.master, SparkEnv.get)
|
||||
}
|
||||
taskSched.setListener(this)
|
||||
taskSched.setDAGScheduler(this)
|
||||
|
||||
// Called by TaskScheduler to report task's starting.
|
||||
override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
|
||||
eventQueue.put(BeginEvent(task, taskInfo))
|
||||
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
|
||||
eventProcessActor ! BeginEvent(task, taskInfo)
|
||||
}
|
||||
|
||||
// Called to report that a task has completed and results are being fetched remotely.
|
||||
def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
|
||||
eventProcessActor ! GettingResultEvent(task, taskInfo)
|
||||
}
|
||||
|
||||
// Called by TaskScheduler to report task completions or failures.
|
||||
override def taskEnded(
|
||||
def taskEnded(
|
||||
task: Task[_],
|
||||
reason: TaskEndReason,
|
||||
result: Any,
|
||||
accumUpdates: Map[Long, Any],
|
||||
taskInfo: TaskInfo,
|
||||
taskMetrics: TaskMetrics) {
|
||||
eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
|
||||
eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
|
||||
}
|
||||
|
||||
// Called by TaskScheduler when an executor fails.
|
||||
override def executorLost(execId: String) {
|
||||
eventQueue.put(ExecutorLost(execId))
|
||||
def executorLost(execId: String) {
|
||||
eventProcessActor ! ExecutorLost(execId)
|
||||
}
|
||||
|
||||
// Called by TaskScheduler when a host is added
|
||||
override def executorGained(execId: String, host: String) {
|
||||
eventQueue.put(ExecutorGained(execId, host))
|
||||
def executorGained(execId: String, host: String) {
|
||||
eventProcessActor ! ExecutorGained(execId, host)
|
||||
}
|
||||
|
||||
// Called by TaskScheduler to cancel an entire TaskSet due to repeated failures.
|
||||
override def taskSetFailed(taskSet: TaskSet, reason: String) {
|
||||
eventQueue.put(TaskSetFailed(taskSet, reason))
|
||||
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
|
||||
// cancellation of the job itself.
|
||||
def taskSetFailed(taskSet: TaskSet, reason: String) {
|
||||
eventProcessActor ! TaskSetFailed(taskSet, reason)
|
||||
}
|
||||
|
||||
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
|
||||
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
|
||||
// as more failure events come in
|
||||
val RESUBMIT_TIMEOUT = 50L
|
||||
val RESUBMIT_TIMEOUT = 50.milliseconds
|
||||
|
||||
// The time, in millis, to wake up between polls of the completion queue in order to potentially
|
||||
// resubmit failed stages
|
||||
val POLL_TIMEOUT = 10L
|
||||
|
||||
private val eventQueue = new LinkedBlockingQueue[DAGSchedulerEvent]
|
||||
private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
|
||||
override def preStart() {
|
||||
context.system.scheduler.schedule(RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT) {
|
||||
if (failed.size > 0) {
|
||||
resubmitFailedStages()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
val nextJobId = new AtomicInteger(0)
|
||||
/**
|
||||
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
|
||||
* events and responds by launching tasks. This runs in a dedicated thread and receives events
|
||||
* via the eventQueue.
|
||||
*/
|
||||
def receive = {
|
||||
case event: DAGSchedulerEvent =>
|
||||
logDebug("Got event of type " + event.getClass.getName)
|
||||
|
||||
val nextStageId = new AtomicInteger(0)
|
||||
if (!processEvent(event))
|
||||
submitWaitingStages()
|
||||
else
|
||||
context.stop(self)
|
||||
}
|
||||
}))
|
||||
|
||||
val stageIdToStage = new TimeStampedHashMap[Int, Stage]
|
||||
private[scheduler] val nextJobId = new AtomicInteger(0)
|
||||
|
||||
val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
|
||||
def numTotalJobs: Int = nextJobId.get()
|
||||
|
||||
private val nextStageId = new AtomicInteger(0)
|
||||
|
||||
private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
|
||||
|
||||
private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
|
||||
|
||||
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
|
||||
|
||||
|
@ -128,6 +163,7 @@ class DAGScheduler(
|
|||
// stray messages to detect.
|
||||
val failedEpoch = new HashMap[String, Long]
|
||||
|
||||
// stage id to the active job
|
||||
val idToActiveJob = new HashMap[Int, ActiveJob]
|
||||
|
||||
val waiting = new HashSet[Stage] // Stages we need to run whose parents aren't done
|
||||
|
@ -139,17 +175,7 @@ class DAGScheduler(
|
|||
val activeJobs = new HashSet[ActiveJob]
|
||||
val resultStageToJob = new HashMap[Stage, ActiveJob]
|
||||
|
||||
val metadataCleaner = new MetadataCleaner("DAGScheduler", this.cleanup)
|
||||
|
||||
// Start a thread to run the DAGScheduler event loop
|
||||
def start() {
|
||||
new Thread("DAGScheduler") {
|
||||
setDaemon(true)
|
||||
override def run() {
|
||||
DAGScheduler.this.run()
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
|
||||
|
||||
def addSparkListener(listener: SparkListener) {
|
||||
listenerBus.addListener(listener)
|
||||
|
@ -157,7 +183,7 @@ class DAGScheduler(
|
|||
|
||||
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
|
||||
if (!cacheLocs.contains(rdd.id)) {
|
||||
val blockIds = rdd.partitions.indices.map(index=> "rdd_%d_%d".format(rdd.id, index)).toArray
|
||||
val blockIds = rdd.partitions.indices.map(index=> RDDBlockId(rdd.id, index)).toArray[BlockId]
|
||||
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
|
||||
cacheLocs(rdd.id) = blockIds.map { id =>
|
||||
locs.getOrElse(id, Nil).map(bm => TaskLocation(bm.host, bm.executorId))
|
||||
|
@ -179,7 +205,7 @@ class DAGScheduler(
|
|||
shuffleToMapStage.get(shuffleDep.shuffleId) match {
|
||||
case Some(stage) => stage
|
||||
case None =>
|
||||
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
|
||||
val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
|
||||
shuffleToMapStage(shuffleDep.shuffleId) = stage
|
||||
stage
|
||||
}
|
||||
|
@ -192,6 +218,7 @@ class DAGScheduler(
|
|||
*/
|
||||
private def newStage(
|
||||
rdd: RDD[_],
|
||||
numTasks: Int,
|
||||
shuffleDep: Option[ShuffleDependency[_,_]],
|
||||
jobId: Int,
|
||||
callSite: Option[String] = None)
|
||||
|
@ -204,9 +231,10 @@ class DAGScheduler(
|
|||
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
|
||||
}
|
||||
val id = nextStageId.getAndIncrement()
|
||||
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
|
||||
val stage =
|
||||
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
|
||||
stageIdToStage(id) = stage
|
||||
stageToInfos(stage) = StageInfo(stage)
|
||||
stageToInfos(stage) = new StageInfo(stage)
|
||||
stage
|
||||
}
|
||||
|
||||
|
@ -262,54 +290,48 @@ class DAGScheduler(
|
|||
}
|
||||
|
||||
/**
|
||||
* Returns (and does not submit) a JobSubmitted event suitable to run a given job, and a
|
||||
* JobWaiter whose getResult() method will return the result of the job when it is complete.
|
||||
*
|
||||
* The job is assumed to have at least one partition; zero partition jobs should be handled
|
||||
* without a JobSubmitted event.
|
||||
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
|
||||
* can be used to block until the the job finishes executing or can be used to cancel the job.
|
||||
*/
|
||||
private[scheduler] def prepareJob[T, U: ClassTag](
|
||||
finalRdd: RDD[T],
|
||||
def submitJob[T, U](
|
||||
rdd: RDD[T],
|
||||
func: (TaskContext, Iterator[T]) => U,
|
||||
partitions: Seq[Int],
|
||||
callSite: String,
|
||||
allowLocal: Boolean,
|
||||
resultHandler: (Int, U) => Unit,
|
||||
properties: Properties = null)
|
||||
: (JobSubmitted, JobWaiter[U]) =
|
||||
properties: Properties = null): JobWaiter[U] =
|
||||
{
|
||||
assert(partitions.size > 0)
|
||||
val waiter = new JobWaiter(partitions.size, resultHandler)
|
||||
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
val toSubmit = JobSubmitted(finalRdd, func2, partitions.toArray, allowLocal, callSite, waiter,
|
||||
properties)
|
||||
(toSubmit, waiter)
|
||||
}
|
||||
|
||||
def runJob[T, U: ClassTag](
|
||||
finalRdd: RDD[T],
|
||||
func: (TaskContext, Iterator[T]) => U,
|
||||
partitions: Seq[Int],
|
||||
callSite: String,
|
||||
allowLocal: Boolean,
|
||||
resultHandler: (Int, U) => Unit,
|
||||
properties: Properties = null)
|
||||
{
|
||||
if (partitions.size == 0) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check to make sure we are not launching a task on a partition that does not exist.
|
||||
val maxPartitions = finalRdd.partitions.length
|
||||
val maxPartitions = rdd.partitions.length
|
||||
partitions.find(p => p >= maxPartitions).foreach { p =>
|
||||
throw new IllegalArgumentException(
|
||||
"Attempting to access a non-existent partition: " + p + ". " +
|
||||
"Total number of partitions: " + maxPartitions)
|
||||
"Total number of partitions: " + maxPartitions)
|
||||
}
|
||||
|
||||
val (toSubmit: JobSubmitted, waiter: JobWaiter[_]) = prepareJob(
|
||||
finalRdd, func, partitions, callSite, allowLocal, resultHandler, properties)
|
||||
eventQueue.put(toSubmit)
|
||||
val jobId = nextJobId.getAndIncrement()
|
||||
if (partitions.size == 0) {
|
||||
return new JobWaiter[U](this, jobId, 0, resultHandler)
|
||||
}
|
||||
|
||||
assert(partitions.size > 0)
|
||||
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
|
||||
eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
|
||||
waiter
|
||||
}
|
||||
|
||||
def runJob[T, U: ClassTag](
|
||||
rdd: RDD[T],
|
||||
func: (TaskContext, Iterator[T]) => U,
|
||||
partitions: Seq[Int],
|
||||
callSite: String,
|
||||
allowLocal: Boolean,
|
||||
resultHandler: (Int, U) => Unit,
|
||||
properties: Properties = null)
|
||||
{
|
||||
val waiter = submitJob(rdd, func, partitions, callSite, allowLocal, resultHandler, properties)
|
||||
waiter.awaitResult() match {
|
||||
case JobSucceeded => {}
|
||||
case JobFailed(exception: Exception, _) =>
|
||||
|
@ -330,19 +352,39 @@ class DAGScheduler(
|
|||
val listener = new ApproximateActionListener(rdd, func, evaluator, timeout)
|
||||
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
val partitions = (0 until rdd.partitions.size).toArray
|
||||
eventQueue.put(JobSubmitted(rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
|
||||
val jobId = nextJobId.getAndIncrement()
|
||||
eventProcessActor ! JobSubmitted(jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
|
||||
listener.awaitResult() // Will throw an exception if the job fails
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel a job that is running or waiting in the queue.
|
||||
*/
|
||||
def cancelJob(jobId: Int) {
|
||||
logInfo("Asked to cancel job " + jobId)
|
||||
eventProcessActor ! JobCancelled(jobId)
|
||||
}
|
||||
|
||||
def cancelJobGroup(groupId: String) {
|
||||
logInfo("Asked to cancel job group " + groupId)
|
||||
eventProcessActor ! JobGroupCancelled(groupId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Cancel all jobs that are running or waiting in the queue.
|
||||
*/
|
||||
def cancelAllJobs() {
|
||||
eventProcessActor ! AllJobsCancelled
|
||||
}
|
||||
|
||||
/**
|
||||
* Process one event retrieved from the event queue.
|
||||
* Returns true if we should stop the event loop.
|
||||
*/
|
||||
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
|
||||
event match {
|
||||
case JobSubmitted(finalRDD, func, partitions, allowLocal, callSite, listener, properties) =>
|
||||
val jobId = nextJobId.getAndIncrement()
|
||||
val finalStage = newStage(finalRDD, None, jobId, Some(callSite))
|
||||
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
|
||||
val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
|
||||
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
|
||||
clearCacheLocs()
|
||||
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
|
||||
|
@ -361,18 +403,43 @@ class DAGScheduler(
|
|||
submitStage(finalStage)
|
||||
}
|
||||
|
||||
case JobCancelled(jobId) =>
|
||||
// Cancel a job: find all the running stages that are linked to this job, and cancel them.
|
||||
running.filter(_.jobId == jobId).foreach { stage =>
|
||||
taskSched.cancelTasks(stage.id)
|
||||
}
|
||||
|
||||
case JobGroupCancelled(groupId) =>
|
||||
// Cancel all jobs belonging to this job group.
|
||||
// First finds all active jobs with this group id, and then kill stages for them.
|
||||
val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
|
||||
.map(_.jobId)
|
||||
if (!jobIds.isEmpty) {
|
||||
running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
|
||||
taskSched.cancelTasks(stage.id)
|
||||
}
|
||||
}
|
||||
|
||||
case AllJobsCancelled =>
|
||||
// Cancel all running jobs.
|
||||
running.foreach { stage =>
|
||||
taskSched.cancelTasks(stage.id)
|
||||
}
|
||||
|
||||
case ExecutorGained(execId, host) =>
|
||||
handleExecutorGained(execId, host)
|
||||
|
||||
case ExecutorLost(execId) =>
|
||||
handleExecutorLost(execId)
|
||||
|
||||
case begin: BeginEvent =>
|
||||
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
|
||||
case BeginEvent(task, taskInfo) =>
|
||||
listenerBus.post(SparkListenerTaskStart(task, taskInfo))
|
||||
|
||||
case completion: CompletionEvent =>
|
||||
listenerBus.post(SparkListenerTaskEnd(
|
||||
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
|
||||
case GettingResultEvent(task, taskInfo) =>
|
||||
listenerBus.post(SparkListenerTaskGettingResult(task, taskInfo))
|
||||
|
||||
case completion @ CompletionEvent(task, reason, _, _, taskInfo, taskMetrics) =>
|
||||
listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics))
|
||||
handleTaskCompletion(completion)
|
||||
|
||||
case TaskSetFailed(taskSet, reason) =>
|
||||
|
@ -422,42 +489,6 @@ class DAGScheduler(
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
|
||||
* events and responds by launching tasks. This runs in a dedicated thread and receives events
|
||||
* via the eventQueue.
|
||||
*/
|
||||
private def run() {
|
||||
SparkEnv.set(env)
|
||||
|
||||
while (true) {
|
||||
val event = eventQueue.poll(POLL_TIMEOUT, TimeUnit.MILLISECONDS)
|
||||
if (event != null) {
|
||||
logDebug("Got event of type " + event.getClass.getName)
|
||||
}
|
||||
this.synchronized { // needed in case other threads makes calls into methods of this class
|
||||
if (event != null) {
|
||||
if (processEvent(event)) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
val time = System.currentTimeMillis() // TODO: use a pluggable clock for testability
|
||||
// Periodically resubmit failed stages if some map output fetches have failed and we have
|
||||
// waited at least RESUBMIT_TIMEOUT. We wait for this short time because when a node fails,
|
||||
// tasks on many other nodes are bound to get a fetch failure, and they won't all get it at
|
||||
// the same time, so we want to make sure we've identified all the reduce tasks that depend
|
||||
// on the failed node.
|
||||
if (failed.size > 0 && time > lastFetchFailureTime + RESUBMIT_TIMEOUT) {
|
||||
resubmitFailedStages()
|
||||
} else {
|
||||
submitWaitingStages()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Run a job on an RDD locally, assuming it has only a single partition and no dependencies.
|
||||
* We run the operation in a separate thread just in case it takes a bunch of time, so that we
|
||||
|
@ -542,7 +573,7 @@ class DAGScheduler(
|
|||
|
||||
// must be run listener before possible NotSerializableException
|
||||
// should be "StageSubmitted" first and then "JobEnded"
|
||||
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
|
||||
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
|
||||
|
||||
if (tasks.size > 0) {
|
||||
// Preemptively serialize a task to make sure it can be serialized. We are catching this
|
||||
|
@ -563,9 +594,7 @@ class DAGScheduler(
|
|||
logDebug("New pending tasks: " + myPending)
|
||||
taskSched.submitTasks(
|
||||
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
|
||||
if (!stage.submissionTime.isDefined) {
|
||||
stage.submissionTime = Some(System.currentTimeMillis())
|
||||
}
|
||||
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
|
||||
} else {
|
||||
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
|
||||
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
|
||||
|
@ -579,15 +608,20 @@ class DAGScheduler(
|
|||
*/
|
||||
private def handleTaskCompletion(event: CompletionEvent) {
|
||||
val task = event.task
|
||||
|
||||
if (!stageIdToStage.contains(task.stageId)) {
|
||||
// Skip all the actions if the stage has been cancelled.
|
||||
return
|
||||
}
|
||||
val stage = stageIdToStage(task.stageId)
|
||||
|
||||
def markStageAsFinished(stage: Stage) = {
|
||||
val serviceTime = stage.submissionTime match {
|
||||
val serviceTime = stageToInfos(stage).submissionTime match {
|
||||
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
|
||||
case _ => "Unkown"
|
||||
case _ => "Unknown"
|
||||
}
|
||||
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
|
||||
stage.completionTime = Some(System.currentTimeMillis)
|
||||
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
|
||||
listenerBus.post(StageCompleted(stageToInfos(stage)))
|
||||
running -= stage
|
||||
}
|
||||
|
@ -627,7 +661,7 @@ class DAGScheduler(
|
|||
if (failedEpoch.contains(execId) && smt.epoch <= failedEpoch(execId)) {
|
||||
logInfo("Ignoring possibly bogus ShuffleMapTask completion from " + execId)
|
||||
} else {
|
||||
stage.addOutputLoc(smt.partition, status)
|
||||
stage.addOutputLoc(smt.partitionId, status)
|
||||
}
|
||||
if (running.contains(stage) && pendingTasks(stage).isEmpty) {
|
||||
markStageAsFinished(stage)
|
||||
|
@ -753,14 +787,14 @@ class DAGScheduler(
|
|||
|
||||
/**
|
||||
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
|
||||
* being cancelled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
|
||||
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
|
||||
*/
|
||||
private def abortStage(failedStage: Stage, reason: String) {
|
||||
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
|
||||
failedStage.completionTime = Some(System.currentTimeMillis())
|
||||
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
|
||||
for (resultStage <- dependentStages) {
|
||||
val job = resultStageToJob(resultStage)
|
||||
val error = new SparkException("Job failed: " + reason)
|
||||
val error = new SparkException("Job aborted: " + reason)
|
||||
job.listener.jobFailed(error)
|
||||
listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
|
||||
idToActiveJob -= resultStage.jobId
|
||||
|
@ -823,7 +857,7 @@ class DAGScheduler(
|
|||
// If the RDD has narrow dependencies, pick the first partition of the first narrow dep
|
||||
// that has any placement preferences. Ideally we would choose based on transfer sizes,
|
||||
// but this will do for now.
|
||||
rdd.dependencies.foreach(_ match {
|
||||
rdd.dependencies.foreach {
|
||||
case n: NarrowDependency[_] =>
|
||||
for (inPart <- n.getParents(partition)) {
|
||||
val locs = getPreferredLocs(n.rdd, inPart)
|
||||
|
@ -831,7 +865,7 @@ class DAGScheduler(
|
|||
return locs
|
||||
}
|
||||
case _ =>
|
||||
})
|
||||
}
|
||||
Nil
|
||||
}
|
||||
|
||||
|
@ -854,7 +888,7 @@ class DAGScheduler(
|
|||
}
|
||||
|
||||
def stop() {
|
||||
eventQueue.put(StopDAGScheduler)
|
||||
eventProcessActor ! StopDAGScheduler
|
||||
metadataCleaner.cancel()
|
||||
taskSched.stop()
|
||||
}
|
||||
|
|
|
@@ -31,9 +31,10 @@ import org.apache.spark.executor.TaskMetrics
 * submitted) but there is a single "logic" thread that reads these events and takes decisions.
 * This greatly simplifies synchronization.
 */
private[spark] sealed trait DAGSchedulerEvent
private[scheduler] sealed trait DAGSchedulerEvent

private[spark] case class JobSubmitted(
private[scheduler] case class JobSubmitted(
    jobId: Int,
    finalRDD: RDD[_],
    func: (TaskContext, Iterator[_]) => _,
    partitions: Array[Int],

@@ -43,9 +44,19 @@ private[spark] case class JobSubmitted(
    properties: Properties = null)
  extends DAGSchedulerEvent

private[spark] case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent

private[spark] case class CompletionEvent(
private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent

private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent

private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent

private[scheduler]
case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent

private[scheduler] case class CompletionEvent(
    task: Task[_],
    reason: TaskEndReason,
    result: Any,

@@ -54,10 +65,12 @@ private[spark] case class CompletionEvent(
    taskMetrics: TaskMetrics)
  extends DAGSchedulerEvent

private[spark] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler]
case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent

private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent

private[spark] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent

private[spark] case object StopDAGScheduler extends DAGSchedulerEvent
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
@@ -40,7 +40,7 @@ private[spark] class DAGSchedulerSource(val dagScheduler: DAGScheduler, sc: Spar
  })

  metricRegistry.register(MetricRegistry.name("job", "allJobs"), new Gauge[Int] {
    override def getValue: Int = dagScheduler.nextJobId.get()
    override def getValue: Int = dagScheduler.numTotalJobs
  })

  metricRegistry.register(MetricRegistry.name("job", "activeJobs"), new Gauge[Int] {
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler

import org.apache.spark.{Logging, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import scala.collection.immutable.Set
import org.apache.hadoop.mapred.{FileInputFormat, JobConf}
import org.apache.hadoop.security.UserGroupInformation

@@ -87,9 +88,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl

  // This method does not expect failures, since validate has already passed ...
  private def prefLocsFromMapreduceInputFormat(): Set[SplitInfo] = {
    val env = SparkEnv.get
    val conf = new JobConf(configuration)
    env.hadoop.addCredentials(conf)
    SparkHadoopUtil.get.addCredentials(conf)
    FileInputFormat.setInputPaths(conf, path)

    val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] =

@@ -108,9 +108,8 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl

  // This method does not expect failures, since validate has already passed ...
  private def prefLocsFromMapredInputFormat(): Set[SplitInfo] = {
    val env = SparkEnv.get
    val jobConf = new JobConf(configuration)
    env.hadoop.addCredentials(jobConf)
    SparkHadoopUtil.get.addCredentials(jobConf)
    FileInputFormat.setInputPaths(jobConf, path)

    val instance: org.apache.hadoop.mapred.InputFormat[_, _] =
@ -17,29 +17,39 @@
|
|||
|
||||
package org.apache.spark.scheduler
|
||||
|
||||
import java.io.PrintWriter
|
||||
import java.io.File
|
||||
import java.io.FileNotFoundException
|
||||
import java.io.{IOException, File, FileNotFoundException, PrintWriter}
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.{Date, Properties}
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
|
||||
import scala.collection.mutable.{Map, HashMap, ListBuffer}
|
||||
import scala.io.Source
|
||||
import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
|
||||
|
||||
import org.apache.spark._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.executor.TaskMetrics
|
||||
import org.apache.spark.storage.StorageLevel
|
||||
|
||||
// Used to record runtime information for each job, including RDD graph
|
||||
// tasks' start/stop shuffle information and information from outside
|
||||
/**
|
||||
* A logger class to record runtime information for jobs in Spark. This class outputs one log file
|
||||
* for each Spark job, containing RDD graph, tasks start/stop, shuffle information.
|
||||
* JobLogger is a subclass of SparkListener, use addSparkListener to add JobLogger to a SparkContext
|
||||
* after the SparkContext is created.
|
||||
* Note that each JobLogger only works for one SparkContext
|
||||
* @param logDirName The base directory for the log files.
|
||||
*/
|
||||
|
||||
class JobLogger(val user: String, val logDirName: String)
|
||||
extends SparkListener with Logging {
|
||||
|
||||
def this() = this(System.getProperty("user.name", "<unknown>"),
|
||||
String.valueOf(System.currentTimeMillis()))
|
||||
|
||||
class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
||||
private val logDir =
|
||||
if (System.getenv("SPARK_LOG_DIR") != null)
|
||||
System.getenv("SPARK_LOG_DIR")
|
||||
else
|
||||
"/tmp/spark"
|
||||
"/tmp/spark-%s".format(user)
|
||||
|
||||
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
|
||||
private val stageIDToJobID = new HashMap[Int, Int]
|
||||
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
|
||||
|
@ -47,37 +57,45 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
|
||||
|
||||
createLogDir()
|
||||
def this() = this(String.valueOf(System.currentTimeMillis()))
|
||||
|
||||
def getLogDir = logDir
|
||||
def getJobIDtoPrintWriter = jobIDToPrintWriter
|
||||
def getStageIDToJobID = stageIDToJobID
|
||||
def getJobIDToStages = jobIDToStages
|
||||
def getEventQueue = eventQueue
|
||||
// The following 5 functions are used only in testing.
|
||||
private[scheduler] def getLogDir = logDir
|
||||
private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
|
||||
private[scheduler] def getStageIDToJobID = stageIDToJobID
|
||||
private[scheduler] def getJobIDToStages = jobIDToStages
|
||||
private[scheduler] def getEventQueue = eventQueue
|
||||
|
||||
// Create a folder for log files, the folder's name is the creation time of the jobLogger
|
||||
/** Create a folder for log files, the folder's name is the creation time of jobLogger */
|
||||
protected def createLogDir() {
|
||||
val dir = new File(logDir + "/" + logDirName + "/")
|
||||
if (dir.exists()) {
|
||||
return
|
||||
}
|
||||
if (dir.mkdirs() == false) {
|
||||
logError("create log directory error:" + logDir + "/" + logDirName + "/")
|
||||
// JobLogger should throw a exception rather than continue to construct this object.
|
||||
throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/")
|
||||
}
|
||||
}
|
||||
|
||||
// Create a log file for one job, the file name is the jobID
|
||||
/**
|
||||
* Create a log file for one job
|
||||
* @param jobID ID of the job
|
||||
* @exception FileNotFoundException Fail to create log file
|
||||
*/
|
||||
protected def createLogWriter(jobID: Int) {
|
||||
try{
|
||||
try {
|
||||
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
|
||||
jobIDToPrintWriter += (jobID -> fileWriter)
|
||||
} catch {
|
||||
case e: FileNotFoundException => e.printStackTrace()
|
||||
}
|
||||
} catch {
|
||||
case e: FileNotFoundException => e.printStackTrace()
|
||||
}
|
||||
}
|
||||
|
||||
// Close log file, and clean the stage relationship in stageIDToJobID
|
||||
protected def closeLogWriter(jobID: Int) =
|
||||
/**
|
||||
* Close log file, and clean the stage relationship in stageIDToJobID
|
||||
* @param jobID ID of the job
|
||||
*/
|
||||
protected def closeLogWriter(jobID: Int) {
|
||||
jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
|
||||
fileWriter.close()
|
||||
jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
|
||||
|
@ -86,9 +104,14 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
jobIDToPrintWriter -= jobID
|
||||
jobIDToStages -= jobID
|
||||
}
|
||||
}
|
||||
|
||||
// Write log information to log file, withTime parameter controls whether to recored
|
||||
// time stamp for the information
|
||||
/**
|
||||
* Write info into log file
|
||||
* @param jobID ID of the job
|
||||
* @param info Info to be recorded
|
||||
* @param withTime Controls whether to record time stamp before the info, default is true
|
||||
*/
|
||||
protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
|
||||
var writeInfo = info
|
||||
if (withTime) {
|
||||
|
@ -98,9 +121,21 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
|
||||
}
|
||||
|
||||
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
|
||||
/**
|
||||
* Write info into log file
|
||||
* @param stageID ID of the stage
|
||||
* @param info Info to be recorded
|
||||
* @param withTime Controls whether to record time stamp before the info, default is true
|
||||
*/
|
||||
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) {
|
||||
stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
|
||||
}
|
||||
|
||||
/**
|
||||
* Build stage dependency for a job
|
||||
* @param jobID ID of the job
|
||||
* @param stage Root stage of the job
|
||||
*/
|
||||
protected def buildJobDep(jobID: Int, stage: Stage) {
|
||||
if (stage.jobId == jobID) {
|
||||
jobIDToStages.get(jobID) match {
|
||||
|
@ -114,14 +149,17 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Record stage dependency and RDD dependency for a stage
|
||||
* @param jobID Job ID of the stage
|
||||
*/
|
||||
protected def recordStageDep(jobID: Int) {
|
||||
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
|
||||
var rddList = new ListBuffer[RDD[_]]
|
||||
rddList += rdd
|
||||
rdd.dependencies.foreach{ dep => dep match {
|
||||
case shufDep: ShuffleDependency[_,_] =>
|
||||
case _ => rddList ++= getRddsInStage(dep.rdd)
|
||||
}
|
||||
rdd.dependencies.foreach {
|
||||
case shufDep: ShuffleDependency[_, _] =>
|
||||
case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
|
||||
}
|
||||
rddList
|
||||
}
|
||||
|
@ -141,8 +179,12 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
// Generate indents and convert to String
|
||||
protected def indentString(indent: Int) = {
|
||||
/**
|
||||
* Generate indents and convert to String
|
||||
* @param indent Number of indents
|
||||
* @return string of indents
|
||||
*/
|
||||
protected def indentString(indent: Int): String = {
|
||||
val sb = new StringBuilder()
|
||||
for (i <- 1 to indent) {
|
||||
sb.append(" ")
|
||||
|
@ -150,86 +192,122 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
sb.toString()
|
||||
}
|
||||
|
||||
protected def getRddName(rdd: RDD[_]) = {
|
||||
var rddName = rdd.getClass.getName
|
||||
/**
|
||||
* Get RDD's name
|
||||
* @param rdd Input RDD
|
||||
* @return String of RDD's name
|
||||
*/
|
||||
protected def getRddName(rdd: RDD[_]): String = {
|
||||
var rddName = rdd.getClass.getSimpleName
|
||||
if (rdd.name != null) {
|
||||
rddName = rdd.name
|
||||
}
|
||||
rddName
|
||||
}
|
||||
|
||||
/**
|
||||
* Record RDD dependency graph in a stage
|
||||
* @param jobID Job ID of the stage
|
||||
* @param rdd Root RDD of the stage
|
||||
* @param indent Indent number before info
|
||||
*/
|
||||
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
|
||||
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
|
||||
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
|
||||
rdd.dependencies.foreach{ dep => dep match {
|
||||
case shufDep: ShuffleDependency[_,_] =>
|
||||
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
|
||||
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
|
||||
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
|
||||
val rddInfo =
|
||||
if (rdd.getStorageLevel != StorageLevel.NONE) {
|
||||
"RDD_ID=" + rdd.id + " " + getRddName(rdd) + " CACHED" + " " +
|
||||
rdd.origin + " " + rdd.generator
|
||||
} else {
|
||||
"RDD_ID=" + rdd.id + " " + getRddName(rdd) + " NONE" + " " +
|
||||
rdd.origin + " " + rdd.generator
|
||||
}
|
||||
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
|
||||
rdd.dependencies.foreach {
|
||||
case shufDep: ShuffleDependency[_, _] =>
|
||||
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
|
||||
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
|
||||
case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
|
||||
}
|
||||
}
|
||||
|
||||
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
|
||||
var stageInfo: String = ""
|
||||
if (stage.isShuffleMap) {
|
||||
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
|
||||
stage.shuffleDep.get.shuffleId
|
||||
}else{
|
||||
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
|
||||
/**
|
||||
* Record stage dependency graph of a job
|
||||
* @param jobID Job ID of the stage
|
||||
* @param stage Root stage of the job
|
||||
* @param indent Indent number before info, default is 0
|
||||
*/
|
||||
protected def recordStageDepGraph(jobID: Int, stage: Stage, idSet: HashSet[Int], indent: Int = 0) {
|
||||
val stageInfo = if (stage.isShuffleMap) {
|
||||
"STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
|
||||
} else {
|
||||
"STAGE_ID=" + stage.id + " RESULT_STAGE"
|
||||
}
|
||||
if (stage.jobId == jobID) {
|
||||
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
|
||||
recordRddInStageGraph(jobID, stage.rdd, indent)
|
||||
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
|
||||
} else
|
||||
if (!idSet.contains(stage.id)) {
|
||||
idSet += stage.id
|
||||
recordRddInStageGraph(jobID, stage.rdd, indent)
|
||||
stage.parents.foreach(recordStageDepGraph(jobID, _, idSet, indent + 2))
|
||||
}
|
||||
} else {
|
||||
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
|
||||
}
|
||||
}
|
||||
|
||||
// Record task metrics into job log files
|
||||
/**
|
||||
* Record task metrics into job log files, including execution info and shuffle metrics
|
||||
* @param stageID Stage ID of the task
|
||||
* @param status Status info of the task
|
||||
* @param taskInfo Task description info
|
||||
* @param taskMetrics Task running metrics
|
||||
*/
|
||||
protected def recordTaskMetrics(stageID: Int, status: String,
|
||||
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
|
||||
val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
|
||||
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
|
||||
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
|
||||
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
|
||||
val readMetrics =
|
||||
taskMetrics.shuffleReadMetrics match {
|
||||
case Some(metrics) =>
|
||||
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
|
||||
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
|
||||
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
|
||||
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
|
||||
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
|
||||
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
|
||||
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
|
||||
case None => ""
|
||||
}
|
||||
val writeMetrics =
|
||||
taskMetrics.shuffleWriteMetrics match {
|
||||
case Some(metrics) =>
|
||||
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
|
||||
case None => ""
|
||||
}
|
||||
val readMetrics = taskMetrics.shuffleReadMetrics match {
|
||||
case Some(metrics) =>
|
||||
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
|
||||
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
|
||||
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
|
||||
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
|
||||
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
|
||||
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
|
||||
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
|
||||
case None => ""
|
||||
}
|
||||
val writeMetrics = taskMetrics.shuffleWriteMetrics match {
|
||||
case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
|
||||
case None => ""
|
||||
}
|
||||
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
|
||||
}
|
||||
|
||||
/**
|
||||
* When stage is submitted, record stage submit info
|
||||
* @param stageSubmitted Stage submitted event
|
||||
*/
|
||||
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
|
||||
stageLogInfo(
|
||||
stageSubmitted.stage.id,
|
||||
"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
|
||||
stageSubmitted.stage.id, stageSubmitted.taskSize))
|
||||
stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
|
||||
stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
|
||||
}
|
||||
|
||||
/**
|
||||
* When stage is completed, record stage completion status
|
||||
* @param stageCompleted Stage completed event
|
||||
*/
|
||||
override def onStageCompleted(stageCompleted: StageCompleted) {
|
||||
stageLogInfo(
|
||||
stageCompleted.stageInfo.stage.id,
|
||||
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
|
||||
|
||||
stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
|
||||
stageCompleted.stage.stageId))
|
||||
}
|
||||
|
||||
override def onTaskStart(taskStart: SparkListenerTaskStart) { }
|
||||
|
||||
/**
|
||||
* When task ends, record task completion status and metrics
|
||||
* @param taskEnd Task end event
|
||||
*/
|
||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
|
||||
val task = taskEnd.task
|
||||
val taskInfo = taskEnd.taskInfo
|
||||
|
@ -258,6 +336,10 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When job ends, recording job completion status and close log file
|
||||
* @param jobEnd Job end event
|
||||
*/
|
||||
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
|
||||
val job = jobEnd.job
|
||||
var info = "JOB_ID=" + job.jobId
|
||||
|
@ -272,6 +354,11 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
closeLogWriter(job.jobId)
|
||||
}
|
||||
|
||||
/**
|
||||
* Record job properties into job log file
|
||||
* @param jobID ID of the job
|
||||
* @param properties Properties of the job
|
||||
*/
|
||||
protected def recordJobProperties(jobID: Int, properties: Properties) {
|
||||
if(properties != null) {
|
||||
val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "")
|
||||
|
@ -279,6 +366,10 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* When job starts, record job property and stage graph
|
||||
* @param jobStart Job start event
|
||||
*/
|
||||
override def onJobStart(jobStart: SparkListenerJobStart) {
|
||||
val job = jobStart.job
|
||||
val properties = jobStart.properties
|
||||
|
@ -286,7 +377,8 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
|||
recordJobProperties(job.jobId, properties)
|
||||
buildJobDep(job.jobId, job.finalStage)
|
||||
recordStageDep(job.jobId)
|
||||
recordStageDepGraph(job.jobId, job.finalStage)
|
||||
recordStageDepGraph(job.jobId, job.finalStage, new HashSet[Int])
|
||||
jobLogInfo(job.jobId, "JOB_ID=" + job.jobId + " STATUS=STARTED")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@@ -17,48 +17,58 @@

package org.apache.spark.scheduler

import scala.collection.mutable.ArrayBuffer

/**
 * An object that waits for a DAGScheduler job to complete. As tasks finish, it passes their
 * results to the given handler function.
 */
private[spark] class JobWaiter[T](totalTasks: Int, resultHandler: (Int, T) => Unit)
private[spark] class JobWaiter[T](
    dagScheduler: DAGScheduler,
    jobId: Int,
    totalTasks: Int,
    resultHandler: (Int, T) => Unit)
  extends JobListener {

  private var finishedTasks = 0

  private var jobFinished = false // Is the job as a whole finished (succeeded or failed)?
  private var jobResult: JobResult = null // If the job is finished, this will be its result
  // Is the job as a whole finished (succeeded or failed)?
  private var _jobFinished = totalTasks == 0

  override def taskSucceeded(index: Int, result: Any) {
    synchronized {
      if (jobFinished) {
        throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
      }
      resultHandler(index, result.asInstanceOf[T])
      finishedTasks += 1
      if (finishedTasks == totalTasks) {
        jobFinished = true
        jobResult = JobSucceeded
        this.notifyAll()
      }
    }
  def jobFinished = _jobFinished

  // If the job is finished, this will be its result. In the case of 0 task jobs (e.g. zero
  // partition RDDs), we set the jobResult directly to JobSucceeded.
  private var jobResult: JobResult = if (jobFinished) JobSucceeded else null

  /**
   * Sends a signal to the DAGScheduler to cancel the job. The cancellation itself is handled
   * asynchronously. After the low level scheduler cancels all the tasks belonging to this job, it
   * will fail this job with a SparkException.
   */
  def cancel() {
    dagScheduler.cancelJob(jobId)
  }

  override def jobFailed(exception: Exception) {
    synchronized {
      if (jobFinished) {
        throw new UnsupportedOperationException("jobFailed() called on a finished JobWaiter")
      }
      jobFinished = true
      jobResult = JobFailed(exception, None)
  override def taskSucceeded(index: Int, result: Any): Unit = synchronized {
    if (_jobFinished) {
      throw new UnsupportedOperationException("taskSucceeded() called on a finished JobWaiter")
    }
    resultHandler(index, result.asInstanceOf[T])
    finishedTasks += 1
    if (finishedTasks == totalTasks) {
      _jobFinished = true
      jobResult = JobSucceeded
      this.notifyAll()
    }
  }

  override def jobFailed(exception: Exception): Unit = synchronized {
    _jobFinished = true
    jobResult = JobFailed(exception, None)
    this.notifyAll()
  }

  def awaitResult(): JobResult = synchronized {
    while (!jobFinished) {
    while (!_jobFinished) {
      this.wait()
    }
    return jobResult
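For reference, a hedged sketch of how the new asynchronous submission path might be driven. DAGScheduler.submitJob and JobWaiter are internal Spark APIs; the dagScheduler and rdd values below are assumed to exist and are not part of this patch:

// Hypothetical sketch only: submit asynchronously, then either block for the result or cancel.
val waiter = dagScheduler.submitJob(
  rdd,
  (context: TaskContext, iter: Iterator[Int]) => iter.sum,
  0 until rdd.partitions.length,
  "example call site",
  false,
  (index: Int, result: Int) => println("partition " + index + " -> " + result))

waiter.awaitResult() match {
  case JobSucceeded => println("job finished")
  case JobFailed(exception, _) => println("job failed: " + exception.getMessage)
}
// Alternatively, waiter.cancel() asks the DAGScheduler to cancel the job.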
@@ -43,7 +43,10 @@ private[spark] class Pool(
  var runningTasks = 0

  var priority = 0
  var stageId = 0

  // A pool's stage id is used to break the tie in scheduling.
  var stageId = -1

  var name = poolName
  var parent: Pool = null

@ -23,7 +23,7 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
|
|||
import org.apache.spark._
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.rdd.RDDCheckpointData
|
||||
import org.apache.spark.util.{MetadataCleaner, TimeStampedHashMap}
|
||||
import org.apache.spark.util.{MetadataCleanerType, MetadataCleaner, TimeStampedHashMap}
|
||||
|
||||
private[spark] object ResultTask {
|
||||
|
||||
|
@ -32,23 +32,23 @@ private[spark] object ResultTask {
|
|||
// expensive on the master node if it needs to launch thousands of tasks.
|
||||
val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]]
|
||||
|
||||
val metadataCleaner = new MetadataCleaner("ResultTask", serializedInfoCache.clearOldValues)
|
||||
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues)
|
||||
|
||||
def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = {
|
||||
synchronized {
|
||||
val old = serializedInfoCache.get(stageId).orNull
|
||||
if (old != null) {
|
||||
return old
|
||||
old
|
||||
} else {
|
||||
val out = new ByteArrayOutputStream
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
val objOut = ser.serializeStream(new GZIPOutputStream(out))
|
||||
objOut.writeObject(rdd)
|
||||
objOut.writeObject(func)
|
||||
objOut.close()
|
||||
val bytes = out.toByteArray
|
||||
serializedInfoCache.put(stageId, bytes)
|
||||
return bytes
|
||||
bytes
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -56,11 +56,11 @@ private[spark] object ResultTask {
|
|||
def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) = {
|
||||
val loader = Thread.currentThread.getContextClassLoader
|
||||
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
val objIn = ser.deserializeStream(in)
|
||||
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
|
||||
val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
|
||||
return (rdd, func)
|
||||
(rdd, func)
|
||||
}
|
||||
|
||||
def clearCache() {
|
||||
|
@ -71,29 +71,37 @@ private[spark] object ResultTask {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* A task that sends back the output to the driver application.
|
||||
*
|
||||
* See [[org.apache.spark.scheduler.Task]] for more information.
|
||||
*
|
||||
* @param stageId id of the stage this task belongs to
|
||||
* @param rdd input to func
|
||||
* @param func a function to apply on a partition of the RDD
|
||||
* @param _partitionId index of the number in the RDD
|
||||
* @param locs preferred task execution locations for locality scheduling
|
||||
* @param outputId index of the task in this job (a job can launch tasks on only a subset of the
|
||||
* input RDD's partitions).
|
||||
*/
|
||||
private[spark] class ResultTask[T, U](
|
||||
stageId: Int,
|
||||
var rdd: RDD[T],
|
||||
var func: (TaskContext, Iterator[T]) => U,
|
||||
var partition: Int,
|
||||
_partitionId: Int,
|
||||
@transient locs: Seq[TaskLocation],
|
||||
var outputId: Int)
|
||||
extends Task[U](stageId) with Externalizable {
|
||||
extends Task[U](stageId, _partitionId) with Externalizable {
|
||||
|
||||
def this() = this(0, null, null, 0, null, 0)
|
||||
|
||||
var split = if (rdd == null) {
|
||||
null
|
||||
} else {
|
||||
rdd.partitions(partition)
|
||||
}
|
||||
var split = if (rdd == null) null else rdd.partitions(partitionId)
|
||||
|
||||
@transient private val preferredLocs: Seq[TaskLocation] = {
|
||||
if (locs == null) Nil else locs.toSet.toSeq
|
||||
}
|
||||
|
||||
override def run(attemptId: Long): U = {
|
||||
val context = new TaskContext(stageId, partition, attemptId, runningLocally = false)
|
||||
override def runTask(context: TaskContext): U = {
|
||||
metrics = Some(context.taskMetrics)
|
||||
try {
|
||||
func(context, rdd.iterator(split, context))
|
||||
|
@ -104,17 +112,17 @@ private[spark] class ResultTask[T, U](
|
|||
|
||||
override def preferredLocations: Seq[TaskLocation] = preferredLocs
|
||||
|
||||
override def toString = "ResultTask(" + stageId + ", " + partition + ")"
|
||||
override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
|
||||
|
||||
override def writeExternal(out: ObjectOutput) {
|
||||
RDDCheckpointData.synchronized {
|
||||
split = rdd.partitions(partition)
|
||||
split = rdd.partitions(partitionId)
|
||||
out.writeInt(stageId)
|
||||
val bytes = ResultTask.serializeInfo(
|
||||
stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
|
||||
out.writeInt(bytes.length)
|
||||
out.write(bytes)
|
||||
out.writeInt(partition)
|
||||
out.writeInt(partitionId)
|
||||
out.writeInt(outputId)
|
||||
out.writeLong(epoch)
|
||||
out.writeObject(split)
|
||||
|
@ -129,7 +137,7 @@ private[spark] class ResultTask[T, U](
|
|||
val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
|
||||
rdd = rdd_.asInstanceOf[RDD[T]]
|
||||
func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
|
||||
partition = in.readInt()
|
||||
partitionId = in.readInt()
|
||||
outputId = in.readInt()
|
||||
epoch = in.readLong()
|
||||
split = in.readObject().asInstanceOf[Partition]
|
||||
|
|