Merge branch 'master' of github.com:mesos/spark into graph
Conflicts: run run2.cmd
This commit is contained in:
commit
564d902d79
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -36,4 +36,5 @@ streaming-tests.log
|
|||
dependency-reduced-pom.xml
|
||||
.ensime
|
||||
.ensime_lucene
|
||||
checkpoint
|
||||
derby.log
|
||||
|
|
52
bin/compute-classpath.cmd
Normal file
52
bin/compute-classpath.cmd
Normal file
|
@ -0,0 +1,52 @@
|
|||
@echo off
|
||||
|
||||
rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
|
||||
rem script and the ExecutorRunner in standalone cluster mode.
|
||||
|
||||
set SCALA_VERSION=2.9.3
|
||||
|
||||
rem Figure out where the Spark framework is installed
|
||||
set FWDIR=%~dp0..\
|
||||
|
||||
rem Load environment variables from conf\spark-env.cmd, if it exists
|
||||
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
|
||||
|
||||
set CORE_DIR=%FWDIR%core
|
||||
set REPL_DIR=%FWDIR%repl
|
||||
set EXAMPLES_DIR=%FWDIR%examples
|
||||
set BAGEL_DIR=%FWDIR%bagel
|
||||
set STREAMING_DIR=%FWDIR%streaming
|
||||
set PYSPARK_DIR=%FWDIR%python
|
||||
|
||||
rem Build up classpath
|
||||
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
|
||||
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
|
||||
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
|
||||
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
|
||||
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
|
||||
rem Add hadoop conf dir - else FileSystem.*, etc fail
|
||||
rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
|
||||
rem the configurtion files.
|
||||
if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
|
||||
set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
|
||||
:no_hadoop_conf_dir
|
||||
|
||||
if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
|
||||
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
|
||||
:no_yarn_conf_dir
|
||||
|
||||
rem Add Scala standard library
|
||||
set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar
|
||||
|
||||
rem A bit of a hack to allow calling this script within run2.cmd without seeing output
|
||||
if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
|
||||
|
||||
echo %CLASSPATH%
|
||||
|
||||
:exit
|
89
bin/compute-classpath.sh
Executable file
89
bin/compute-classpath.sh
Executable file
|
@ -0,0 +1,89 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
|
||||
# script and the ExecutorRunner in standalone cluster mode.
|
||||
|
||||
SCALA_VERSION=2.9.3
|
||||
|
||||
# Figure out where Spark is installed
|
||||
FWDIR="$(cd `dirname $0`/..; pwd)"
|
||||
|
||||
# Load environment variables from conf/spark-env.sh, if it exists
|
||||
if [ -e $FWDIR/conf/spark-env.sh ] ; then
|
||||
. $FWDIR/conf/spark-env.sh
|
||||
fi
|
||||
|
||||
CORE_DIR="$FWDIR/core"
|
||||
REPL_DIR="$FWDIR/repl"
|
||||
REPL_BIN_DIR="$FWDIR/repl-bin"
|
||||
EXAMPLES_DIR="$FWDIR/examples"
|
||||
BAGEL_DIR="$FWDIR/bagel"
|
||||
STREAMING_DIR="$FWDIR/streaming"
|
||||
PYSPARK_DIR="$FWDIR/python"
|
||||
|
||||
# Build up classpath
|
||||
CLASSPATH="$SPARK_CLASSPATH"
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/conf"
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
if [ -n "$SPARK_TESTING" ] ; then
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
|
||||
CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
|
||||
if [ -e "$FWDIR/lib_managed" ]; then
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
|
||||
# Add the shaded JAR for Maven builds
|
||||
if [ -e $REPL_BIN_DIR/target ]; then
|
||||
for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
|
||||
CLASSPATH="$CLASSPATH:$jar"
|
||||
done
|
||||
# The shaded JAR doesn't contain examples, so include those separately
|
||||
EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
|
||||
CLASSPATH+=":$EXAMPLES_JAR"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
|
||||
CLASSPATH="$CLASSPATH:$jar"
|
||||
done
|
||||
|
||||
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
|
||||
# to avoid the -sources and -doc packages that are built by publish-local.
|
||||
if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then
|
||||
# Use the JAR from the SBT build
|
||||
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
|
||||
fi
|
||||
if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
|
||||
# Use the JAR from the Maven build
|
||||
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
|
||||
fi
|
||||
|
||||
# Add hadoop conf dir - else FileSystem.*, etc fail !
|
||||
# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
|
||||
# the configurtion files.
|
||||
if [ "x" != "x$HADOOP_CONF_DIR" ]; then
|
||||
CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
|
||||
fi
|
||||
if [ "x" != "x$YARN_CONF_DIR" ]; then
|
||||
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
|
||||
fi
|
||||
|
||||
# Add Scala standard library
|
||||
if [ -z "$SCALA_LIBRARY_PATH" ]; then
|
||||
if [ -z "$SCALA_HOME" ]; then
|
||||
echo "SCALA_HOME is not set" >&2
|
||||
exit 1
|
||||
fi
|
||||
SCALA_LIBRARY_PATH="$SCALA_HOME/lib"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
|
||||
|
||||
echo "$CLASSPATH"
|
|
@ -32,8 +32,8 @@
|
|||
<artifactId>compress-lzf</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>asm</groupId>
|
||||
<artifactId>asm-all</artifactId>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
|
|
|
@ -8,15 +8,20 @@ import io.netty.channel.ChannelOption;
|
|||
import io.netty.channel.oio.OioEventLoopGroup;
|
||||
import io.netty.channel.socket.oio.OioSocketChannel;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
class FileClient {
|
||||
|
||||
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
|
||||
private FileClientHandler handler = null;
|
||||
private Channel channel = null;
|
||||
private Bootstrap bootstrap = null;
|
||||
private int connectTimeout = 60*1000; // 1 min
|
||||
|
||||
public FileClient(FileClientHandler handler) {
|
||||
public FileClient(FileClientHandler handler, int connectTimeout) {
|
||||
this.handler = handler;
|
||||
this.connectTimeout = connectTimeout;
|
||||
}
|
||||
|
||||
public void init() {
|
||||
|
@ -25,25 +30,10 @@ class FileClient {
|
|||
.channel(OioSocketChannel.class)
|
||||
.option(ChannelOption.SO_KEEPALIVE, true)
|
||||
.option(ChannelOption.TCP_NODELAY, true)
|
||||
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
|
||||
.handler(new FileClientChannelInitializer(handler));
|
||||
}
|
||||
|
||||
public static final class ChannelCloseListener implements ChannelFutureListener {
|
||||
private FileClient fc = null;
|
||||
|
||||
public ChannelCloseListener(FileClient fc){
|
||||
this.fc = fc;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void operationComplete(ChannelFuture future) {
|
||||
if (fc.bootstrap!=null){
|
||||
fc.bootstrap.shutdown();
|
||||
fc.bootstrap = null;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public void connect(String host, int port) {
|
||||
try {
|
||||
// Start the connection attempt.
|
||||
|
@ -58,8 +48,8 @@ class FileClient {
|
|||
public void waitForClose() {
|
||||
try {
|
||||
channel.closeFuture().sync();
|
||||
} catch (InterruptedException e){
|
||||
e.printStackTrace();
|
||||
} catch (InterruptedException e) {
|
||||
LOG.warn("FileClient interrupted", e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
|
|||
|
||||
private FileHeader currentHeader = null;
|
||||
|
||||
private volatile boolean handlerCalled = false;
|
||||
|
||||
public boolean isComplete() {
|
||||
return handlerCalled;
|
||||
}
|
||||
|
||||
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
|
||||
public abstract void handleError(String blockId);
|
||||
|
||||
@Override
|
||||
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
|
||||
|
@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
|
|||
// get file
|
||||
if(in.readableBytes() >= currentHeader.fileLen()) {
|
||||
handle(ctx, in, currentHeader);
|
||||
handlerCalled = true;
|
||||
currentHeader = null;
|
||||
ctx.close();
|
||||
}
|
||||
|
|
|
@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
|
|||
|
||||
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
|
||||
val shuffleMetrics = new ShuffleReadMetrics
|
||||
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
|
||||
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
|
||||
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
|
||||
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
|
||||
|
|
|
@ -5,8 +5,7 @@ import java.lang.reflect.Field
|
|||
import scala.collection.mutable.Map
|
||||
import scala.collection.mutable.Set
|
||||
|
||||
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
|
||||
import org.objectweb.asm.commons.EmptyVisitor
|
||||
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
|
||||
import org.objectweb.asm.Opcodes._
|
||||
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
|
||||
|
||||
|
@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
|
|||
}
|
||||
}
|
||||
|
||||
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
|
||||
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
|
||||
override def visitMethod(access: Int, name: String, desc: String,
|
||||
sig: String, exceptions: Array[String]): MethodVisitor = {
|
||||
return new EmptyVisitor {
|
||||
return new MethodVisitor(ASM4) {
|
||||
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
|
||||
if (op == GETFIELD) {
|
||||
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
|
||||
|
@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
|
|||
}
|
||||
}
|
||||
|
||||
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
|
||||
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
|
||||
var myName: String = null
|
||||
|
||||
override def visit(version: Int, access: Int, name: String, sig: String,
|
||||
|
@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
|
|||
|
||||
override def visitMethod(access: Int, name: String, desc: String,
|
||||
sig: String, exceptions: Array[String]): MethodVisitor = {
|
||||
return new EmptyVisitor {
|
||||
return new MethodVisitor(ASM4) {
|
||||
override def visitMethodInsn(op: Int, owner: String, name: String,
|
||||
desc: String) {
|
||||
val argTypes = Type.getArgumentTypes(desc)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
package spark
|
||||
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.{Date, HashMap => JHashMap}
|
||||
import java.text.SimpleDateFormat
|
||||
|
||||
|
@ -10,6 +11,8 @@ import scala.collection.JavaConversions._
|
|||
|
||||
import org.apache.hadoop.conf.Configuration
|
||||
import org.apache.hadoop.fs.Path
|
||||
import org.apache.hadoop.io.compress.CompressionCodec
|
||||
import org.apache.hadoop.io.SequenceFile.CompressionType
|
||||
import org.apache.hadoop.mapred.FileOutputCommitter
|
||||
import org.apache.hadoop.mapred.FileOutputFormat
|
||||
import org.apache.hadoop.mapred.HadoopWriter
|
||||
|
@ -17,7 +20,7 @@ import org.apache.hadoop.mapred.JobConf
|
|||
import org.apache.hadoop.mapred.OutputFormat
|
||||
|
||||
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
|
||||
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
|
||||
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
|
||||
|
||||
import spark.partial.BoundedDouble
|
||||
import spark.partial.PartialResult
|
||||
|
@ -62,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
throw new SparkException("Default partitioner cannot partition array keys.")
|
||||
}
|
||||
}
|
||||
val aggregator =
|
||||
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
||||
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
||||
if (self.partitioner == Some(partitioner)) {
|
||||
self.mapPartitions(aggregator.combineValuesByKey(_), true)
|
||||
} else if (mapSideCombine) {
|
||||
|
@ -95,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
* list concatenation, 0 for addition, or 1 for multiplication.).
|
||||
*/
|
||||
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
|
||||
combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
|
||||
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
|
||||
val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
|
||||
val zeroArray = new Array[Byte](zeroBuffer.limit)
|
||||
zeroBuffer.get(zeroArray)
|
||||
|
||||
// When deserializing, use a lazy val to create just one instance of the serializer per task
|
||||
lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
|
||||
def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
|
||||
|
||||
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -185,11 +196,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
|
||||
*/
|
||||
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
|
||||
// groupByKey shouldn't use map side combine because map side combine does not
|
||||
// reduce the amount of data shuffled and requires all map side data be inserted
|
||||
// into a hash table, leading to more objects in the old gen.
|
||||
def createCombiner(v: V) = ArrayBuffer(v)
|
||||
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
|
||||
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
|
||||
val bufs = combineByKey[ArrayBuffer[V]](
|
||||
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
|
||||
createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
|
||||
bufs.asInstanceOf[RDD[(K, Seq[V])]]
|
||||
}
|
||||
|
||||
|
@ -515,6 +528,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
|
||||
}
|
||||
|
||||
/**
|
||||
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
|
||||
* supporting the key and value types K and V in this RDD. Compress the result with the
|
||||
* supplied codec.
|
||||
*/
|
||||
def saveAsHadoopFile[F <: OutputFormat[K, V]](
|
||||
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
|
||||
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
|
||||
}
|
||||
|
||||
/**
|
||||
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
|
||||
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
|
||||
|
@ -574,6 +597,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
jobCommitter.cleanupJob(jobTaskContext)
|
||||
}
|
||||
|
||||
/**
|
||||
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
|
||||
* supporting the key and value types K and V in this RDD. Compress with the supplied codec.
|
||||
*/
|
||||
def saveAsHadoopFile(
|
||||
path: String,
|
||||
keyClass: Class[_],
|
||||
valueClass: Class[_],
|
||||
outputFormatClass: Class[_ <: OutputFormat[_, _]],
|
||||
codec: Class[_ <: CompressionCodec]) {
|
||||
saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
|
||||
new JobConf(self.context.hadoopConfiguration), Some(codec))
|
||||
}
|
||||
|
||||
/**
|
||||
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
|
||||
* supporting the key and value types K and V in this RDD.
|
||||
|
@ -583,11 +620,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
|
|||
keyClass: Class[_],
|
||||
valueClass: Class[_],
|
||||
outputFormatClass: Class[_ <: OutputFormat[_, _]],
|
||||
conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
|
||||
conf: JobConf = new JobConf(self.context.hadoopConfiguration),
|
||||
codec: Option[Class[_ <: CompressionCodec]] = None) {
|
||||
conf.setOutputKeyClass(keyClass)
|
||||
conf.setOutputValueClass(valueClass)
|
||||
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
|
||||
conf.set("mapred.output.format.class", outputFormatClass.getName)
|
||||
for (c <- codec) {
|
||||
conf.setCompressMapOutput(true)
|
||||
conf.set("mapred.output.compress", "true")
|
||||
conf.setMapOutputCompressorClass(c)
|
||||
conf.set("mapred.output.compression.codec", c.getCanonicalName)
|
||||
conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
|
||||
}
|
||||
conf.setOutputCommitter(classOf[FileOutputCommitter])
|
||||
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
|
||||
saveAsHadoopDataset(conf)
|
||||
|
|
|
@ -7,12 +7,14 @@ import scala.collection.JavaConversions.mapAsScalaMap
|
|||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.hadoop.io.BytesWritable
|
||||
import org.apache.hadoop.io.compress.CompressionCodec
|
||||
import org.apache.hadoop.io.NullWritable
|
||||
import org.apache.hadoop.io.Text
|
||||
import org.apache.hadoop.mapred.TextOutputFormat
|
||||
|
||||
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
|
||||
|
||||
import spark.broadcast.Broadcast
|
||||
import spark.Partitioner._
|
||||
import spark.partial.BoundedDouble
|
||||
import spark.partial.CountEvaluator
|
||||
|
@ -35,6 +37,7 @@ import spark.rdd.ZippedPartitionsRDD2
|
|||
import spark.rdd.ZippedPartitionsRDD3
|
||||
import spark.rdd.ZippedPartitionsRDD4
|
||||
import spark.storage.StorageLevel
|
||||
import spark.util.BoundedPriorityQueue
|
||||
|
||||
import SparkContext._
|
||||
|
||||
|
@ -114,6 +117,14 @@ abstract class RDD[T: ClassManifest](
|
|||
this
|
||||
}
|
||||
|
||||
/** User-defined generator of this RDD*/
|
||||
var generator = Utils.getCallSiteInfo.firstUserClass
|
||||
|
||||
/** Reset generator*/
|
||||
def setGenerator(_generator: String) = {
|
||||
generator = _generator
|
||||
}
|
||||
|
||||
/**
|
||||
* Set this RDD's storage level to persist its values across operations after the first time
|
||||
* it is computed. This can only be used to assign a new storage level if the RDD does not
|
||||
|
@ -352,13 +363,36 @@ abstract class RDD[T: ClassManifest](
|
|||
/**
|
||||
* Return an RDD created by piping elements to a forked external process.
|
||||
*/
|
||||
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
|
||||
def pipe(command: String, env: Map[String, String]): RDD[String] =
|
||||
new PipedRDD(this, command, env)
|
||||
|
||||
|
||||
/**
|
||||
* Return an RDD created by piping elements to a forked external process.
|
||||
* The print behavior can be customized by providing two functions.
|
||||
*
|
||||
* @param command command to run in forked process.
|
||||
* @param env environment variables to set.
|
||||
* @param printPipeContext Before piping elements, this function is called as an oppotunity
|
||||
* to pipe context data. Print line function (like out.println) will be
|
||||
* passed as printPipeContext's parameter.
|
||||
* @param printRDDElement Use this function to customize how to pipe elements. This function
|
||||
* will be called with each RDD element as the 1st parameter, and the
|
||||
* print line function (like out.println()) as the 2nd parameter.
|
||||
* An example of pipe the RDD data of groupBy() in a streaming way,
|
||||
* instead of constructing a huge String to concat all the elements:
|
||||
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
|
||||
* for (e <- record._2){f(e)}
|
||||
* @return the result RDD
|
||||
*/
|
||||
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
|
||||
new PipedRDD(this, command, env)
|
||||
def pipe(
|
||||
command: Seq[String],
|
||||
env: Map[String, String] = Map(),
|
||||
printPipeContext: (String => Unit) => Unit = null,
|
||||
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
|
||||
new PipedRDD(this, command, env,
|
||||
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
|
||||
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
|
||||
|
||||
/**
|
||||
* Return a new RDD by applying a function to each partition of this RDD.
|
||||
|
@ -722,6 +756,24 @@ abstract class RDD[T: ClassManifest](
|
|||
case _ => throw new UnsupportedOperationException("empty collection")
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the top K elements from this RDD as defined by
|
||||
* the specified implicit Ordering[T].
|
||||
* @param num the number of top elements to return
|
||||
* @param ord the implicit ordering for T
|
||||
* @return an array of top elements
|
||||
*/
|
||||
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
|
||||
mapPartitions { items =>
|
||||
val queue = new BoundedPriorityQueue[T](num)
|
||||
queue ++= items
|
||||
Iterator.single(queue)
|
||||
}.reduce { (queue1, queue2) =>
|
||||
queue1 ++= queue2
|
||||
queue1
|
||||
}.toArray
|
||||
}
|
||||
|
||||
/**
|
||||
* Save this RDD as a text file, using string representations of elements.
|
||||
*/
|
||||
|
@ -730,6 +782,14 @@ abstract class RDD[T: ClassManifest](
|
|||
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save this RDD as a compressed text file, using string representations of elements.
|
||||
*/
|
||||
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
|
||||
this.map(x => (NullWritable.get(), new Text(x.toString)))
|
||||
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
|
||||
}
|
||||
|
||||
/**
|
||||
* Save this RDD as a SequenceFile of serialized objects.
|
||||
*/
|
||||
|
@ -788,7 +848,7 @@ abstract class RDD[T: ClassManifest](
|
|||
private var storageLevel: StorageLevel = StorageLevel.NONE
|
||||
|
||||
/** Record user function generating this RDD. */
|
||||
private[spark] val origin = Utils.getSparkCallSite
|
||||
private[spark] val origin = Utils.formatSparkCallSite
|
||||
|
||||
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
|
|||
import org.apache.hadoop.mapred.SequenceFileOutputFormat
|
||||
import org.apache.hadoop.mapred.OutputCommitter
|
||||
import org.apache.hadoop.mapred.FileOutputCommitter
|
||||
import org.apache.hadoop.io.compress.CompressionCodec
|
||||
import org.apache.hadoop.io.Writable
|
||||
import org.apache.hadoop.io.NullWritable
|
||||
import org.apache.hadoop.io.BytesWritable
|
||||
|
@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
|
|||
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
|
||||
* file system.
|
||||
*/
|
||||
def saveAsSequenceFile(path: String) {
|
||||
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
|
||||
def anyToWritable[U <% Writable](u: U): Writable = u
|
||||
|
||||
val keyClass = getWritableClass[K]
|
||||
|
@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
|
|||
|
||||
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
|
||||
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
|
||||
val jobConf = new JobConf(self.context.hadoopConfiguration)
|
||||
if (!convertKey && !convertValue) {
|
||||
self.saveAsHadoopFile(path, keyClass, valueClass, format)
|
||||
self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
|
||||
} else if (!convertKey && convertValue) {
|
||||
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
|
||||
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
|
||||
path, keyClass, valueClass, format, jobConf, codec)
|
||||
} else if (convertKey && !convertValue) {
|
||||
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
|
||||
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
|
||||
path, keyClass, valueClass, format, jobConf, codec)
|
||||
} else if (convertKey && convertValue) {
|
||||
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
|
||||
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
|
||||
path, keyClass, valueClass, format, jobConf, codec)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,6 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
|
|||
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
|
||||
import spark.util.{MetadataCleaner, TimeStampedHashMap}
|
||||
|
||||
|
||||
/**
|
||||
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
|
||||
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
|
||||
|
@ -630,7 +629,7 @@ class SparkContext(
|
|||
partitions: Seq[Int],
|
||||
allowLocal: Boolean,
|
||||
resultHandler: (Int, U) => Unit) {
|
||||
val callSite = Utils.getSparkCallSite
|
||||
val callSite = Utils.formatSparkCallSite
|
||||
logInfo("Starting job: " + callSite)
|
||||
val start = System.nanoTime
|
||||
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
|
||||
|
@ -713,7 +712,7 @@ class SparkContext(
|
|||
func: (TaskContext, Iterator[T]) => U,
|
||||
evaluator: ApproximateEvaluator[U, R],
|
||||
timeout: Long): PartialResult[R] = {
|
||||
val callSite = Utils.getSparkCallSite
|
||||
val callSite = Utils.formatSparkCallSite
|
||||
logInfo("Starting job: " + callSite)
|
||||
val start = System.nanoTime
|
||||
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
package spark
|
||||
|
||||
import collection.mutable
|
||||
import serializer.Serializer
|
||||
|
||||
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
|
||||
import akka.remote.RemoteActorRefProvider
|
||||
|
||||
|
@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
|
|||
import spark.network.ConnectionManager
|
||||
import spark.serializer.{Serializer, SerializerManager}
|
||||
import spark.util.AkkaUtils
|
||||
import spark.api.python.PythonWorkerFactory
|
||||
|
||||
|
||||
/**
|
||||
|
@ -37,7 +41,10 @@ class SparkEnv (
|
|||
// If executorId is NOT found, return defaultHostPort
|
||||
var executorIdToHostPort: Option[(String, String) => String]) {
|
||||
|
||||
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
|
||||
|
||||
def stop() {
|
||||
pythonWorkers.foreach { case(key, worker) => worker.stop() }
|
||||
httpFileServer.stop()
|
||||
mapOutputTracker.stop()
|
||||
shuffleFetcher.stop()
|
||||
|
@ -50,6 +57,11 @@ class SparkEnv (
|
|||
actorSystem.awaitTermination()
|
||||
}
|
||||
|
||||
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
|
||||
synchronized {
|
||||
pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
|
||||
}
|
||||
}
|
||||
|
||||
def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
|
||||
val env = SparkEnv.get
|
||||
|
|
|
@ -116,8 +116,8 @@ private object Utils extends Logging {
|
|||
while (dir == null) {
|
||||
attempts += 1
|
||||
if (attempts > maxAttempts) {
|
||||
throw new IOException("Failed to create a temp directory after " + maxAttempts +
|
||||
" attempts!")
|
||||
throw new IOException("Failed to create a temp directory (under " + root + ") after " +
|
||||
maxAttempts + " attempts!")
|
||||
}
|
||||
try {
|
||||
dir = new File(root, "spark-" + UUID.randomUUID.toString)
|
||||
|
@ -522,13 +522,45 @@ private object Utils extends Logging {
|
|||
execute(command, new File("."))
|
||||
}
|
||||
|
||||
/**
|
||||
* Execute a command and get its output, throwing an exception if it yields a code other than 0.
|
||||
*/
|
||||
def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = {
|
||||
val process = new ProcessBuilder(command: _*)
|
||||
.directory(workingDir)
|
||||
.start()
|
||||
new Thread("read stderr for " + command(0)) {
|
||||
override def run() {
|
||||
for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
|
||||
System.err.println(line)
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
val output = new StringBuffer
|
||||
val stdoutThread = new Thread("read stdout for " + command(0)) {
|
||||
override def run() {
|
||||
for (line <- Source.fromInputStream(process.getInputStream).getLines) {
|
||||
output.append(line)
|
||||
}
|
||||
}
|
||||
}
|
||||
stdoutThread.start()
|
||||
val exitCode = process.waitFor()
|
||||
stdoutThread.join() // Wait for it to finish reading output
|
||||
if (exitCode != 0) {
|
||||
throw new SparkException("Process " + command + " exited with code " + exitCode)
|
||||
}
|
||||
output.toString
|
||||
}
|
||||
|
||||
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
|
||||
val firstUserLine: Int, val firstUserClass: String)
|
||||
/**
|
||||
* When called inside a class in the spark package, returns the name of the user code class
|
||||
* (outside the spark package) that called into Spark, as well as which Spark method they called.
|
||||
* This is used, for example, to tell users where in their code each RDD got created.
|
||||
*/
|
||||
def getSparkCallSite: String = {
|
||||
def getCallSiteInfo: CallSiteInfo = {
|
||||
val trace = Thread.currentThread.getStackTrace().filter( el =>
|
||||
(!el.getMethodName.contains("getStackTrace")))
|
||||
|
||||
|
@ -540,6 +572,7 @@ private object Utils extends Logging {
|
|||
var firstUserFile = "<unknown>"
|
||||
var firstUserLine = 0
|
||||
var finished = false
|
||||
var firstUserClass = "<unknown>"
|
||||
|
||||
for (el <- trace) {
|
||||
if (!finished) {
|
||||
|
@ -554,13 +587,19 @@ private object Utils extends Logging {
|
|||
else {
|
||||
firstUserLine = el.getLineNumber
|
||||
firstUserFile = el.getFileName
|
||||
firstUserClass = el.getClassName
|
||||
finished = true
|
||||
}
|
||||
}
|
||||
}
|
||||
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
|
||||
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
|
||||
}
|
||||
|
||||
def formatSparkCallSite = {
|
||||
val callSiteInfo = getCallSiteInfo
|
||||
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
|
||||
callSiteInfo.firstUserLine)
|
||||
}
|
||||
/**
|
||||
* Try to find a free port to bind to on the local host. This should ideally never be needed,
|
||||
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
|
||||
|
@ -602,4 +641,67 @@ private object Utils extends Logging {
|
|||
}
|
||||
return false
|
||||
}
|
||||
|
||||
def isSpace(c: Char): Boolean = {
|
||||
" \t\r\n".indexOf(c) != -1
|
||||
}
|
||||
|
||||
/**
|
||||
* Split a string of potentially quoted arguments from the command line the way that a shell
|
||||
* would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
|
||||
* then it would be parsed as three arguments: 'a', 'b c' and 'd'.
|
||||
*/
|
||||
def splitCommandString(s: String): Seq[String] = {
|
||||
val buf = new ArrayBuffer[String]
|
||||
var inWord = false
|
||||
var inSingleQuote = false
|
||||
var inDoubleQuote = false
|
||||
var curWord = new StringBuilder
|
||||
def endWord() {
|
||||
buf += curWord.toString
|
||||
curWord.clear()
|
||||
}
|
||||
var i = 0
|
||||
while (i < s.length) {
|
||||
var nextChar = s.charAt(i)
|
||||
if (inDoubleQuote) {
|
||||
if (nextChar == '"') {
|
||||
inDoubleQuote = false
|
||||
} else if (nextChar == '\\') {
|
||||
if (i < s.length - 1) {
|
||||
// Append the next character directly, because only " and \ may be escaped in
|
||||
// double quotes after the shell's own expansion
|
||||
curWord.append(s.charAt(i + 1))
|
||||
i += 1
|
||||
}
|
||||
} else {
|
||||
curWord.append(nextChar)
|
||||
}
|
||||
} else if (inSingleQuote) {
|
||||
if (nextChar == '\'') {
|
||||
inSingleQuote = false
|
||||
} else {
|
||||
curWord.append(nextChar)
|
||||
}
|
||||
// Backslashes are not treated specially in single quotes
|
||||
} else if (nextChar == '"') {
|
||||
inWord = true
|
||||
inDoubleQuote = true
|
||||
} else if (nextChar == '\'') {
|
||||
inWord = true
|
||||
inSingleQuote = true
|
||||
} else if (!isSpace(nextChar)) {
|
||||
curWord.append(nextChar)
|
||||
inWord = true
|
||||
} else if (inWord && isSpace(nextChar)) {
|
||||
endWord()
|
||||
inWord = false
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
if (inWord || inDoubleQuote || inSingleQuote) {
|
||||
endWord()
|
||||
}
|
||||
return buf
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import java.util.Comparator
|
|||
import scala.Tuple2
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import org.apache.hadoop.io.compress.CompressionCodec
|
||||
import org.apache.hadoop.mapred.JobConf
|
||||
import org.apache.hadoop.mapred.OutputFormat
|
||||
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
|
||||
|
@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
|
|||
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
|
||||
}
|
||||
|
||||
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
|
||||
def saveAsHadoopFile[F <: OutputFormat[_, _]](
|
||||
path: String,
|
||||
keyClass: Class[_],
|
||||
valueClass: Class[_],
|
||||
outputFormatClass: Class[F],
|
||||
codec: Class[_ <: CompressionCodec]) {
|
||||
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
|
||||
}
|
||||
|
||||
/** Output the RDD to any Hadoop-supported file system. */
|
||||
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
|
||||
path: String,
|
||||
|
|
|
@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
|
|||
*/
|
||||
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
|
||||
wrapRDD(rdd.subtract(other, p))
|
||||
|
||||
}
|
||||
|
||||
object JavaRDD {
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
package spark.api.java
|
||||
|
||||
import java.util.{List => JList}
|
||||
import java.util.{List => JList, Comparator}
|
||||
import scala.Tuple2
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import org.apache.hadoop.io.compress.CompressionCodec
|
||||
import spark.{SparkContext, Partition, RDD, TaskContext}
|
||||
import spark.api.java.JavaPairRDD._
|
||||
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
|
||||
|
@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
|
|||
*/
|
||||
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
|
||||
|
||||
|
||||
/**
|
||||
* Save this RDD as a compressed text file, using string representations of elements.
|
||||
*/
|
||||
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
|
||||
rdd.saveAsTextFile(path, codec)
|
||||
|
||||
/**
|
||||
* Save this RDD as a SequenceFile of serialized objects.
|
||||
*/
|
||||
|
@ -351,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
|
|||
def toDebugString(): String = {
|
||||
rdd.toDebugString
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the top K elements from this RDD as defined by
|
||||
* the specified Comparator[T].
|
||||
* @param num the number of top elements to return
|
||||
* @param comp the comparator that defines the order
|
||||
* @return an array of top elements
|
||||
*/
|
||||
def top(num: Int, comp: Comparator[T]): JList[T] = {
|
||||
import scala.collection.JavaConversions._
|
||||
val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
|
||||
val arr: java.util.Collection[T] = topElems.toSeq
|
||||
new java.util.ArrayList(arr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the top K elements from this RDD using the
|
||||
* natural ordering for T.
|
||||
* @param num the number of top elements to return
|
||||
* @return an array of top elements
|
||||
*/
|
||||
def top(num: Int): JList[T] = {
|
||||
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
|
||||
top(num, comp)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,10 +2,9 @@ package spark.api.python
|
|||
|
||||
import java.io._
|
||||
import java.net._
|
||||
import java.util.{List => JList, ArrayList => JArrayList, Collections}
|
||||
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
import scala.io.Source
|
||||
|
||||
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
|
||||
import spark.broadcast.Broadcast
|
||||
|
@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
|
|||
private[spark] class PythonRDD[T: ClassManifest](
|
||||
parent: RDD[T],
|
||||
command: Seq[String],
|
||||
envVars: java.util.Map[String, String],
|
||||
envVars: JMap[String, String],
|
||||
preservePartitoning: Boolean,
|
||||
pythonExec: String,
|
||||
broadcastVars: JList[Broadcast[Array[Byte]]],
|
||||
|
@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
|
||||
// Similar to Runtime.exec(), if we are given a single string, split it into words
|
||||
// using a standard StringTokenizer (i.e. by spaces)
|
||||
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
|
||||
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
|
||||
preservePartitoning: Boolean, pythonExec: String,
|
||||
broadcastVars: JList[Broadcast[Array[Byte]]],
|
||||
accumulator: Accumulator[JList[Array[Byte]]]) =
|
||||
|
@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
|
||||
override val partitioner = if (preservePartitoning) parent.partitioner else None
|
||||
|
||||
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
|
||||
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
|
||||
|
||||
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
|
||||
// Add the environmental variables to the process.
|
||||
val currentEnvVars = pb.environment()
|
||||
|
||||
for ((variable, value) <- envVars) {
|
||||
currentEnvVars.put(variable, value)
|
||||
}
|
||||
|
||||
val proc = pb.start()
|
||||
val startTime = System.currentTimeMillis
|
||||
val env = SparkEnv.get
|
||||
|
||||
// Start a thread to print the process's stderr to ours
|
||||
new Thread("stderr reader for " + pythonExec) {
|
||||
override def run() {
|
||||
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
|
||||
System.err.println(line)
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
val worker = env.createPythonWorker(pythonExec, envVars.toMap)
|
||||
|
||||
// Start a thread to feed the process input from our parent's iterator
|
||||
new Thread("stdin writer for " + pythonExec) {
|
||||
override def run() {
|
||||
SparkEnv.set(env)
|
||||
val out = new PrintWriter(proc.getOutputStream)
|
||||
val dOut = new DataOutputStream(proc.getOutputStream)
|
||||
val out = new PrintWriter(worker.getOutputStream)
|
||||
val dOut = new DataOutputStream(worker.getOutputStream)
|
||||
// Partition index
|
||||
dOut.writeInt(split.index)
|
||||
// sparkFilesDir
|
||||
|
@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
}
|
||||
dOut.flush()
|
||||
out.flush()
|
||||
proc.getOutputStream.close()
|
||||
worker.shutdownOutput()
|
||||
}
|
||||
}.start()
|
||||
|
||||
// Return an iterator that read lines from the process's stdout
|
||||
val stream = new DataInputStream(proc.getInputStream)
|
||||
val stream = new DataInputStream(worker.getInputStream)
|
||||
return new Iterator[Array[Byte]] {
|
||||
def next(): Array[Byte] = {
|
||||
val obj = _nextObj
|
||||
_nextObj = read()
|
||||
if (hasNext) {
|
||||
// FIXME: can deadlock if worker is waiting for us to
|
||||
// respond to current message (currently irrelevant because
|
||||
// output is shutdown before we read any input)
|
||||
_nextObj = read()
|
||||
}
|
||||
obj
|
||||
}
|
||||
|
||||
|
@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
val obj = new Array[Byte](length)
|
||||
stream.readFully(obj)
|
||||
obj
|
||||
case -3 =>
|
||||
// Timing data from worker
|
||||
val bootTime = stream.readLong()
|
||||
val initTime = stream.readLong()
|
||||
val finishTime = stream.readLong()
|
||||
val boot = bootTime - startTime
|
||||
val init = initTime - bootTime
|
||||
val finish = finishTime - initTime
|
||||
val total = finishTime - startTime
|
||||
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
|
||||
read
|
||||
case -2 =>
|
||||
// Signals that an exception has been thrown in python
|
||||
val exLength = stream.readInt()
|
||||
|
@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
|
|||
stream.readFully(obj)
|
||||
throw new PythonException(new String(obj))
|
||||
case -1 =>
|
||||
// We've finished the data section of the output, but we can still read some
|
||||
// accumulator updates; let's do that, breaking when we get EOFException
|
||||
while (true) {
|
||||
val len2 = stream.readInt()
|
||||
// We've finished the data section of the output, but we can still
|
||||
// read some accumulator updates; let's do that, breaking when we
|
||||
// get a negative length record.
|
||||
var len2 = stream.readInt()
|
||||
while (len2 >= 0) {
|
||||
val update = new Array[Byte](len2)
|
||||
stream.readFully(update)
|
||||
accumulator += Collections.singletonList(update)
|
||||
len2 = stream.readInt()
|
||||
}
|
||||
new Array[Byte](0)
|
||||
}
|
||||
} catch {
|
||||
case eof: EOFException => {
|
||||
val exitStatus = proc.waitFor()
|
||||
if (exitStatus != 0) {
|
||||
throw new Exception("Subprocess exited with status " + exitStatus)
|
||||
}
|
||||
new Array[Byte](0)
|
||||
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
|
||||
}
|
||||
case e => throw e
|
||||
}
|
||||
|
@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
|
|||
override def compute(split: Partition, context: TaskContext) =
|
||||
prev.iterator(split, context).grouped(2).map {
|
||||
case Seq(a, b) => (a, b)
|
||||
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
|
||||
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
|
||||
}
|
||||
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
|
||||
}
|
||||
|
@ -215,7 +211,7 @@ private[spark] object PythonRDD {
|
|||
dOut.write(s)
|
||||
dOut.writeByte(Pickle.STOP)
|
||||
} else {
|
||||
throw new Exception("Unexpected RDD type")
|
||||
throw new SparkException("Unexpected RDD type")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,95 @@
|
|||
package spark.api.python
|
||||
|
||||
import java.io.{DataInputStream, IOException}
|
||||
import java.net.{Socket, SocketException, InetAddress}
|
||||
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
import spark._
|
||||
|
||||
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
|
||||
extends Logging {
|
||||
var daemon: Process = null
|
||||
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
|
||||
var daemonPort: Int = 0
|
||||
|
||||
def create(): Socket = {
|
||||
synchronized {
|
||||
// Start the daemon if it hasn't been started
|
||||
startDaemon()
|
||||
|
||||
// Attempt to connect, restart and retry once if it fails
|
||||
try {
|
||||
new Socket(daemonHost, daemonPort)
|
||||
} catch {
|
||||
case exc: SocketException => {
|
||||
logWarning("Python daemon unexpectedly quit, attempting to restart")
|
||||
stopDaemon()
|
||||
startDaemon()
|
||||
new Socket(daemonHost, daemonPort)
|
||||
}
|
||||
case e => throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def stop() {
|
||||
stopDaemon()
|
||||
}
|
||||
|
||||
private def startDaemon() {
|
||||
synchronized {
|
||||
// Is it already running?
|
||||
if (daemon != null) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
// Create and start the daemon
|
||||
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
|
||||
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
|
||||
val workerEnv = pb.environment()
|
||||
workerEnv.putAll(envVars)
|
||||
daemon = pb.start()
|
||||
daemonPort = new DataInputStream(daemon.getInputStream).readInt()
|
||||
|
||||
// Redirect the stderr to ours
|
||||
new Thread("stderr reader for " + pythonExec) {
|
||||
override def run() {
|
||||
scala.util.control.Exception.ignoring(classOf[IOException]) {
|
||||
// FIXME HACK: We copy the stream on the level of bytes to
|
||||
// attempt to dodge encoding problems.
|
||||
val in = daemon.getErrorStream
|
||||
var buf = new Array[Byte](1024)
|
||||
var len = in.read(buf)
|
||||
while (len != -1) {
|
||||
System.err.write(buf, 0, len)
|
||||
len = in.read(buf)
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
} catch {
|
||||
case e => {
|
||||
stopDaemon()
|
||||
throw e
|
||||
}
|
||||
}
|
||||
|
||||
// Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
|
||||
// detect our disappearance.
|
||||
}
|
||||
}
|
||||
|
||||
private def stopDaemon() {
|
||||
synchronized {
|
||||
// Request shutdown of existing daemon by sending SIGTERM
|
||||
if (daemon != null) {
|
||||
daemon.destroy()
|
||||
}
|
||||
|
||||
daemon = null
|
||||
daemonPort = 0
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
package spark.deploy.worker
|
||||
|
||||
import java.io._
|
||||
import java.lang.System.getenv
|
||||
import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
|
||||
import akka.actor.ActorRef
|
||||
import spark.{Utils, Logging}
|
||||
|
@ -40,7 +41,7 @@ private[spark] class ExecutorRunner(
|
|||
workerThread.start()
|
||||
|
||||
// Shutdown hook that kills actors on shutdown.
|
||||
shutdownHook = new Thread() {
|
||||
shutdownHook = new Thread() {
|
||||
override def run() {
|
||||
if (process != null) {
|
||||
logInfo("Shutdown hook killing child process.")
|
||||
|
@ -77,9 +78,29 @@ private[spark] class ExecutorRunner(
|
|||
|
||||
def buildCommandSeq(): Seq[String] = {
|
||||
val command = appDesc.command
|
||||
val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
|
||||
val runScript = new File(sparkHome, script).getCanonicalPath
|
||||
Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
|
||||
val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java")
|
||||
// SPARK-698: do not call the run.cmd script, as process.destroy()
|
||||
// fails to kill a process tree on Windows
|
||||
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
|
||||
command.arguments.map(substituteVariables)
|
||||
}
|
||||
|
||||
/**
|
||||
* Attention: this must always be aligned with the environment variables in the run scripts and
|
||||
* the way the JAVA_OPTS are assembled there.
|
||||
*/
|
||||
def buildJavaOpts(): Seq[String] = {
|
||||
val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH"))
|
||||
.map(p => List("-Djava.library.path=" + p))
|
||||
.getOrElse(Nil)
|
||||
val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
|
||||
val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
|
||||
|
||||
// Figure out our classpath with the external compute-classpath script
|
||||
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
|
||||
val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext))
|
||||
|
||||
Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts
|
||||
}
|
||||
|
||||
/** Spawn a thread that will redirect a given stream to a file */
|
||||
|
@ -115,7 +136,6 @@ private[spark] class ExecutorRunner(
|
|||
for ((key, value) <- appDesc.command.environment) {
|
||||
env.put(key, value)
|
||||
}
|
||||
env.put("SPARK_MEM", memory.toString + "m")
|
||||
// In case we are running this from within the Spark Shell, avoid creating a "scala"
|
||||
// parent process for the executor command
|
||||
env.put("SPARK_LAUNCH_WITH_SCALA", "0")
|
||||
|
|
|
@ -42,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
|
|||
|
||||
// Create our ClassLoader and set it on this thread
|
||||
private val urlClassLoader = createClassLoader()
|
||||
Thread.currentThread.setContextClassLoader(urlClassLoader)
|
||||
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
|
||||
Thread.currentThread.setContextClassLoader(replClassLoader)
|
||||
|
||||
// Make any thread terminations due to uncaught exceptions kill the entire
|
||||
// executor process to avoid surprising stalls.
|
||||
|
@ -88,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
|
|||
override def run() {
|
||||
val startTime = System.currentTimeMillis()
|
||||
SparkEnv.set(env)
|
||||
Thread.currentThread.setContextClassLoader(urlClassLoader)
|
||||
Thread.currentThread.setContextClassLoader(replClassLoader)
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
logInfo("Running task ID " + taskId)
|
||||
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
|
||||
|
@ -104,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
|
|||
val value = task.run(taskId.toInt)
|
||||
val taskFinish = System.currentTimeMillis()
|
||||
task.metrics.foreach{ m =>
|
||||
m.hostname = Utils.localHostName
|
||||
m.executorDeserializeTime = (taskStart - startTime).toInt
|
||||
m.executorRunTime = (taskFinish - taskStart).toInt
|
||||
}
|
||||
|
@ -152,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
|
|||
val urls = currentJars.keySet.map { uri =>
|
||||
new File(uri.split("/").last).toURI.toURL
|
||||
}.toArray
|
||||
loader = new URLClassLoader(urls, loader)
|
||||
new ExecutorURLClassLoader(urls, loader)
|
||||
}
|
||||
|
||||
// If the REPL is in use, add another ClassLoader that will read
|
||||
// new classes defined by the REPL as the user types code
|
||||
/**
|
||||
* If the REPL is in use, add another ClassLoader that will read
|
||||
* new classes defined by the REPL as the user types code
|
||||
*/
|
||||
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
|
||||
val classUri = System.getProperty("spark.repl.class.uri")
|
||||
if (classUri != null) {
|
||||
logInfo("Using REPL class URI: " + classUri)
|
||||
loader = {
|
||||
try {
|
||||
val klass = Class.forName("spark.repl.ExecutorClassLoader")
|
||||
.asInstanceOf[Class[_ <: ClassLoader]]
|
||||
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
|
||||
constructor.newInstance(classUri, loader)
|
||||
} catch {
|
||||
case _: ClassNotFoundException => loader
|
||||
}
|
||||
try {
|
||||
val klass = Class.forName("spark.repl.ExecutorClassLoader")
|
||||
.asInstanceOf[Class[_ <: ClassLoader]]
|
||||
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
|
||||
return constructor.newInstance(classUri, parent)
|
||||
} catch {
|
||||
case _: ClassNotFoundException =>
|
||||
logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
|
||||
System.exit(1)
|
||||
null
|
||||
}
|
||||
} else {
|
||||
return parent
|
||||
}
|
||||
|
||||
return new ExecutorURLClassLoader(Array(), loader)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -1,6 +1,11 @@
|
|||
package spark.executor
|
||||
|
||||
class TaskMetrics extends Serializable {
|
||||
/**
|
||||
* Host's name the task runs on
|
||||
*/
|
||||
var hostname: String = _
|
||||
|
||||
/**
|
||||
* Time taken on the executor to deserialize this task
|
||||
*/
|
||||
|
@ -33,10 +38,15 @@ object TaskMetrics {
|
|||
|
||||
|
||||
class ShuffleReadMetrics extends Serializable {
|
||||
/**
|
||||
* Time when shuffle finishs
|
||||
*/
|
||||
var shuffleFinishTime: Long = _
|
||||
|
||||
/**
|
||||
* Total number of blocks fetched in a shuffle (remote or local)
|
||||
*/
|
||||
var totalBlocksFetched : Int = _
|
||||
var totalBlocksFetched: Int = _
|
||||
|
||||
/**
|
||||
* Number of remote blocks fetched in a shuffle
|
||||
|
|
|
@ -9,19 +9,36 @@ import io.netty.util.CharsetUtil
|
|||
import spark.Logging
|
||||
import spark.network.ConnectionManagerId
|
||||
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
|
||||
private[spark] class ShuffleCopier extends Logging {
|
||||
|
||||
def getBlock(cmId: ConnectionManagerId, blockId: String,
|
||||
def getBlock(host: String, port: Int, blockId: String,
|
||||
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
|
||||
|
||||
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
|
||||
val fc = new FileClient(handler)
|
||||
fc.init()
|
||||
fc.connect(cmId.host, cmId.port)
|
||||
fc.sendRequest(blockId)
|
||||
fc.waitForClose()
|
||||
fc.close()
|
||||
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
|
||||
val fc = new FileClient(handler, connectTimeout)
|
||||
|
||||
try {
|
||||
fc.init()
|
||||
fc.connect(host, port)
|
||||
fc.sendRequest(blockId)
|
||||
fc.waitForClose()
|
||||
fc.close()
|
||||
} catch {
|
||||
// Handle any socket-related exceptions in FileClient
|
||||
case e: Exception => {
|
||||
logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
|
||||
handler.handleError(blockId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def getBlock(cmId: ConnectionManagerId, blockId: String,
|
||||
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
|
||||
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
|
||||
}
|
||||
|
||||
def getBlocks(cmId: ConnectionManagerId,
|
||||
|
@ -44,20 +61,18 @@ private[spark] object ShuffleCopier extends Logging {
|
|||
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
|
||||
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
|
||||
}
|
||||
|
||||
override def handleError(blockId: String) {
|
||||
if (!isComplete) {
|
||||
resultCollectCallBack(blockId, -1, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
|
||||
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
|
||||
}
|
||||
|
||||
def runGetBlock(host:String, port:Int, file:String){
|
||||
val handler = new ShuffleClientHandler(echoResultCollectCallBack)
|
||||
val fc = new FileClient(handler)
|
||||
fc.init();
|
||||
fc.connect(host, port)
|
||||
fc.sendRequest(file)
|
||||
fc.waitForClose();
|
||||
fc.close()
|
||||
if (size != -1) {
|
||||
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
|
||||
}
|
||||
}
|
||||
|
||||
def main(args: Array[String]) {
|
||||
|
@ -71,14 +86,16 @@ private[spark] object ShuffleCopier extends Logging {
|
|||
val threads = if (args.length > 3) args(3).toInt else 10
|
||||
|
||||
val copiers = Executors.newFixedThreadPool(80)
|
||||
for (i <- Range(0, threads)) {
|
||||
val runnable = new Runnable() {
|
||||
val tasks = (for (i <- Range(0, threads)) yield {
|
||||
Executors.callable(new Runnable() {
|
||||
def run() {
|
||||
runGetBlock(host, port, file)
|
||||
val copier = new ShuffleCopier()
|
||||
copier.getBlock(host, port, file, echoResultCollectCallBack)
|
||||
}
|
||||
}
|
||||
copiers.execute(runnable)
|
||||
}
|
||||
})
|
||||
}).asJava
|
||||
copiers.invokeAll(tasks)
|
||||
copiers.shutdown
|
||||
System.exit(0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap}
|
|||
import scala.collection.JavaConversions
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
|
||||
import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
|
||||
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
|
||||
|
||||
|
||||
|
@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator
|
|||
*
|
||||
* @param rdds parent RDDs.
|
||||
* @param part partitioner used to partition the shuffle output.
|
||||
* @param mapSideCombine flag indicating whether to merge values before shuffle step.
|
||||
* @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
|
||||
* is on, Spark does an extra pass over the data on the map side to merge
|
||||
* all values belonging to the same key together. This can reduce the amount
|
||||
* of data shuffled if and only if the number of distinct keys is very small,
|
||||
* and the ratio of key size to value size is also very small.
|
||||
*/
|
||||
class CoGroupedRDD[K](
|
||||
@transient var rdds: Seq[RDD[(K, _)]],
|
||||
part: Partitioner,
|
||||
val mapSideCombine: Boolean = true,
|
||||
val mapSideCombine: Boolean = false,
|
||||
val serializerClass: String = null)
|
||||
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ import scala.collection.mutable.ArrayBuffer
|
|||
import scala.io.Source
|
||||
|
||||
import spark.{RDD, SparkEnv, Partition, TaskContext}
|
||||
import spark.broadcast.Broadcast
|
||||
|
||||
|
||||
/**
|
||||
|
@ -18,14 +19,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext}
|
|||
class PipedRDD[T: ClassManifest](
|
||||
prev: RDD[T],
|
||||
command: Seq[String],
|
||||
envVars: Map[String, String])
|
||||
envVars: Map[String, String],
|
||||
printPipeContext: (String => Unit) => Unit,
|
||||
printRDDElement: (T, String => Unit) => Unit)
|
||||
extends RDD[String](prev) {
|
||||
|
||||
def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
|
||||
|
||||
// Similar to Runtime.exec(), if we are given a single string, split it into words
|
||||
// using a standard StringTokenizer (i.e. by spaces)
|
||||
def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
|
||||
def this(
|
||||
prev: RDD[T],
|
||||
command: String,
|
||||
envVars: Map[String, String] = Map(),
|
||||
printPipeContext: (String => Unit) => Unit = null,
|
||||
printRDDElement: (T, String => Unit) => Unit = null) =
|
||||
this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
|
||||
|
||||
|
||||
override def getPartitions: Array[Partition] = firstParent[T].partitions
|
||||
|
||||
|
@ -52,8 +60,17 @@ class PipedRDD[T: ClassManifest](
|
|||
override def run() {
|
||||
SparkEnv.set(env)
|
||||
val out = new PrintWriter(proc.getOutputStream)
|
||||
|
||||
// input the pipe context firstly
|
||||
if (printPipeContext != null) {
|
||||
printPipeContext(out.println(_))
|
||||
}
|
||||
for (elem <- firstParent[T].iterator(split, context)) {
|
||||
out.println(elem)
|
||||
if (printRDDElement != null) {
|
||||
printRDDElement(elem, out.println(_))
|
||||
} else {
|
||||
out.println(elem)
|
||||
}
|
||||
}
|
||||
out.close()
|
||||
}
|
||||
|
|
|
@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
|
|||
val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))
|
||||
|
||||
// Remove exact match and then do host local match.
|
||||
val otherNodePreferredLocations = rddSplitZip.map(x => {
|
||||
x._1.preferredLocations(x._2).map(hostPort => {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
|
||||
if (exactMatchLocations.contains(host)) null else host
|
||||
}).filter(_ != null)
|
||||
})
|
||||
val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y))
|
||||
val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
|
||||
val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
|
||||
.reduce((x, y) => x.intersect(y))
|
||||
val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }
|
||||
|
||||
otherNodeLocalLocations ++ exactMatchLocations
|
||||
}
|
||||
|
|
|
@ -298,6 +298,7 @@ class DAGScheduler(
|
|||
// Compute very short actions like first() or take() with no parent stages locally.
|
||||
runLocally(job)
|
||||
} else {
|
||||
sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
|
||||
idToActiveJob(runId) = job
|
||||
activeJobs += job
|
||||
resultStageToJob(finalStage) = job
|
||||
|
@ -311,6 +312,8 @@ class DAGScheduler(
|
|||
handleExecutorLost(execId)
|
||||
|
||||
case completion: CompletionEvent =>
|
||||
sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
|
||||
completion.reason, completion.taskInfo, completion.taskMetrics)))
|
||||
handleTaskCompletion(completion)
|
||||
|
||||
case TaskSetFailed(taskSet, reason) =>
|
||||
|
@ -321,6 +324,7 @@ class DAGScheduler(
|
|||
for (job <- activeJobs) {
|
||||
val error = new SparkException("Job cancelled because SparkContext was shut down")
|
||||
job.listener.jobFailed(error)
|
||||
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
@ -468,6 +472,7 @@ class DAGScheduler(
|
|||
}
|
||||
}
|
||||
if (tasks.size > 0) {
|
||||
sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
|
||||
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
|
||||
myPending ++= tasks
|
||||
logDebug("New pending tasks: " + myPending)
|
||||
|
@ -522,6 +527,7 @@ class DAGScheduler(
|
|||
activeJobs -= job
|
||||
resultStageToJob -= stage
|
||||
markStageAsFinished(stage)
|
||||
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
|
||||
}
|
||||
job.listener.taskSucceeded(rt.outputId, event.result)
|
||||
}
|
||||
|
@ -662,7 +668,9 @@ class DAGScheduler(
|
|||
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
|
||||
for (resultStage <- dependentStages) {
|
||||
val job = resultStageToJob(resultStage)
|
||||
job.listener.jobFailed(new SparkException("Job failed: " + reason))
|
||||
val error = new SparkException("Job failed: " + reason)
|
||||
job.listener.jobFailed(error)
|
||||
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
|
||||
activeJobs -= job
|
||||
resultStageToJob -= resultStage
|
||||
}
|
||||
|
|
306
core/src/main/scala/spark/scheduler/JobLogger.scala
Normal file
306
core/src/main/scala/spark/scheduler/JobLogger.scala
Normal file
|
@ -0,0 +1,306 @@
|
|||
package spark.scheduler
|
||||
|
||||
import java.io.PrintWriter
|
||||
import java.io.File
|
||||
import java.io.FileNotFoundException
|
||||
import java.text.SimpleDateFormat
|
||||
import java.util.{Date, Properties}
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import scala.collection.mutable.{Map, HashMap, ListBuffer}
|
||||
import scala.io.Source
|
||||
import spark._
|
||||
import spark.executor.TaskMetrics
|
||||
import spark.scheduler.cluster.TaskInfo
|
||||
|
||||
// Used to record runtime information for each job, including RDD graph
|
||||
// tasks' start/stop shuffle information and information from outside
|
||||
|
||||
class JobLogger(val logDirName: String) extends SparkListener with Logging {
|
||||
private val logDir =
|
||||
if (System.getenv("SPARK_LOG_DIR") != null)
|
||||
System.getenv("SPARK_LOG_DIR")
|
||||
else
|
||||
"/tmp/spark"
|
||||
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
|
||||
private val stageIDToJobID = new HashMap[Int, Int]
|
||||
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
|
||||
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
|
||||
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
|
||||
|
||||
createLogDir()
|
||||
def this() = this(String.valueOf(System.currentTimeMillis()))
|
||||
|
||||
def getLogDir = logDir
|
||||
def getJobIDtoPrintWriter = jobIDToPrintWriter
|
||||
def getStageIDToJobID = stageIDToJobID
|
||||
def getJobIDToStages = jobIDToStages
|
||||
def getEventQueue = eventQueue
|
||||
|
||||
new Thread("JobLogger") {
|
||||
setDaemon(true)
|
||||
override def run() {
|
||||
while (true) {
|
||||
val event = eventQueue.take
|
||||
logDebug("Got event of type " + event.getClass.getName)
|
||||
event match {
|
||||
case SparkListenerJobStart(job, properties) =>
|
||||
processJobStartEvent(job, properties)
|
||||
case SparkListenerStageSubmitted(stage, taskSize) =>
|
||||
processStageSubmittedEvent(stage, taskSize)
|
||||
case StageCompleted(stageInfo) =>
|
||||
processStageCompletedEvent(stageInfo)
|
||||
case SparkListenerJobEnd(job, result) =>
|
||||
processJobEndEvent(job, result)
|
||||
case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
|
||||
processTaskEndEvent(task, reason, taskInfo, taskMetrics)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
}
|
||||
}.start()
|
||||
|
||||
// Create a folder for log files, the folder's name is the creation time of the jobLogger
|
||||
protected def createLogDir() {
|
||||
val dir = new File(logDir + "/" + logDirName + "/")
|
||||
if (dir.exists()) {
|
||||
return
|
||||
}
|
||||
if (dir.mkdirs() == false) {
|
||||
logError("create log directory error:" + logDir + "/" + logDirName + "/")
|
||||
}
|
||||
}
|
||||
|
||||
// Create a log file for one job, the file name is the jobID
|
||||
protected def createLogWriter(jobID: Int) {
|
||||
try{
|
||||
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
|
||||
jobIDToPrintWriter += (jobID -> fileWriter)
|
||||
} catch {
|
||||
case e: FileNotFoundException => e.printStackTrace()
|
||||
}
|
||||
}
|
||||
|
||||
// Close log file, and clean the stage relationship in stageIDToJobID
|
||||
protected def closeLogWriter(jobID: Int) =
|
||||
jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
|
||||
fileWriter.close()
|
||||
jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
|
||||
stageIDToJobID -= stage.id
|
||||
})
|
||||
jobIDToPrintWriter -= jobID
|
||||
jobIDToStages -= jobID
|
||||
}
|
||||
|
||||
// Write log information to log file, withTime parameter controls whether to recored
|
||||
// time stamp for the information
|
||||
protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
|
||||
var writeInfo = info
|
||||
if (withTime) {
|
||||
val date = new Date(System.currentTimeMillis())
|
||||
writeInfo = DATE_FORMAT.format(date) + ": " +info
|
||||
}
|
||||
jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
|
||||
}
|
||||
|
||||
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
|
||||
stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
|
||||
|
||||
protected def buildJobDep(jobID: Int, stage: Stage) {
|
||||
if (stage.priority == jobID) {
|
||||
jobIDToStages.get(jobID) match {
|
||||
case Some(stageList) => stageList += stage
|
||||
case None => val stageList = new ListBuffer[Stage]
|
||||
stageList += stage
|
||||
jobIDToStages += (jobID -> stageList)
|
||||
}
|
||||
stageIDToJobID += (stage.id -> jobID)
|
||||
stage.parents.foreach(buildJobDep(jobID, _))
|
||||
}
|
||||
}
|
||||
|
||||
protected def recordStageDep(jobID: Int) {
|
||||
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
|
||||
var rddList = new ListBuffer[RDD[_]]
|
||||
rddList += rdd
|
||||
rdd.dependencies.foreach{ dep => dep match {
|
||||
case shufDep: ShuffleDependency[_,_] =>
|
||||
case _ => rddList ++= getRddsInStage(dep.rdd)
|
||||
}
|
||||
}
|
||||
rddList
|
||||
}
|
||||
jobIDToStages.get(jobID).foreach {_.foreach { stage =>
|
||||
var depRddDesc: String = ""
|
||||
getRddsInStage(stage.rdd).foreach { rdd =>
|
||||
depRddDesc += rdd.id + ","
|
||||
}
|
||||
var depStageDesc: String = ""
|
||||
stage.parents.foreach { stage =>
|
||||
depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
|
||||
}
|
||||
jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
|
||||
depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
|
||||
" STAGE_DEP=" + depStageDesc, false)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate indents and convert to String
|
||||
protected def indentString(indent: Int) = {
|
||||
val sb = new StringBuilder()
|
||||
for (i <- 1 to indent) {
|
||||
sb.append(" ")
|
||||
}
|
||||
sb.toString()
|
||||
}
|
||||
|
||||
protected def getRddName(rdd: RDD[_]) = {
|
||||
var rddName = rdd.getClass.getName
|
||||
if (rdd.name != null) {
|
||||
rddName = rdd.name
|
||||
}
|
||||
rddName
|
||||
}
|
||||
|
||||
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
|
||||
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
|
||||
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
|
||||
rdd.dependencies.foreach{ dep => dep match {
|
||||
case shufDep: ShuffleDependency[_,_] =>
|
||||
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
|
||||
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
|
||||
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
|
||||
var stageInfo: String = ""
|
||||
if (stage.isShuffleMap) {
|
||||
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
|
||||
stage.shuffleDep.get.shuffleId
|
||||
}else{
|
||||
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
|
||||
}
|
||||
if (stage.priority == jobID) {
|
||||
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
|
||||
recordRddInStageGraph(jobID, stage.rdd, indent)
|
||||
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
|
||||
} else
|
||||
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
|
||||
}
|
||||
|
||||
// Record task metrics into job log files
|
||||
protected def recordTaskMetrics(stageID: Int, status: String,
|
||||
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
|
||||
val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
|
||||
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
|
||||
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
|
||||
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
|
||||
val readMetrics =
|
||||
taskMetrics.shuffleReadMetrics match {
|
||||
case Some(metrics) =>
|
||||
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
|
||||
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
|
||||
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
|
||||
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
|
||||
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
|
||||
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
|
||||
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
|
||||
case None => ""
|
||||
}
|
||||
val writeMetrics =
|
||||
taskMetrics.shuffleWriteMetrics match {
|
||||
case Some(metrics) =>
|
||||
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
|
||||
case None => ""
|
||||
}
|
||||
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
|
||||
}
|
||||
|
||||
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
|
||||
eventQueue.put(stageSubmitted)
|
||||
}
|
||||
|
||||
protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
|
||||
stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
|
||||
}
|
||||
|
||||
override def onStageCompleted(stageCompleted: StageCompleted) {
|
||||
eventQueue.put(stageCompleted)
|
||||
}
|
||||
|
||||
protected def processStageCompletedEvent(stageInfo: StageInfo) {
|
||||
stageLogInfo(stageInfo.stage.id, "STAGE_ID=" +
|
||||
stageInfo.stage.id + " STATUS=COMPLETED")
|
||||
|
||||
}
|
||||
|
||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
|
||||
eventQueue.put(taskEnd)
|
||||
}
|
||||
|
||||
protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
|
||||
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
|
||||
var taskStatus = ""
|
||||
task match {
|
||||
case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
|
||||
case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
|
||||
}
|
||||
reason match {
|
||||
case Success => taskStatus += " STATUS=SUCCESS"
|
||||
recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics)
|
||||
case Resubmitted =>
|
||||
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
|
||||
" STAGE_ID=" + task.stageId
|
||||
stageLogInfo(task.stageId, taskStatus)
|
||||
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
|
||||
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
|
||||
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
|
||||
mapId + " REDUCE_ID=" + reduceId
|
||||
stageLogInfo(task.stageId, taskStatus)
|
||||
case OtherFailure(message) =>
|
||||
taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
|
||||
" STAGE_ID=" + task.stageId + " INFO=" + message
|
||||
stageLogInfo(task.stageId, taskStatus)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
|
||||
eventQueue.put(jobEnd)
|
||||
}
|
||||
|
||||
protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
|
||||
var info = "JOB_ID=" + job.runId
|
||||
reason match {
|
||||
case JobSucceeded => info += " STATUS=SUCCESS"
|
||||
case JobFailed(exception) =>
|
||||
info += " STATUS=FAILED REASON="
|
||||
exception.getMessage.split("\\s+").foreach(info += _ + "_")
|
||||
case _ =>
|
||||
}
|
||||
jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase)
|
||||
closeLogWriter(job.runId)
|
||||
}
|
||||
|
||||
protected def recordJobProperties(jobID: Int, properties: Properties) {
|
||||
if(properties != null) {
|
||||
val annotation = properties.getProperty("spark.job.annotation", "")
|
||||
jobLogInfo(jobID, annotation, false)
|
||||
}
|
||||
}
|
||||
|
||||
override def onJobStart(jobStart: SparkListenerJobStart) {
|
||||
eventQueue.put(jobStart)
|
||||
}
|
||||
|
||||
protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
|
||||
createLogWriter(job.runId)
|
||||
recordJobProperties(job.runId, properties)
|
||||
buildJobDep(job.runId, job.finalStage)
|
||||
recordStageDep(job.runId)
|
||||
recordStageDepGraph(job.runId, job.finalStage)
|
||||
jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
|
||||
}
|
||||
}
|
|
@ -1,27 +1,59 @@
|
|||
package spark.scheduler
|
||||
|
||||
import java.util.Properties
|
||||
import spark.scheduler.cluster.TaskInfo
|
||||
import spark.util.Distribution
|
||||
import spark.{Utils, Logging}
|
||||
import spark.{Logging, SparkContext, TaskEndReason, Utils}
|
||||
import spark.executor.TaskMetrics
|
||||
|
||||
trait SparkListener {
|
||||
/**
|
||||
* called when a stage is completed, with information on the completed stage
|
||||
*/
|
||||
def onStageCompleted(stageCompleted: StageCompleted)
|
||||
}
|
||||
|
||||
sealed trait SparkListenerEvents
|
||||
|
||||
case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
|
||||
|
||||
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
|
||||
|
||||
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
|
||||
taskMetrics: TaskMetrics) extends SparkListenerEvents
|
||||
|
||||
case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
|
||||
extends SparkListenerEvents
|
||||
|
||||
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
|
||||
extends SparkListenerEvents
|
||||
|
||||
trait SparkListener {
|
||||
/**
|
||||
* Called when a stage is completed, with information on the completed stage
|
||||
*/
|
||||
def onStageCompleted(stageCompleted: StageCompleted) { }
|
||||
|
||||
/**
|
||||
* Called when a stage is submitted
|
||||
*/
|
||||
def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
|
||||
|
||||
/**
|
||||
* Called when a task ends
|
||||
*/
|
||||
def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
|
||||
|
||||
/**
|
||||
* Called when a job starts
|
||||
*/
|
||||
def onJobStart(jobStart: SparkListenerJobStart) { }
|
||||
|
||||
/**
|
||||
* Called when a job ends
|
||||
*/
|
||||
def onJobEnd(jobEnd: SparkListenerJobEnd) { }
|
||||
|
||||
}
|
||||
|
||||
/**
|
||||
* Simple SparkListener that logs a few summary statistics when each stage completes
|
||||
*/
|
||||
class StatsReportListener extends SparkListener with Logging {
|
||||
def onStageCompleted(stageCompleted: StageCompleted) {
|
||||
override def onStageCompleted(stageCompleted: StageCompleted) {
|
||||
import spark.scheduler.StatsReportListener._
|
||||
implicit val sc = stageCompleted
|
||||
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
|
||||
|
|
|
@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
|
|||
val tasks = taskSet.tasks
|
||||
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
|
||||
this.synchronized {
|
||||
val manager = new TaskSetManager(this, taskSet)
|
||||
val manager = new ClusterTaskSetManager(this, taskSet)
|
||||
activeTaskSets(taskSet.id) = manager
|
||||
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
|
||||
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
|
||||
|
|
|
@ -0,0 +1,747 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.HashSet
|
||||
import scala.math.max
|
||||
import scala.math.min
|
||||
|
||||
import spark._
|
||||
import spark.scheduler._
|
||||
import spark.TaskState.TaskState
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
|
||||
|
||||
// process local is expected to be used ONLY within tasksetmanager for now.
|
||||
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
|
||||
|
||||
type TaskLocality = Value
|
||||
|
||||
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
|
||||
|
||||
// Must not be the constraint.
|
||||
assert (constraint != TaskLocality.PROCESS_LOCAL)
|
||||
|
||||
constraint match {
|
||||
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
|
||||
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
|
||||
// For anything else, allow
|
||||
case _ => true
|
||||
}
|
||||
}
|
||||
|
||||
def parse(str: String): TaskLocality = {
|
||||
// better way to do this ?
|
||||
try {
|
||||
val retval = TaskLocality.withName(str)
|
||||
// Must not specify PROCESS_LOCAL !
|
||||
assert (retval != TaskLocality.PROCESS_LOCAL)
|
||||
|
||||
retval
|
||||
} catch {
|
||||
case nEx: NoSuchElementException => {
|
||||
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
|
||||
// default to preserve earlier behavior
|
||||
NODE_LOCAL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
|
||||
*/
|
||||
private[spark] class ClusterTaskSetManager(
|
||||
sched: ClusterScheduler,
|
||||
val taskSet: TaskSet)
|
||||
extends TaskSetManager
|
||||
with Logging {
|
||||
|
||||
// Maximum time to wait to run a task in a preferred location (in ms)
|
||||
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
|
||||
|
||||
// CPUs to request per task
|
||||
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
|
||||
|
||||
// Maximum times a task is allowed to fail before failing the job
|
||||
val MAX_TASK_FAILURES = 4
|
||||
|
||||
// Quantile of tasks at which to start speculation
|
||||
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
|
||||
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
|
||||
|
||||
// Serializer for closures and tasks.
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
|
||||
val tasks = taskSet.tasks
|
||||
val numTasks = tasks.length
|
||||
val copiesRunning = new Array[Int](numTasks)
|
||||
val finished = new Array[Boolean](numTasks)
|
||||
val numFailures = new Array[Int](numTasks)
|
||||
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
|
||||
var tasksFinished = 0
|
||||
|
||||
var weight = 1
|
||||
var minShare = 0
|
||||
var runningTasks = 0
|
||||
var priority = taskSet.priority
|
||||
var stageId = taskSet.stageId
|
||||
var name = "TaskSet_"+taskSet.stageId.toString
|
||||
var parent:Schedulable = null
|
||||
|
||||
// Last time when we launched a preferred task (for delay scheduling)
|
||||
var lastPreferredLaunchTime = System.currentTimeMillis
|
||||
|
||||
// List of pending tasks for each node (process local to container). These collections are actually
|
||||
// treated as stacks, in which new tasks are added to the end of the
|
||||
// ArrayBuffer and removed from the end. This makes it faster to detect
|
||||
// tasks that repeatedly fail because whenever a task failed, it is put
|
||||
// back at the head of the stack. They are also only cleaned up lazily;
|
||||
// when a task is launched, it remains in all the pending lists except
|
||||
// the one that it was launched from, but gets removed from them later.
|
||||
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node.
|
||||
// Essentially, similar to pendingTasksForHostPort, except at host level
|
||||
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node based on rack locality.
|
||||
// Essentially, similar to pendingTasksForHost, except at rack level
|
||||
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List containing pending tasks with no locality preferences
|
||||
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
|
||||
|
||||
// List containing all pending tasks (also used as a stack, as above)
|
||||
val allPendingTasks = new ArrayBuffer[Int]
|
||||
|
||||
// Tasks that can be speculated. Since these will be a small fraction of total
|
||||
// tasks, we'll just hold them in a HashSet.
|
||||
val speculatableTasks = new HashSet[Int]
|
||||
|
||||
// Task index, start and finish time for each task attempt (indexed by task ID)
|
||||
val taskInfos = new HashMap[Long, TaskInfo]
|
||||
|
||||
// Did the job fail?
|
||||
var failed = false
|
||||
var causeOfFailure = ""
|
||||
|
||||
// How frequently to reprint duplicate exceptions in full, in milliseconds
|
||||
val EXCEPTION_PRINT_INTERVAL =
|
||||
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
|
||||
// Map of recent exceptions (identified by string representation and
|
||||
// top stack frame) to duplicate count (how many times the same
|
||||
// exception has appeared) and time the full exception was
|
||||
// printed. This should ideally be an LRU map that can drop old
|
||||
// exceptions automatically.
|
||||
val recentExceptions = HashMap[String, (Int, Long)]()
|
||||
|
||||
// Figure out the current map output tracker generation and set it on all tasks
|
||||
val generation = sched.mapOutputTracker.getGeneration
|
||||
logDebug("Generation for " + taskSet.id + ": " + generation)
|
||||
for (t <- tasks) {
|
||||
t.generation = generation
|
||||
}
|
||||
|
||||
// Add all our tasks to the pending lists. We do this in reverse order
|
||||
// of task index so that tasks with low indices get launched first.
|
||||
for (i <- (0 until numTasks).reverse) {
|
||||
addPendingTask(i)
|
||||
}
|
||||
|
||||
// Note that it follows the hierarchy.
|
||||
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
|
||||
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
|
||||
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
|
||||
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
|
||||
|
||||
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
|
||||
// straight forward comparison ! Special case it.
|
||||
val retval = new HashSet[String]()
|
||||
scheduler.synchronized {
|
||||
for (location <- _taskPreferredLocations) {
|
||||
if (scheduler.isExecutorAliveOnHostPort(location)) {
|
||||
retval += location
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return retval
|
||||
}
|
||||
|
||||
val taskPreferredLocations =
|
||||
if (TaskLocality.NODE_LOCAL == taskLocality) {
|
||||
_taskPreferredLocations
|
||||
} else {
|
||||
assert (TaskLocality.RACK_LOCAL == taskLocality)
|
||||
// Expand set to include all 'seen' rack local hosts.
|
||||
// This works since container allocation/management happens within master - so any rack locality information is updated in msater.
|
||||
// Best case effort, and maybe sort of kludge for now ... rework it later ?
|
||||
val hosts = new HashSet[String]
|
||||
_taskPreferredLocations.foreach(h => {
|
||||
val rackOpt = scheduler.getRackForHost(h)
|
||||
if (rackOpt.isDefined) {
|
||||
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
|
||||
if (hostsOpt.isDefined) {
|
||||
hosts ++= hostsOpt.get
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that irrespective of what scheduler says, host is always added !
|
||||
hosts += h
|
||||
})
|
||||
|
||||
hosts
|
||||
}
|
||||
|
||||
val retval = new HashSet[String]
|
||||
scheduler.synchronized {
|
||||
for (prefLocation <- taskPreferredLocations) {
|
||||
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
|
||||
if (aliveLocationsOpt.isDefined) {
|
||||
retval ++= aliveLocationsOpt.get
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
|
||||
// Add a task to all the pending-task lists that it should be on.
|
||||
private def addPendingTask(index: Int) {
|
||||
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
|
||||
// hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
|
||||
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
|
||||
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
|
||||
if (rackLocalLocations.size == 0) {
|
||||
// Current impl ensures this.
|
||||
assert (processLocalLocations.size == 0)
|
||||
assert (hostLocalLocations.size == 0)
|
||||
pendingTasksWithNoPrefs += index
|
||||
} else {
|
||||
|
||||
// process local locality
|
||||
for (hostPort <- processLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
|
||||
hostPortList += index
|
||||
}
|
||||
|
||||
// host locality (includes process local)
|
||||
for (hostPort <- hostLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
|
||||
hostList += index
|
||||
}
|
||||
|
||||
// rack locality (includes process local and host local)
|
||||
for (rackLocalHostPort <- rackLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(rackLocalHostPort)
|
||||
|
||||
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
|
||||
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
|
||||
list += index
|
||||
}
|
||||
}
|
||||
|
||||
allPendingTasks += index
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host port (process local), or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks (rack level) list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Number of pending tasks for a given host Port (which would be process local)
|
||||
def numPendingTasksForHostPort(hostPort: String): Int = {
|
||||
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
// Number of pending tasks for a given host (which would be data local)
|
||||
def numPendingTasksForHost(hostPort: String): Int = {
|
||||
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
// Number of pending rack local tasks for a given host
|
||||
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
|
||||
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
|
||||
// Dequeue a pending task from the given list and return its index.
|
||||
// Return None if the list is empty.
|
||||
// This method also cleans up any tasks in the list that have already
|
||||
// been launched, since we want that to happen lazily.
|
||||
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
|
||||
while (!list.isEmpty) {
|
||||
val index = list.last
|
||||
list.trimEnd(1)
|
||||
if (copiesRunning(index) == 0 && !finished(index)) {
|
||||
return Some(index)
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
// Return a speculative task for a given host if any are available. The task should not have an
|
||||
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
|
||||
// task must have a preference for this host/rack/no preferred locations at all.
|
||||
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
|
||||
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
|
||||
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
|
||||
|
||||
if (speculatableTasks.size > 0) {
|
||||
val localTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (localTask != None) {
|
||||
speculatableTasks -= localTask.get
|
||||
return localTask
|
||||
}
|
||||
|
||||
// check for rack locality
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (rackTask != None) {
|
||||
speculatableTasks -= rackTask.get
|
||||
return rackTask
|
||||
}
|
||||
}
|
||||
|
||||
// Any task ...
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
// Check for attemptLocs also ?
|
||||
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
|
||||
if (nonLocalTask != None) {
|
||||
speculatableTasks -= nonLocalTask.get
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
// Dequeue a pending task for a given node and return its index.
|
||||
// If localOnly is set to false, allow non-local tasks as well.
|
||||
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
|
||||
if (processLocalTask != None) {
|
||||
return processLocalTask
|
||||
}
|
||||
|
||||
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
|
||||
if (localTask != None) {
|
||||
return localTask
|
||||
}
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
|
||||
if (rackLocalTask != None) {
|
||||
return rackLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
|
||||
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
|
||||
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
|
||||
if (noPrefTask != None) {
|
||||
return noPrefTask
|
||||
}
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
val nonLocalTask = findTaskFromList(allPendingTasks)
|
||||
if (nonLocalTask != None) {
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if all else has failed, find a speculative task
|
||||
return findSpeculativeTask(hostPort, locality)
|
||||
}
|
||||
|
||||
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val locs = task.preferredLocations
|
||||
|
||||
locs.contains(hostPort)
|
||||
}
|
||||
|
||||
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
val locs = task.preferredLocations
|
||||
|
||||
// If no preference, consider it as host local
|
||||
if (locs.isEmpty) return true
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
|
||||
}
|
||||
|
||||
// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
|
||||
// This is true if either the task has preferred locations and this host is one, or it has
|
||||
// no preferred locations (in which we still count the launch as preferred).
|
||||
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
|
||||
val locs = task.preferredLocations
|
||||
|
||||
val preferredRacks = new HashSet[String]()
|
||||
for (preferredHost <- locs) {
|
||||
val rack = sched.getRackForHost(preferredHost)
|
||||
if (None != rack) preferredRacks += rack.get
|
||||
}
|
||||
|
||||
if (preferredRacks.isEmpty) return false
|
||||
|
||||
val hostRack = sched.getRackForHost(hostPort)
|
||||
|
||||
return None != hostRack && preferredRacks.contains(hostRack.get)
|
||||
}
|
||||
|
||||
// Respond to an offer of a single slave from the scheduler by finding a task
|
||||
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
|
||||
|
||||
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
|
||||
// If explicitly specified, use that
|
||||
val locality = if (overrideLocality != null) overrideLocality else {
|
||||
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
|
||||
val time = System.currentTimeMillis
|
||||
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
|
||||
}
|
||||
|
||||
findTask(hostPort, locality) match {
|
||||
case Some(index) => {
|
||||
// Found a task; do some bookkeeping and return a Mesos task for it
|
||||
val task = tasks(index)
|
||||
val taskId = sched.newTaskId()
|
||||
// Figure out whether this should count as a preferred launch
|
||||
val taskLocality =
|
||||
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
|
||||
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
|
||||
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
|
||||
TaskLocality.ANY
|
||||
val prefStr = taskLocality.toString
|
||||
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
|
||||
taskSet.id, index, taskId, execId, hostPort, prefStr))
|
||||
// Do various bookkeeping
|
||||
copiesRunning(index) += 1
|
||||
val time = System.currentTimeMillis
|
||||
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
|
||||
taskInfos(taskId) = info
|
||||
taskAttempts(index) = info :: taskAttempts(index)
|
||||
if (TaskLocality.NODE_LOCAL == taskLocality) {
|
||||
lastPreferredLaunchTime = time
|
||||
}
|
||||
// Serialize and return the task
|
||||
val startTime = System.currentTimeMillis
|
||||
val serializedTask = Task.serializeWithDependencies(
|
||||
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
|
||||
val timeTaken = System.currentTimeMillis - startTime
|
||||
increaseRunningTasks(1)
|
||||
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
|
||||
taskSet.id, index, serializedTask.limit, timeTaken))
|
||||
val taskName = "task %s:%d".format(taskSet.id, index)
|
||||
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
state match {
|
||||
case TaskState.FINISHED =>
|
||||
taskFinished(tid, state, serializedData)
|
||||
case TaskState.LOST =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case TaskState.FAILED =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case TaskState.KILLED =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
if (info.failed) {
|
||||
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
|
||||
// or even from Mesos itself when acks get delayed.
|
||||
return
|
||||
}
|
||||
val index = info.index
|
||||
info.markSuccessful()
|
||||
decreaseRunningTasks(1)
|
||||
if (!finished(index)) {
|
||||
tasksFinished += 1
|
||||
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
|
||||
tid, info.duration, tasksFinished, numTasks))
|
||||
// Deserialize task result and pass it to the scheduler
|
||||
try {
|
||||
val result = ser.deserialize[TaskResult[_]](serializedData)
|
||||
result.metrics.resultSize = serializedData.limit()
|
||||
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
|
||||
} catch {
|
||||
case cnf: ClassNotFoundException =>
|
||||
val loader = Thread.currentThread().getContextClassLoader
|
||||
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
|
||||
case ex => throw ex
|
||||
}
|
||||
// Mark finished and stop if we've finished all the tasks
|
||||
finished(index) = true
|
||||
if (tasksFinished == numTasks) {
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
} else {
|
||||
logInfo("Ignoring task-finished event for TID " + tid +
|
||||
" because task " + index + " is already finished")
|
||||
}
|
||||
}
|
||||
|
||||
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
if (info.failed) {
|
||||
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
|
||||
// or even from Mesos itself when acks get delayed.
|
||||
return
|
||||
}
|
||||
val index = info.index
|
||||
info.markFailed()
|
||||
decreaseRunningTasks(1)
|
||||
if (!finished(index)) {
|
||||
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
|
||||
copiesRunning(index) -= 1
|
||||
// Check if the problem is a map output fetch failure. In that case, this
|
||||
// task will never succeed on any node, so tell the scheduler about it.
|
||||
if (serializedData != null && serializedData.limit() > 0) {
|
||||
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
|
||||
reason match {
|
||||
case fetchFailed: FetchFailed =>
|
||||
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
|
||||
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
|
||||
finished(index) = true
|
||||
tasksFinished += 1
|
||||
sched.taskSetFinished(this)
|
||||
decreaseRunningTasks(runningTasks)
|
||||
return
|
||||
|
||||
case taskResultTooBig: TaskResultTooBigFailure =>
|
||||
logInfo("Loss was due to task %s result exceeding Akka frame size; " +
|
||||
"aborting job".format(tid))
|
||||
abort("Task %s result exceeded Akka frame size".format(tid))
|
||||
return
|
||||
|
||||
case ef: ExceptionFailure =>
|
||||
val key = ef.description
|
||||
val now = System.currentTimeMillis
|
||||
val (printFull, dupCount) = {
|
||||
if (recentExceptions.contains(key)) {
|
||||
val (dupCount, printTime) = recentExceptions(key)
|
||||
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
|
||||
recentExceptions(key) = (0, now)
|
||||
(true, 0)
|
||||
} else {
|
||||
recentExceptions(key) = (dupCount + 1, printTime)
|
||||
(false, dupCount + 1)
|
||||
}
|
||||
} else {
|
||||
recentExceptions(key) = (0, now)
|
||||
(true, 0)
|
||||
}
|
||||
}
|
||||
if (printFull) {
|
||||
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
|
||||
logInfo("Loss was due to %s\n%s\n%s".format(
|
||||
ef.className, ef.description, locs.mkString("\n")))
|
||||
} else {
|
||||
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
|
||||
}
|
||||
|
||||
case _ => {}
|
||||
}
|
||||
}
|
||||
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
|
||||
addPendingTask(index)
|
||||
// Count failed attempts only on FAILED and LOST state (not on KILLED)
|
||||
if (state == TaskState.FAILED || state == TaskState.LOST) {
|
||||
numFailures(index) += 1
|
||||
if (numFailures(index) > MAX_TASK_FAILURES) {
|
||||
logError("Task %s:%d failed more than %d times; aborting job".format(
|
||||
taskSet.id, index, MAX_TASK_FAILURES))
|
||||
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logInfo("Ignoring task-lost event for TID " + tid +
|
||||
" because task " + index + " is already finished")
|
||||
}
|
||||
}
|
||||
|
||||
def error(message: String) {
|
||||
// Save the error message
|
||||
abort("Error: " + message)
|
||||
}
|
||||
|
||||
def abort(message: String) {
|
||||
failed = true
|
||||
causeOfFailure = message
|
||||
// TODO: Kill running tasks if we were not terminated due to a Mesos error
|
||||
sched.listener.taskSetFailed(taskSet, message)
|
||||
decreaseRunningTasks(runningTasks)
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
|
||||
override def increaseRunningTasks(taskNum: Int) {
|
||||
runningTasks += taskNum
|
||||
if (parent != null) {
|
||||
parent.increaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
override def decreaseRunningTasks(taskNum: Int) {
|
||||
runningTasks -= taskNum
|
||||
if (parent != null) {
|
||||
parent.decreaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
|
||||
override def getSchedulableByName(name: String): Schedulable = {
|
||||
return null
|
||||
}
|
||||
|
||||
override def addSchedulable(schedulable:Schedulable) {
|
||||
//nothing
|
||||
}
|
||||
|
||||
override def removeSchedulable(schedulable:Schedulable) {
|
||||
//nothing
|
||||
}
|
||||
|
||||
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
|
||||
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
|
||||
sortedTaskSetQueue += this
|
||||
return sortedTaskSetQueue
|
||||
}
|
||||
|
||||
override def executorLost(execId: String, hostPort: String) {
|
||||
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
|
||||
|
||||
// If some task has preferred locations only on hostname, and there are no more executors there,
|
||||
// put it in the no-prefs list to avoid the wait from delay scheduling
|
||||
|
||||
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
|
||||
// no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
|
||||
// Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
|
||||
// there is no host local node for the task (not if there is no process local node for the task)
|
||||
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
|
||||
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
if (newLocs.isEmpty) {
|
||||
pendingTasksWithNoPrefs += index
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
|
||||
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
|
||||
for ((tid, info) <- taskInfos if info.executorId == execId) {
|
||||
val index = taskInfos(tid).index
|
||||
if (finished(index)) {
|
||||
finished(index) = false
|
||||
copiesRunning(index) -= 1
|
||||
tasksFinished -= 1
|
||||
addPendingTask(index)
|
||||
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
|
||||
// stage finishes when a total of tasks.size tasks finish.
|
||||
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Also re-enqueue any tasks that were running on the node
|
||||
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
|
||||
taskLost(tid, TaskState.KILLED, null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for tasks to be speculated and return true if there are any. This is called periodically
|
||||
* by the ClusterScheduler.
|
||||
*
|
||||
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
|
||||
* we don't scan the whole task set. It might also help to make this sorted by launch time.
|
||||
*/
|
||||
override def checkSpeculatableTasks(): Boolean = {
|
||||
// Can't speculate if we only have one task, or if all tasks have finished.
|
||||
if (numTasks == 1 || tasksFinished == numTasks) {
|
||||
return false
|
||||
}
|
||||
var foundTasks = false
|
||||
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
|
||||
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
|
||||
if (tasksFinished >= minFinishedForSpeculation) {
|
||||
val time = System.currentTimeMillis()
|
||||
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
|
||||
Arrays.sort(durations)
|
||||
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
|
||||
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
|
||||
// TODO: Threshold should also look at standard deviation of task durations and have a lower
|
||||
// bound based on that.
|
||||
logDebug("Task length threshold for speculation: " + threshold)
|
||||
for ((tid, info) <- taskInfos) {
|
||||
val index = info.index
|
||||
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
|
||||
!speculatableTasks.contains(index)) {
|
||||
logInfo(
|
||||
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
|
||||
taskSet.id, index, info.hostPort, threshold))
|
||||
speculatableTasks += index
|
||||
foundTasks = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundTasks
|
||||
}
|
||||
|
||||
override def hasPendingTasks(): Boolean = {
|
||||
numTasks > 0 && tasksFinished < numTasks
|
||||
}
|
||||
}
|
|
@ -1,747 +1,17 @@
|
|||
package spark.scheduler.cluster
|
||||
|
||||
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.HashSet
|
||||
import scala.math.max
|
||||
import scala.math.min
|
||||
|
||||
import spark._
|
||||
import spark.scheduler._
|
||||
import spark.TaskState.TaskState
|
||||
import java.nio.ByteBuffer
|
||||
|
||||
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
|
||||
|
||||
// process local is expected to be used ONLY within tasksetmanager for now.
|
||||
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
|
||||
|
||||
type TaskLocality = Value
|
||||
|
||||
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
|
||||
|
||||
// Must not be the constraint.
|
||||
assert (constraint != TaskLocality.PROCESS_LOCAL)
|
||||
|
||||
constraint match {
|
||||
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
|
||||
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
|
||||
// For anything else, allow
|
||||
case _ => true
|
||||
}
|
||||
}
|
||||
|
||||
def parse(str: String): TaskLocality = {
|
||||
// better way to do this ?
|
||||
try {
|
||||
val retval = TaskLocality.withName(str)
|
||||
// Must not specify PROCESS_LOCAL !
|
||||
assert (retval != TaskLocality.PROCESS_LOCAL)
|
||||
|
||||
retval
|
||||
} catch {
|
||||
case nEx: NoSuchElementException => {
|
||||
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
|
||||
// default to preserve earlier behavior
|
||||
NODE_LOCAL
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
|
||||
*/
|
||||
private[spark] class TaskSetManager(
|
||||
sched: ClusterScheduler,
|
||||
val taskSet: TaskSet)
|
||||
extends Schedulable
|
||||
with Logging {
|
||||
|
||||
// Maximum time to wait to run a task in a preferred location (in ms)
|
||||
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
|
||||
|
||||
// CPUs to request per task
|
||||
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
|
||||
|
||||
// Maximum times a task is allowed to fail before failing the job
|
||||
val MAX_TASK_FAILURES = 4
|
||||
|
||||
// Quantile of tasks at which to start speculation
|
||||
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
|
||||
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
|
||||
|
||||
// Serializer for closures and tasks.
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
|
||||
val tasks = taskSet.tasks
|
||||
val numTasks = tasks.length
|
||||
val copiesRunning = new Array[Int](numTasks)
|
||||
val finished = new Array[Boolean](numTasks)
|
||||
val numFailures = new Array[Int](numTasks)
|
||||
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
|
||||
var tasksFinished = 0
|
||||
|
||||
var weight = 1
|
||||
var minShare = 0
|
||||
var runningTasks = 0
|
||||
var priority = taskSet.priority
|
||||
var stageId = taskSet.stageId
|
||||
var name = "TaskSet_"+taskSet.stageId.toString
|
||||
var parent:Schedulable = null
|
||||
|
||||
// Last time when we launched a preferred task (for delay scheduling)
|
||||
var lastPreferredLaunchTime = System.currentTimeMillis
|
||||
|
||||
// List of pending tasks for each node (process local to container). These collections are actually
|
||||
// treated as stacks, in which new tasks are added to the end of the
|
||||
// ArrayBuffer and removed from the end. This makes it faster to detect
|
||||
// tasks that repeatedly fail because whenever a task failed, it is put
|
||||
// back at the head of the stack. They are also only cleaned up lazily;
|
||||
// when a task is launched, it remains in all the pending lists except
|
||||
// the one that it was launched from, but gets removed from them later.
|
||||
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node.
|
||||
// Essentially, similar to pendingTasksForHostPort, except at host level
|
||||
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List of pending tasks for each node based on rack locality.
|
||||
// Essentially, similar to pendingTasksForHost, except at rack level
|
||||
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
|
||||
|
||||
// List containing pending tasks with no locality preferences
|
||||
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
|
||||
|
||||
// List containing all pending tasks (also used as a stack, as above)
|
||||
val allPendingTasks = new ArrayBuffer[Int]
|
||||
|
||||
// Tasks that can be speculated. Since these will be a small fraction of total
|
||||
// tasks, we'll just hold them in a HashSet.
|
||||
val speculatableTasks = new HashSet[Int]
|
||||
|
||||
// Task index, start and finish time for each task attempt (indexed by task ID)
|
||||
val taskInfos = new HashMap[Long, TaskInfo]
|
||||
|
||||
// Did the job fail?
|
||||
var failed = false
|
||||
var causeOfFailure = ""
|
||||
|
||||
// How frequently to reprint duplicate exceptions in full, in milliseconds
|
||||
val EXCEPTION_PRINT_INTERVAL =
|
||||
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
|
||||
// Map of recent exceptions (identified by string representation and
|
||||
// top stack frame) to duplicate count (how many times the same
|
||||
// exception has appeared) and time the full exception was
|
||||
// printed. This should ideally be an LRU map that can drop old
|
||||
// exceptions automatically.
|
||||
val recentExceptions = HashMap[String, (Int, Long)]()
|
||||
|
||||
// Figure out the current map output tracker generation and set it on all tasks
|
||||
val generation = sched.mapOutputTracker.getGeneration
|
||||
logDebug("Generation for " + taskSet.id + ": " + generation)
|
||||
for (t <- tasks) {
|
||||
t.generation = generation
|
||||
}
|
||||
|
||||
// Add all our tasks to the pending lists. We do this in reverse order
|
||||
// of task index so that tasks with low indices get launched first.
|
||||
for (i <- (0 until numTasks).reverse) {
|
||||
addPendingTask(i)
|
||||
}
|
||||
|
||||
// Note that it follows the hierarchy.
|
||||
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
|
||||
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
|
||||
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
|
||||
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
|
||||
|
||||
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
|
||||
// straight forward comparison ! Special case it.
|
||||
val retval = new HashSet[String]()
|
||||
scheduler.synchronized {
|
||||
for (location <- _taskPreferredLocations) {
|
||||
if (scheduler.isExecutorAliveOnHostPort(location)) {
|
||||
retval += location
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return retval
|
||||
}
|
||||
|
||||
val taskPreferredLocations =
|
||||
if (TaskLocality.NODE_LOCAL == taskLocality) {
|
||||
_taskPreferredLocations
|
||||
} else {
|
||||
assert (TaskLocality.RACK_LOCAL == taskLocality)
|
||||
// Expand set to include all 'seen' rack local hosts.
|
||||
// This works since container allocation/management happens within master - so any rack locality information is updated in msater.
|
||||
// Best case effort, and maybe sort of kludge for now ... rework it later ?
|
||||
val hosts = new HashSet[String]
|
||||
_taskPreferredLocations.foreach(h => {
|
||||
val rackOpt = scheduler.getRackForHost(h)
|
||||
if (rackOpt.isDefined) {
|
||||
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
|
||||
if (hostsOpt.isDefined) {
|
||||
hosts ++= hostsOpt.get
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure that irrespective of what scheduler says, host is always added !
|
||||
hosts += h
|
||||
})
|
||||
|
||||
hosts
|
||||
}
|
||||
|
||||
val retval = new HashSet[String]
|
||||
scheduler.synchronized {
|
||||
for (prefLocation <- taskPreferredLocations) {
|
||||
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
|
||||
if (aliveLocationsOpt.isDefined) {
|
||||
retval ++= aliveLocationsOpt.get
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
retval
|
||||
}
|
||||
|
||||
// Add a task to all the pending-task lists that it should be on.
|
||||
private def addPendingTask(index: Int) {
|
||||
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
|
||||
// hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
|
||||
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
|
||||
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
|
||||
if (rackLocalLocations.size == 0) {
|
||||
// Current impl ensures this.
|
||||
assert (processLocalLocations.size == 0)
|
||||
assert (hostLocalLocations.size == 0)
|
||||
pendingTasksWithNoPrefs += index
|
||||
} else {
|
||||
|
||||
// process local locality
|
||||
for (hostPort <- processLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
|
||||
hostPortList += index
|
||||
}
|
||||
|
||||
// host locality (includes process local)
|
||||
for (hostPort <- hostLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
|
||||
hostList += index
|
||||
}
|
||||
|
||||
// rack locality (includes process local and host local)
|
||||
for (rackLocalHostPort <- rackLocalLocations) {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(rackLocalHostPort)
|
||||
|
||||
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
|
||||
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
|
||||
list += index
|
||||
}
|
||||
}
|
||||
|
||||
allPendingTasks += index
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host port (process local), or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
|
||||
// DEBUG Code
|
||||
Utils.checkHostPort(hostPort)
|
||||
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Return the pending tasks (rack level) list for a given host, or an empty list if
|
||||
// there is no map entry for that host
|
||||
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
|
||||
}
|
||||
|
||||
// Number of pending tasks for a given host Port (which would be process local)
|
||||
def numPendingTasksForHostPort(hostPort: String): Int = {
|
||||
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
// Number of pending tasks for a given host (which would be data local)
|
||||
def numPendingTasksForHost(hostPort: String): Int = {
|
||||
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
// Number of pending rack local tasks for a given host
|
||||
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
|
||||
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
|
||||
}
|
||||
|
||||
|
||||
// Dequeue a pending task from the given list and return its index.
|
||||
// Return None if the list is empty.
|
||||
// This method also cleans up any tasks in the list that have already
|
||||
// been launched, since we want that to happen lazily.
|
||||
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
|
||||
while (!list.isEmpty) {
|
||||
val index = list.last
|
||||
list.trimEnd(1)
|
||||
if (copiesRunning(index) == 0 && !finished(index)) {
|
||||
return Some(index)
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
// Return a speculative task for a given host if any are available. The task should not have an
|
||||
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
|
||||
// task must have a preference for this host/rack/no preferred locations at all.
|
||||
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
|
||||
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
|
||||
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
|
||||
|
||||
if (speculatableTasks.size > 0) {
|
||||
val localTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (localTask != None) {
|
||||
speculatableTasks -= localTask.get
|
||||
return localTask
|
||||
}
|
||||
|
||||
// check for rack locality
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackTask = speculatableTasks.find {
|
||||
index =>
|
||||
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
val attemptLocs = taskAttempts(index).map(_.hostPort)
|
||||
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
|
||||
}
|
||||
|
||||
if (rackTask != None) {
|
||||
speculatableTasks -= rackTask.get
|
||||
return rackTask
|
||||
}
|
||||
}
|
||||
|
||||
// Any task ...
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
// Check for attemptLocs also ?
|
||||
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
|
||||
if (nonLocalTask != None) {
|
||||
speculatableTasks -= nonLocalTask.get
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
// Dequeue a pending task for a given node and return its index.
|
||||
// If localOnly is set to false, allow non-local tasks as well.
|
||||
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
|
||||
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
|
||||
if (processLocalTask != None) {
|
||||
return processLocalTask
|
||||
}
|
||||
|
||||
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
|
||||
if (localTask != None) {
|
||||
return localTask
|
||||
}
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
|
||||
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
|
||||
if (rackLocalTask != None) {
|
||||
return rackLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
|
||||
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
|
||||
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
|
||||
if (noPrefTask != None) {
|
||||
return noPrefTask
|
||||
}
|
||||
|
||||
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
|
||||
val nonLocalTask = findTaskFromList(allPendingTasks)
|
||||
if (nonLocalTask != None) {
|
||||
return nonLocalTask
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if all else has failed, find a speculative task
|
||||
return findSpeculativeTask(hostPort, locality)
|
||||
}
|
||||
|
||||
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
Utils.checkHostPort(hostPort)
|
||||
|
||||
val locs = task.preferredLocations
|
||||
|
||||
locs.contains(hostPort)
|
||||
}
|
||||
|
||||
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
val locs = task.preferredLocations
|
||||
|
||||
// If no preference, consider it as host local
|
||||
if (locs.isEmpty) return true
|
||||
|
||||
val host = Utils.parseHostPort(hostPort)._1
|
||||
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
|
||||
}
|
||||
|
||||
// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
|
||||
// This is true if either the task has preferred locations and this host is one, or it has
|
||||
// no preferred locations (in which we still count the launch as preferred).
|
||||
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
|
||||
|
||||
val locs = task.preferredLocations
|
||||
|
||||
val preferredRacks = new HashSet[String]()
|
||||
for (preferredHost <- locs) {
|
||||
val rack = sched.getRackForHost(preferredHost)
|
||||
if (None != rack) preferredRacks += rack.get
|
||||
}
|
||||
|
||||
if (preferredRacks.isEmpty) return false
|
||||
|
||||
val hostRack = sched.getRackForHost(hostPort)
|
||||
|
||||
return None != hostRack && preferredRacks.contains(hostRack.get)
|
||||
}
|
||||
|
||||
// Respond to an offer of a single slave from the scheduler by finding a task
|
||||
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
|
||||
|
||||
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
|
||||
// If explicitly specified, use that
|
||||
val locality = if (overrideLocality != null) overrideLocality else {
|
||||
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
|
||||
val time = System.currentTimeMillis
|
||||
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
|
||||
}
|
||||
|
||||
findTask(hostPort, locality) match {
|
||||
case Some(index) => {
|
||||
// Found a task; do some bookkeeping and return a Mesos task for it
|
||||
val task = tasks(index)
|
||||
val taskId = sched.newTaskId()
|
||||
// Figure out whether this should count as a preferred launch
|
||||
val taskLocality =
|
||||
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
|
||||
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
|
||||
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
|
||||
TaskLocality.ANY
|
||||
val prefStr = taskLocality.toString
|
||||
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
|
||||
taskSet.id, index, taskId, execId, hostPort, prefStr))
|
||||
// Do various bookkeeping
|
||||
copiesRunning(index) += 1
|
||||
val time = System.currentTimeMillis
|
||||
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
|
||||
taskInfos(taskId) = info
|
||||
taskAttempts(index) = info :: taskAttempts(index)
|
||||
if (TaskLocality.NODE_LOCAL == taskLocality) {
|
||||
lastPreferredLaunchTime = time
|
||||
}
|
||||
// Serialize and return the task
|
||||
val startTime = System.currentTimeMillis
|
||||
val serializedTask = Task.serializeWithDependencies(
|
||||
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
|
||||
val timeTaken = System.currentTimeMillis - startTime
|
||||
increaseRunningTasks(1)
|
||||
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
|
||||
taskSet.id, index, serializedTask.limit, timeTaken))
|
||||
val taskName = "task %s:%d".format(taskSet.id, index)
|
||||
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
state match {
|
||||
case TaskState.FINISHED =>
|
||||
taskFinished(tid, state, serializedData)
|
||||
case TaskState.LOST =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case TaskState.FAILED =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case TaskState.KILLED =>
|
||||
taskLost(tid, state, serializedData)
|
||||
case _ =>
|
||||
}
|
||||
}
|
||||
|
||||
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
if (info.failed) {
|
||||
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
|
||||
// or even from Mesos itself when acks get delayed.
|
||||
return
|
||||
}
|
||||
val index = info.index
|
||||
info.markSuccessful()
|
||||
decreaseRunningTasks(1)
|
||||
if (!finished(index)) {
|
||||
tasksFinished += 1
|
||||
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
|
||||
tid, info.duration, tasksFinished, numTasks))
|
||||
// Deserialize task result and pass it to the scheduler
|
||||
try {
|
||||
val result = ser.deserialize[TaskResult[_]](serializedData)
|
||||
result.metrics.resultSize = serializedData.limit()
|
||||
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
|
||||
} catch {
|
||||
case cnf: ClassNotFoundException =>
|
||||
val loader = Thread.currentThread().getContextClassLoader
|
||||
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
|
||||
case ex => throw ex
|
||||
}
|
||||
// Mark finished and stop if we've finished all the tasks
|
||||
finished(index) = true
|
||||
if (tasksFinished == numTasks) {
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
} else {
|
||||
logInfo("Ignoring task-finished event for TID " + tid +
|
||||
" because task " + index + " is already finished")
|
||||
}
|
||||
}
|
||||
|
||||
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
if (info.failed) {
|
||||
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
|
||||
// or even from Mesos itself when acks get delayed.
|
||||
return
|
||||
}
|
||||
val index = info.index
|
||||
info.markFailed()
|
||||
decreaseRunningTasks(1)
|
||||
if (!finished(index)) {
|
||||
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
|
||||
copiesRunning(index) -= 1
|
||||
// Check if the problem is a map output fetch failure. In that case, this
|
||||
// task will never succeed on any node, so tell the scheduler about it.
|
||||
if (serializedData != null && serializedData.limit() > 0) {
|
||||
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
|
||||
reason match {
|
||||
case fetchFailed: FetchFailed =>
|
||||
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
|
||||
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
|
||||
finished(index) = true
|
||||
tasksFinished += 1
|
||||
sched.taskSetFinished(this)
|
||||
decreaseRunningTasks(runningTasks)
|
||||
return
|
||||
|
||||
case taskResultTooBig: TaskResultTooBigFailure =>
|
||||
logInfo("Loss was due to task %s result exceeding Akka frame size;" +
|
||||
"aborting job".format(tid))
|
||||
abort("Task %s result exceeded Akka frame size".format(tid))
|
||||
return
|
||||
|
||||
case ef: ExceptionFailure =>
|
||||
val key = ef.description
|
||||
val now = System.currentTimeMillis
|
||||
val (printFull, dupCount) = {
|
||||
if (recentExceptions.contains(key)) {
|
||||
val (dupCount, printTime) = recentExceptions(key)
|
||||
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
|
||||
recentExceptions(key) = (0, now)
|
||||
(true, 0)
|
||||
} else {
|
||||
recentExceptions(key) = (dupCount + 1, printTime)
|
||||
(false, dupCount + 1)
|
||||
}
|
||||
} else {
|
||||
recentExceptions(key) = (0, now)
|
||||
(true, 0)
|
||||
}
|
||||
}
|
||||
if (printFull) {
|
||||
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
|
||||
logInfo("Loss was due to %s\n%s\n%s".format(
|
||||
ef.className, ef.description, locs.mkString("\n")))
|
||||
} else {
|
||||
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
|
||||
}
|
||||
|
||||
case _ => {}
|
||||
}
|
||||
}
|
||||
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
|
||||
addPendingTask(index)
|
||||
// Count failed attempts only on FAILED and LOST state (not on KILLED)
|
||||
if (state == TaskState.FAILED || state == TaskState.LOST) {
|
||||
numFailures(index) += 1
|
||||
if (numFailures(index) > MAX_TASK_FAILURES) {
|
||||
logError("Task %s:%d failed more than %d times; aborting job".format(
|
||||
taskSet.id, index, MAX_TASK_FAILURES))
|
||||
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
logInfo("Ignoring task-lost event for TID " + tid +
|
||||
" because task " + index + " is already finished")
|
||||
}
|
||||
}
|
||||
|
||||
def error(message: String) {
|
||||
// Save the error message
|
||||
abort("Error: " + message)
|
||||
}
|
||||
|
||||
def abort(message: String) {
|
||||
failed = true
|
||||
causeOfFailure = message
|
||||
// TODO: Kill running tasks if we were not terminated due to a Mesos error
|
||||
sched.listener.taskSetFailed(taskSet, message)
|
||||
decreaseRunningTasks(runningTasks)
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
|
||||
override def increaseRunningTasks(taskNum: Int) {
|
||||
runningTasks += taskNum
|
||||
if (parent != null) {
|
||||
parent.increaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
override def decreaseRunningTasks(taskNum: Int) {
|
||||
runningTasks -= taskNum
|
||||
if (parent != null) {
|
||||
parent.decreaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
|
||||
override def getSchedulableByName(name: String): Schedulable = {
|
||||
return null
|
||||
}
|
||||
|
||||
override def addSchedulable(schedulable:Schedulable) {
|
||||
//nothing
|
||||
}
|
||||
|
||||
override def removeSchedulable(schedulable:Schedulable) {
|
||||
//nothing
|
||||
}
|
||||
|
||||
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
|
||||
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
|
||||
sortedTaskSetQueue += this
|
||||
return sortedTaskSetQueue
|
||||
}
|
||||
|
||||
override def executorLost(execId: String, hostPort: String) {
|
||||
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
|
||||
|
||||
// If some task has preferred locations only on hostname, and there are no more executors there,
|
||||
// put it in the no-prefs list to avoid the wait from delay scheduling
|
||||
|
||||
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
|
||||
// no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
|
||||
// Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
|
||||
// there is no host local node for the task (not if there is no process local node for the task)
|
||||
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
|
||||
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
|
||||
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
|
||||
if (newLocs.isEmpty) {
|
||||
pendingTasksWithNoPrefs += index
|
||||
}
|
||||
}
|
||||
|
||||
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
|
||||
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
|
||||
for ((tid, info) <- taskInfos if info.executorId == execId) {
|
||||
val index = taskInfos(tid).index
|
||||
if (finished(index)) {
|
||||
finished(index) = false
|
||||
copiesRunning(index) -= 1
|
||||
tasksFinished -= 1
|
||||
addPendingTask(index)
|
||||
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
|
||||
// stage finishes when a total of tasks.size tasks finish.
|
||||
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Also re-enqueue any tasks that were running on the node
|
||||
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
|
||||
taskLost(tid, TaskState.KILLED, null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check for tasks to be speculated and return true if there are any. This is called periodically
|
||||
* by the ClusterScheduler.
|
||||
*
|
||||
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
|
||||
* we don't scan the whole task set. It might also help to make this sorted by launch time.
|
||||
*/
|
||||
override def checkSpeculatableTasks(): Boolean = {
|
||||
// Can't speculate if we only have one task, or if all tasks have finished.
|
||||
if (numTasks == 1 || tasksFinished == numTasks) {
|
||||
return false
|
||||
}
|
||||
var foundTasks = false
|
||||
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
|
||||
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
|
||||
if (tasksFinished >= minFinishedForSpeculation) {
|
||||
val time = System.currentTimeMillis()
|
||||
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
|
||||
Arrays.sort(durations)
|
||||
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
|
||||
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
|
||||
// TODO: Threshold should also look at standard deviation of task durations and have a lower
|
||||
// bound based on that.
|
||||
logDebug("Task length threshold for speculation: " + threshold)
|
||||
for ((tid, info) <- taskInfos) {
|
||||
val index = info.index
|
||||
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
|
||||
!speculatableTasks.contains(index)) {
|
||||
logInfo(
|
||||
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
|
||||
taskSet.id, index, info.hostPort, threshold))
|
||||
speculatableTasks += index
|
||||
foundTasks = true
|
||||
}
|
||||
}
|
||||
}
|
||||
return foundTasks
|
||||
}
|
||||
|
||||
override def hasPendingTasks(): Boolean = {
|
||||
numTasks > 0 && tasksFinished < numTasks
|
||||
}
|
||||
private[spark] trait TaskSetManager extends Schedulable {
|
||||
def taskSet: TaskSet
|
||||
def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
|
||||
overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
|
||||
def numPendingTasksForHostPort(hostPort: String): Int
|
||||
def numRackLocalPendingTasksForHost(hostPort :String): Int
|
||||
def numPendingTasksForHost(hostPort: String): Int
|
||||
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
|
||||
def error(message: String)
|
||||
}
|
||||
|
|
|
@ -2,19 +2,50 @@ package spark.scheduler.local
|
|||
|
||||
import java.io.File
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.nio.ByteBuffer
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import spark._
|
||||
import spark.TaskState.TaskState
|
||||
import spark.executor.ExecutorURLClassLoader
|
||||
import spark.scheduler._
|
||||
import spark.scheduler.cluster.{TaskLocality, TaskInfo}
|
||||
import spark.scheduler.cluster._
|
||||
import akka.actor._
|
||||
|
||||
/**
|
||||
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
|
||||
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
|
||||
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
|
||||
* testing fault recovery.
|
||||
*/
|
||||
private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
|
||||
|
||||
private[spark] case class LocalReviveOffers()
|
||||
private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
|
||||
|
||||
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
|
||||
def receive = {
|
||||
case LocalReviveOffers =>
|
||||
launchTask(localScheduler.resourceOffer(freeCores))
|
||||
case LocalStatusUpdate(taskId, state, serializeData) =>
|
||||
freeCores += 1
|
||||
localScheduler.statusUpdate(taskId, state, serializeData)
|
||||
launchTask(localScheduler.resourceOffer(freeCores))
|
||||
}
|
||||
|
||||
def launchTask(tasks : Seq[TaskDescription]) {
|
||||
for (task <- tasks) {
|
||||
freeCores -= 1
|
||||
localScheduler.threadPool.submit(new Runnable {
|
||||
def run() {
|
||||
localScheduler.runTask(task.taskId,task.serializedTask)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
|
||||
extends TaskScheduler
|
||||
with Logging {
|
||||
|
||||
|
@ -30,89 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
|||
|
||||
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
|
||||
|
||||
// TODO: Need to take into account stage priority in scheduling
|
||||
var schedulableBuilder: SchedulableBuilder = null
|
||||
var rootPool: Pool = null
|
||||
val activeTaskSets = new HashMap[String, TaskSetManager]
|
||||
val taskIdToTaskSetId = new HashMap[Long, String]
|
||||
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
|
||||
|
||||
override def start() { }
|
||||
var localActor: ActorRef = null
|
||||
|
||||
override def start() {
|
||||
//default scheduler is FIFO
|
||||
val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
|
||||
//temporarily set rootPool name to empty
|
||||
rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
|
||||
schedulableBuilder = {
|
||||
schedulingMode match {
|
||||
case "FIFO" =>
|
||||
new FIFOSchedulableBuilder(rootPool)
|
||||
case "FAIR" =>
|
||||
new FairSchedulableBuilder(rootPool)
|
||||
}
|
||||
}
|
||||
schedulableBuilder.buildPools()
|
||||
|
||||
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
|
||||
}
|
||||
|
||||
override def setListener(listener: TaskSchedulerListener) {
|
||||
this.listener = listener
|
||||
}
|
||||
|
||||
override def submitTasks(taskSet: TaskSet) {
|
||||
val tasks = taskSet.tasks
|
||||
val failCount = new Array[Int](tasks.size)
|
||||
|
||||
def submitTask(task: Task[_], idInJob: Int) {
|
||||
val myAttemptId = attemptId.getAndIncrement()
|
||||
threadPool.submit(new Runnable {
|
||||
def run() {
|
||||
runTask(task, idInJob, myAttemptId)
|
||||
}
|
||||
})
|
||||
synchronized {
|
||||
var manager = new LocalTaskSetManager(this, taskSet)
|
||||
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
|
||||
activeTaskSets(taskSet.id) = manager
|
||||
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
|
||||
localActor ! LocalReviveOffers
|
||||
}
|
||||
}
|
||||
|
||||
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
|
||||
logInfo("Running " + task)
|
||||
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
|
||||
// Set the Spark execution environment for the worker thread
|
||||
SparkEnv.set(env)
|
||||
try {
|
||||
Accumulators.clear()
|
||||
Thread.currentThread().setContextClassLoader(classLoader)
|
||||
|
||||
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
|
||||
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
|
||||
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
|
||||
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
|
||||
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
|
||||
val deserStart = System.currentTimeMillis()
|
||||
val deserializedTask = ser.deserialize[Task[_]](
|
||||
taskBytes, Thread.currentThread.getContextClassLoader)
|
||||
val deserTime = System.currentTimeMillis() - deserStart
|
||||
|
||||
// Run it
|
||||
val result: Any = deserializedTask.run(attemptId)
|
||||
|
||||
// Serialize and deserialize the result to emulate what the Mesos
|
||||
// executor does. This is useful to catch serialization errors early
|
||||
// on in development (so when users move their local Spark programs
|
||||
// to the cluster, they don't get surprised by serialization errors).
|
||||
val serResult = ser.serialize(result)
|
||||
deserializedTask.metrics.get.resultSize = serResult.limit()
|
||||
val resultToReturn = ser.deserialize[Any](serResult)
|
||||
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
|
||||
ser.serialize(Accumulators.values))
|
||||
logInfo("Finished " + task)
|
||||
info.markSuccessful()
|
||||
deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
|
||||
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
|
||||
|
||||
// If the threadpool has not already been shutdown, notify DAGScheduler
|
||||
if (!Thread.currentThread().isInterrupted)
|
||||
listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
|
||||
} catch {
|
||||
case t: Throwable => {
|
||||
logError("Exception in task " + idInJob, t)
|
||||
failCount.synchronized {
|
||||
failCount(idInJob) += 1
|
||||
if (failCount(idInJob) <= maxFailures) {
|
||||
submitTask(task, idInJob)
|
||||
} else {
|
||||
// TODO: Do something nicer here to return all the way to the user
|
||||
if (!Thread.currentThread().isInterrupted) {
|
||||
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
|
||||
listener.taskEnded(task, failure, null, null, info, null)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
|
||||
synchronized {
|
||||
var freeCpuCores = freeCores
|
||||
val tasks = new ArrayBuffer[TaskDescription](freeCores)
|
||||
val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
|
||||
for (manager <- sortedTaskSetQueue) {
|
||||
logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
|
||||
}
|
||||
}
|
||||
|
||||
for ((task, i) <- tasks.zipWithIndex) {
|
||||
submitTask(task, i)
|
||||
var launchTask = false
|
||||
for (manager <- sortedTaskSetQueue) {
|
||||
do {
|
||||
launchTask = false
|
||||
manager.slaveOffer(null,null,freeCpuCores) match {
|
||||
case Some(task) =>
|
||||
tasks += task
|
||||
taskIdToTaskSetId(task.taskId) = manager.taskSet.id
|
||||
taskSetTaskIds(manager.taskSet.id) += task.taskId
|
||||
freeCpuCores -= 1
|
||||
launchTask = true
|
||||
case None => {}
|
||||
}
|
||||
} while(launchTask)
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
}
|
||||
|
||||
def taskSetFinished(manager: TaskSetManager) {
|
||||
synchronized {
|
||||
activeTaskSets -= manager.taskSet.id
|
||||
manager.parent.removeSchedulable(manager)
|
||||
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
|
||||
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
|
||||
taskSetTaskIds -= manager.taskSet.id
|
||||
}
|
||||
}
|
||||
|
||||
def runTask(taskId: Long, bytes: ByteBuffer) {
|
||||
logInfo("Running " + taskId)
|
||||
val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
|
||||
// Set the Spark execution environment for the worker thread
|
||||
SparkEnv.set(env)
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
try {
|
||||
Accumulators.clear()
|
||||
Thread.currentThread().setContextClassLoader(classLoader)
|
||||
|
||||
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
|
||||
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
|
||||
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
|
||||
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
|
||||
val deserStart = System.currentTimeMillis()
|
||||
val deserializedTask = ser.deserialize[Task[_]](
|
||||
taskBytes, Thread.currentThread.getContextClassLoader)
|
||||
val deserTime = System.currentTimeMillis() - deserStart
|
||||
|
||||
// Run it
|
||||
val result: Any = deserializedTask.run(taskId)
|
||||
|
||||
// Serialize and deserialize the result to emulate what the Mesos
|
||||
// executor does. This is useful to catch serialization errors early
|
||||
// on in development (so when users move their local Spark programs
|
||||
// to the cluster, they don't get surprised by serialization errors).
|
||||
val serResult = ser.serialize(result)
|
||||
deserializedTask.metrics.get.resultSize = serResult.limit()
|
||||
val resultToReturn = ser.deserialize[Any](serResult)
|
||||
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
|
||||
ser.serialize(Accumulators.values))
|
||||
logInfo("Finished " + taskId)
|
||||
deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough
|
||||
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
|
||||
|
||||
val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
|
||||
val serializedResult = ser.serialize(taskResult)
|
||||
localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
|
||||
} catch {
|
||||
case t: Throwable => {
|
||||
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
|
||||
localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -128,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
|||
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
|
||||
currentFiles(name) = timestamp
|
||||
}
|
||||
|
||||
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
|
||||
logInfo("Fetching " + name + " with timestamp " + timestamp)
|
||||
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
|
||||
|
@ -143,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
|
|||
}
|
||||
}
|
||||
|
||||
override def stop() {
|
||||
def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
synchronized {
|
||||
val taskSetId = taskIdToTaskSetId(taskId)
|
||||
val taskSetManager = activeTaskSets(taskSetId)
|
||||
taskSetTaskIds(taskSetId) -= taskId
|
||||
taskSetManager.statusUpdate(taskId, state, serializedData)
|
||||
}
|
||||
}
|
||||
|
||||
override def stop() {
|
||||
threadPool.shutdownNow()
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,172 @@
|
|||
package spark.scheduler.local
|
||||
|
||||
import java.io.File
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
import java.nio.ByteBuffer
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import spark._
|
||||
import spark.TaskState.TaskState
|
||||
import spark.scheduler._
|
||||
import spark.scheduler.cluster._
|
||||
|
||||
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
|
||||
var parent: Schedulable = null
|
||||
var weight: Int = 1
|
||||
var minShare: Int = 0
|
||||
var runningTasks: Int = 0
|
||||
var priority: Int = taskSet.priority
|
||||
var stageId: Int = taskSet.stageId
|
||||
var name: String = "TaskSet_"+taskSet.stageId.toString
|
||||
|
||||
|
||||
var failCount = new Array[Int](taskSet.tasks.size)
|
||||
val taskInfos = new HashMap[Long, TaskInfo]
|
||||
val numTasks = taskSet.tasks.size
|
||||
var numFinished = 0
|
||||
val ser = SparkEnv.get.closureSerializer.newInstance()
|
||||
val copiesRunning = new Array[Int](numTasks)
|
||||
val finished = new Array[Boolean](numTasks)
|
||||
val numFailures = new Array[Int](numTasks)
|
||||
val MAX_TASK_FAILURES = sched.maxFailures
|
||||
|
||||
def increaseRunningTasks(taskNum: Int): Unit = {
|
||||
runningTasks += taskNum
|
||||
if (parent != null) {
|
||||
parent.increaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
def decreaseRunningTasks(taskNum: Int): Unit = {
|
||||
runningTasks -= taskNum
|
||||
if (parent != null) {
|
||||
parent.decreaseRunningTasks(taskNum)
|
||||
}
|
||||
}
|
||||
|
||||
def addSchedulable(schedulable: Schedulable): Unit = {
|
||||
//nothing
|
||||
}
|
||||
|
||||
def removeSchedulable(schedulable: Schedulable): Unit = {
|
||||
//nothing
|
||||
}
|
||||
|
||||
def getSchedulableByName(name: String): Schedulable = {
|
||||
return null
|
||||
}
|
||||
|
||||
def executorLost(executorId: String, host: String): Unit = {
|
||||
//nothing
|
||||
}
|
||||
|
||||
def checkSpeculatableTasks(): Boolean = {
|
||||
return true
|
||||
}
|
||||
|
||||
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
|
||||
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
|
||||
sortedTaskSetQueue += this
|
||||
return sortedTaskSetQueue
|
||||
}
|
||||
|
||||
def hasPendingTasks(): Boolean = {
|
||||
return true
|
||||
}
|
||||
|
||||
def findTask(): Option[Int] = {
|
||||
for (i <- 0 to numTasks-1) {
|
||||
if (copiesRunning(i) == 0 && !finished(i)) {
|
||||
return Some(i)
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
|
||||
SparkEnv.set(sched.env)
|
||||
logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
|
||||
if (availableCpus > 0 && numFinished < numTasks) {
|
||||
findTask() match {
|
||||
case Some(index) =>
|
||||
val taskId = sched.attemptId.getAndIncrement()
|
||||
val task = taskSet.tasks(index)
|
||||
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
|
||||
taskInfos(taskId) = info
|
||||
val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
|
||||
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
|
||||
val taskName = "task %s:%d".format(taskSet.id, index)
|
||||
copiesRunning(index) += 1
|
||||
increaseRunningTasks(1)
|
||||
return Some(new TaskDescription(taskId, null, taskName, bytes))
|
||||
case None => {}
|
||||
}
|
||||
}
|
||||
return None
|
||||
}
|
||||
|
||||
def numPendingTasksForHostPort(hostPort: String): Int = {
|
||||
return 0
|
||||
}
|
||||
|
||||
def numRackLocalPendingTasksForHost(hostPort :String): Int = {
|
||||
return 0
|
||||
}
|
||||
|
||||
def numPendingTasksForHost(hostPort: String): Int = {
|
||||
return 0
|
||||
}
|
||||
|
||||
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
state match {
|
||||
case TaskState.FINISHED =>
|
||||
taskEnded(tid, state, serializedData)
|
||||
case TaskState.FAILED =>
|
||||
taskFailed(tid, state, serializedData)
|
||||
case _ => {}
|
||||
}
|
||||
}
|
||||
|
||||
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
val index = info.index
|
||||
val task = taskSet.tasks(index)
|
||||
info.markSuccessful()
|
||||
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
|
||||
result.metrics.resultSize = serializedData.limit()
|
||||
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
|
||||
numFinished += 1
|
||||
decreaseRunningTasks(1)
|
||||
finished(index) = true
|
||||
if (numFinished == numTasks) {
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
}
|
||||
|
||||
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
|
||||
val info = taskInfos(tid)
|
||||
val index = info.index
|
||||
val task = taskSet.tasks(index)
|
||||
info.markFailed()
|
||||
decreaseRunningTasks(1)
|
||||
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
|
||||
if (!finished(index)) {
|
||||
copiesRunning(index) -= 1
|
||||
numFailures(index) += 1
|
||||
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
|
||||
logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
|
||||
if (numFailures(index) > MAX_TASK_FAILURES) {
|
||||
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
|
||||
decreaseRunningTasks(runningTasks)
|
||||
sched.listener.taskSetFailed(taskSet, errorMessage)
|
||||
// need to delete failed Taskset from schedule queue
|
||||
sched.taskSetFinished(this)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def error(message: String) {
|
||||
}
|
||||
}
|
|
@ -67,11 +67,20 @@ object BlockFetcherIterator {
|
|||
throw new IllegalArgumentException("BlocksByAddress is null")
|
||||
}
|
||||
|
||||
protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
|
||||
logDebug("Getting " + _totalBlocks + " blocks")
|
||||
// Total number blocks fetched (local + remote). Also number of FetchResults expected
|
||||
protected var _numBlocksToFetch = 0
|
||||
|
||||
protected var startTime = System.currentTimeMillis
|
||||
protected val localBlockIds = new ArrayBuffer[String]()
|
||||
protected val remoteBlockIds = new HashSet[String]()
|
||||
|
||||
// This represents the number of local blocks, also counting zero-sized blocks
|
||||
private var numLocal = 0
|
||||
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
|
||||
protected val localBlocksToFetch = new ArrayBuffer[String]()
|
||||
|
||||
// This represents the number of remote blocks, also counting zero-sized blocks
|
||||
private var numRemote = 0
|
||||
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
|
||||
protected val remoteBlocksToFetch = new HashSet[String]()
|
||||
|
||||
// A queue to hold our results.
|
||||
protected val results = new LinkedBlockingQueue[FetchResult]
|
||||
|
@ -124,13 +133,15 @@ object BlockFetcherIterator {
|
|||
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
|
||||
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
|
||||
// at most maxBytesInFlight in order to limit the amount of data in flight.
|
||||
val originalTotalBlocks = _totalBlocks
|
||||
val remoteRequests = new ArrayBuffer[FetchRequest]
|
||||
for ((address, blockInfos) <- blocksByAddress) {
|
||||
if (address == blockManagerId) {
|
||||
localBlockIds ++= blockInfos.map(_._1)
|
||||
numLocal = blockInfos.size
|
||||
// Filter out zero-sized blocks
|
||||
localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
|
||||
_numBlocksToFetch += localBlocksToFetch.size
|
||||
} else {
|
||||
remoteBlockIds ++= blockInfos.map(_._1)
|
||||
numRemote += blockInfos.size
|
||||
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
|
||||
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
|
||||
// nodes, rather than blocking on reading output from one node.
|
||||
|
@ -144,10 +155,10 @@ object BlockFetcherIterator {
|
|||
// Skip empty blocks
|
||||
if (size > 0) {
|
||||
curBlocks += ((blockId, size))
|
||||
remoteBlocksToFetch += blockId
|
||||
_numBlocksToFetch += 1
|
||||
curRequestSize += size
|
||||
} else if (size == 0) {
|
||||
_totalBlocks -= 1
|
||||
} else {
|
||||
} else if (size < 0) {
|
||||
throw new BlockException(blockId, "Negative block size " + size)
|
||||
}
|
||||
if (curRequestSize >= minRequestSize) {
|
||||
|
@ -163,8 +174,8 @@ object BlockFetcherIterator {
|
|||
}
|
||||
}
|
||||
}
|
||||
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
|
||||
originalTotalBlocks + " blocks")
|
||||
logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
|
||||
totalBlocks + " blocks")
|
||||
remoteRequests
|
||||
}
|
||||
|
||||
|
@ -172,7 +183,7 @@ object BlockFetcherIterator {
|
|||
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
|
||||
// these all at once because they will just memory-map some files, so they won't consume
|
||||
// any memory that might exceed our maxBytesInFlight
|
||||
for (id <- localBlockIds) {
|
||||
for (id <- localBlocksToFetch) {
|
||||
getLocalFromDisk(id, serializer) match {
|
||||
case Some(iter) => {
|
||||
// Pass 0 as size since it's not in flight
|
||||
|
@ -198,7 +209,7 @@ object BlockFetcherIterator {
|
|||
sendRequest(fetchRequests.dequeue())
|
||||
}
|
||||
|
||||
val numGets = remoteBlockIds.size - fetchRequests.size
|
||||
val numGets = remoteRequests.size - fetchRequests.size
|
||||
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
|
||||
|
||||
// Get Local Blocks
|
||||
|
@ -210,7 +221,7 @@ object BlockFetcherIterator {
|
|||
//an iterator that will read fetched blocks off the queue as they arrive.
|
||||
@volatile protected var resultsGotten = 0
|
||||
|
||||
override def hasNext: Boolean = resultsGotten < _totalBlocks
|
||||
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
|
||||
|
||||
override def next(): (String, Option[Iterator[Any]]) = {
|
||||
resultsGotten += 1
|
||||
|
@ -227,9 +238,9 @@ object BlockFetcherIterator {
|
|||
}
|
||||
|
||||
// Implementing BlockFetchTracker trait.
|
||||
override def totalBlocks: Int = _totalBlocks
|
||||
override def numLocalBlocks: Int = localBlockIds.size
|
||||
override def numRemoteBlocks: Int = remoteBlockIds.size
|
||||
override def totalBlocks: Int = numLocal + numRemote
|
||||
override def numLocalBlocks: Int = numLocal
|
||||
override def numRemoteBlocks: Int = numRemote
|
||||
override def remoteFetchTime: Long = _remoteFetchTime
|
||||
override def fetchWaitTime: Long = _fetchWaitTime
|
||||
override def remoteBytesRead: Long = _remoteBytesRead
|
||||
|
@ -265,7 +276,7 @@ object BlockFetcherIterator {
|
|||
}).toList
|
||||
}
|
||||
|
||||
//keep this to interrupt the threads when necessary
|
||||
// keep this to interrupt the threads when necessary
|
||||
private def stopCopiers() {
|
||||
for (copier <- copiers) {
|
||||
copier.interrupt()
|
||||
|
@ -291,7 +302,7 @@ object BlockFetcherIterator {
|
|||
private var copiers: List[_ <: Thread] = null
|
||||
|
||||
override def initialize() {
|
||||
// Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
|
||||
// Split Local Remote Blocks and set numBlocksToFetch
|
||||
val remoteRequests = splitLocalRemoteBlocks()
|
||||
// Add the remote requests into our queue in a random order
|
||||
for (request <- Utils.randomize(remoteRequests)) {
|
||||
|
@ -311,10 +322,7 @@ object BlockFetcherIterator {
|
|||
override def next(): (String, Option[Iterator[Any]]) = {
|
||||
resultsGotten += 1
|
||||
val result = results.take()
|
||||
// if all the results has been retrieved, shutdown the copiers
|
||||
if (resultsGotten == _totalBlocks && copiers != null) {
|
||||
stopCopiers()
|
||||
}
|
||||
// If all the results has been retrieved, copiers will exit automatically
|
||||
(result.blockId, if (result.failed) None else Some(result.deserialize()))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
private var bs: OutputStream = null
|
||||
private var objOut: SerializationStream = null
|
||||
private var lastValidPosition = 0L
|
||||
private var initialized = false
|
||||
|
||||
override def open(): DiskBlockObjectWriter = {
|
||||
val fos = new FileOutputStream(f, true)
|
||||
channel = fos.getChannel()
|
||||
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
|
||||
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
|
||||
objOut = serializer.newInstance().serializeStream(bs)
|
||||
initialized = true
|
||||
this
|
||||
}
|
||||
|
||||
override def close() {
|
||||
objOut.close()
|
||||
bs.close()
|
||||
channel = null
|
||||
bs = null
|
||||
objOut = null
|
||||
if (initialized) {
|
||||
objOut.close()
|
||||
bs.close()
|
||||
channel = null
|
||||
bs = null
|
||||
objOut = null
|
||||
}
|
||||
// Invoke the close callback handler.
|
||||
super.close()
|
||||
}
|
||||
|
@ -59,38 +63,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
// Flush the partial writes, and set valid length to be the length of the entire file.
|
||||
// Return the number of bytes written for this commit.
|
||||
override def commit(): Long = {
|
||||
// NOTE: Flush the serializer first and then the compressed/buffered output stream
|
||||
objOut.flush()
|
||||
bs.flush()
|
||||
val prevPos = lastValidPosition
|
||||
lastValidPosition = channel.position()
|
||||
lastValidPosition - prevPos
|
||||
if (initialized) {
|
||||
// NOTE: Flush the serializer first and then the compressed/buffered output stream
|
||||
objOut.flush()
|
||||
bs.flush()
|
||||
val prevPos = lastValidPosition
|
||||
lastValidPosition = channel.position()
|
||||
lastValidPosition - prevPos
|
||||
} else {
|
||||
// lastValidPosition is zero if stream is uninitialized
|
||||
lastValidPosition
|
||||
}
|
||||
}
|
||||
|
||||
override def revertPartialWrites() {
|
||||
// Discard current writes. We do this by flushing the outstanding writes and
|
||||
// truncate the file to the last valid position.
|
||||
objOut.flush()
|
||||
bs.flush()
|
||||
channel.truncate(lastValidPosition)
|
||||
if (initialized) {
|
||||
// Discard current writes. We do this by flushing the outstanding writes and
|
||||
// truncate the file to the last valid position.
|
||||
objOut.flush()
|
||||
bs.flush()
|
||||
channel.truncate(lastValidPosition)
|
||||
}
|
||||
}
|
||||
|
||||
override def write(value: Any) {
|
||||
if (!initialized) {
|
||||
open()
|
||||
}
|
||||
objOut.writeObject(value)
|
||||
}
|
||||
|
||||
override def size(): Long = lastValidPosition
|
||||
}
|
||||
|
||||
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
|
||||
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
|
||||
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
|
||||
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
|
||||
|
||||
var shuffleSender : ShuffleSender = null
|
||||
private var shuffleSender : ShuffleSender = null
|
||||
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
|
||||
// directory, create multiple subdirectories that we will hash files into, in order to avoid
|
||||
// having really large inodes at the top level.
|
||||
val localDirs = createLocalDirs()
|
||||
val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
|
||||
private val localDirs: Array[File] = createLocalDirs()
|
||||
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
|
||||
|
||||
addShutdownHook()
|
||||
|
||||
|
@ -99,7 +113,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
|
||||
}
|
||||
|
||||
|
||||
override def getSize(blockId: String): Long = {
|
||||
getFile(blockId).length()
|
||||
}
|
||||
|
@ -197,7 +210,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
|
||||
val file = getFile(blockId)
|
||||
if (!allowAppendExisting && file.exists()) {
|
||||
throw new Exception("File for block " + blockId + " already exists on disk: " + file)
|
||||
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
|
||||
// was rescheduled on the same machine as the old task.
|
||||
logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
|
||||
file.delete()
|
||||
}
|
||||
file
|
||||
}
|
||||
|
@ -232,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
private def createLocalDirs(): Array[File] = {
|
||||
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
|
||||
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
|
||||
rootDirs.split(",").map(rootDir => {
|
||||
var foundLocalDir: Boolean = false
|
||||
rootDirs.split(",").map { rootDir =>
|
||||
var foundLocalDir = false
|
||||
var localDir: File = null
|
||||
var localDirId: String = null
|
||||
var tries = 0
|
||||
|
@ -248,7 +264,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
}
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logWarning("Attempt " + tries + " to create local dir failed", e)
|
||||
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
|
||||
}
|
||||
}
|
||||
if (!foundLocalDir) {
|
||||
|
@ -258,7 +274,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
}
|
||||
logInfo("Created local directory at " + localDir)
|
||||
localDir
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
private def addShutdownHook() {
|
||||
|
@ -266,15 +282,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
|
|||
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
|
||||
override def run() {
|
||||
logDebug("Shutdown hook called")
|
||||
try {
|
||||
localDirs.foreach { localDir =>
|
||||
localDirs.foreach { localDir =>
|
||||
try {
|
||||
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
|
||||
} catch {
|
||||
case t: Throwable =>
|
||||
logError("Exception while deleting local spark dir: " + localDir, t)
|
||||
}
|
||||
if (shuffleSender != null) {
|
||||
shuffleSender.stop
|
||||
}
|
||||
} catch {
|
||||
case t: Throwable => logError("Exception while deleting local spark dirs", t)
|
||||
}
|
||||
if (shuffleSender != null) {
|
||||
shuffleSender.stop
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
|
@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
|
|||
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
|
||||
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
|
||||
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
|
||||
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
|
||||
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
|
||||
}
|
||||
new ShuffleWriterGroup(mapId, writers)
|
||||
}
|
||||
|
|
45
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
45
core/src/main/scala/spark/util/BoundedPriorityQueue.scala
Normal file
|
@ -0,0 +1,45 @@
|
|||
package spark.util
|
||||
|
||||
import java.io.Serializable
|
||||
import java.util.{PriorityQueue => JPriorityQueue}
|
||||
import scala.collection.generic.Growable
|
||||
import scala.collection.JavaConverters._
|
||||
|
||||
/**
|
||||
* Bounded priority queue. This class wraps the original PriorityQueue
|
||||
* class and modifies it such that only the top K elements are retained.
|
||||
* The top K elements are defined by an implicit Ordering[A].
|
||||
*/
|
||||
class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
|
||||
extends Iterable[A] with Growable[A] with Serializable {
|
||||
|
||||
private val underlying = new JPriorityQueue[A](maxSize, ord)
|
||||
|
||||
override def iterator: Iterator[A] = underlying.iterator.asScala
|
||||
|
||||
override def ++=(xs: TraversableOnce[A]): this.type = {
|
||||
xs.foreach { this += _ }
|
||||
this
|
||||
}
|
||||
|
||||
override def +=(elem: A): this.type = {
|
||||
if (size < maxSize) underlying.offer(elem)
|
||||
else maybeReplaceLowest(elem)
|
||||
this
|
||||
}
|
||||
|
||||
override def +=(elem1: A, elem2: A, elems: A*): this.type = {
|
||||
this += elem1 += elem2 ++= elems
|
||||
}
|
||||
|
||||
override def clear() { underlying.clear() }
|
||||
|
||||
private def maybeReplaceLowest(a: A): Boolean = {
|
||||
val head = underlying.peek()
|
||||
if (head != null && ord.gt(a, head)) {
|
||||
underlying.poll()
|
||||
underlying.offer(a)
|
||||
} else false
|
||||
}
|
||||
}
|
||||
|
|
@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
|
|||
if (other == this) {
|
||||
merge(other.copy()) // Avoid overwriting fields in a weird order
|
||||
} else {
|
||||
val delta = other.mu - mu
|
||||
if (other.n * 10 < n) {
|
||||
mu = mu + (delta * other.n) / (n + other.n)
|
||||
} else if (n * 10 < other.n) {
|
||||
mu = other.mu - (delta * n) / (n + other.n)
|
||||
} else {
|
||||
mu = (mu * n + other.mu * other.n) / (n + other.n)
|
||||
if (n == 0) {
|
||||
mu = other.mu
|
||||
m2 = other.m2
|
||||
n = other.n
|
||||
} else if (other.n != 0) {
|
||||
val delta = other.mu - mu
|
||||
if (other.n * 10 < n) {
|
||||
mu = mu + (delta * other.n) / (n + other.n)
|
||||
} else if (n * 10 < other.n) {
|
||||
mu = other.mu - (delta * n) / (n + other.n)
|
||||
} else {
|
||||
mu = (mu * n + other.mu * other.n) / (n + other.n)
|
||||
}
|
||||
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
|
||||
n += other.n
|
||||
}
|
||||
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
|
||||
n += other.n
|
||||
this
|
||||
this
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -27,6 +27,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
|
|||
}
|
||||
}
|
||||
|
||||
test("basic checkpointing") {
|
||||
val parCollection = sc.makeRDD(1 to 4)
|
||||
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
|
||||
flatMappedRDD.checkpoint()
|
||||
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
|
||||
val result = flatMappedRDD.collect()
|
||||
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
|
||||
assert(flatMappedRDD.collect() === result)
|
||||
}
|
||||
|
||||
test("RDDs with one-to-one dependencies") {
|
||||
testCheckpointing(_.map(x => x.toString))
|
||||
testCheckpointing(_.flatMap(x => 1 to x))
|
||||
|
|
|
@ -7,6 +7,8 @@ import scala.io.Source
|
|||
import com.google.common.io.Files
|
||||
import org.scalatest.FunSuite
|
||||
import org.apache.hadoop.io._
|
||||
import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
|
||||
|
||||
|
||||
import SparkContext._
|
||||
|
||||
|
@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
|
|||
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
|
||||
}
|
||||
|
||||
test("text files (compressed)") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val tempDir = Files.createTempDir()
|
||||
val normalDir = new File(tempDir, "output_normal").getAbsolutePath
|
||||
val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
|
||||
val codec = new DefaultCodec()
|
||||
|
||||
val data = sc.parallelize("a" * 10000, 1)
|
||||
data.saveAsTextFile(normalDir)
|
||||
data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec])
|
||||
|
||||
val normalFile = new File(normalDir, "part-00000")
|
||||
val normalContent = sc.textFile(normalDir).collect
|
||||
assert(normalContent === Array.fill(10000)("a"))
|
||||
|
||||
val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
|
||||
val compressedContent = sc.textFile(compressedOutputDir).collect
|
||||
assert(compressedContent === Array.fill(10000)("a"))
|
||||
|
||||
assert(compressedFile.length < normalFile.length)
|
||||
}
|
||||
|
||||
test("SequenceFiles") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val tempDir = Files.createTempDir()
|
||||
|
@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
|
|||
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
|
||||
}
|
||||
|
||||
test("SequenceFile (compressed)") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val tempDir = Files.createTempDir()
|
||||
val normalDir = new File(tempDir, "output_normal").getAbsolutePath
|
||||
val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
|
||||
val codec = new DefaultCodec()
|
||||
|
||||
val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x))
|
||||
data.saveAsSequenceFile(normalDir)
|
||||
data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec]))
|
||||
|
||||
val normalFile = new File(normalDir, "part-00000")
|
||||
val normalContent = sc.sequenceFile[String, String](normalDir).collect
|
||||
assert(normalContent === Array.fill(100)("abc", "abc"))
|
||||
|
||||
val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
|
||||
val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect
|
||||
assert(compressedContent === Array.fill(100)("abc", "abc"))
|
||||
|
||||
assert(compressedFile.length < normalFile.length)
|
||||
}
|
||||
|
||||
test("SequenceFile with writable key") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val tempDir = Files.createTempDir()
|
||||
|
|
|
@ -8,6 +8,7 @@ import java.util.*;
|
|||
import scala.Tuple2;
|
||||
|
||||
import com.google.common.base.Charsets;
|
||||
import org.apache.hadoop.io.compress.DefaultCodec;
|
||||
import com.google.common.io.Files;
|
||||
import org.apache.hadoop.io.IntWritable;
|
||||
import org.apache.hadoop.io.Text;
|
||||
|
@ -473,6 +474,19 @@ public class JavaAPISuite implements Serializable {
|
|||
Assert.assertEquals(expected, readRDD.collect());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void textFilesCompressed() throws IOException {
|
||||
File tempDir = Files.createTempDir();
|
||||
String outputDir = new File(tempDir, "output").getAbsolutePath();
|
||||
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
|
||||
rdd.saveAsTextFile(outputDir, DefaultCodec.class);
|
||||
|
||||
// Try reading it in as a text file RDD
|
||||
List<String> expected = Arrays.asList("1", "2", "3", "4");
|
||||
JavaRDD<String> readRDD = sc.textFile(outputDir);
|
||||
Assert.assertEquals(expected, readRDD.collect());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sequenceFile() {
|
||||
File tempDir = Files.createTempDir();
|
||||
|
@ -619,6 +633,37 @@ public class JavaAPISuite implements Serializable {
|
|||
}).collect().toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void hadoopFileCompressed() {
|
||||
File tempDir = Files.createTempDir();
|
||||
String outputDir = new File(tempDir, "output_compressed").getAbsolutePath();
|
||||
List<Tuple2<Integer, String>> pairs = Arrays.asList(
|
||||
new Tuple2<Integer, String>(1, "a"),
|
||||
new Tuple2<Integer, String>(2, "aa"),
|
||||
new Tuple2<Integer, String>(3, "aaa")
|
||||
);
|
||||
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
|
||||
|
||||
rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
|
||||
@Override
|
||||
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
|
||||
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
|
||||
}
|
||||
}).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
|
||||
DefaultCodec.class);
|
||||
|
||||
JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
|
||||
SequenceFileInputFormat.class, IntWritable.class, Text.class);
|
||||
|
||||
Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
|
||||
String>() {
|
||||
@Override
|
||||
public String call(Tuple2<IntWritable, Text> x) {
|
||||
return x.toString();
|
||||
}
|
||||
}).collect().toString());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void zip() {
|
||||
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
|
||||
|
|
287
core/src/test/scala/spark/PairRDDFunctionsSuite.scala
Normal file
287
core/src/test/scala/spark/PairRDDFunctionsSuite.scala
Normal file
|
@ -0,0 +1,287 @@
|
|||
package spark
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.HashSet
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.prop.Checkers
|
||||
import org.scalacheck.Arbitrary._
|
||||
import org.scalacheck.Gen
|
||||
import org.scalacheck.Prop._
|
||||
|
||||
import com.google.common.io.Files
|
||||
|
||||
import spark.rdd.ShuffledRDD
|
||||
import spark.SparkContext._
|
||||
|
||||
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
|
||||
test("groupByKey") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with duplicates") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with negative key hash codes") {
|
||||
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesForMinus1 = groups.find(_._1 == -1).get._2
|
||||
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with many output partitions") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey(10).collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("reduceByKey") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_).collect()
|
||||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("reduceByKey with collectAsMap") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_).collectAsMap()
|
||||
assert(sums.size === 2)
|
||||
assert(sums(1) === 7)
|
||||
assert(sums(2) === 1)
|
||||
}
|
||||
|
||||
test("reduceByKey with many output partitons") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_, 10).collect()
|
||||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("reduceByKey with partitioner") {
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 2
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
|
||||
val sums = pairs.reduceByKey(_+_)
|
||||
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
|
||||
assert(sums.partitioner === Some(p))
|
||||
// count the dependencies to make sure there is only 1 ShuffledRDD
|
||||
val deps = new HashSet[RDD[_]]()
|
||||
def visit(r: RDD[_]) {
|
||||
for (dep <- r.dependencies) {
|
||||
deps += dep.rdd
|
||||
visit(dep.rdd)
|
||||
}
|
||||
}
|
||||
visit(sums)
|
||||
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
|
||||
}
|
||||
|
||||
test("join") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (2, 'x')),
|
||||
(2, (1, 'y')),
|
||||
(2, (1, 'z'))
|
||||
))
|
||||
}
|
||||
|
||||
test("join all-to-all") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 6)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (1, 'y')),
|
||||
(1, (2, 'x')),
|
||||
(1, (2, 'y')),
|
||||
(1, (3, 'x')),
|
||||
(1, (3, 'y'))
|
||||
))
|
||||
}
|
||||
|
||||
test("leftOuterJoin") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.leftOuterJoin(rdd2).collect()
|
||||
assert(joined.size === 5)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, Some('x'))),
|
||||
(1, (2, Some('x'))),
|
||||
(2, (1, Some('y'))),
|
||||
(2, (1, Some('z'))),
|
||||
(3, (1, None))
|
||||
))
|
||||
}
|
||||
|
||||
test("rightOuterJoin") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.rightOuterJoin(rdd2).collect()
|
||||
assert(joined.size === 5)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (Some(1), 'x')),
|
||||
(1, (Some(2), 'x')),
|
||||
(2, (Some(1), 'y')),
|
||||
(2, (Some(1), 'z')),
|
||||
(4, (None, 'w'))
|
||||
))
|
||||
}
|
||||
|
||||
test("join with no matches") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 0)
|
||||
}
|
||||
|
||||
test("join with many output partitions") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.join(rdd2, 10).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (2, 'x')),
|
||||
(2, (1, 'y')),
|
||||
(2, (1, 'z'))
|
||||
))
|
||||
}
|
||||
|
||||
test("groupWith") {
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.groupWith(rdd2).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
|
||||
(2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
|
||||
(3, (ArrayBuffer(1), ArrayBuffer())),
|
||||
(4, (ArrayBuffer(), ArrayBuffer('w')))
|
||||
))
|
||||
}
|
||||
|
||||
test("zero-partition RDD") {
|
||||
val emptyDir = Files.createTempDir()
|
||||
val file = sc.textFile(emptyDir.getAbsolutePath)
|
||||
assert(file.partitions.size == 0)
|
||||
assert(file.collect().toList === Nil)
|
||||
// Test that a shuffle on the file works, because this used to be a bug
|
||||
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
|
||||
}
|
||||
|
||||
test("keys and values") {
|
||||
val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
|
||||
assert(rdd.keys.collect().toList === List(1, 2))
|
||||
assert(rdd.values.collect().toList === List("a", "b"))
|
||||
}
|
||||
|
||||
test("default partitioner uses partition size") {
|
||||
// specify 2000 partitions
|
||||
val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
|
||||
// do a map, which loses the partitioner
|
||||
val b = a.map(a => (a, (a * 2).toString))
|
||||
// then a group by, and see we didn't revert to 2 partitions
|
||||
val c = b.groupByKey()
|
||||
assert(c.partitions.size === 2000)
|
||||
}
|
||||
|
||||
test("default partitioner uses largest partitioner") {
|
||||
val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
|
||||
val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
|
||||
val c = a.join(b)
|
||||
assert(c.partitions.size === 2000)
|
||||
}
|
||||
|
||||
test("subtract") {
|
||||
val a = sc.parallelize(Array(1, 2, 3), 2)
|
||||
val b = sc.parallelize(Array(2, 3, 4), 4)
|
||||
val c = a.subtract(b)
|
||||
assert(c.collect().toSet === Set(1))
|
||||
assert(c.partitions.size === a.partitions.size)
|
||||
}
|
||||
|
||||
test("subtract with narrow dependency") {
|
||||
// use a deterministic partitioner
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 5
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
// partitionBy so we have a narrow dependency
|
||||
val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
|
||||
// more partitions/no partitioner so a shuffle dependency
|
||||
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
|
||||
val c = a.subtract(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (3, "c")))
|
||||
// Ideally we could keep the original partitioner...
|
||||
assert(c.partitioner === None)
|
||||
}
|
||||
|
||||
test("subtractByKey") {
|
||||
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
|
||||
val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
|
||||
val c = a.subtractByKey(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
|
||||
assert(c.partitions.size === a.partitions.size)
|
||||
}
|
||||
|
||||
test("subtractByKey with narrow dependency") {
|
||||
// use a deterministic partitioner
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 5
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
// partitionBy so we have a narrow dependency
|
||||
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
|
||||
// more partitions/no partitioner so a shuffle dependency
|
||||
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
|
||||
val c = a.subtractByKey(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
|
||||
assert(c.partitioner.get === p)
|
||||
}
|
||||
|
||||
test("foldByKey") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.foldByKey(0)(_+_).collect()
|
||||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("foldByKey with mutable result type") {
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
|
||||
// Fold the values using in-place mutation
|
||||
val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
|
||||
assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
|
||||
// Check that the mutable objects in the original RDD were not changed
|
||||
assert(bufs.collect().toSet === Set(
|
||||
(1, ArrayBuffer(1)),
|
||||
(1, ArrayBuffer(2)),
|
||||
(1, ArrayBuffer(3)),
|
||||
(1, ArrayBuffer(1)),
|
||||
(2, ArrayBuffer(1))))
|
||||
}
|
||||
}
|
|
@ -1,13 +1,13 @@
|
|||
package spark
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import SparkContext._
|
||||
import spark.util.StatCounter
|
||||
import scala.math.abs
|
||||
|
||||
class PartitioningSuite extends FunSuite with SharedSparkContext {
|
||||
|
||||
class PartitioningSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
test("HashPartitioner equality") {
|
||||
val p2 = new HashPartitioner(2)
|
||||
val p4 = new HashPartitioner(4)
|
||||
|
@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("RangePartitioner equality") {
|
||||
sc = new SparkContext("local", "test")
|
||||
|
||||
// Make an RDD where all the elements are the same so that the partition range bounds
|
||||
// are deterministically all the same.
|
||||
val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
|
||||
|
@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("HashPartitioner not equal to RangePartitioner") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
|
||||
val rangeP2 = new RangePartitioner(2, rdd)
|
||||
val hashP2 = new HashPartitioner(2)
|
||||
|
@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("partitioner preservation") {
|
||||
sc = new SparkContext("local", "test")
|
||||
|
||||
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
|
||||
|
||||
val grouped2 = rdd.groupByKey(2)
|
||||
|
@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("partitioning Java arrays should fail") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
|
||||
val arrPairs: RDD[(Array[Int], Int)] =
|
||||
sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
|
||||
|
@ -120,4 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
|
|||
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
|
||||
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
|
||||
}
|
||||
|
||||
test("zero-length partitions should be correctly handled") {
|
||||
// Create RDD with some consecutive empty partitions (including the "first" one)
|
||||
val rdd: RDD[Double] = sc
|
||||
.parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
|
||||
.filter(_ >= 0.0)
|
||||
|
||||
// Run the partitions, including the consecutive empty ones, through StatCounter
|
||||
val stats: StatCounter = rdd.stats();
|
||||
assert(abs(6.0 - stats.sum) < 0.01);
|
||||
assert(abs(6.0/2 - rdd.mean) < 0.01);
|
||||
assert(abs(1.0 - rdd.variance) < 0.01);
|
||||
assert(abs(1.0 - rdd.stdev) < 0.01);
|
||||
|
||||
// Add other tests here for classes that should be able to handle empty partitions correctly
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,10 +3,9 @@ package spark
|
|||
import org.scalatest.FunSuite
|
||||
import SparkContext._
|
||||
|
||||
class PipedRDDSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
class PipedRDDSuite extends FunSuite with SharedSparkContext {
|
||||
|
||||
test("basic pipe") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
|
||||
val piped = nums.pipe(Seq("cat"))
|
||||
|
@ -19,8 +18,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
|
|||
assert(c(3) === "4")
|
||||
}
|
||||
|
||||
test("advanced pipe") {
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
val bl = sc.broadcast(List("0"))
|
||||
|
||||
val piped = nums.pipe(Seq("cat"),
|
||||
Map[String, String](),
|
||||
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
|
||||
(i:Int, f: String=> Unit) => f(i + "_"))
|
||||
|
||||
val c = piped.collect()
|
||||
|
||||
assert(c.size === 8)
|
||||
assert(c(0) === "0")
|
||||
assert(c(1) === "\u0001")
|
||||
assert(c(2) === "1_")
|
||||
assert(c(3) === "2_")
|
||||
assert(c(4) === "0")
|
||||
assert(c(5) === "\u0001")
|
||||
assert(c(6) === "3_")
|
||||
assert(c(7) === "4_")
|
||||
|
||||
val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
|
||||
val d = nums1.groupBy(str=>str.split("\t")(0)).
|
||||
pipe(Seq("cat"),
|
||||
Map[String, String](),
|
||||
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
|
||||
(i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
|
||||
assert(d.size === 8)
|
||||
assert(d(0) === "0")
|
||||
assert(d(1) === "\u0001")
|
||||
assert(d(2) === "b\t2_")
|
||||
assert(d(3) === "b\t4_")
|
||||
assert(d(4) === "0")
|
||||
assert(d(5) === "\u0001")
|
||||
assert(d(6) === "a\t1_")
|
||||
assert(d(7) === "a\t3_")
|
||||
}
|
||||
|
||||
test("pipe with env variable") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
|
||||
val c = piped.collect()
|
||||
|
@ -30,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("pipe with non-zero exit status") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
val piped = nums.pipe("cat nonexistent_file")
|
||||
intercept[SparkException] {
|
||||
|
|
|
@ -7,10 +7,9 @@ import org.scalatest.time.{Span, Millis}
|
|||
import spark.SparkContext._
|
||||
import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD}
|
||||
|
||||
class RDDSuite extends FunSuite with LocalSparkContext {
|
||||
class RDDSuite extends FunSuite with SharedSparkContext {
|
||||
|
||||
test("basic operations") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
assert(nums.collect().toList === List(1, 2, 3, 4))
|
||||
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
|
||||
|
@ -46,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("SparkContext.union") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
|
||||
assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
|
||||
|
@ -55,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("aggregate") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
|
||||
type StringMap = HashMap[String, Int]
|
||||
val emptyMap = new StringMap {
|
||||
|
@ -75,57 +72,14 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
|
||||
}
|
||||
|
||||
test("basic checkpointing") {
|
||||
import java.io.File
|
||||
val checkpointDir = File.createTempFile("temp", "")
|
||||
checkpointDir.delete()
|
||||
|
||||
sc = new SparkContext("local", "test")
|
||||
sc.setCheckpointDir(checkpointDir.toString)
|
||||
val parCollection = sc.makeRDD(1 to 4)
|
||||
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
|
||||
flatMappedRDD.checkpoint()
|
||||
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
|
||||
val result = flatMappedRDD.collect()
|
||||
Thread.sleep(1000)
|
||||
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
|
||||
assert(flatMappedRDD.collect() === result)
|
||||
|
||||
checkpointDir.deleteOnExit()
|
||||
}
|
||||
|
||||
test("basic caching") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
|
||||
assert(rdd.collect().toList === List(1, 2, 3, 4))
|
||||
assert(rdd.collect().toList === List(1, 2, 3, 4))
|
||||
assert(rdd.collect().toList === List(1, 2, 3, 4))
|
||||
}
|
||||
|
||||
test("unpersist RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
|
||||
rdd.count
|
||||
assert(sc.persistentRdds.isEmpty === false)
|
||||
rdd.unpersist()
|
||||
assert(sc.persistentRdds.isEmpty === true)
|
||||
|
||||
failAfter(Span(3000, Millis)) {
|
||||
try {
|
||||
while (! sc.getRDDStorageInfo.isEmpty) {
|
||||
Thread.sleep(200)
|
||||
}
|
||||
} catch {
|
||||
case _ => { Thread.sleep(10) }
|
||||
// Do nothing. We might see exceptions because block manager
|
||||
// is racing this thread to remove entries from the driver.
|
||||
}
|
||||
}
|
||||
assert(sc.getRDDStorageInfo.isEmpty === true)
|
||||
}
|
||||
|
||||
test("caching with failures") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val onlySplit = new Partition { override def index: Int = 0 }
|
||||
var shouldFail = true
|
||||
val rdd = new RDD[Int](sc, Nil) {
|
||||
|
@ -148,7 +102,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("empty RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val empty = new EmptyRDD[Int](sc)
|
||||
assert(empty.count === 0)
|
||||
assert(empty.collect().size === 0)
|
||||
|
@ -168,37 +121,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("cogrouped RDDs") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
|
||||
val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
|
||||
|
||||
// Use cogroup function
|
||||
val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
|
||||
assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
|
||||
assert(cogrouped(2) === (Seq("two"), Seq("two1")))
|
||||
assert(cogrouped(3) === (Seq("three"), Seq()))
|
||||
|
||||
// Construct CoGroupedRDD directly, with map side combine enabled
|
||||
val cogrouped1 = new CoGroupedRDD[Int](
|
||||
Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
|
||||
new HashPartitioner(3),
|
||||
true).collectAsMap()
|
||||
assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
|
||||
assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
|
||||
assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
|
||||
|
||||
// Construct CoGroupedRDD directly, with map side combine disabled
|
||||
val cogrouped2 = new CoGroupedRDD[Int](
|
||||
Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
|
||||
new HashPartitioner(3),
|
||||
false).collectAsMap()
|
||||
assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
|
||||
assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
|
||||
assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
|
||||
}
|
||||
|
||||
test("coalesced RDDs") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val data = sc.parallelize(1 to 10, 10)
|
||||
|
||||
val coalesced1 = data.coalesce(2)
|
||||
|
@ -236,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("zipped RDDs") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
val zipped = nums.zip(nums.map(_ + 1.0))
|
||||
assert(zipped.glom().map(_.toList).collect().toList ===
|
||||
|
@ -248,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
}
|
||||
|
||||
test("partition pruning") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val data = sc.parallelize(1 to 10, 10)
|
||||
// Note that split number starts from 0, so > 8 means only 10th partition left.
|
||||
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
|
||||
|
@ -260,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
|
||||
test("mapWith") {
|
||||
import java.util.Random
|
||||
sc = new SparkContext("local", "test")
|
||||
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
|
||||
val randoms = ones.mapWith(
|
||||
(index: Int) => new Random(index + 42))
|
||||
|
@ -279,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
|
||||
test("flatMapWith") {
|
||||
import java.util.Random
|
||||
sc = new SparkContext("local", "test")
|
||||
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
|
||||
val randoms = ones.flatMapWith(
|
||||
(index: Int) => new Random(index + 42))
|
||||
|
@ -301,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
|
||||
test("filterWith") {
|
||||
import java.util.Random
|
||||
sc = new SparkContext("local", "test")
|
||||
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
|
||||
val sample = ints.filterWith(
|
||||
(index: Int) => new Random(index + 42))
|
||||
|
@ -317,4 +234,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
|
|||
assert(sample.size === checkSample.size)
|
||||
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
|
||||
}
|
||||
|
||||
test("top with predefined ordering") {
|
||||
val nums = Array.range(1, 100000)
|
||||
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
|
||||
val topK = ints.top(5)
|
||||
assert(topK.size === 5)
|
||||
assert(topK.sorted === nums.sorted.takeRight(5))
|
||||
}
|
||||
|
||||
test("top with custom ordering") {
|
||||
val words = Vector("a", "b", "c", "d")
|
||||
implicit val ord = implicitly[Ordering[String]].reverse
|
||||
val rdd = sc.makeRDD(words, 2)
|
||||
val topK = rdd.top(2)
|
||||
assert(topK.size === 2)
|
||||
assert(topK.sorted === Array("b", "a"))
|
||||
}
|
||||
}
|
||||
|
|
25
core/src/test/scala/spark/SharedSparkContext.scala
Normal file
25
core/src/test/scala/spark/SharedSparkContext.scala
Normal file
|
@ -0,0 +1,25 @@
|
|||
package spark
|
||||
|
||||
import org.scalatest.Suite
|
||||
import org.scalatest.BeforeAndAfterAll
|
||||
|
||||
/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
|
||||
trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
|
||||
|
||||
@transient private var _sc: SparkContext = _
|
||||
|
||||
def sc: SparkContext = _sc
|
||||
|
||||
override def beforeAll() {
|
||||
_sc = new SparkContext("local", "test")
|
||||
super.beforeAll()
|
||||
}
|
||||
|
||||
override def afterAll() {
|
||||
if (_sc != null) {
|
||||
LocalSparkContext.stop(_sc)
|
||||
_sc = null
|
||||
}
|
||||
super.afterAll()
|
||||
}
|
||||
}
|
|
@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD
|
|||
import spark.SparkContext._
|
||||
|
||||
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
|
||||
|
||||
test("groupByKey") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with duplicates") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with negative key hash codes") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey().collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesForMinus1 = groups.find(_._1 == -1).get._2
|
||||
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with many output partitions") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
|
||||
val groups = pairs.groupByKey(10).collect()
|
||||
assert(groups.size === 2)
|
||||
val valuesFor1 = groups.find(_._1 == 1).get._2
|
||||
assert(valuesFor1.toList.sorted === List(1, 2, 3))
|
||||
val valuesFor2 = groups.find(_._1 == 2).get._2
|
||||
assert(valuesFor2.toList.sorted === List(1))
|
||||
}
|
||||
|
||||
test("groupByKey with compression") {
|
||||
try {
|
||||
System.setProperty("spark.blockManager.compress", "true")
|
||||
System.setProperty("spark.shuffle.compress", "true")
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
|
||||
val groups = pairs.groupByKey(4).collect()
|
||||
|
@ -77,234 +32,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("reduceByKey") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_).collect()
|
||||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("reduceByKey with collectAsMap") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_).collectAsMap()
|
||||
assert(sums.size === 2)
|
||||
assert(sums(1) === 7)
|
||||
assert(sums(2) === 1)
|
||||
}
|
||||
|
||||
test("reduceByKey with many output partitons") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
|
||||
val sums = pairs.reduceByKey(_+_, 10).collect()
|
||||
assert(sums.toSet === Set((1, 7), (2, 1)))
|
||||
}
|
||||
|
||||
test("reduceByKey with partitioner") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 2
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
|
||||
val sums = pairs.reduceByKey(_+_)
|
||||
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
|
||||
assert(sums.partitioner === Some(p))
|
||||
// count the dependencies to make sure there is only 1 ShuffledRDD
|
||||
val deps = new HashSet[RDD[_]]()
|
||||
def visit(r: RDD[_]) {
|
||||
for (dep <- r.dependencies) {
|
||||
deps += dep.rdd
|
||||
visit(dep.rdd)
|
||||
}
|
||||
}
|
||||
visit(sums)
|
||||
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
|
||||
}
|
||||
|
||||
test("join") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (2, 'x')),
|
||||
(2, (1, 'y')),
|
||||
(2, (1, 'z'))
|
||||
))
|
||||
}
|
||||
|
||||
test("join all-to-all") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 6)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (1, 'y')),
|
||||
(1, (2, 'x')),
|
||||
(1, (2, 'y')),
|
||||
(1, (3, 'x')),
|
||||
(1, (3, 'y'))
|
||||
))
|
||||
}
|
||||
|
||||
test("leftOuterJoin") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.leftOuterJoin(rdd2).collect()
|
||||
assert(joined.size === 5)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, Some('x'))),
|
||||
(1, (2, Some('x'))),
|
||||
(2, (1, Some('y'))),
|
||||
(2, (1, Some('z'))),
|
||||
(3, (1, None))
|
||||
))
|
||||
}
|
||||
|
||||
test("rightOuterJoin") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.rightOuterJoin(rdd2).collect()
|
||||
assert(joined.size === 5)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (Some(1), 'x')),
|
||||
(1, (Some(2), 'x')),
|
||||
(2, (Some(1), 'y')),
|
||||
(2, (Some(1), 'z')),
|
||||
(4, (None, 'w'))
|
||||
))
|
||||
}
|
||||
|
||||
test("join with no matches") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
|
||||
val joined = rdd1.join(rdd2).collect()
|
||||
assert(joined.size === 0)
|
||||
}
|
||||
|
||||
test("join with many output partitions") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.join(rdd2, 10).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (1, 'x')),
|
||||
(1, (2, 'x')),
|
||||
(2, (1, 'y')),
|
||||
(2, (1, 'z'))
|
||||
))
|
||||
}
|
||||
|
||||
test("groupWith") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
|
||||
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
|
||||
val joined = rdd1.groupWith(rdd2).collect()
|
||||
assert(joined.size === 4)
|
||||
assert(joined.toSet === Set(
|
||||
(1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
|
||||
(2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
|
||||
(3, (ArrayBuffer(1), ArrayBuffer())),
|
||||
(4, (ArrayBuffer(), ArrayBuffer('w')))
|
||||
))
|
||||
}
|
||||
|
||||
test("zero-partition RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val emptyDir = Files.createTempDir()
|
||||
val file = sc.textFile(emptyDir.getAbsolutePath)
|
||||
assert(file.partitions.size == 0)
|
||||
assert(file.collect().toList === Nil)
|
||||
// Test that a shuffle on the file works, because this used to be a bug
|
||||
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
|
||||
}
|
||||
|
||||
test("keys and values") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
|
||||
assert(rdd.keys.collect().toList === List(1, 2))
|
||||
assert(rdd.values.collect().toList === List("a", "b"))
|
||||
}
|
||||
|
||||
test("default partitioner uses partition size") {
|
||||
sc = new SparkContext("local", "test")
|
||||
// specify 2000 partitions
|
||||
val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
|
||||
// do a map, which loses the partitioner
|
||||
val b = a.map(a => (a, (a * 2).toString))
|
||||
// then a group by, and see we didn't revert to 2 partitions
|
||||
val c = b.groupByKey()
|
||||
assert(c.partitions.size === 2000)
|
||||
}
|
||||
|
||||
test("default partitioner uses largest partitioner") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
|
||||
val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
|
||||
val c = a.join(b)
|
||||
assert(c.partitions.size === 2000)
|
||||
}
|
||||
|
||||
test("subtract") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val a = sc.parallelize(Array(1, 2, 3), 2)
|
||||
val b = sc.parallelize(Array(2, 3, 4), 4)
|
||||
val c = a.subtract(b)
|
||||
assert(c.collect().toSet === Set(1))
|
||||
assert(c.partitions.size === a.partitions.size)
|
||||
}
|
||||
|
||||
test("subtract with narrow dependency") {
|
||||
sc = new SparkContext("local", "test")
|
||||
// use a deterministic partitioner
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 5
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
// partitionBy so we have a narrow dependency
|
||||
val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
|
||||
// more partitions/no partitioner so a shuffle dependency
|
||||
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
|
||||
val c = a.subtract(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (3, "c")))
|
||||
// Ideally we could keep the original partitioner...
|
||||
assert(c.partitioner === None)
|
||||
}
|
||||
|
||||
test("subtractByKey") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
|
||||
val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
|
||||
val c = a.subtractByKey(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
|
||||
assert(c.partitions.size === a.partitions.size)
|
||||
}
|
||||
|
||||
test("subtractByKey with narrow dependency") {
|
||||
sc = new SparkContext("local", "test")
|
||||
// use a deterministic partitioner
|
||||
val p = new Partitioner() {
|
||||
def numPartitions = 5
|
||||
def getPartition(key: Any) = key.asInstanceOf[Int]
|
||||
}
|
||||
// partitionBy so we have a narrow dependency
|
||||
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
|
||||
// more partitions/no partitioner so a shuffle dependency
|
||||
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
|
||||
val c = a.subtractByKey(b)
|
||||
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
|
||||
assert(c.partitioner.get === p)
|
||||
}
|
||||
|
||||
test("shuffle non-zero block size") {
|
||||
sc = new SparkContext("local-cluster[2,1,512]", "test")
|
||||
val NUM_BLOCKS = 3
|
||||
|
@ -367,6 +94,30 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
|
|||
assert(nonEmptyBlocks.size <= 4)
|
||||
}
|
||||
|
||||
test("zero sized blocks without kryo") {
|
||||
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
|
||||
sc = new SparkContext("local-cluster[2,1,512]", "test")
|
||||
|
||||
// 10 partitions from 4 keys
|
||||
val NUM_BLOCKS = 10
|
||||
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
|
||||
val b = a.map(x => (x, x*2))
|
||||
|
||||
// NOTE: The default Java serializer should create zero-sized blocks
|
||||
val c = new ShuffledRDD(b, new HashPartitioner(10))
|
||||
|
||||
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
|
||||
assert(c.count === 4)
|
||||
|
||||
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
|
||||
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
|
||||
statuses.map(x => x._2)
|
||||
}
|
||||
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
|
||||
|
||||
// We should have at most 4 non-zero sized partitions
|
||||
assert(nonEmptyBlocks.size <= 4)
|
||||
}
|
||||
}
|
||||
|
||||
object ShuffleSuite {
|
||||
|
|
|
@ -35,7 +35,7 @@ class SizeEstimatorSuite
|
|||
var oldOops: String = _
|
||||
|
||||
override def beforeAll() {
|
||||
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
|
||||
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
|
||||
oldArch = System.setProperty("os.arch", "amd64")
|
||||
oldOops = System.setProperty("spark.test.useCompressedOops", "true")
|
||||
}
|
||||
|
@ -46,54 +46,54 @@ class SizeEstimatorSuite
|
|||
}
|
||||
|
||||
test("simple classes") {
|
||||
expect(16)(SizeEstimator.estimate(new DummyClass1))
|
||||
expect(16)(SizeEstimator.estimate(new DummyClass2))
|
||||
expect(24)(SizeEstimator.estimate(new DummyClass3))
|
||||
expect(24)(SizeEstimator.estimate(new DummyClass4(null)))
|
||||
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
|
||||
assert(SizeEstimator.estimate(new DummyClass1) === 16)
|
||||
assert(SizeEstimator.estimate(new DummyClass2) === 16)
|
||||
assert(SizeEstimator.estimate(new DummyClass3) === 24)
|
||||
assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
|
||||
assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
|
||||
}
|
||||
|
||||
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
|
||||
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
|
||||
test("strings") {
|
||||
expect(40)(SizeEstimator.estimate(DummyString("")))
|
||||
expect(48)(SizeEstimator.estimate(DummyString("a")))
|
||||
expect(48)(SizeEstimator.estimate(DummyString("ab")))
|
||||
expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
|
||||
assert(SizeEstimator.estimate(DummyString("")) === 40)
|
||||
assert(SizeEstimator.estimate(DummyString("a")) === 48)
|
||||
assert(SizeEstimator.estimate(DummyString("ab")) === 48)
|
||||
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
|
||||
}
|
||||
|
||||
test("primitive arrays") {
|
||||
expect(32)(SizeEstimator.estimate(new Array[Byte](10)))
|
||||
expect(40)(SizeEstimator.estimate(new Array[Char](10)))
|
||||
expect(40)(SizeEstimator.estimate(new Array[Short](10)))
|
||||
expect(56)(SizeEstimator.estimate(new Array[Int](10)))
|
||||
expect(96)(SizeEstimator.estimate(new Array[Long](10)))
|
||||
expect(56)(SizeEstimator.estimate(new Array[Float](10)))
|
||||
expect(96)(SizeEstimator.estimate(new Array[Double](10)))
|
||||
expect(4016)(SizeEstimator.estimate(new Array[Int](1000)))
|
||||
expect(8016)(SizeEstimator.estimate(new Array[Long](1000)))
|
||||
assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
|
||||
assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
|
||||
assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
|
||||
assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
|
||||
assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
|
||||
assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
|
||||
assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
|
||||
assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
|
||||
assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
|
||||
}
|
||||
|
||||
test("object arrays") {
|
||||
// Arrays containing nulls should just have one pointer per element
|
||||
expect(56)(SizeEstimator.estimate(new Array[String](10)))
|
||||
expect(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
|
||||
assert(SizeEstimator.estimate(new Array[String](10)) === 56)
|
||||
assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
|
||||
|
||||
// For object arrays with non-null elements, each object should take one pointer plus
|
||||
// however many bytes that class takes. (Note that Array.fill calls the code in its
|
||||
// second parameter separately for each object, so we get distinct objects.)
|
||||
expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
|
||||
expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
|
||||
expect(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
|
||||
expect(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
|
||||
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
|
||||
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
|
||||
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
|
||||
assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
|
||||
|
||||
// Past size 100, our samples 100 elements, but we should still get the right size.
|
||||
expect(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
|
||||
assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
|
||||
|
||||
// If an array contains the *same* element many times, we should only count it once.
|
||||
val d1 = new DummyClass1
|
||||
expect(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
|
||||
expect(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
|
||||
assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
|
||||
assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
|
||||
|
||||
// Same thing with huge array containing the same element many times. Note that this won't
|
||||
// return exactly 4032 because it can't tell that *all* the elements will equal the first
|
||||
|
@ -111,10 +111,10 @@ class SizeEstimatorSuite
|
|||
val initialize = PrivateMethod[Unit]('initialize)
|
||||
SizeEstimator invokePrivate initialize()
|
||||
|
||||
expect(40)(SizeEstimator.estimate(DummyString("")))
|
||||
expect(48)(SizeEstimator.estimate(DummyString("a")))
|
||||
expect(48)(SizeEstimator.estimate(DummyString("ab")))
|
||||
expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
|
||||
assert(SizeEstimator.estimate(DummyString("")) === 40)
|
||||
assert(SizeEstimator.estimate(DummyString("a")) === 48)
|
||||
assert(SizeEstimator.estimate(DummyString("ab")) === 48)
|
||||
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
|
||||
|
||||
resetOrClear("os.arch", arch)
|
||||
}
|
||||
|
@ -128,10 +128,10 @@ class SizeEstimatorSuite
|
|||
val initialize = PrivateMethod[Unit]('initialize)
|
||||
SizeEstimator invokePrivate initialize()
|
||||
|
||||
expect(56)(SizeEstimator.estimate(DummyString("")))
|
||||
expect(64)(SizeEstimator.estimate(DummyString("a")))
|
||||
expect(64)(SizeEstimator.estimate(DummyString("ab")))
|
||||
expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
|
||||
assert(SizeEstimator.estimate(DummyString("")) === 56)
|
||||
assert(SizeEstimator.estimate(DummyString("a")) === 64)
|
||||
assert(SizeEstimator.estimate(DummyString("ab")) === 64)
|
||||
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
|
||||
|
||||
resetOrClear("os.arch", arch)
|
||||
resetOrClear("spark.test.useCompressedOops", oops)
|
||||
|
|
|
@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter
|
|||
import org.scalatest.matchers.ShouldMatchers
|
||||
import SparkContext._
|
||||
|
||||
class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
|
||||
|
||||
class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
|
||||
|
||||
test("sortByKey") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
|
||||
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
|
||||
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
|
||||
}
|
||||
|
||||
test("large array") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
|
@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
}
|
||||
|
||||
test("large array with one split") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
|
@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
assert(sorted.partitions.size === 1)
|
||||
assert(sorted.collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
|
||||
test("large array with many partitions") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
|
@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
assert(sorted.partitions.size === 20)
|
||||
assert(sorted.collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
|
||||
test("sort descending") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
|
@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
}
|
||||
|
||||
test("sort descending with one split") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 1)
|
||||
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
|
||||
}
|
||||
|
||||
|
||||
test("sort descending with many partitions") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
|
@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
}
|
||||
|
||||
test("more partitions than elements") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rand = new scala.util.Random()
|
||||
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
|
||||
val pairs = sc.parallelize(pairArr, 30)
|
||||
|
@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
}
|
||||
|
||||
test("empty RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairArr = new Array[(Int, Int)](0)
|
||||
val pairs = sc.parallelize(pairArr, 2)
|
||||
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
|
||||
}
|
||||
|
||||
test("partition balancing") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairArr = (1 to 1000).map(x => (x, x)).toArray
|
||||
val sorted = sc.parallelize(pairArr, 4).sortByKey()
|
||||
assert(sorted.collect() === pairArr.sortBy(_._1))
|
||||
|
@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
|
|||
}
|
||||
|
||||
test("partition balancing for descending sort") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val pairArr = (1 to 1000).map(x => (x, x)).toArray
|
||||
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
|
||||
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)
|
||||
|
|
30
core/src/test/scala/spark/UnpersistSuite.scala
Normal file
30
core/src/test/scala/spark/UnpersistSuite.scala
Normal file
|
@ -0,0 +1,30 @@
|
|||
package spark
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.concurrent.Timeouts._
|
||||
import org.scalatest.time.{Span, Millis}
|
||||
import spark.SparkContext._
|
||||
|
||||
class UnpersistSuite extends FunSuite with LocalSparkContext {
|
||||
test("unpersist RDD") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
|
||||
rdd.count
|
||||
assert(sc.persistentRdds.isEmpty === false)
|
||||
rdd.unpersist()
|
||||
assert(sc.persistentRdds.isEmpty === true)
|
||||
|
||||
failAfter(Span(3000, Millis)) {
|
||||
try {
|
||||
while (! sc.getRDDStorageInfo.isEmpty) {
|
||||
Thread.sleep(200)
|
||||
}
|
||||
} catch {
|
||||
case _ => { Thread.sleep(10) }
|
||||
// Do nothing. We might see exceptions because block manager
|
||||
// is racing this thread to remove entries from the driver.
|
||||
}
|
||||
}
|
||||
assert(sc.getRDDStorageInfo.isEmpty === true)
|
||||
}
|
||||
}
|
|
@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite {
|
|||
assert(os.toByteArray.toList.equals(bytes.toList))
|
||||
}
|
||||
|
||||
test("memoryStringToMb"){
|
||||
assert(Utils.memoryStringToMb("1") == 0)
|
||||
assert(Utils.memoryStringToMb("1048575") == 0)
|
||||
assert(Utils.memoryStringToMb("3145728") == 3)
|
||||
test("memoryStringToMb") {
|
||||
assert(Utils.memoryStringToMb("1") === 0)
|
||||
assert(Utils.memoryStringToMb("1048575") === 0)
|
||||
assert(Utils.memoryStringToMb("3145728") === 3)
|
||||
|
||||
assert(Utils.memoryStringToMb("1024k") == 1)
|
||||
assert(Utils.memoryStringToMb("5000k") == 4)
|
||||
assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K"))
|
||||
assert(Utils.memoryStringToMb("1024k") === 1)
|
||||
assert(Utils.memoryStringToMb("5000k") === 4)
|
||||
assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
|
||||
|
||||
assert(Utils.memoryStringToMb("1024m") == 1024)
|
||||
assert(Utils.memoryStringToMb("5000m") == 5000)
|
||||
assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M"))
|
||||
assert(Utils.memoryStringToMb("1024m") === 1024)
|
||||
assert(Utils.memoryStringToMb("5000m") === 5000)
|
||||
assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
|
||||
|
||||
assert(Utils.memoryStringToMb("2g") == 2048)
|
||||
assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G"))
|
||||
assert(Utils.memoryStringToMb("2g") === 2048)
|
||||
assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
|
||||
|
||||
assert(Utils.memoryStringToMb("2t") == 2097152)
|
||||
assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T"))
|
||||
assert(Utils.memoryStringToMb("2t") === 2097152)
|
||||
assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
|
||||
}
|
||||
|
||||
test("splitCommandString") {
|
||||
assert(Utils.splitCommandString("") === Seq())
|
||||
assert(Utils.splitCommandString("a") === Seq("a"))
|
||||
assert(Utils.splitCommandString("aaa") === Seq("aaa"))
|
||||
assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
|
||||
assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
|
||||
assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
|
||||
assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
|
||||
assert(Utils.splitCommandString("'b c'") === Seq("b c"))
|
||||
assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
|
||||
assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
|
||||
assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
|
||||
assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
|
||||
assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
|
||||
assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
|
||||
assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
|
||||
assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
|
||||
assert(Utils.splitCommandString("'a'b") === Seq("ab"))
|
||||
assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
|
||||
assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
|
||||
assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
|
||||
assert(Utils.splitCommandString("''") === Seq(""))
|
||||
assert(Utils.splitCommandString("\"\"") === Seq(""))
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -17,9 +17,8 @@ object ZippedPartitionsSuite {
|
|||
}
|
||||
}
|
||||
|
||||
class ZippedPartitionsSuite extends FunSuite with LocalSparkContext {
|
||||
class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
|
||||
test("print sizes") {
|
||||
sc = new SparkContext("local", "test")
|
||||
val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
|
||||
val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
|
||||
val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
|
||||
|
|
|
@ -16,7 +16,7 @@ class DummyTaskSetManager(
|
|||
initNumTasks: Int,
|
||||
clusterScheduler: ClusterScheduler,
|
||||
taskSet: TaskSet)
|
||||
extends TaskSetManager(clusterScheduler,taskSet) {
|
||||
extends ClusterTaskSetManager(clusterScheduler,taskSet) {
|
||||
|
||||
parent = null
|
||||
weight = 1
|
||||
|
|
104
core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
Normal file
104
core/src/test/scala/spark/scheduler/JobLoggerSuite.scala
Normal file
|
@ -0,0 +1,104 @@
|
|||
package spark.scheduler
|
||||
|
||||
import java.util.Properties
|
||||
import java.util.concurrent.LinkedBlockingQueue
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.matchers.ShouldMatchers
|
||||
import scala.collection.mutable
|
||||
import spark._
|
||||
import spark.SparkContext._
|
||||
|
||||
|
||||
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
|
||||
|
||||
test("inner method") {
|
||||
sc = new SparkContext("local", "joblogger")
|
||||
val joblogger = new JobLogger {
|
||||
def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
|
||||
def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
|
||||
def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
|
||||
def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
|
||||
}
|
||||
type MyRDD = RDD[(Int, Int)]
|
||||
def makeRdd(
|
||||
numPartitions: Int,
|
||||
dependencies: List[Dependency[_]]
|
||||
): MyRDD = {
|
||||
val maxPartition = numPartitions - 1
|
||||
return new MyRDD(sc, dependencies) {
|
||||
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
|
||||
throw new RuntimeException("should not be reached")
|
||||
override def getPartitions = (0 to maxPartition).map(i => new Partition {
|
||||
override def index = i
|
||||
}).toArray
|
||||
}
|
||||
}
|
||||
val jobID = 5
|
||||
val parentRdd = makeRdd(4, Nil)
|
||||
val shuffleDep = new ShuffleDependency(parentRdd, null)
|
||||
val rootRdd = makeRdd(4, List(shuffleDep))
|
||||
val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
|
||||
val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
|
||||
|
||||
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
|
||||
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
|
||||
parentRdd.setName("MyRDD")
|
||||
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
|
||||
joblogger.createLogWriterTest(jobID)
|
||||
joblogger.getJobIDtoPrintWriter.size should be (1)
|
||||
joblogger.buildJobDepTest(jobID, rootStage)
|
||||
joblogger.getJobIDToStages.get(jobID).get.size should be (2)
|
||||
joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
|
||||
joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
|
||||
joblogger.closeLogWriterTest(jobID)
|
||||
joblogger.getStageIDToJobID.size should be (0)
|
||||
joblogger.getJobIDToStages.size should be (0)
|
||||
joblogger.getJobIDtoPrintWriter.size should be (0)
|
||||
}
|
||||
|
||||
test("inner variables") {
|
||||
sc = new SparkContext("local[4]", "joblogger")
|
||||
val joblogger = new JobLogger {
|
||||
override protected def closeLogWriter(jobID: Int) =
|
||||
getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
|
||||
fileWriter.close()
|
||||
}
|
||||
}
|
||||
sc.addSparkListener(joblogger)
|
||||
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
|
||||
rdd.reduceByKey(_+_).collect()
|
||||
|
||||
joblogger.getLogDir should be ("/tmp/spark")
|
||||
joblogger.getJobIDtoPrintWriter.size should be (1)
|
||||
joblogger.getStageIDToJobID.size should be (2)
|
||||
joblogger.getStageIDToJobID.get(0) should be (Some(0))
|
||||
joblogger.getStageIDToJobID.get(1) should be (Some(0))
|
||||
joblogger.getJobIDToStages.size should be (1)
|
||||
}
|
||||
|
||||
|
||||
test("interface functions") {
|
||||
sc = new SparkContext("local[4]", "joblogger")
|
||||
val joblogger = new JobLogger {
|
||||
var onTaskEndCount = 0
|
||||
var onJobEndCount = 0
|
||||
var onJobStartCount = 0
|
||||
var onStageCompletedCount = 0
|
||||
var onStageSubmittedCount = 0
|
||||
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
|
||||
override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
|
||||
override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
|
||||
override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
|
||||
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
|
||||
}
|
||||
sc.addSparkListener(joblogger)
|
||||
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
|
||||
rdd.reduceByKey(_+_).collect()
|
||||
|
||||
joblogger.onJobStartCount should be (1)
|
||||
joblogger.onJobEndCount should be (1)
|
||||
joblogger.onTaskEndCount should be (8)
|
||||
joblogger.onStageSubmittedCount should be (2)
|
||||
joblogger.onStageCompletedCount should be (2)
|
||||
}
|
||||
}
|
206
core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
Normal file
206
core/src/test/scala/spark/scheduler/LocalSchedulerSuite.scala
Normal file
|
@ -0,0 +1,206 @@
|
|||
package spark.scheduler
|
||||
|
||||
import org.scalatest.FunSuite
|
||||
import org.scalatest.BeforeAndAfter
|
||||
|
||||
import spark._
|
||||
import spark.scheduler._
|
||||
import spark.scheduler.cluster._
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
import scala.collection.mutable.{ConcurrentMap, HashMap}
|
||||
import java.util.concurrent.Semaphore
|
||||
import java.util.concurrent.CountDownLatch
|
||||
import java.util.Properties
|
||||
|
||||
class Lock() {
|
||||
var finished = false
|
||||
def jobWait() = {
|
||||
synchronized {
|
||||
while(!finished) {
|
||||
this.wait()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def jobFinished() = {
|
||||
synchronized {
|
||||
finished = true
|
||||
this.notifyAll()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object TaskThreadInfo {
|
||||
val threadToLock = HashMap[Int, Lock]()
|
||||
val threadToRunning = HashMap[Int, Boolean]()
|
||||
val threadToStarted = HashMap[Int, CountDownLatch]()
|
||||
}
|
||||
|
||||
/*
|
||||
* 1. each thread contains one job.
|
||||
* 2. each job contains one stage.
|
||||
* 3. each stage only contains one task.
|
||||
* 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
|
||||
* it will get cpu core resource, and will wait to finished after user manually
|
||||
* release "Lock" and then cluster will contain another free cpu cores.
|
||||
* 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
|
||||
* thus it will be scheduled later when cluster has free cpu cores.
|
||||
*/
|
||||
class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
|
||||
|
||||
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
|
||||
|
||||
TaskThreadInfo.threadToRunning(threadIndex) = false
|
||||
val nums = sc.parallelize(threadIndex to threadIndex, 1)
|
||||
TaskThreadInfo.threadToLock(threadIndex) = new Lock()
|
||||
TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
|
||||
new Thread {
|
||||
if (poolName != null) {
|
||||
sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
|
||||
}
|
||||
override def run() {
|
||||
val ans = nums.map(number => {
|
||||
TaskThreadInfo.threadToRunning(number) = true
|
||||
TaskThreadInfo.threadToStarted(number).countDown()
|
||||
TaskThreadInfo.threadToLock(number).jobWait()
|
||||
TaskThreadInfo.threadToRunning(number) = false
|
||||
number
|
||||
}).collect()
|
||||
assert(ans.toList === List(threadIndex))
|
||||
sem.release()
|
||||
}
|
||||
}.start()
|
||||
}
|
||||
|
||||
test("Local FIFO scheduler end-to-end test") {
|
||||
System.setProperty("spark.cluster.schedulingmode", "FIFO")
|
||||
sc = new SparkContext("local[4]", "test")
|
||||
val sem = new Semaphore(0)
|
||||
|
||||
createThread(1,null,sc,sem)
|
||||
TaskThreadInfo.threadToStarted(1).await()
|
||||
createThread(2,null,sc,sem)
|
||||
TaskThreadInfo.threadToStarted(2).await()
|
||||
createThread(3,null,sc,sem)
|
||||
TaskThreadInfo.threadToStarted(3).await()
|
||||
createThread(4,null,sc,sem)
|
||||
TaskThreadInfo.threadToStarted(4).await()
|
||||
// thread 5 and 6 (stage pending)must meet following two points
|
||||
// 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
|
||||
// queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
|
||||
// 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
|
||||
// So I just use "sleep" 1s here for each thread.
|
||||
// TODO: any better solution?
|
||||
createThread(5,null,sc,sem)
|
||||
Thread.sleep(1000)
|
||||
createThread(6,null,sc,sem)
|
||||
Thread.sleep(1000)
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(1) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(2) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(3) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(4) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(5) === false)
|
||||
assert(TaskThreadInfo.threadToRunning(6) === false)
|
||||
|
||||
TaskThreadInfo.threadToLock(1).jobFinished()
|
||||
TaskThreadInfo.threadToStarted(5).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(1) === false)
|
||||
assert(TaskThreadInfo.threadToRunning(2) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(3) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(4) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(5) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(6) === false)
|
||||
|
||||
TaskThreadInfo.threadToLock(3).jobFinished()
|
||||
TaskThreadInfo.threadToStarted(6).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(1) === false)
|
||||
assert(TaskThreadInfo.threadToRunning(2) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(3) === false)
|
||||
assert(TaskThreadInfo.threadToRunning(4) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(5) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(6) === true)
|
||||
|
||||
TaskThreadInfo.threadToLock(2).jobFinished()
|
||||
TaskThreadInfo.threadToLock(4).jobFinished()
|
||||
TaskThreadInfo.threadToLock(5).jobFinished()
|
||||
TaskThreadInfo.threadToLock(6).jobFinished()
|
||||
sem.acquire(6)
|
||||
}
|
||||
|
||||
test("Local fair scheduler end-to-end test") {
|
||||
sc = new SparkContext("local[8]", "LocalSchedulerSuite")
|
||||
val sem = new Semaphore(0)
|
||||
System.setProperty("spark.cluster.schedulingmode", "FAIR")
|
||||
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
|
||||
System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
|
||||
|
||||
createThread(10,"1",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(10).await()
|
||||
createThread(20,"2",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(20).await()
|
||||
createThread(30,"3",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(30).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(10) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(20) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(30) === true)
|
||||
|
||||
createThread(11,"1",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(11).await()
|
||||
createThread(21,"2",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(21).await()
|
||||
createThread(31,"3",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(31).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(11) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(21) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(31) === true)
|
||||
|
||||
createThread(12,"1",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(12).await()
|
||||
createThread(22,"2",sc,sem)
|
||||
TaskThreadInfo.threadToStarted(22).await()
|
||||
createThread(32,"3",sc,sem)
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(12) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(22) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(32) === false)
|
||||
|
||||
TaskThreadInfo.threadToLock(10).jobFinished()
|
||||
TaskThreadInfo.threadToStarted(32).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(32) === true)
|
||||
|
||||
//1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
|
||||
// queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
|
||||
//2. priority of 23 and 33 will be meaningless as using fair scheduler here.
|
||||
createThread(23,"2",sc,sem)
|
||||
createThread(33,"3",sc,sem)
|
||||
Thread.sleep(1000)
|
||||
|
||||
TaskThreadInfo.threadToLock(11).jobFinished()
|
||||
TaskThreadInfo.threadToStarted(23).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(23) === true)
|
||||
assert(TaskThreadInfo.threadToRunning(33) === false)
|
||||
|
||||
TaskThreadInfo.threadToLock(12).jobFinished()
|
||||
TaskThreadInfo.threadToStarted(33).await()
|
||||
|
||||
assert(TaskThreadInfo.threadToRunning(33) === true)
|
||||
|
||||
TaskThreadInfo.threadToLock(20).jobFinished()
|
||||
TaskThreadInfo.threadToLock(21).jobFinished()
|
||||
TaskThreadInfo.threadToLock(22).jobFinished()
|
||||
TaskThreadInfo.threadToLock(23).jobFinished()
|
||||
TaskThreadInfo.threadToLock(30).jobFinished()
|
||||
TaskThreadInfo.threadToLock(31).jobFinished()
|
||||
TaskThreadInfo.threadToLock(32).jobFinished()
|
||||
TaskThreadInfo.threadToLock(33).jobFinished()
|
||||
|
||||
sem.acquire(11)
|
||||
}
|
||||
}
|
|
@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
|
|||
|
||||
class SaveStageInfo extends SparkListener {
|
||||
val stageInfos = mutable.Buffer[StageInfo]()
|
||||
def onStageCompleted(stage: StageCompleted) {
|
||||
override def onStageCompleted(stage: StageCompleted) {
|
||||
stageInfos += stage.stageInfo
|
||||
}
|
||||
}
|
||||
|
|
|
@ -260,6 +260,13 @@ Apart from these, the following properties are also available, and may be useful
|
|||
applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>spark.streaming.blockInterval</td>
|
||||
<td>200</td>
|
||||
<td>
|
||||
Duration (milliseconds) of how long to batch new objects coming from network receivers.
|
||||
</td>
|
||||
</tr>
|
||||
|
||||
</table>
|
||||
|
||||
|
|
|
@ -27,14 +27,14 @@ Short functions can be passed to RDD methods using Python's [`lambda`](http://ww
|
|||
|
||||
{% highlight python %}
|
||||
logData = sc.textFile(logFile).cache()
|
||||
errors = logData.filter(lambda s: 'ERROR' in s.split())
|
||||
errors = logData.filter(lambda line: "ERROR" in line)
|
||||
{% endhighlight %}
|
||||
|
||||
You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
|
||||
|
||||
{% highlight python %}
|
||||
def is_error(line):
|
||||
return 'ERROR' in line.split()
|
||||
return "ERROR" in line
|
||||
errors = logData.filter(is_error)
|
||||
{% endhighlight %}
|
||||
|
||||
|
@ -43,8 +43,7 @@ Functions can access objects in enclosing scopes, although modifications to thos
|
|||
{% highlight python %}
|
||||
error_keywords = ["Exception", "Error"]
|
||||
def is_error(line):
|
||||
words = line.split()
|
||||
return any(keyword in words for keyword in error_keywords)
|
||||
return any(keyword in line for keyword in error_keywords)
|
||||
errors = logData.filter(is_error)
|
||||
{% endhighlight %}
|
||||
|
||||
|
|
|
@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars])
|
|||
|
||||
The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later.
|
||||
|
||||
In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use
|
||||
In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use
|
||||
|
||||
{% highlight bash %}
|
||||
$ MASTER=local[4] ./spark-shell
|
||||
{% endhighlight %}
|
||||
|
||||
Or, to also add `code.jar` to its classpath, use:
|
||||
|
||||
{% highlight bash %}
|
||||
$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell
|
||||
{% endhighlight %}
|
||||
|
||||
### Master URLs
|
||||
|
||||
The master URL passed to Spark can be in one of the following formats:
|
||||
|
@ -78,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio
|
|||
* `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them).
|
||||
* `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies.
|
||||
|
||||
If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed.
|
||||
If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed.
|
||||
|
||||
|
||||
# Resilient Distributed Datasets (RDDs)
|
||||
|
|
|
@ -34,6 +34,41 @@
|
|||
<artifactId>scalacheck_${scala.version}</artifactId>
|
||||
<scope>test</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.cassandra</groupId>
|
||||
<artifactId>cassandra-all</artifactId>
|
||||
<version>1.2.5</version>
|
||||
<exclusions>
|
||||
<exclusion>
|
||||
<groupId>com.google.guava</groupId>
|
||||
<artifactId>guava</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>com.googlecode.concurrentlinkedhashmap</groupId>
|
||||
<artifactId>concurrentlinkedhashmap-lru</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>com.ning</groupId>
|
||||
<artifactId>compress-lzf</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>io.netty</groupId>
|
||||
<artifactId>netty</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>jline</groupId>
|
||||
<artifactId>jline</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>log4j</groupId>
|
||||
<artifactId>log4j</artifactId>
|
||||
</exclusion>
|
||||
<exclusion>
|
||||
<groupId>org.apache.cassandra.deps</groupId>
|
||||
<artifactId>avro</artifactId>
|
||||
</exclusion>
|
||||
</exclusions>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
|
||||
|
@ -67,6 +102,11 @@
|
|||
<artifactId>hadoop-core</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hbase</groupId>
|
||||
<artifactId>hbase</artifactId>
|
||||
<version>0.94.6</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
|
@ -105,6 +145,11 @@
|
|||
<artifactId>hadoop-client</artifactId>
|
||||
<scope>provided</scope>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>org.apache.hbase</groupId>
|
||||
<artifactId>hbase</artifactId>
|
||||
<version>0.94.6</version>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
<build>
|
||||
<plugins>
|
||||
|
|
196
examples/src/main/scala/spark/examples/CassandraTest.scala
Normal file
196
examples/src/main/scala/spark/examples/CassandraTest.scala
Normal file
|
@ -0,0 +1,196 @@
|
|||
package spark.examples
|
||||
|
||||
import org.apache.hadoop.mapreduce.Job
|
||||
import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat
|
||||
import org.apache.cassandra.hadoop.ConfigHelper
|
||||
import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
|
||||
import org.apache.cassandra.thrift._
|
||||
import spark.SparkContext
|
||||
import spark.SparkContext._
|
||||
import java.nio.ByteBuffer
|
||||
import java.util.SortedMap
|
||||
import org.apache.cassandra.db.IColumn
|
||||
import org.apache.cassandra.utils.ByteBufferUtil
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
|
||||
/*
|
||||
* This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra
|
||||
* support for Hadoop.
|
||||
*
|
||||
* To run this example, run this file with the following command params -
|
||||
* <spark_master> <cassandra_node> <cassandra_port>
|
||||
*
|
||||
* So if you want to run this on localhost this will be,
|
||||
* local[3] localhost 9160
|
||||
*
|
||||
* The example makes some assumptions:
|
||||
* 1. You have already created a keyspace called casDemo and it has a column family named Words
|
||||
* 2. There are column family has a column named "para" which has test content.
|
||||
*
|
||||
* You can create the content by running the following script at the bottom of this file with
|
||||
* cassandra-cli.
|
||||
*
|
||||
*/
|
||||
object CassandraTest {
|
||||
|
||||
def main(args: Array[String]) {
|
||||
|
||||
// Get a SparkContext
|
||||
val sc = new SparkContext(args(0), "casDemo")
|
||||
|
||||
// Build the job configuration with ConfigHelper provided by Cassandra
|
||||
val job = new Job()
|
||||
job.setInputFormatClass(classOf[ColumnFamilyInputFormat])
|
||||
|
||||
val host: String = args(1)
|
||||
val port: String = args(2)
|
||||
|
||||
ConfigHelper.setInputInitialAddress(job.getConfiguration(), host)
|
||||
ConfigHelper.setInputRpcPort(job.getConfiguration(), port)
|
||||
ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host)
|
||||
ConfigHelper.setOutputRpcPort(job.getConfiguration(), port)
|
||||
ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words")
|
||||
ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount")
|
||||
|
||||
val predicate = new SlicePredicate()
|
||||
val sliceRange = new SliceRange()
|
||||
sliceRange.setStart(Array.empty[Byte])
|
||||
sliceRange.setFinish(Array.empty[Byte])
|
||||
predicate.setSlice_range(sliceRange)
|
||||
ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate)
|
||||
|
||||
ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
|
||||
ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
|
||||
|
||||
// Make a new Hadoop RDD
|
||||
val casRdd = sc.newAPIHadoopRDD(
|
||||
job.getConfiguration(),
|
||||
classOf[ColumnFamilyInputFormat],
|
||||
classOf[ByteBuffer],
|
||||
classOf[SortedMap[ByteBuffer, IColumn]])
|
||||
|
||||
// Let us first get all the paragraphs from the retrieved rows
|
||||
val paraRdd = casRdd.map {
|
||||
case (key, value) => {
|
||||
ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value())
|
||||
}
|
||||
}
|
||||
|
||||
// Lets get the word count in paras
|
||||
val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
|
||||
|
||||
counts.collect().foreach {
|
||||
case (word, count) => println(word + ":" + count)
|
||||
}
|
||||
|
||||
counts.map {
|
||||
case (word, count) => {
|
||||
val colWord = new org.apache.cassandra.thrift.Column()
|
||||
colWord.setName(ByteBufferUtil.bytes("word"))
|
||||
colWord.setValue(ByteBufferUtil.bytes(word))
|
||||
colWord.setTimestamp(System.currentTimeMillis)
|
||||
|
||||
val colCount = new org.apache.cassandra.thrift.Column()
|
||||
colCount.setName(ByteBufferUtil.bytes("wcount"))
|
||||
colCount.setValue(ByteBufferUtil.bytes(count.toLong))
|
||||
colCount.setTimestamp(System.currentTimeMillis)
|
||||
|
||||
val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis)
|
||||
|
||||
val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil
|
||||
mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn())
|
||||
mutations.get(0).column_or_supercolumn.setColumn(colWord)
|
||||
mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn())
|
||||
mutations.get(1).column_or_supercolumn.setColumn(colCount)
|
||||
(outputkey, mutations)
|
||||
}
|
||||
}.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
|
||||
classOf[ColumnFamilyOutputFormat], job.getConfiguration)
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
create keyspace casDemo;
|
||||
use casDemo;
|
||||
|
||||
create column family WordCount with comparator = UTF8Type;
|
||||
update column family WordCount with column_metadata =
|
||||
[{column_name: word, validation_class: UTF8Type},
|
||||
{column_name: wcount, validation_class: LongType}];
|
||||
|
||||
create column family Words with comparator = UTF8Type;
|
||||
update column family Words with column_metadata =
|
||||
[{column_name: book, validation_class: UTF8Type},
|
||||
{column_name: para, validation_class: UTF8Type}];
|
||||
|
||||
assume Words keys as utf8;
|
||||
|
||||
set Words['3musk001']['book'] = 'The Three Musketeers';
|
||||
set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market
|
||||
town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to
|
||||
be in as perfect a state of revolution as if the Huguenots had just made
|
||||
a second La Rochelle of it. Many citizens, seeing the women flying
|
||||
toward the High Street, leaving their children crying at the open doors,
|
||||
hastened to don the cuirass, and supporting their somewhat uncertain
|
||||
courage with a musket or a partisan, directed their steps toward the
|
||||
hostelry of the Jolly Miller, before which was gathered, increasing
|
||||
every minute, a compact group, vociferous and full of curiosity.';
|
||||
|
||||
set Words['3musk002']['book'] = 'The Three Musketeers';
|
||||
set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without
|
||||
some city or other registering in its archives an event of this kind. There were
|
||||
nobles, who made war against each other; there was the king, who made
|
||||
war against the cardinal; there was Spain, which made war against the
|
||||
king. Then, in addition to these concealed or public, secret or open
|
||||
wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels,
|
||||
who made war upon everybody. The citizens always took up arms readily
|
||||
against thieves, wolves or scoundrels, often against nobles or
|
||||
Huguenots, sometimes against the king, but never against cardinal or
|
||||
Spain. It resulted, then, from this habit that on the said first Monday
|
||||
of April, 1625, the citizens, on hearing the clamor, and seeing neither
|
||||
the red-and-yellow standard nor the livery of the Duc de Richelieu,
|
||||
rushed toward the hostel of the Jolly Miller. When arrived there, the
|
||||
cause of the hubbub was apparent to all';
|
||||
|
||||
set Words['3musk003']['book'] = 'The Three Musketeers';
|
||||
set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however
|
||||
large the sum may be; but you ought also to endeavor to perfect yourself in
|
||||
the exercises becoming a gentleman. I will write a letter today to the
|
||||
Director of the Royal Academy, and tomorrow he will admit you without
|
||||
any expense to yourself. Do not refuse this little service. Our
|
||||
best-born and richest gentlemen sometimes solicit it without being able
|
||||
to obtain it. You will learn horsemanship, swordsmanship in all its
|
||||
branches, and dancing. You will make some desirable acquaintances; and
|
||||
from time to time you can call upon me, just to tell me how you are
|
||||
getting on, and to say whether I can be of further service to you.';
|
||||
|
||||
|
||||
set Words['thelostworld001']['book'] = 'The Lost World';
|
||||
set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined
|
||||
against the red curtain. How beautiful she was! And yet how aloof! We had been
|
||||
friends, quite good friends; but never could I get beyond the same
|
||||
comradeship which I might have established with one of my
|
||||
fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly,
|
||||
and perfectly unsexual. My instincts are all against a woman being too
|
||||
frank and at her ease with me. It is no compliment to a man. Where
|
||||
the real sex feeling begins, timidity and distrust are its companions,
|
||||
heritage from old wicked days when love and violence went often hand in
|
||||
hand. The bent head, the averted eye, the faltering voice, the wincing
|
||||
figure--these, and not the unshrinking gaze and frank reply, are the
|
||||
true signals of passion. Even in my short life I had learned as much
|
||||
as that--or had inherited it in that race memory which we call instinct.';
|
||||
|
||||
set Words['thelostworld002']['book'] = 'The Lost World';
|
||||
set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed,
|
||||
red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was
|
||||
the real boss; but he lived in the rarefied atmosphere of some Olympian
|
||||
height from which he could distinguish nothing smaller than an
|
||||
international crisis or a split in the Cabinet. Sometimes we saw him
|
||||
passing in lonely majesty to his inner sanctum, with his eyes staring
|
||||
vaguely and his mind hovering over the Balkans or the Persian Gulf. He
|
||||
was above and beyond us. But McArdle was his first lieutenant, and it
|
||||
was he that we knew. The old man nodded as I entered the room, and he
|
||||
pushed his spectacles far up on his bald forehead.';
|
||||
|
||||
*/
|
35
examples/src/main/scala/spark/examples/HBaseTest.scala
Normal file
35
examples/src/main/scala/spark/examples/HBaseTest.scala
Normal file
|
@ -0,0 +1,35 @@
|
|||
package spark.examples
|
||||
|
||||
import spark._
|
||||
import spark.rdd.NewHadoopRDD
|
||||
import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor}
|
||||
import org.apache.hadoop.hbase.client.HBaseAdmin
|
||||
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
|
||||
|
||||
object HBaseTest {
|
||||
def main(args: Array[String]) {
|
||||
val sc = new SparkContext(args(0), "HBaseTest",
|
||||
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
|
||||
|
||||
val conf = HBaseConfiguration.create()
|
||||
|
||||
// Other options for configuring scan behavior are available. More information available at
|
||||
// http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html
|
||||
conf.set(TableInputFormat.INPUT_TABLE, args(1))
|
||||
|
||||
// Initialize hBase table if necessary
|
||||
val admin = new HBaseAdmin(conf)
|
||||
if(!admin.isTableAvailable(args(1))) {
|
||||
val tableDesc = new HTableDescriptor(args(1))
|
||||
admin.createTable(tableDesc)
|
||||
}
|
||||
|
||||
val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat],
|
||||
classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable],
|
||||
classOf[org.apache.hadoop.hbase.client.Result])
|
||||
|
||||
hBaseRDD.count()
|
||||
|
||||
System.exit(0)
|
||||
}
|
||||
}
|
|
@ -37,7 +37,7 @@ object KafkaWordCount {
|
|||
ssc.checkpoint("checkpoint")
|
||||
|
||||
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
|
||||
val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
|
||||
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
|
||||
val words = lines.flatMap(_.split(" "))
|
||||
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
|
||||
wordCounts.print()
|
||||
|
|
14
pom.xml
14
pom.xml
|
@ -61,7 +61,7 @@
|
|||
<cdh.version>4.1.2</cdh.version>
|
||||
<log4j.version>1.2.17</log4j.version>
|
||||
|
||||
<PermGen>0m</PermGen>
|
||||
<PermGen>64m</PermGen>
|
||||
<MaxPermGen>512m</MaxPermGen>
|
||||
</properties>
|
||||
|
||||
|
@ -191,9 +191,9 @@
|
|||
<version>0.8.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>asm</groupId>
|
||||
<artifactId>asm-all</artifactId>
|
||||
<version>3.3.1</version>
|
||||
<groupId>org.ow2.asm</groupId>
|
||||
<artifactId>asm</artifactId>
|
||||
<version>4.0</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.protobuf</groupId>
|
||||
|
@ -396,10 +396,8 @@
|
|||
<jvmArgs>
|
||||
<jvmArg>-Xms64m</jvmArg>
|
||||
<jvmArg>-Xmx1024m</jvmArg>
|
||||
<jvmArg>-XX:PermSize</jvmArg>
|
||||
<jvmArg>${PermGen}</jvmArg>
|
||||
<jvmArg>-XX:MaxPermSize</jvmArg>
|
||||
<jvmArg>${MaxPermGen}</jvmArg>
|
||||
<jvmArg>-XX:PermSize=${PermGen}</jvmArg>
|
||||
<jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
|
||||
</jvmArgs>
|
||||
<javacArgs>
|
||||
<javacArg>-source</javacArg>
|
||||
|
|
|
@ -56,7 +56,7 @@ object SparkBuild extends Build {
|
|||
|
||||
// Fork new JVMs for tests and set Java options for those
|
||||
fork := true,
|
||||
javaOptions += "-Xmx2g",
|
||||
javaOptions += "-Xmx2500m",
|
||||
|
||||
// Only allow one test at a time, even across projects, since they run in the same JVM
|
||||
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
|
||||
|
@ -127,12 +127,13 @@ object SparkBuild extends Build {
|
|||
publishMavenStyle in MavenCompile := true,
|
||||
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
|
||||
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
|
||||
)
|
||||
) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings
|
||||
|
||||
val slf4jVersion = "1.6.1"
|
||||
val slf4jVersion = "1.7.2"
|
||||
|
||||
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
|
||||
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
|
||||
val excludeAsm = ExclusionRule(organization = "asm")
|
||||
|
||||
def coreSettings = sharedSettings ++ Seq(
|
||||
name := "spark-core",
|
||||
|
@ -150,7 +151,7 @@ object SparkBuild extends Build {
|
|||
"org.slf4j" % "slf4j-log4j12" % slf4jVersion,
|
||||
"commons-daemon" % "commons-daemon" % "1.0.10",
|
||||
"com.ning" % "compress-lzf" % "0.8.4",
|
||||
"asm" % "asm-all" % "3.3.1",
|
||||
"org.ow2.asm" % "asm" % "4.0",
|
||||
"com.google.protobuf" % "protobuf-java" % "2.4.1",
|
||||
"de.javakaffee" % "kryo-serializers" % "0.22",
|
||||
"com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty),
|
||||
|
@ -203,7 +204,20 @@ object SparkBuild extends Build {
|
|||
|
||||
def examplesSettings = sharedSettings ++ Seq(
|
||||
name := "spark-examples",
|
||||
libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11")
|
||||
libraryDependencies ++= Seq(
|
||||
"com.twitter" % "algebird-core_2.9.2" % "0.1.11",
|
||||
|
||||
"org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
|
||||
|
||||
"org.apache.cassandra" % "cassandra-all" % "1.2.5"
|
||||
exclude("com.google.guava", "guava")
|
||||
exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
|
||||
exclude("com.ning","compress-lzf")
|
||||
exclude("io.netty", "netty")
|
||||
exclude("jline","jline")
|
||||
exclude("log4j","log4j")
|
||||
exclude("org.apache.cassandra.deps", "avro")
|
||||
)
|
||||
)
|
||||
|
||||
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
|
||||
|
@ -212,9 +226,12 @@ object SparkBuild extends Build {
|
|||
|
||||
def streamingSettings = sharedSettings ++ Seq(
|
||||
name := "spark-streaming",
|
||||
resolvers ++= Seq(
|
||||
"Akka Repository" at "http://repo.akka.io/releases/"
|
||||
),
|
||||
libraryDependencies ++= Seq(
|
||||
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
|
||||
"com.github.sgroschupf" % "zkclient" % "0.1",
|
||||
"com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
|
||||
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
|
||||
"com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty)
|
||||
)
|
||||
|
@ -223,7 +240,7 @@ object SparkBuild extends Build {
|
|||
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
|
||||
mergeStrategy in assembly := {
|
||||
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
|
||||
case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard
|
||||
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
|
||||
case "reference.conf" => MergeStrategy.concat
|
||||
case _ => MergeStrategy.first
|
||||
}
|
||||
|
|
|
@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
|
|||
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
|
||||
|
||||
//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
|
||||
|
||||
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3")
|
||||
|
|
158
python/pyspark/daemon.py
Normal file
158
python/pyspark/daemon.py
Normal file
|
@ -0,0 +1,158 @@
|
|||
import os
|
||||
import sys
|
||||
import multiprocessing
|
||||
from ctypes import c_bool
|
||||
from errno import EINTR, ECHILD
|
||||
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
|
||||
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
|
||||
from pyspark.worker import main as worker_main
|
||||
from pyspark.serializers import write_int
|
||||
|
||||
try:
|
||||
POOLSIZE = multiprocessing.cpu_count()
|
||||
except NotImplementedError:
|
||||
POOLSIZE = 4
|
||||
|
||||
exit_flag = multiprocessing.Value(c_bool, False)
|
||||
|
||||
|
||||
def should_exit():
|
||||
global exit_flag
|
||||
return exit_flag.value
|
||||
|
||||
|
||||
def compute_real_exit_code(exit_code):
|
||||
# SystemExit's code can be integer or string, but os._exit only accepts integers
|
||||
import numbers
|
||||
if isinstance(exit_code, numbers.Integral):
|
||||
return exit_code
|
||||
else:
|
||||
return 1
|
||||
|
||||
|
||||
def worker(listen_sock):
|
||||
# Redirect stdout to stderr
|
||||
os.dup2(2, 1)
|
||||
|
||||
# Manager sends SIGHUP to request termination of workers in the pool
|
||||
def handle_sighup(*args):
|
||||
assert should_exit()
|
||||
signal(SIGHUP, handle_sighup)
|
||||
|
||||
# Cleanup zombie children
|
||||
def handle_sigchld(*args):
|
||||
pid = status = None
|
||||
try:
|
||||
while (pid, status) != (0, 0):
|
||||
pid, status = os.waitpid(0, os.WNOHANG)
|
||||
except EnvironmentError as err:
|
||||
if err.errno == EINTR:
|
||||
# retry
|
||||
handle_sigchld()
|
||||
elif err.errno != ECHILD:
|
||||
raise
|
||||
signal(SIGCHLD, handle_sigchld)
|
||||
|
||||
# Handle clients
|
||||
while not should_exit():
|
||||
# Wait until a client arrives or we have to exit
|
||||
sock = None
|
||||
while not should_exit() and sock is None:
|
||||
try:
|
||||
sock, addr = listen_sock.accept()
|
||||
except EnvironmentError as err:
|
||||
if err.errno != EINTR:
|
||||
raise
|
||||
|
||||
if sock is not None:
|
||||
# Fork a child to handle the client.
|
||||
# The client is handled in the child so that the manager
|
||||
# never receives SIGCHLD unless a worker crashes.
|
||||
if os.fork() == 0:
|
||||
# Leave the worker pool
|
||||
signal(SIGHUP, SIG_DFL)
|
||||
listen_sock.close()
|
||||
# Handle the client then exit
|
||||
sockfile = sock.makefile()
|
||||
exit_code = 0
|
||||
try:
|
||||
worker_main(sockfile, sockfile)
|
||||
except SystemExit as exc:
|
||||
exit_code = exc.code
|
||||
finally:
|
||||
sockfile.close()
|
||||
sock.close()
|
||||
os._exit(compute_real_exit_code(exit_code))
|
||||
else:
|
||||
sock.close()
|
||||
|
||||
|
||||
def launch_worker(listen_sock):
|
||||
if os.fork() == 0:
|
||||
try:
|
||||
worker(listen_sock)
|
||||
except Exception as err:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
os._exit(1)
|
||||
else:
|
||||
assert should_exit()
|
||||
os._exit(0)
|
||||
|
||||
|
||||
def manager():
|
||||
# Create a new process group to corral our children
|
||||
os.setpgid(0, 0)
|
||||
|
||||
# Create a listening socket on the AF_INET loopback interface
|
||||
listen_sock = socket(AF_INET, SOCK_STREAM)
|
||||
listen_sock.bind(('127.0.0.1', 0))
|
||||
listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
|
||||
listen_host, listen_port = listen_sock.getsockname()
|
||||
write_int(listen_port, sys.stdout)
|
||||
|
||||
# Launch initial worker pool
|
||||
for idx in range(POOLSIZE):
|
||||
launch_worker(listen_sock)
|
||||
listen_sock.close()
|
||||
|
||||
def shutdown():
|
||||
global exit_flag
|
||||
exit_flag.value = True
|
||||
|
||||
# Gracefully exit on SIGTERM, don't die on SIGHUP
|
||||
signal(SIGTERM, lambda signum, frame: shutdown())
|
||||
signal(SIGHUP, SIG_IGN)
|
||||
|
||||
# Cleanup zombie children
|
||||
def handle_sigchld(*args):
|
||||
try:
|
||||
pid, status = os.waitpid(0, os.WNOHANG)
|
||||
if status != 0 and not should_exit():
|
||||
raise RuntimeError("worker crashed: %s, %s" % (pid, status))
|
||||
except EnvironmentError as err:
|
||||
if err.errno not in (ECHILD, EINTR):
|
||||
raise
|
||||
signal(SIGCHLD, handle_sigchld)
|
||||
|
||||
# Initialization complete
|
||||
sys.stdout.close()
|
||||
try:
|
||||
while not should_exit():
|
||||
try:
|
||||
# Spark tells us to exit by closing stdin
|
||||
if os.read(0, 512) == '':
|
||||
shutdown()
|
||||
except EnvironmentError as err:
|
||||
if err.errno != EINTR:
|
||||
shutdown()
|
||||
raise
|
||||
finally:
|
||||
signal(SIGTERM, SIG_DFL)
|
||||
exit_flag.value = True
|
||||
# Send SIGHUP to notify workers of shutdown
|
||||
os.kill(0, SIGHUP)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
manager()
|
|
@ -46,6 +46,10 @@ def read_long(stream):
|
|||
return struct.unpack("!q", length)[0]
|
||||
|
||||
|
||||
def write_long(value, stream):
|
||||
stream.write(struct.pack("!q", value))
|
||||
|
||||
|
||||
def read_int(stream):
|
||||
length = stream.read(4)
|
||||
if length == "":
|
||||
|
|
|
@ -12,6 +12,7 @@ import unittest
|
|||
from pyspark.context import SparkContext
|
||||
from pyspark.files import SparkFiles
|
||||
from pyspark.java_gateway import SPARK_HOME
|
||||
from pyspark.serializers import read_int
|
||||
|
||||
|
||||
class PySparkTestCase(unittest.TestCase):
|
||||
|
@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
|
|||
self.sc.parallelize([1]).foreach(func)
|
||||
|
||||
|
||||
class TestDaemon(unittest.TestCase):
|
||||
def connect(self, port):
|
||||
from socket import socket, AF_INET, SOCK_STREAM
|
||||
sock = socket(AF_INET, SOCK_STREAM)
|
||||
sock.connect(('127.0.0.1', port))
|
||||
# send a split index of -1 to shutdown the worker
|
||||
sock.send("\xFF\xFF\xFF\xFF")
|
||||
sock.close()
|
||||
return True
|
||||
|
||||
def do_termination_test(self, terminator):
|
||||
from subprocess import Popen, PIPE
|
||||
from errno import ECONNREFUSED
|
||||
|
||||
# start daemon
|
||||
daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
|
||||
daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
|
||||
|
||||
# read the port number
|
||||
port = read_int(daemon.stdout)
|
||||
|
||||
# daemon should accept connections
|
||||
self.assertTrue(self.connect(port))
|
||||
|
||||
# request shutdown
|
||||
terminator(daemon)
|
||||
time.sleep(1)
|
||||
|
||||
# daemon should no longer accept connections
|
||||
with self.assertRaises(EnvironmentError) as trap:
|
||||
self.connect(port)
|
||||
self.assertEqual(trap.exception.errno, ECONNREFUSED)
|
||||
|
||||
def test_termination_stdin(self):
|
||||
"""Ensure that daemon and workers terminate when stdin is closed."""
|
||||
self.do_termination_test(lambda daemon: daemon.stdin.close())
|
||||
|
||||
def test_termination_sigterm(self):
|
||||
"""Ensure that daemon and workers terminate on SIGTERM."""
|
||||
from signal import SIGTERM
|
||||
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
|
@ -3,6 +3,7 @@ Worker that receives input from Piped RDD.
|
|||
"""
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from base64 import standard_b64decode
|
||||
# CloudPickler needs to be imported so that depicklers are registered using the
|
||||
|
@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
|
|||
from pyspark.cloudpickle import CloudPickler
|
||||
from pyspark.files import SparkFiles
|
||||
from pyspark.serializers import write_with_length, read_with_length, write_int, \
|
||||
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
|
||||
read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
|
||||
|
||||
|
||||
# Redirect stdout to stderr so that users must return values from functions.
|
||||
old_stdout = os.fdopen(os.dup(1), 'w')
|
||||
os.dup2(2, 1)
|
||||
def load_obj(infile):
|
||||
return load_pickle(standard_b64decode(infile.readline().strip()))
|
||||
|
||||
|
||||
def load_obj():
|
||||
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
|
||||
def report_times(outfile, boot, init, finish):
|
||||
write_int(-3, outfile)
|
||||
write_long(1000 * boot, outfile)
|
||||
write_long(1000 * init, outfile)
|
||||
write_long(1000 * finish, outfile)
|
||||
|
||||
|
||||
def main():
|
||||
split_index = read_int(sys.stdin)
|
||||
spark_files_dir = load_pickle(read_with_length(sys.stdin))
|
||||
def main(infile, outfile):
|
||||
boot_time = time.time()
|
||||
split_index = read_int(infile)
|
||||
if split_index == -1: # for unit tests
|
||||
return
|
||||
spark_files_dir = load_pickle(read_with_length(infile))
|
||||
SparkFiles._root_directory = spark_files_dir
|
||||
SparkFiles._is_running_on_worker = True
|
||||
sys.path.append(spark_files_dir)
|
||||
num_broadcast_variables = read_int(sys.stdin)
|
||||
num_broadcast_variables = read_int(infile)
|
||||
for _ in range(num_broadcast_variables):
|
||||
bid = read_long(sys.stdin)
|
||||
value = read_with_length(sys.stdin)
|
||||
bid = read_long(infile)
|
||||
value = read_with_length(infile)
|
||||
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
|
||||
func = load_obj()
|
||||
bypassSerializer = load_obj()
|
||||
func = load_obj(infile)
|
||||
bypassSerializer = load_obj(infile)
|
||||
if bypassSerializer:
|
||||
dumps = lambda x: x
|
||||
else:
|
||||
dumps = dump_pickle
|
||||
iterator = read_from_pickle_file(sys.stdin)
|
||||
init_time = time.time()
|
||||
iterator = read_from_pickle_file(infile)
|
||||
try:
|
||||
for obj in func(split_index, iterator):
|
||||
write_with_length(dumps(obj), old_stdout)
|
||||
write_with_length(dumps(obj), outfile)
|
||||
except Exception as e:
|
||||
write_int(-2, old_stdout)
|
||||
write_with_length(traceback.format_exc(), old_stdout)
|
||||
write_int(-2, outfile)
|
||||
write_with_length(traceback.format_exc(), outfile)
|
||||
sys.exit(-1)
|
||||
finish_time = time.time()
|
||||
report_times(outfile, boot_time, init_time, finish_time)
|
||||
# Mark the beginning of the accumulators section of the output
|
||||
write_int(-1, old_stdout)
|
||||
write_int(-1, outfile)
|
||||
for aid, accum in _accumulatorRegistry.items():
|
||||
write_with_length(dump_pickle((aid, accum._value)), old_stdout)
|
||||
write_with_length(dump_pickle((aid, accum._value)), outfile)
|
||||
write_int(-1, outfile)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
# Redirect stdout to stderr so that users must return values from functions.
|
||||
old_stdout = os.fdopen(os.dup(1), 'w')
|
||||
os.dup2(2, 1)
|
||||
main(sys.stdin, old_stdout)
|
||||
|
|
|
@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration
|
|||
import org.apache.hadoop.fs.{FileSystem, Path}
|
||||
|
||||
import org.objectweb.asm._
|
||||
import org.objectweb.asm.commons.EmptyVisitor
|
||||
import org.objectweb.asm.Opcodes._
|
||||
|
||||
|
||||
|
@ -83,7 +82,7 @@ extends ClassLoader(parent) {
|
|||
}
|
||||
|
||||
class ConstructorCleaner(className: String, cv: ClassVisitor)
|
||||
extends ClassAdapter(cv) {
|
||||
extends ClassVisitor(ASM4, cv) {
|
||||
override def visitMethod(access: Int, name: String, desc: String,
|
||||
sig: String, exceptions: Array[String]): MethodVisitor = {
|
||||
val mv = cv.visitMethod(access, name, desc, sig, exceptions)
|
||||
|
|
|
@ -822,7 +822,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
|
|||
spark.repl.Main.interp.out.println("Spark context available as sc.");
|
||||
spark.repl.Main.interp.out.flush();
|
||||
""")
|
||||
command("import spark.SparkContext._");
|
||||
command("import spark.SparkContext._")
|
||||
}
|
||||
echo("Type in expressions to have them evaluated.")
|
||||
echo("Type :help for more information.")
|
||||
|
@ -838,7 +838,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
|
|||
if (prop != null) prop else "local"
|
||||
}
|
||||
}
|
||||
sparkContext = new SparkContext(master, "Spark shell")
|
||||
val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
|
||||
.getOrElse(new Array[String](0))
|
||||
.map(new java.io.File(_).getAbsolutePath)
|
||||
sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
|
||||
sparkContext
|
||||
}
|
||||
|
||||
|
@ -850,6 +853,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
|
|||
printWelcome()
|
||||
echo("Initializing interpreter...")
|
||||
|
||||
// Add JARS specified in Spark's ADD_JARS variable to classpath
|
||||
val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0))
|
||||
jars.foreach(settings.classpath.append(_))
|
||||
|
||||
this.settings = settings
|
||||
createInterpreter()
|
||||
|
||||
|
|
|
@ -28,24 +28,25 @@ class ReplSuite extends FunSuite {
|
|||
val separator = System.getProperty("path.separator")
|
||||
interp.process(Array("-classpath", paths.mkString(separator)))
|
||||
spark.repl.Main.interp = null
|
||||
if (interp.sparkContext != null)
|
||||
if (interp.sparkContext != null) {
|
||||
interp.sparkContext.stop()
|
||||
}
|
||||
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
|
||||
System.clearProperty("spark.driver.port")
|
||||
System.clearProperty("spark.hostPort")
|
||||
return out.toString
|
||||
}
|
||||
|
||||
|
||||
def assertContains(message: String, output: String) {
|
||||
assert(output contains message,
|
||||
assert(output.contains(message),
|
||||
"Interpreter output did not contain '" + message + "':\n" + output)
|
||||
}
|
||||
|
||||
|
||||
def assertDoesNotContain(message: String, output: String) {
|
||||
assert(!(output contains message),
|
||||
assert(!output.contains(message),
|
||||
"Interpreter output contained '" + message + "':\n" + output)
|
||||
}
|
||||
|
||||
|
||||
test ("simple foreach with accumulator") {
|
||||
val output = runInterpreter("local", """
|
||||
val accum = sc.accumulator(0)
|
||||
|
@ -56,7 +57,7 @@ class ReplSuite extends FunSuite {
|
|||
assertDoesNotContain("Exception", output)
|
||||
assertContains("res1: Int = 55", output)
|
||||
}
|
||||
|
||||
|
||||
test ("external vars") {
|
||||
val output = runInterpreter("local", """
|
||||
var v = 7
|
||||
|
@ -105,7 +106,7 @@ class ReplSuite extends FunSuite {
|
|||
assertContains("res0: Int = 70", output)
|
||||
assertContains("res1: Int = 100", output)
|
||||
}
|
||||
|
||||
|
||||
test ("broadcast vars") {
|
||||
// Test that the value that a broadcast var had when it was created is used,
|
||||
// even if that variable is then modified in the driver program
|
||||
|
@ -143,6 +144,27 @@ class ReplSuite extends FunSuite {
|
|||
assertContains("res2: Long = 3", output)
|
||||
}
|
||||
|
||||
test ("local-cluster mode") {
|
||||
val output = runInterpreter("local-cluster[1,1,512]", """
|
||||
var v = 7
|
||||
def getV() = v
|
||||
sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
|
||||
v = 10
|
||||
sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
|
||||
var array = new Array[Int](5)
|
||||
val broadcastArray = sc.broadcast(array)
|
||||
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
|
||||
array(0) = 5
|
||||
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
|
||||
""")
|
||||
assertDoesNotContain("error:", output)
|
||||
assertDoesNotContain("Exception", output)
|
||||
assertContains("res0: Int = 70", output)
|
||||
assertContains("res1: Int = 100", output)
|
||||
assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
|
||||
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
|
||||
}
|
||||
|
||||
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
|
||||
test ("running on Mesos") {
|
||||
val output = runInterpreter("localquiet", """
|
||||
|
|
100
run
100
run
|
@ -23,29 +23,38 @@ fi
|
|||
if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
|
||||
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
|
||||
SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
|
||||
SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
|
||||
# Do not overwrite SPARK_JAVA_OPTS environment variable in this script
|
||||
OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
|
||||
else
|
||||
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
|
||||
fi
|
||||
|
||||
|
||||
# Add java opts for master, worker, executor. The opts maybe null
|
||||
case "$1" in
|
||||
'spark.deploy.master.Master')
|
||||
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_MASTER_OPTS"
|
||||
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
|
||||
;;
|
||||
'spark.deploy.worker.Worker')
|
||||
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_WORKER_OPTS"
|
||||
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
|
||||
;;
|
||||
'spark.executor.StandaloneExecutorBackend')
|
||||
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
|
||||
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
|
||||
;;
|
||||
'spark.executor.MesosExecutorBackend')
|
||||
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
|
||||
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
|
||||
;;
|
||||
'spark.repl.Main')
|
||||
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS"
|
||||
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
|
||||
;;
|
||||
esac
|
||||
|
||||
# Figure out whether to run our class with java or with the scala launcher.
|
||||
# In most cases, we'd prefer to execute our process with java because scala
|
||||
# creates a shell script as the parent of its Java process, which makes it
|
||||
# hard to kill the child with stuff like Process.destroy(). However, for
|
||||
# the Spark shell, the wrapper is necessary to properly reset the terminal
|
||||
# when we exit, so we allow it to set a variable to launch with scala.
|
||||
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
|
||||
if [ "$SCALA_HOME" ]; then
|
||||
RUNNER="${SCALA_HOME}/bin/scala"
|
||||
|
@ -58,14 +67,15 @@ if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
|
|||
fi
|
||||
fi
|
||||
else
|
||||
if [ `command -v java` ]; then
|
||||
RUNNER="java"
|
||||
if [ -n "${JAVA_HOME}" ]; then
|
||||
RUNNER="${JAVA_HOME}/bin/java"
|
||||
else
|
||||
if [ -z "$JAVA_HOME" ]; then
|
||||
if [ `command -v java` ]; then
|
||||
RUNNER="java"
|
||||
else
|
||||
echo "JAVA_HOME is not set" >&2
|
||||
exit 1
|
||||
fi
|
||||
RUNNER="${JAVA_HOME}/bin/java"
|
||||
fi
|
||||
if [ -z "$SCALA_LIBRARY_PATH" ]; then
|
||||
if [ -z "$SCALA_HOME" ]; then
|
||||
|
@ -84,7 +94,7 @@ fi
|
|||
export SPARK_MEM
|
||||
|
||||
# Set JAVA_OPTS to be able to load native libraries and to set heap size
|
||||
JAVA_OPTS="$SPARK_JAVA_OPTS"
|
||||
JAVA_OPTS="$OUR_JAVA_OPTS"
|
||||
JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
|
||||
JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
|
||||
# Load extra JAVA_OPTS from conf/java-opts, if it exists
|
||||
|
@ -92,15 +102,11 @@ if [ -e $FWDIR/conf/java-opts ] ; then
|
|||
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
|
||||
fi
|
||||
export JAVA_OPTS
|
||||
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
|
||||
|
||||
CORE_DIR="$FWDIR/core"
|
||||
REPL_DIR="$FWDIR/repl"
|
||||
REPL_BIN_DIR="$FWDIR/repl-bin"
|
||||
EXAMPLES_DIR="$FWDIR/examples"
|
||||
BAGEL_DIR="$FWDIR/bagel"
|
||||
GRAPH_DIR="$FWDIR/graph"
|
||||
STREAMING_DIR="$FWDIR/streaming"
|
||||
PYSPARK_DIR="$FWDIR/python"
|
||||
REPL_DIR="$FWDIR/repl"
|
||||
|
||||
# Exit if the user hasn't compiled Spark
|
||||
if [ ! -e "$CORE_DIR/target" ]; then
|
||||
|
@ -115,36 +121,9 @@ if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
# Build up classpath
|
||||
CLASSPATH="$SPARK_CLASSPATH"
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/conf"
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
if [ -n "$SPARK_TESTING" ] ; then
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
|
||||
CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
|
||||
if [ -e "$FWDIR/lib_managed" ]; then
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
|
||||
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
|
||||
fi
|
||||
CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
|
||||
if [ -e $REPL_BIN_DIR/target ]; then
|
||||
for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
|
||||
CLASSPATH="$CLASSPATH:$jar"
|
||||
done
|
||||
fi
|
||||
|
||||
CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
CLASSPATH="$CLASSPATH:$GRAPH_DIR/target/scala-$SCALA_VERSION/classes"
|
||||
|
||||
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
|
||||
CLASSPATH="$CLASSPATH:$jar"
|
||||
done
|
||||
# Compute classpath using external script
|
||||
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
|
||||
export CLASSPATH
|
||||
|
||||
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
|
||||
# to avoid the -sources and -doc packages that are built by publish-local.
|
||||
|
@ -152,37 +131,16 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ];
|
|||
# Use the JAR from the SBT build
|
||||
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
|
||||
fi
|
||||
if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then
|
||||
if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
|
||||
# Use the JAR from the Maven build
|
||||
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar`
|
||||
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
|
||||
fi
|
||||
|
||||
# Add hadoop conf dir - else FileSystem.*, etc fail !
|
||||
# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
|
||||
# the configurtion files.
|
||||
if [ "x" != "x$HADOOP_CONF_DIR" ]; then
|
||||
CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
|
||||
fi
|
||||
if [ "x" != "x$YARN_CONF_DIR" ]; then
|
||||
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
|
||||
fi
|
||||
|
||||
|
||||
# Figure out whether to run our class with java or with the scala launcher.
|
||||
# In most cases, we'd prefer to execute our process with java because scala
|
||||
# creates a shell script as the parent of its Java process, which makes it
|
||||
# hard to kill the child with stuff like Process.destroy(). However, for
|
||||
# the Spark shell, the wrapper is necessary to properly reset the terminal
|
||||
# when we exit, so we allow it to set a variable to launch with scala.
|
||||
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
|
||||
EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS
|
||||
else
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
|
||||
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
|
||||
# The JVM doesn't read JAVA_OPTS by default so we need to pass it in
|
||||
EXTRA_ARGS="$JAVA_OPTS"
|
||||
fi
|
||||
|
||||
export CLASSPATH # Needed for spark-shell
|
||||
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"
|
||||
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"
|
50
run2.cmd
50
run2.cmd
|
@ -23,7 +23,9 @@ if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
|
|||
if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
|
||||
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
|
||||
if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
|
||||
if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
|
||||
rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
|
||||
if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
|
||||
if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
|
||||
|
||||
rem Check that SCALA_HOME has been specified
|
||||
if not "x%SCALA_HOME%"=="x" goto scala_exists
|
||||
|
@ -31,52 +33,22 @@ if not "x%SCALA_HOME%"=="x" goto scala_exists
|
|||
goto exit
|
||||
:scala_exists
|
||||
|
||||
rem If the user specifies a Mesos JAR, put it before our included one on the classpath
|
||||
set MESOS_CLASSPATH=
|
||||
if not "x%MESOS_JAR%"=="x" set MESOS_CLASSPATH=%MESOS_JAR%
|
||||
|
||||
rem Figure out how much memory to use per executor and set it as an environment
|
||||
rem variable so that our process sees it and can report it to Mesos
|
||||
if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
|
||||
|
||||
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
|
||||
set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
|
||||
rem Load extra JAVA_OPTS from conf/java-opts, if it exists
|
||||
if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd"
|
||||
set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
|
||||
rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
|
||||
|
||||
set CORE_DIR=%FWDIR%core
|
||||
set REPL_DIR=%FWDIR%repl
|
||||
set EXAMPLES_DIR=%FWDIR%examples
|
||||
set BAGEL_DIR=%FWDIR%bagel
|
||||
set GRAPH_DIR=%FWDIR%graph
|
||||
set STREAMING_DIR=%FWDIR%streaming
|
||||
set PYSPARK_DIR=%FWDIR%python
|
||||
|
||||
rem Build up classpath
|
||||
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
|
||||
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
|
||||
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
|
||||
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
|
||||
set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
|
||||
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
set CLASSPATH=%CLASSPATH%;%GRAPH_DIR%\target\scala-%SCALA_VERSION%\classes
|
||||
|
||||
rem Add hadoop conf dir - else FileSystem.*, etc fail
|
||||
rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
|
||||
rem the configurtion files.
|
||||
if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
|
||||
set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
|
||||
:no_hadoop_conf_dir
|
||||
|
||||
if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
|
||||
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
|
||||
:no_yarn_conf_dir
|
||||
|
||||
set REPL_DIR=%FWDIR%repl
|
||||
|
||||
rem Compute classpath using external script
|
||||
set DONT_PRINT_CLASSPATH=1
|
||||
call "%FWDIR%bin\compute-classpath.cmd"
|
||||
set DONT_PRINT_CLASSPATH=0
|
||||
|
||||
rem Figure out the JAR file that our examples were packaged into.
|
||||
rem First search in the build path from SBT:
|
||||
|
@ -108,4 +80,4 @@ if "%SPARK_LAUNCH_WITH_SCALA%" NEQ 1 goto java_runner
|
|||
:run_spark
|
||||
|
||||
"%RUNNER%" -cp "%CLASSPATH%" %EXTRA_ARGS% %*
|
||||
:exit
|
||||
:exit
|
|
@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] (
|
|||
* Return a new DStream in which each RDD has a single element generated by counting each RDD
|
||||
* of this DStream.
|
||||
*/
|
||||
def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
|
||||
def count(): DStream[Long] = {
|
||||
this.map(_ => (null, 1L))
|
||||
.transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1)))
|
||||
.reduceByKey(_ + _)
|
||||
.map(_._2)
|
||||
}
|
||||
|
||||
/**
|
||||
* Return a new DStream in which each RDD contains the counts of each distinct value in
|
||||
|
@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] (
|
|||
* this DStream will be registered as an output stream and therefore materialized.
|
||||
*/
|
||||
def foreach(foreachFunc: RDD[T] => Unit) {
|
||||
foreach((r: RDD[T], t: Time) => foreachFunc(r))
|
||||
this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
|
|||
import org.apache.hadoop.fs.Path
|
||||
import twitter4j.Status
|
||||
|
||||
|
||||
/**
|
||||
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
|
||||
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
|
||||
|
@ -186,10 +187,11 @@ class StreamingContext private (
|
|||
* should be same.
|
||||
*/
|
||||
def actorStream[T: ClassManifest](
|
||||
props: Props,
|
||||
name: String,
|
||||
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
|
||||
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = {
|
||||
props: Props,
|
||||
name: String,
|
||||
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
|
||||
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
|
||||
): DStream[T] = {
|
||||
networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
|
||||
}
|
||||
|
||||
|
@ -197,9 +199,10 @@ class StreamingContext private (
|
|||
* Create an input stream that receives messages pushed by a zeromq publisher.
|
||||
* @param publisherUrl Url of remote zeromq publisher
|
||||
* @param subscribe topic to subscribe to
|
||||
* @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
|
||||
* of byte thus it needs the converter(which might be deserializer of bytes)
|
||||
* to translate from sequence of sequence of bytes, where sequence refer to a frame
|
||||
* @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic
|
||||
* and each frame has sequence of byte thus it needs the converter
|
||||
* (which might be deserializer of bytes) to translate from sequence
|
||||
* of sequence of bytes, where sequence refer to a frame
|
||||
* and sub sequence refer to its payload.
|
||||
* @param storageLevel RDD storage level. Defaults to memory-only.
|
||||
*/
|
||||
|
@ -215,24 +218,39 @@ class StreamingContext private (
|
|||
}
|
||||
|
||||
/**
|
||||
* Create an input stream that pulls messages form a Kafka Broker.
|
||||
* Create an input stream that pulls messages from a Kafka Broker.
|
||||
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
|
||||
* @param groupId The group id for this consumer.
|
||||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
|
||||
* By default the value is pulled from zookeper.
|
||||
* in its own thread.
|
||||
* @param storageLevel Storage level to use for storing the received objects
|
||||
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
|
||||
*/
|
||||
def kafkaStream[T: ClassManifest](
|
||||
def kafkaStream(
|
||||
zkQuorum: String,
|
||||
groupId: String,
|
||||
topics: Map[String, Int],
|
||||
initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
|
||||
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
|
||||
): DStream[String] = {
|
||||
val kafkaParams = Map[String, String](
|
||||
"zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
|
||||
kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an input stream that pulls messages from a Kafka Broker.
|
||||
* @param kafkaParams Map of kafka configuration paramaters.
|
||||
* See: http://kafka.apache.org/configuration.html
|
||||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
* @param storageLevel Storage level to use for storing the received objects
|
||||
*/
|
||||
def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
|
||||
kafkaParams: Map[String, String],
|
||||
topics: Map[String, Int],
|
||||
storageLevel: StorageLevel
|
||||
): DStream[T] = {
|
||||
val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel)
|
||||
val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
|
||||
registerInputStream(inputStream)
|
||||
inputStream
|
||||
}
|
||||
|
@ -397,7 +415,8 @@ class StreamingContext private (
|
|||
* it will process either one or all of the RDDs returned by the queue.
|
||||
* @param queue Queue of RDDs
|
||||
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
|
||||
* @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty
|
||||
* @param defaultRDD Default RDD is returned by the DStream when the queue is empty.
|
||||
* Set as null if no RDD should be returned when empty
|
||||
* @tparam T Type of objects in the RDD
|
||||
*/
|
||||
def queueStream[T: ClassManifest](
|
||||
|
|
|
@ -121,14 +121,15 @@ class JavaStreamingContext(val ssc: StreamingContext) {
|
|||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
*/
|
||||
def kafkaStream[T](
|
||||
def kafkaStream(
|
||||
zkQuorum: String,
|
||||
groupId: String,
|
||||
topics: JMap[String, JInt])
|
||||
: JavaDStream[T] = {
|
||||
implicit val cmt: ClassManifest[T] =
|
||||
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
|
||||
ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
|
||||
: JavaDStream[String] = {
|
||||
implicit val cmt: ClassManifest[String] =
|
||||
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
|
||||
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
|
||||
StorageLevel.MEMORY_ONLY_SER_2)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -136,49 +137,45 @@ class JavaStreamingContext(val ssc: StreamingContext) {
|
|||
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
|
||||
* @param groupId The group id for this consumer.
|
||||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
|
||||
* By default the value is pulled from zookeper.
|
||||
* in its own thread.
|
||||
* @param storageLevel RDD storage level. Defaults to memory-only
|
||||
*
|
||||
*/
|
||||
def kafkaStream[T](
|
||||
def kafkaStream(
|
||||
zkQuorum: String,
|
||||
groupId: String,
|
||||
topics: JMap[String, JInt],
|
||||
initialOffsets: JMap[KafkaPartitionKey, JLong])
|
||||
: JavaDStream[T] = {
|
||||
implicit val cmt: ClassManifest[T] =
|
||||
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
|
||||
ssc.kafkaStream[T](
|
||||
zkQuorum,
|
||||
groupId,
|
||||
Map(topics.mapValues(_.intValue()).toSeq: _*),
|
||||
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
|
||||
storageLevel: StorageLevel)
|
||||
: JavaDStream[String] = {
|
||||
implicit val cmt: ClassManifest[String] =
|
||||
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
|
||||
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
|
||||
storageLevel)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an input stream that pulls messages form a Kafka Broker.
|
||||
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
|
||||
* @param groupId The group id for this consumer.
|
||||
* @param typeClass Type of RDD
|
||||
* @param decoderClass Type of kafka decoder
|
||||
* @param kafkaParams Map of kafka configuration paramaters.
|
||||
* See: http://kafka.apache.org/configuration.html
|
||||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
|
||||
* By default the value is pulled from zookeper.
|
||||
* @param storageLevel RDD storage level. Defaults to memory-only
|
||||
*/
|
||||
def kafkaStream[T](
|
||||
zkQuorum: String,
|
||||
groupId: String,
|
||||
def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
|
||||
typeClass: Class[T],
|
||||
decoderClass: Class[D],
|
||||
kafkaParams: JMap[String, String],
|
||||
topics: JMap[String, JInt],
|
||||
initialOffsets: JMap[KafkaPartitionKey, JLong],
|
||||
storageLevel: StorageLevel)
|
||||
: JavaDStream[T] = {
|
||||
implicit val cmt: ClassManifest[T] =
|
||||
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
|
||||
ssc.kafkaStream[T](
|
||||
zkQuorum,
|
||||
groupId,
|
||||
implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
|
||||
ssc.kafkaStream[T, D](
|
||||
kafkaParams.toMap,
|
||||
Map(topics.mapValues(_.intValue()).toSeq: _*),
|
||||
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
|
||||
storageLevel)
|
||||
}
|
||||
|
||||
|
|
|
@ -9,58 +9,51 @@ import java.util.concurrent.Executors
|
|||
|
||||
import kafka.consumer._
|
||||
import kafka.message.{Message, MessageSet, MessageAndMetadata}
|
||||
import kafka.serializer.StringDecoder
|
||||
import kafka.serializer.Decoder
|
||||
import kafka.utils.{Utils, ZKGroupTopicDirs}
|
||||
import kafka.utils.ZkUtils._
|
||||
import kafka.utils.ZKStringSerializer
|
||||
import org.I0Itec.zkclient._
|
||||
|
||||
import scala.collection.Map
|
||||
import scala.collection.mutable.HashMap
|
||||
import scala.collection.JavaConversions._
|
||||
|
||||
|
||||
// Key for a specific Kafka Partition: (broker, topic, group, part)
|
||||
case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
|
||||
|
||||
/**
|
||||
* Input stream that pulls messages from a Kafka Broker.
|
||||
*
|
||||
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
|
||||
* @param groupId The group id for this consumer.
|
||||
* @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html
|
||||
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
|
||||
* in its own thread.
|
||||
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
|
||||
* By default the value is pulled from zookeper.
|
||||
* @param storageLevel RDD storage level.
|
||||
*/
|
||||
private[streaming]
|
||||
class KafkaInputDStream[T: ClassManifest](
|
||||
class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
|
||||
@transient ssc_ : StreamingContext,
|
||||
zkQuorum: String,
|
||||
groupId: String,
|
||||
kafkaParams: Map[String, String],
|
||||
topics: Map[String, Int],
|
||||
initialOffsets: Map[KafkaPartitionKey, Long],
|
||||
storageLevel: StorageLevel
|
||||
) extends NetworkInputDStream[T](ssc_ ) with Logging {
|
||||
|
||||
|
||||
def getReceiver(): NetworkReceiver[T] = {
|
||||
new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel)
|
||||
new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
|
||||
.asInstanceOf[NetworkReceiver[T]]
|
||||
}
|
||||
}
|
||||
|
||||
private[streaming]
|
||||
class KafkaReceiver(zkQuorum: String, groupId: String,
|
||||
topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
|
||||
storageLevel: StorageLevel) extends NetworkReceiver[Any] {
|
||||
|
||||
// Timeout for establishing a connection to Zookeper in ms.
|
||||
val ZK_TIMEOUT = 10000
|
||||
class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
|
||||
kafkaParams: Map[String, String],
|
||||
topics: Map[String, Int],
|
||||
storageLevel: StorageLevel
|
||||
) extends NetworkReceiver[Any] {
|
||||
|
||||
// Handles pushing data into the BlockManager
|
||||
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
|
||||
// Connection to Kafka
|
||||
var consumerConnector : ZookeeperConsumerConnector = null
|
||||
var consumerConnector : ConsumerConnector = null
|
||||
|
||||
def onStop() {
|
||||
blockGenerator.stop()
|
||||
|
@ -73,54 +66,59 @@ class KafkaReceiver(zkQuorum: String, groupId: String,
|
|||
// In case we are using multiple Threads to handle Kafka Messages
|
||||
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
|
||||
|
||||
logInfo("Starting Kafka Consumer Stream with group: " + groupId)
|
||||
logInfo("Initial offsets: " + initialOffsets.toString)
|
||||
logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
|
||||
|
||||
// Zookeper connection properties
|
||||
// Kafka connection properties
|
||||
val props = new Properties()
|
||||
props.put("zk.connect", zkQuorum)
|
||||
props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
|
||||
props.put("groupid", groupId)
|
||||
kafkaParams.foreach(param => props.put(param._1, param._2))
|
||||
|
||||
// Create the connection to the cluster
|
||||
logInfo("Connecting to Zookeper: " + zkQuorum)
|
||||
logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
|
||||
val consumerConfig = new ConsumerConfig(props)
|
||||
consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
|
||||
logInfo("Connected to " + zkQuorum)
|
||||
consumerConnector = Consumer.create(consumerConfig)
|
||||
logInfo("Connected to " + kafkaParams("zk.connect"))
|
||||
|
||||
// If specified, set the topic offset
|
||||
setOffsets(initialOffsets)
|
||||
// When autooffset.reset is defined, it is our responsibility to try and whack the
|
||||
// consumer group zk node.
|
||||
if (kafkaParams.contains("autooffset.reset")) {
|
||||
tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
|
||||
}
|
||||
|
||||
// Create Threads for each Topic/Message Stream we are listening
|
||||
val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
|
||||
val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
|
||||
val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
|
||||
|
||||
// Start the messages handler for each partition
|
||||
topicMessageStreams.values.foreach { streams =>
|
||||
streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Overwrites the offets in Zookeper.
|
||||
private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) {
|
||||
offsets.foreach { case(key, offset) =>
|
||||
val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
|
||||
val partitionName = key.brokerId + "-" + key.partId
|
||||
updatePersistentPath(consumerConnector.zkClient,
|
||||
topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
|
||||
}
|
||||
}
|
||||
|
||||
// Handles Kafka Messages
|
||||
private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
|
||||
private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
|
||||
def run() {
|
||||
logInfo("Starting MessageHandler.")
|
||||
stream.takeWhile { msgAndMetadata =>
|
||||
for (msgAndMetadata <- stream) {
|
||||
blockGenerator += msgAndMetadata.message
|
||||
// Keep on handling messages
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because
|
||||
// Kafka 0.7.2 only honors this param when the group is not in zookeeper.
|
||||
//
|
||||
// The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas'
|
||||
// ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest'/'largest':
|
||||
// https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
|
||||
private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
|
||||
try {
|
||||
val dir = "/consumers/" + groupId
|
||||
logInfo("Cleaning up temporary zookeeper data under " + dir + ".")
|
||||
val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
|
||||
zk.deleteRecursive(dir)
|
||||
zk.close()
|
||||
} catch {
|
||||
case _ => // swallow
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -198,7 +198,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
|
|||
case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
|
||||
|
||||
val clock = new SystemClock()
|
||||
val blockInterval = 200L
|
||||
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
|
||||
val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
|
||||
val blockStorageLevel = storageLevel
|
||||
val blocksForPushing = new ArrayBlockingQueue[Block](1000)
|
||||
|
|
|
@ -4,6 +4,7 @@ import com.google.common.base.Optional;
|
|||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Maps;
|
||||
import com.google.common.io.Files;
|
||||
import kafka.serializer.StringDecoder;
|
||||
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
|
||||
import org.junit.After;
|
||||
import org.junit.Assert;
|
||||
|
@ -23,7 +24,6 @@ import spark.streaming.api.java.JavaPairDStream;
|
|||
import spark.streaming.api.java.JavaStreamingContext;
|
||||
import spark.streaming.JavaTestUtils;
|
||||
import spark.streaming.JavaCheckpointTestUtils;
|
||||
import spark.streaming.dstream.KafkaPartitionKey;
|
||||
import spark.streaming.InputStreamsSuite;
|
||||
|
||||
import java.io.*;
|
||||
|
@ -1203,10 +1203,14 @@ public class JavaAPISuite implements Serializable {
|
|||
@Test
|
||||
public void testKafkaStream() {
|
||||
HashMap<String, Integer> topics = Maps.newHashMap();
|
||||
HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
|
||||
JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
|
||||
JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
|
||||
JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
|
||||
JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
|
||||
StorageLevel.MEMORY_AND_DISK());
|
||||
|
||||
HashMap<String, String> kafkaParams = Maps.newHashMap();
|
||||
kafkaParams.put("zk.connect","localhost:12345");
|
||||
kafkaParams.put("groupid","consumer-group");
|
||||
JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
|
||||
StorageLevel.MEMORY_AND_DISK());
|
||||
}
|
||||
|
||||
|
|
|
@ -93,9 +93,9 @@ class BasicOperationsSuite extends TestSuiteBase {
|
|||
|
||||
test("count") {
|
||||
testOperation(
|
||||
Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
|
||||
Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4),
|
||||
(s: DStream[Int]) => s.count(),
|
||||
Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
|
||||
Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L))
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -243,6 +243,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
|
|||
assert(output(i) === expectedOutput(i))
|
||||
}
|
||||
}
|
||||
|
||||
test("kafka input stream") {
|
||||
val ssc = new StreamingContext(master, framework, batchDuration)
|
||||
val topics = Map("my-topic" -> 1)
|
||||
val test1 = ssc.kafkaStream("localhost:12345", "group", topics)
|
||||
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
|
||||
|
||||
// Test specifying decoder
|
||||
val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
|
||||
val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue