Merge branch 'master' of github.com:mesos/spark into graph

Conflicts:
	run
	run2.cmd
This commit is contained in:
Reynold Xin 2013-06-29 15:30:21 -07:00
commit 564d902d79
84 changed files with 4030 additions and 1761 deletions

1
.gitignore vendored
View file

@ -36,4 +36,5 @@ streaming-tests.log
dependency-reduced-pom.xml
.ensime
.ensime_lucene
checkpoint
derby.log

52
bin/compute-classpath.cmd Normal file
View file

@ -0,0 +1,52 @@
@echo off
rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
rem script and the ExecutorRunner in standalone cluster mode.
set SCALA_VERSION=2.9.3
rem Figure out where the Spark framework is installed
set FWDIR=%~dp0..\
rem Load environment variables from conf\spark-env.cmd, if it exists
if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
set STREAMING_DIR=%FWDIR%streaming
set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
rem Add hadoop conf dir - else FileSystem.*, etc fail
rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
rem the configurtion files.
if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
:no_hadoop_conf_dir
if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
:no_yarn_conf_dir
rem Add Scala standard library
set CLASSPATH=%CLASSPATH%;%SCALA_HOME%\lib\scala-library.jar;%SCALA_HOME%\lib\scala-compiler.jar;%SCALA_HOME%\lib\jline.jar
rem A bit of a hack to allow calling this script within run2.cmd without seeing output
if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
echo %CLASSPATH%
:exit

89
bin/compute-classpath.sh Executable file
View file

@ -0,0 +1,89 @@
#!/bin/bash
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
# script and the ExecutorRunner in standalone cluster mode.
SCALA_VERSION=2.9.3
# Figure out where Spark is installed
FWDIR="$(cd `dirname $0`/..; pwd)"
# Load environment variables from conf/spark-env.sh, if it exists
if [ -e $FWDIR/conf/spark-env.sh ] ; then
. $FWDIR/conf/spark-env.sh
fi
CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
REPL_BIN_DIR="$FWDIR/repl-bin"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
STREAMING_DIR="$FWDIR/streaming"
PYSPARK_DIR="$FWDIR/python"
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
CLASSPATH="$CLASSPATH:$FWDIR/conf"
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
if [ -n "$SPARK_TESTING" ] ; then
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
fi
CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
fi
CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
# Add the shaded JAR for Maven builds
if [ -e $REPL_BIN_DIR/target ]; then
for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH="$CLASSPATH:$jar"
done
# The shaded JAR doesn't contain examples, so include those separately
EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
CLASSPATH+=":$EXAMPLES_JAR"
fi
CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
CLASSPATH="$CLASSPATH:$jar"
done
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the SBT build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
fi
if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the Maven build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
fi
# Add hadoop conf dir - else FileSystem.*, etc fail !
# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
# the configurtion files.
if [ "x" != "x$HADOOP_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
fi
if [ "x" != "x$YARN_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
fi
# Add Scala standard library
if [ -z "$SCALA_LIBRARY_PATH" ]; then
if [ -z "$SCALA_HOME" ]; then
echo "SCALA_HOME is not set" >&2
exit 1
fi
SCALA_LIBRARY_PATH="$SCALA_HOME/lib"
fi
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
echo "$CLASSPATH"

View file

@ -32,8 +32,8 @@
<artifactId>compress-lzf</artifactId>
</dependency>
<dependency>
<groupId>asm</groupId>
<artifactId>asm-all</artifactId>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>

View file

@ -8,15 +8,20 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
class FileClient {
private Logger LOG = LoggerFactory.getLogger(this.getClass().getName());
private FileClientHandler handler = null;
private Channel channel = null;
private Bootstrap bootstrap = null;
private int connectTimeout = 60*1000; // 1 min
public FileClient(FileClientHandler handler) {
public FileClient(FileClientHandler handler, int connectTimeout) {
this.handler = handler;
this.connectTimeout = connectTimeout;
}
public void init() {
@ -25,25 +30,10 @@ class FileClient {
.channel(OioSocketChannel.class)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeout)
.handler(new FileClientChannelInitializer(handler));
}
public static final class ChannelCloseListener implements ChannelFutureListener {
private FileClient fc = null;
public ChannelCloseListener(FileClient fc){
this.fc = fc;
}
@Override
public void operationComplete(ChannelFuture future) {
if (fc.bootstrap!=null){
fc.bootstrap.shutdown();
fc.bootstrap = null;
}
}
}
public void connect(String host, int port) {
try {
// Start the connection attempt.
@ -58,8 +48,8 @@ class FileClient {
public void waitForClose() {
try {
channel.closeFuture().sync();
} catch (InterruptedException e){
e.printStackTrace();
} catch (InterruptedException e) {
LOG.warn("FileClient interrupted", e);
}
}

View file

@ -9,7 +9,14 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
private FileHeader currentHeader = null;
private volatile boolean handlerCalled = false;
public boolean isComplete() {
return handlerCalled;
}
public abstract void handle(ChannelHandlerContext ctx, ByteBuf in, FileHeader header);
public abstract void handleError(String blockId);
@Override
public ByteBuf newInboundBuffer(ChannelHandlerContext ctx) {
@ -26,6 +33,7 @@ abstract class FileClientHandler extends ChannelInboundByteHandlerAdapter {
// get file
if(in.readableBytes() >= currentHeader.fileLen()) {
handle(ctx, in, currentHeader);
handlerCalled = true;
currentHeader = null;
ctx.close();
}

View file

@ -58,6 +58,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
val shuffleMetrics = new ShuffleReadMetrics
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead

View file

@ -5,8 +5,7 @@ import java.lang.reflect.Field
import scala.collection.mutable.Map
import scala.collection.mutable.Set
import org.objectweb.asm.{ClassReader, MethodVisitor, Type}
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.{ClassReader, ClassVisitor, MethodVisitor, Type}
import org.objectweb.asm.Opcodes._
import java.io.{InputStream, IOException, ByteArrayOutputStream, ByteArrayInputStream, BufferedInputStream}
@ -162,10 +161,10 @@ private[spark] object ClosureCleaner extends Logging {
}
}
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends EmptyVisitor {
private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) extends ClassVisitor(ASM4) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
return new MethodVisitor(ASM4) {
override def visitFieldInsn(op: Int, owner: String, name: String, desc: String) {
if (op == GETFIELD) {
for (cl <- output.keys if cl.getName == owner.replace('/', '.')) {
@ -188,7 +187,7 @@ private[spark] class FieldAccessFinder(output: Map[Class[_], Set[String]]) exten
}
}
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisitor {
private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends ClassVisitor(ASM4) {
var myName: String = null
override def visit(version: Int, access: Int, name: String, sig: String,
@ -198,7 +197,7 @@ private[spark] class InnerClosureFinder(output: Set[Class[_]]) extends EmptyVisi
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
return new EmptyVisitor {
return new MethodVisitor(ASM4) {
override def visitMethodInsn(op: Int, owner: String, name: String,
desc: String) {
val argTypes = Type.getArgumentTypes(desc)

View file

@ -1,5 +1,6 @@
package spark
import java.nio.ByteBuffer
import java.util.{Date, HashMap => JHashMap}
import java.text.SimpleDateFormat
@ -10,6 +11,8 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.mapred.FileOutputFormat
import org.apache.hadoop.mapred.HadoopWriter
@ -17,7 +20,7 @@ import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil, TaskAttemptID, TaskAttemptContext}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, Job => NewAPIHadoopJob, HadoopMapReduceUtil}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
@ -62,8 +65,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
val aggregator =
new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(aggregator.combineValuesByKey(_), true)
} else if (mapSideCombine) {
@ -95,7 +97,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* list concatenation, 0 for addition, or 1 for multiplication.).
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
combineByKey[V]({v: V => func(zeroValue, v)}, func, func, partitioner)
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
// When deserializing, use a lazy val to create just one instance of the serializer per task
lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
def createZero() = cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
}
/**
@ -185,11 +196,13 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
* partitioning of the resulting key-value pair RDD by passing a Partitioner.
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Seq[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
// reduce the amount of data shuffled and requires all map side data be inserted
// into a hash table, leading to more objects in the old gen.
def createCombiner(v: V) = ArrayBuffer(v)
def mergeValue(buf: ArrayBuffer[V], v: V) = buf += v
def mergeCombiners(b1: ArrayBuffer[V], b2: ArrayBuffer[V]) = b1 ++= b2
val bufs = combineByKey[ArrayBuffer[V]](
createCombiner _, mergeValue _, mergeCombiners _, partitioner)
createCombiner _, mergeValue _, null, partitioner, mapSideCombine=false)
bufs.asInstanceOf[RDD[(K, Seq[V])]]
}
@ -515,6 +528,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress the result with the
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
@ -574,6 +597,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
jobCommitter.cleanupJob(jobTaskContext)
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD. Compress with the supplied codec.
*/
def saveAsHadoopFile(
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
codec: Class[_ <: CompressionCodec]) {
saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass,
new JobConf(self.context.hadoopConfiguration), Some(codec))
}
/**
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
@ -583,11 +620,19 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[_ <: OutputFormat[_, _]],
conf: JobConf = new JobConf(self.context.hadoopConfiguration)) {
conf: JobConf = new JobConf(self.context.hadoopConfiguration),
codec: Option[Class[_ <: CompressionCodec]] = None) {
conf.setOutputKeyClass(keyClass)
conf.setOutputValueClass(valueClass)
// conf.setOutputFormat(outputFormatClass) // Doesn't work in Scala 2.9 due to what may be a generics bug
conf.set("mapred.output.format.class", outputFormatClass.getName)
for (c <- codec) {
conf.setCompressMapOutput(true)
conf.set("mapred.output.compress", "true")
conf.setMapOutputCompressorClass(c)
conf.set("mapred.output.compression.codec", c.getCanonicalName)
conf.set("mapred.output.compression.type", CompressionType.BLOCK.toString)
}
conf.setOutputCommitter(classOf[FileOutputCommitter])
FileOutputFormat.setOutputPath(conf, HadoopWriter.createPathFromString(path, conf))
saveAsHadoopDataset(conf)

View file

@ -7,12 +7,14 @@ import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.Text
import org.apache.hadoop.mapred.TextOutputFormat
import it.unimi.dsi.fastutil.objects.{Object2LongOpenHashMap => OLMap}
import spark.broadcast.Broadcast
import spark.Partitioner._
import spark.partial.BoundedDouble
import spark.partial.CountEvaluator
@ -35,6 +37,7 @@ import spark.rdd.ZippedPartitionsRDD2
import spark.rdd.ZippedPartitionsRDD3
import spark.rdd.ZippedPartitionsRDD4
import spark.storage.StorageLevel
import spark.util.BoundedPriorityQueue
import SparkContext._
@ -114,6 +117,14 @@ abstract class RDD[T: ClassManifest](
this
}
/** User-defined generator of this RDD*/
var generator = Utils.getCallSiteInfo.firstUserClass
/** Reset generator*/
def setGenerator(_generator: String) = {
generator = _generator
}
/**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. This can only be used to assign a new storage level if the RDD does not
@ -352,13 +363,36 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD created by piping elements to a forked external process.
*/
def pipe(command: Seq[String]): RDD[String] = new PipedRDD(this, command)
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
*
* @param command command to run in forked process.
* @param env environment variables to set.
* @param printPipeContext Before piping elements, this function is called as an oppotunity
* to pipe context data. Print line function (like out.println) will be
* passed as printPipeContext's parameter.
* @param printRDDElement Use this function to customize how to pipe elements. This function
* will be called with each RDD element as the 1st parameter, and the
* print line function (like out.println()) as the 2nd parameter.
* An example of pipe the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concat all the elements:
* def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
* for (e <- record._2){f(e)}
* @return the result RDD
*/
def pipe(command: Seq[String], env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
def pipe(
command: Seq[String],
env: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null): RDD[String] =
new PipedRDD(this, command, env,
if (printPipeContext ne null) sc.clean(printPipeContext) else null,
if (printRDDElement ne null) sc.clean(printRDDElement) else null)
/**
* Return a new RDD by applying a function to each partition of this RDD.
@ -722,6 +756,24 @@ abstract class RDD[T: ClassManifest](
case _ => throw new UnsupportedOperationException("empty collection")
}
/**
* Returns the top K elements from this RDD as defined by
* the specified implicit Ordering[T].
* @param num the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = {
mapPartitions { items =>
val queue = new BoundedPriorityQueue[T](num)
queue ++= items
Iterator.single(queue)
}.reduce { (queue1, queue2) =>
queue1 ++= queue2
queue1
}.toArray
}
/**
* Save this RDD as a text file, using string representations of elements.
*/
@ -730,6 +782,14 @@ abstract class RDD[T: ClassManifest](
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
this.map(x => (NullWritable.get(), new Text(x.toString)))
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@ -788,7 +848,7 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
private[spark] val origin = Utils.getSparkCallSite
private[spark] val origin = Utils.formatSparkCallSite
private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]

View file

@ -18,6 +18,7 @@ import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.mapred.OutputCommitter
import org.apache.hadoop.mapred.FileOutputCommitter
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.Writable
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.io.BytesWritable
@ -62,7 +63,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
* byte arrays to BytesWritable, and Strings to Text. The `path` can be on any Hadoop-supported
* file system.
*/
def saveAsSequenceFile(path: String) {
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
val keyClass = getWritableClass[K]
@ -72,14 +73,18 @@ class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : Cla
logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," + valueClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
self.saveAsHadoopFile(path, keyClass, valueClass, format)
self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(path, keyClass, valueClass, format)
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
path, keyClass, valueClass, format, jobConf, codec)
}
}
}

View file

@ -49,7 +49,6 @@ import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend
import spark.storage.{BlockManagerUI, StorageStatus, StorageUtils, RDDInfo}
import spark.util.{MetadataCleaner, TimeStampedHashMap}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
@ -630,7 +629,7 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
val callSite = Utils.getSparkCallSite
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runJob(rdd, func, partitions, callSite, allowLocal, resultHandler, localProperties.value)
@ -713,7 +712,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
val callSite = Utils.getSparkCallSite
val callSite = Utils.formatSparkCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout, localProperties.value)

View file

@ -1,5 +1,8 @@
package spark
import collection.mutable
import serializer.Serializer
import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
import akka.remote.RemoteActorRefProvider
@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
import spark.network.ConnectionManager
import spark.serializer.{Serializer, SerializerManager}
import spark.util.AkkaUtils
import spark.api.python.PythonWorkerFactory
/**
@ -37,7 +41,10 @@ class SparkEnv (
// If executorId is NOT found, return defaultHostPort
var executorIdToHostPort: Option[(String, String) => String]) {
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
def stop() {
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
shuffleFetcher.stop()
@ -50,6 +57,11 @@ class SparkEnv (
actorSystem.awaitTermination()
}
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
synchronized {
pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
}
}
def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
val env = SparkEnv.get

View file

@ -116,8 +116,8 @@ private object Utils extends Logging {
while (dir == null) {
attempts += 1
if (attempts > maxAttempts) {
throw new IOException("Failed to create a temp directory after " + maxAttempts +
" attempts!")
throw new IOException("Failed to create a temp directory (under " + root + ") after " +
maxAttempts + " attempts!")
}
try {
dir = new File(root, "spark-" + UUID.randomUUID.toString)
@ -522,13 +522,45 @@ private object Utils extends Logging {
execute(command, new File("."))
}
/**
* Execute a command and get its output, throwing an exception if it yields a code other than 0.
*/
def executeAndGetOutput(command: Seq[String], workingDir: File = new File(".")): String = {
val process = new ProcessBuilder(command: _*)
.directory(workingDir)
.start()
new Thread("read stderr for " + command(0)) {
override def run() {
for (line <- Source.fromInputStream(process.getErrorStream).getLines) {
System.err.println(line)
}
}
}.start()
val output = new StringBuffer
val stdoutThread = new Thread("read stdout for " + command(0)) {
override def run() {
for (line <- Source.fromInputStream(process.getInputStream).getLines) {
output.append(line)
}
}
}
stdoutThread.start()
val exitCode = process.waitFor()
stdoutThread.join() // Wait for it to finish reading output
if (exitCode != 0) {
throw new SparkException("Process " + command + " exited with code " + exitCode)
}
output.toString
}
private[spark] class CallSiteInfo(val lastSparkMethod: String, val firstUserFile: String,
val firstUserLine: Int, val firstUserClass: String)
/**
* When called inside a class in the spark package, returns the name of the user code class
* (outside the spark package) that called into Spark, as well as which Spark method they called.
* This is used, for example, to tell users where in their code each RDD got created.
*/
def getSparkCallSite: String = {
def getCallSiteInfo: CallSiteInfo = {
val trace = Thread.currentThread.getStackTrace().filter( el =>
(!el.getMethodName.contains("getStackTrace")))
@ -540,6 +572,7 @@ private object Utils extends Logging {
var firstUserFile = "<unknown>"
var firstUserLine = 0
var finished = false
var firstUserClass = "<unknown>"
for (el <- trace) {
if (!finished) {
@ -554,13 +587,19 @@ private object Utils extends Logging {
else {
firstUserLine = el.getLineNumber
firstUserFile = el.getFileName
firstUserClass = el.getClassName
finished = true
}
}
}
"%s at %s:%s".format(lastSparkMethod, firstUserFile, firstUserLine)
new CallSiteInfo(lastSparkMethod, firstUserFile, firstUserLine, firstUserClass)
}
def formatSparkCallSite = {
val callSiteInfo = getCallSiteInfo
"%s at %s:%s".format(callSiteInfo.lastSparkMethod, callSiteInfo.firstUserFile,
callSiteInfo.firstUserLine)
}
/**
* Try to find a free port to bind to on the local host. This should ideally never be needed,
* except that, unfortunately, some of the networking libraries we currently rely on (e.g. Spray)
@ -602,4 +641,67 @@ private object Utils extends Logging {
}
return false
}
def isSpace(c: Char): Boolean = {
" \t\r\n".indexOf(c) != -1
}
/**
* Split a string of potentially quoted arguments from the command line the way that a shell
* would do it to determine arguments to a command. For example, if the string is 'a "b c" d',
* then it would be parsed as three arguments: 'a', 'b c' and 'd'.
*/
def splitCommandString(s: String): Seq[String] = {
val buf = new ArrayBuffer[String]
var inWord = false
var inSingleQuote = false
var inDoubleQuote = false
var curWord = new StringBuilder
def endWord() {
buf += curWord.toString
curWord.clear()
}
var i = 0
while (i < s.length) {
var nextChar = s.charAt(i)
if (inDoubleQuote) {
if (nextChar == '"') {
inDoubleQuote = false
} else if (nextChar == '\\') {
if (i < s.length - 1) {
// Append the next character directly, because only " and \ may be escaped in
// double quotes after the shell's own expansion
curWord.append(s.charAt(i + 1))
i += 1
}
} else {
curWord.append(nextChar)
}
} else if (inSingleQuote) {
if (nextChar == '\'') {
inSingleQuote = false
} else {
curWord.append(nextChar)
}
// Backslashes are not treated specially in single quotes
} else if (nextChar == '"') {
inWord = true
inDoubleQuote = true
} else if (nextChar == '\'') {
inWord = true
inSingleQuote = true
} else if (!isSpace(nextChar)) {
curWord.append(nextChar)
inWord = true
} else if (inWord && isSpace(nextChar)) {
endWord()
inWord = false
}
i += 1
}
if (inWord || inDoubleQuote || inSingleQuote) {
endWord()
}
return buf
}
}

View file

@ -6,6 +6,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.OutputFormat
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
@ -459,6 +460,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass)
}
/** Output the RDD to any Hadoop-supported file system, compressing with the supplied codec. */
def saveAsHadoopFile[F <: OutputFormat[_, _]](
path: String,
keyClass: Class[_],
valueClass: Class[_],
outputFormatClass: Class[F],
codec: Class[_ <: CompressionCodec]) {
rdd.saveAsHadoopFile(path, keyClass, valueClass, outputFormatClass, codec)
}
/** Output the RDD to any Hadoop-supported file system. */
def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
path: String,

View file

@ -86,7 +86,6 @@ JavaRDDLike[T, JavaRDD[T]] {
*/
def subtract(other: JavaRDD[T], p: Partitioner): JavaRDD[T] =
wrapRDD(rdd.subtract(other, p))
}
object JavaRDD {

View file

@ -1,9 +1,10 @@
package spark.api.java
import java.util.{List => JList}
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
import org.apache.hadoop.io.compress.CompressionCodec
import spark.{SparkContext, Partition, RDD, TaskContext}
import spark.api.java.JavaPairRDD._
import spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
@ -310,6 +311,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def saveAsTextFile(path: String) = rdd.saveAsTextFile(path)
/**
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) =
rdd.saveAsTextFile(path, codec)
/**
* Save this RDD as a SequenceFile of serialized objects.
*/
@ -351,4 +359,29 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def toDebugString(): String = {
rdd.toDebugString
}
/**
* Returns the top K elements from this RDD as defined by
* the specified Comparator[T].
* @param num the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
def top(num: Int, comp: Comparator[T]): JList[T] = {
import scala.collection.JavaConversions._
val topElems = rdd.top(num)(Ordering.comparatorToOrdering(comp))
val arr: java.util.Collection[T] = topElems.toSeq
new java.util.ArrayList(arr)
}
/**
* Returns the top K elements from this RDD using the
* natural ordering for T.
* @param num the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[T]]
top(num, comp)
}
}

View file

@ -2,10 +2,9 @@ package spark.api.python
import java.io._
import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Collections}
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
import scala.io.Source
import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import spark.broadcast.Broadcast
@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
private[spark] class PythonRDD[T: ClassManifest](
parent: RDD[T],
command: Seq[String],
envVars: java.util.Map[String, String],
envVars: JMap[String, String],
preservePartitoning: Boolean,
pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
def this(parent: RDD[T], command: String, envVars: JMap[String, String],
preservePartitoning: Boolean, pythonExec: String,
broadcastVars: JList[Broadcast[Array[Byte]]],
accumulator: Accumulator[JList[Array[Byte]]]) =
@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](
override val partitioner = if (preservePartitoning) parent.partitioner else None
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
// Add the environmental variables to the process.
val currentEnvVars = pb.environment()
for ((variable, value) <- envVars) {
currentEnvVars.put(variable, value)
}
val proc = pb.start()
val startTime = System.currentTimeMillis
val env = SparkEnv.get
// Start a thread to print the process's stderr to ours
new Thread("stderr reader for " + pythonExec) {
override def run() {
for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
System.err.println(line)
}
}
}.start()
val worker = env.createPythonWorker(pythonExec, envVars.toMap)
// Start a thread to feed the process input from our parent's iterator
new Thread("stdin writer for " + pythonExec) {
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
val dOut = new DataOutputStream(proc.getOutputStream)
val out = new PrintWriter(worker.getOutputStream)
val dOut = new DataOutputStream(worker.getOutputStream)
// Partition index
dOut.writeInt(split.index)
// sparkFilesDir
@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
}
dOut.flush()
out.flush()
proc.getOutputStream.close()
worker.shutdownOutput()
}
}.start()
// Return an iterator that read lines from the process's stdout
val stream = new DataInputStream(proc.getInputStream)
val stream = new DataInputStream(worker.getInputStream)
return new Iterator[Array[Byte]] {
def next(): Array[Byte] = {
val obj = _nextObj
_nextObj = read()
if (hasNext) {
// FIXME: can deadlock if worker is waiting for us to
// respond to current message (currently irrelevant because
// output is shutdown before we read any input)
_nextObj = read()
}
obj
}
@ -108,6 +95,17 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
case -3 =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
val finishTime = stream.readLong()
val boot = bootTime - startTime
val init = initTime - bootTime
val finish = finishTime - initTime
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
case -2 =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
@ -115,23 +113,21 @@ private[spark] class PythonRDD[T: ClassManifest](
stream.readFully(obj)
throw new PythonException(new String(obj))
case -1 =>
// We've finished the data section of the output, but we can still read some
// accumulator updates; let's do that, breaking when we get EOFException
while (true) {
val len2 = stream.readInt()
// We've finished the data section of the output, but we can still
// read some accumulator updates; let's do that, breaking when we
// get a negative length record.
var len2 = stream.readInt()
while (len2 >= 0) {
val update = new Array[Byte](len2)
stream.readFully(update)
accumulator += Collections.singletonList(update)
len2 = stream.readInt()
}
new Array[Byte](0)
}
} catch {
case eof: EOFException => {
val exitStatus = proc.waitFor()
if (exitStatus != 0) {
throw new Exception("Subprocess exited with status " + exitStatus)
}
new Array[Byte](0)
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
case e => throw e
}
@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (a, b)
case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
}
val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
}
@ -215,7 +211,7 @@ private[spark] object PythonRDD {
dOut.write(s)
dOut.writeByte(Pickle.STOP)
} else {
throw new Exception("Unexpected RDD type")
throw new SparkException("Unexpected RDD type")
}
}

View file

@ -0,0 +1,95 @@
package spark.api.python
import java.io.{DataInputStream, IOException}
import java.net.{Socket, SocketException, InetAddress}
import scala.collection.JavaConversions._
import spark._
private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
extends Logging {
var daemon: Process = null
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
var daemonPort: Int = 0
def create(): Socket = {
synchronized {
// Start the daemon if it hasn't been started
startDaemon()
// Attempt to connect, restart and retry once if it fails
try {
new Socket(daemonHost, daemonPort)
} catch {
case exc: SocketException => {
logWarning("Python daemon unexpectedly quit, attempting to restart")
stopDaemon()
startDaemon()
new Socket(daemonHost, daemonPort)
}
case e => throw e
}
}
}
def stop() {
stopDaemon()
}
private def startDaemon() {
synchronized {
// Is it already running?
if (daemon != null) {
return
}
try {
// Create and start the daemon
val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
val workerEnv = pb.environment()
workerEnv.putAll(envVars)
daemon = pb.start()
daemonPort = new DataInputStream(daemon.getInputStream).readInt()
// Redirect the stderr to ours
new Thread("stderr reader for " + pythonExec) {
override def run() {
scala.util.control.Exception.ignoring(classOf[IOException]) {
// FIXME HACK: We copy the stream on the level of bytes to
// attempt to dodge encoding problems.
val in = daemon.getErrorStream
var buf = new Array[Byte](1024)
var len = in.read(buf)
while (len != -1) {
System.err.write(buf, 0, len)
len = in.read(buf)
}
}
}
}.start()
} catch {
case e => {
stopDaemon()
throw e
}
}
// Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
// detect our disappearance.
}
}
private def stopDaemon() {
synchronized {
// Request shutdown of existing daemon by sending SIGTERM
if (daemon != null) {
daemon.destroy()
}
daemon = null
daemonPort = 0
}
}
}

View file

@ -1,6 +1,7 @@
package spark.deploy.worker
import java.io._
import java.lang.System.getenv
import spark.deploy.{ExecutorState, ExecutorStateChanged, ApplicationDescription}
import akka.actor.ActorRef
import spark.{Utils, Logging}
@ -40,7 +41,7 @@ private[spark] class ExecutorRunner(
workerThread.start()
// Shutdown hook that kills actors on shutdown.
shutdownHook = new Thread() {
shutdownHook = new Thread() {
override def run() {
if (process != null) {
logInfo("Shutdown hook killing child process.")
@ -77,9 +78,29 @@ private[spark] class ExecutorRunner(
def buildCommandSeq(): Seq[String] = {
val command = appDesc.command
val script = if (System.getProperty("os.name").startsWith("Windows")) "run.cmd" else "run"
val runScript = new File(sparkHome, script).getCanonicalPath
Seq(runScript, command.mainClass) ++ (command.arguments ++ Seq(appId)).map(substituteVariables)
val runner = Option(getenv("JAVA_HOME")).map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
// fails to kill a process tree on Windows
Seq(runner) ++ buildJavaOpts() ++ Seq(command.mainClass) ++
command.arguments.map(substituteVariables)
}
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
*/
def buildJavaOpts(): Seq[String] = {
val libraryOpts = Option(getenv("SPARK_LIBRARY_PATH"))
.map(p => List("-Djava.library.path=" + p))
.getOrElse(Nil)
val userOpts = Option(getenv("SPARK_JAVA_OPTS")).map(Utils.splitCommandString).getOrElse(Nil)
val memoryOpts = Seq("-Xms" + memory + "M", "-Xmx" + memory + "M")
// Figure out our classpath with the external compute-classpath script
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
val classPath = Utils.executeAndGetOutput(Seq(sparkHome + "/bin/compute-classpath" + ext))
Seq("-cp", classPath) ++ libraryOpts ++ userOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
@ -115,7 +136,6 @@ private[spark] class ExecutorRunner(
for ((key, value) <- appDesc.command.environment) {
env.put(key, value)
}
env.put("SPARK_MEM", memory.toString + "m")
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
env.put("SPARK_LAUNCH_WITH_SCALA", "0")

View file

@ -42,7 +42,8 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
// Create our ClassLoader and set it on this thread
private val urlClassLoader = createClassLoader()
Thread.currentThread.setContextClassLoader(urlClassLoader)
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
@ -88,7 +89,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
override def run() {
val startTime = System.currentTimeMillis()
SparkEnv.set(env)
Thread.currentThread.setContextClassLoader(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo("Running task ID " + taskId)
context.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
@ -104,6 +105,7 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val value = task.run(taskId.toInt)
val taskFinish = System.currentTimeMillis()
task.metrics.foreach{ m =>
m.hostname = Utils.localHostName
m.executorDeserializeTime = (taskStart - startTime).toInt
m.executorRunTime = (taskFinish - taskStart).toInt
}
@ -152,26 +154,31 @@ private[spark] class Executor(executorId: String, slaveHostname: String, propert
val urls = currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
}.toArray
loader = new URLClassLoader(urls, loader)
new ExecutorURLClassLoader(urls, loader)
}
// If the REPL is in use, add another ClassLoader that will read
// new classes defined by the REPL as the user types code
/**
* If the REPL is in use, add another ClassLoader that will read
* new classes defined by the REPL as the user types code
*/
private def addReplClassLoaderIfNeeded(parent: ClassLoader): ClassLoader = {
val classUri = System.getProperty("spark.repl.class.uri")
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
loader = {
try {
val klass = Class.forName("spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
constructor.newInstance(classUri, loader)
} catch {
case _: ClassNotFoundException => loader
}
try {
val klass = Class.forName("spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
val constructor = klass.getConstructor(classOf[String], classOf[ClassLoader])
return constructor.newInstance(classUri, parent)
} catch {
case _: ClassNotFoundException =>
logError("Could not find spark.repl.ExecutorClassLoader on classpath!")
System.exit(1)
null
}
} else {
return parent
}
return new ExecutorURLClassLoader(Array(), loader)
}
/**

View file

@ -1,6 +1,11 @@
package spark.executor
class TaskMetrics extends Serializable {
/**
* Host's name the task runs on
*/
var hostname: String = _
/**
* Time taken on the executor to deserialize this task
*/
@ -33,10 +38,15 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
* Time when shuffle finishs
*/
var shuffleFinishTime: Long = _
/**
* Total number of blocks fetched in a shuffle (remote or local)
*/
var totalBlocksFetched : Int = _
var totalBlocksFetched: Int = _
/**
* Number of remote blocks fetched in a shuffle

View file

@ -9,19 +9,36 @@ import io.netty.util.CharsetUtil
import spark.Logging
import spark.network.ConnectionManagerId
import scala.collection.JavaConverters._
private[spark] class ShuffleCopier extends Logging {
def getBlock(cmId: ConnectionManagerId, blockId: String,
def getBlock(host: String, port: Int, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback)
val fc = new FileClient(handler)
fc.init()
fc.connect(cmId.host, cmId.port)
fc.sendRequest(blockId)
fc.waitForClose()
fc.close()
val connectTimeout = System.getProperty("spark.shuffle.netty.connect.timeout", "60000").toInt
val fc = new FileClient(handler, connectTimeout)
try {
fc.init()
fc.connect(host, port)
fc.sendRequest(blockId)
fc.waitForClose()
fc.close()
} catch {
// Handle any socket-related exceptions in FileClient
case e: Exception => {
logError("Shuffle copy of block " + blockId + " from " + host + ":" + port + " failed", e)
handler.handleError(blockId)
}
}
}
def getBlock(cmId: ConnectionManagerId, blockId: String,
resultCollectCallback: (String, Long, ByteBuf) => Unit) {
getBlock(cmId.host, cmId.port, blockId, resultCollectCallback)
}
def getBlocks(cmId: ConnectionManagerId,
@ -44,20 +61,18 @@ private[spark] object ShuffleCopier extends Logging {
logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
override def handleError(blockId: String) {
if (!isComplete) {
resultCollectCallBack(blockId, -1, null)
}
}
}
def echoResultCollectCallBack(blockId: String, size: Long, content: ByteBuf) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
def runGetBlock(host:String, port:Int, file:String){
val handler = new ShuffleClientHandler(echoResultCollectCallBack)
val fc = new FileClient(handler)
fc.init();
fc.connect(host, port)
fc.sendRequest(file)
fc.waitForClose();
fc.close()
if (size != -1) {
logInfo("File: " + blockId + " content is : \" " + content.toString(CharsetUtil.UTF_8) + "\"")
}
}
def main(args: Array[String]) {
@ -71,14 +86,16 @@ private[spark] object ShuffleCopier extends Logging {
val threads = if (args.length > 3) args(3).toInt else 10
val copiers = Executors.newFixedThreadPool(80)
for (i <- Range(0, threads)) {
val runnable = new Runnable() {
val tasks = (for (i <- Range(0, threads)) yield {
Executors.callable(new Runnable() {
def run() {
runGetBlock(host, port, file)
val copier = new ShuffleCopier()
copier.getBlock(host, port, file, echoResultCollectCallBack)
}
}
copiers.execute(runnable)
}
})
}).asJava
copiers.invokeAll(tasks)
copiers.shutdown
System.exit(0)
}
}

View file

@ -6,7 +6,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.JavaConversions
import scala.collection.mutable.ArrayBuffer
import spark.{Aggregator, Logging, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Aggregator, Partition, Partitioner, RDD, SparkEnv, TaskContext}
import spark.{Dependency, OneToOneDependency, ShuffleDependency}
@ -49,12 +49,16 @@ private[spark] class CoGroupAggregator
*
* @param rdds parent RDDs.
* @param part partitioner used to partition the shuffle output.
* @param mapSideCombine flag indicating whether to merge values before shuffle step.
* @param mapSideCombine flag indicating whether to merge values before shuffle step. If the flag
* is on, Spark does an extra pass over the data on the map side to merge
* all values belonging to the same key together. This can reduce the amount
* of data shuffled if and only if the number of distinct keys is very small,
* and the ratio of key size to value size is also very small.
*/
class CoGroupedRDD[K](
@transient var rdds: Seq[RDD[(K, _)]],
part: Partitioner,
val mapSideCombine: Boolean = true,
val mapSideCombine: Boolean = false,
val serializerClass: String = null)
extends RDD[(K, Seq[Seq[_]])](rdds.head.context, Nil) {

View file

@ -9,6 +9,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import spark.{RDD, SparkEnv, Partition, TaskContext}
import spark.broadcast.Broadcast
/**
@ -18,14 +19,21 @@ import spark.{RDD, SparkEnv, Partition, TaskContext}
class PipedRDD[T: ClassManifest](
prev: RDD[T],
command: Seq[String],
envVars: Map[String, String])
envVars: Map[String, String],
printPipeContext: (String => Unit) => Unit,
printRDDElement: (T, String => Unit) => Unit)
extends RDD[String](prev) {
def this(prev: RDD[T], command: Seq[String]) = this(prev, command, Map())
// Similar to Runtime.exec(), if we are given a single string, split it into words
// using a standard StringTokenizer (i.e. by spaces)
def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command))
def this(
prev: RDD[T],
command: String,
envVars: Map[String, String] = Map(),
printPipeContext: (String => Unit) => Unit = null,
printRDDElement: (T, String => Unit) => Unit = null) =
this(prev, PipedRDD.tokenize(command), envVars, printPipeContext, printRDDElement)
override def getPartitions: Array[Partition] = firstParent[T].partitions
@ -52,8 +60,17 @@ class PipedRDD[T: ClassManifest](
override def run() {
SparkEnv.set(env)
val out = new PrintWriter(proc.getOutputStream)
// input the pipe context firstly
if (printPipeContext != null) {
printPipeContext(out.println(_))
}
for (elem <- firstParent[T].iterator(split, context)) {
out.println(elem)
if (printRDDElement != null) {
printRDDElement(elem, out.println(_))
} else {
out.println(elem)
}
}
out.close()
}

View file

@ -53,14 +53,10 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
val exactMatchLocations = exactMatchPreferredLocations.reduce((x, y) => x.intersect(y))
// Remove exact match and then do host local match.
val otherNodePreferredLocations = rddSplitZip.map(x => {
x._1.preferredLocations(x._2).map(hostPort => {
val host = Utils.parseHostPort(hostPort)._1
if (exactMatchLocations.contains(host)) null else host
}).filter(_ != null)
})
val otherNodeLocalLocations = otherNodePreferredLocations.reduce((x, y) => x.intersect(y))
val exactMatchHosts = exactMatchLocations.map(Utils.parseHostPort(_)._1)
val matchPreferredHosts = exactMatchPreferredLocations.map(locs => locs.map(Utils.parseHostPort(_)._1))
.reduce((x, y) => x.intersect(y))
val otherNodeLocalLocations = matchPreferredHosts.filter { s => !exactMatchHosts.contains(s) }
otherNodeLocalLocations ++ exactMatchLocations
}

View file

@ -298,6 +298,7 @@ class DAGScheduler(
// Compute very short actions like first() or take() with no parent stages locally.
runLocally(job)
} else {
sparkListeners.foreach(_.onJobStart(SparkListenerJobStart(job, properties)))
idToActiveJob(runId) = job
activeJobs += job
resultStageToJob(finalStage) = job
@ -311,6 +312,8 @@ class DAGScheduler(
handleExecutorLost(execId)
case completion: CompletionEvent =>
sparkListeners.foreach(_.onTaskEnd(SparkListenerTaskEnd(completion.task,
completion.reason, completion.taskInfo, completion.taskMetrics)))
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
@ -321,6 +324,7 @@ class DAGScheduler(
for (job <- activeJobs) {
val error = new SparkException("Job cancelled because SparkContext was shut down")
job.listener.jobFailed(error)
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
}
return true
}
@ -468,6 +472,7 @@ class DAGScheduler(
}
}
if (tasks.size > 0) {
sparkListeners.foreach(_.onStageSubmitted(SparkListenerStageSubmitted(stage, tasks.size)))
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
myPending ++= tasks
logDebug("New pending tasks: " + myPending)
@ -522,6 +527,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobSucceeded)))
}
job.listener.taskSucceeded(rt.outputId, event.result)
}
@ -662,7 +668,9 @@ class DAGScheduler(
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
job.listener.jobFailed(new SparkException("Job failed: " + reason))
val error = new SparkException("Job failed: " + reason)
job.listener.jobFailed(error)
sparkListeners.foreach(_.onJobEnd(SparkListenerJobEnd(job, JobFailed(error))))
activeJobs -= job
resultStageToJob -= resultStage
}

View file

@ -0,0 +1,306 @@
package spark.scheduler
import java.io.PrintWriter
import java.io.File
import java.io.FileNotFoundException
import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import scala.io.Source
import spark._
import spark.executor.TaskMetrics
import spark.scheduler.cluster.TaskInfo
// Used to record runtime information for each job, including RDD graph
// tasks' start/stop shuffle information and information from outside
class JobLogger(val logDirName: String) extends SparkListener with Logging {
private val logDir =
if (System.getenv("SPARK_LOG_DIR") != null)
System.getenv("SPARK_LOG_DIR")
else
"/tmp/spark"
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
def getLogDir = logDir
def getJobIDtoPrintWriter = jobIDToPrintWriter
def getStageIDToJobID = stageIDToJobID
def getJobIDToStages = jobIDToStages
def getEventQueue = eventQueue
new Thread("JobLogger") {
setDaemon(true)
override def run() {
while (true) {
val event = eventQueue.take
logDebug("Got event of type " + event.getClass.getName)
event match {
case SparkListenerJobStart(job, properties) =>
processJobStartEvent(job, properties)
case SparkListenerStageSubmitted(stage, taskSize) =>
processStageSubmittedEvent(stage, taskSize)
case StageCompleted(stageInfo) =>
processStageCompletedEvent(stageInfo)
case SparkListenerJobEnd(job, result) =>
processJobEndEvent(job, result)
case SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics) =>
processTaskEndEvent(task, reason, taskInfo, taskMetrics)
case _ =>
}
}
}
}.start()
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
if (dir.mkdirs() == false) {
logError("create log directory error:" + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
try{
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
// Close log file, and clean the stage relationship in stageIDToJobID
protected def closeLogWriter(jobID: Int) =
jobIDToPrintWriter.get(jobID).foreach { fileWriter =>
fileWriter.close()
jobIDToStages.get(jobID).foreach(_.foreach{ stage =>
stageIDToJobID -= stage.id
})
jobIDToPrintWriter -= jobID
jobIDToStages -= jobID
}
// Write log information to log file, withTime parameter controls whether to recored
// time stamp for the information
protected def jobLogInfo(jobID: Int, info: String, withTime: Boolean = true) {
var writeInfo = info
if (withTime) {
val date = new Date(System.currentTimeMillis())
writeInfo = DATE_FORMAT.format(date) + ": " +info
}
jobIDToPrintWriter.get(jobID).foreach(_.println(writeInfo))
}
protected def stageLogInfo(stageID: Int, info: String, withTime: Boolean = true) =
stageIDToJobID.get(stageID).foreach(jobID => jobLogInfo(jobID, info, withTime))
protected def buildJobDep(jobID: Int, stage: Stage) {
if (stage.priority == jobID) {
jobIDToStages.get(jobID) match {
case Some(stageList) => stageList += stage
case None => val stageList = new ListBuffer[Stage]
stageList += stage
jobIDToStages += (jobID -> stageList)
}
stageIDToJobID += (stage.id -> jobID)
stage.parents.foreach(buildJobDep(jobID, _))
}
}
protected def recordStageDep(jobID: Int) {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
case _ => rddList ++= getRddsInStage(dep.rdd)
}
}
rddList
}
jobIDToStages.get(jobID).foreach {_.foreach { stage =>
var depRddDesc: String = ""
getRddsInStage(stage.rdd).foreach { rdd =>
depRddDesc += rdd.id + ","
}
var depStageDesc: String = ""
stage.parents.foreach { stage =>
depStageDesc += "(" + stage.id + "," + stage.shuffleDep.get.shuffleId + ")"
}
jobLogInfo(jobID, "STAGE_ID=" + stage.id + " RDD_DEP=(" +
depRddDesc.substring(0, depRddDesc.length - 1) + ")" +
" STAGE_DEP=" + depStageDesc, false)
}
}
}
// Generate indents and convert to String
protected def indentString(indent: Int) = {
val sb = new StringBuilder()
for (i <- 1 to indent) {
sb.append(" ")
}
sb.toString()
}
protected def getRddName(rdd: RDD[_]) = {
var rddName = rdd.getClass.getName
if (rdd.name != null) {
rddName = rdd.name
}
rddName
}
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
var stageInfo: String = ""
if (stage.isShuffleMap) {
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
stage.shuffleDep.get.shuffleId
}else{
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.priority == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
} else
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.priority, false)
}
// Record task metrics into job log files
protected def recordTaskMetrics(stageID: Int, status: String,
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageID +
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
val readMetrics =
taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics =
taskMetrics.shuffleWriteMetrics match {
case Some(metrics) =>
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
eventQueue.put(stageSubmitted)
}
protected def processStageSubmittedEvent(stage: Stage, taskSize: Int) {
stageLogInfo(stage.id, "STAGE_ID=" + stage.id + " STATUS=SUBMITTED" + " TASK_SIZE=" + taskSize)
}
override def onStageCompleted(stageCompleted: StageCompleted) {
eventQueue.put(stageCompleted)
}
protected def processStageCompletedEvent(stageInfo: StageInfo) {
stageLogInfo(stageInfo.stage.id, "STAGE_ID=" +
stageInfo.stage.id + " STATUS=COMPLETED")
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
eventQueue.put(taskEnd)
}
protected def processTaskEndEvent(task: Task[_], reason: TaskEndReason,
taskInfo: TaskInfo, taskMetrics: TaskMetrics) {
var taskStatus = ""
task match {
case resultTask: ResultTask[_, _] => taskStatus = "TASK_TYPE=RESULT_TASK"
case shuffleMapTask: ShuffleMapTask => taskStatus = "TASK_TYPE=SHUFFLE_MAP_TASK"
}
reason match {
case Success => taskStatus += " STATUS=SUCCESS"
recordTaskMetrics(task.stageId, taskStatus, taskInfo, taskMetrics)
case Resubmitted =>
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId
stageLogInfo(task.stageId, taskStatus)
case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
task.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
stageLogInfo(task.stageId, taskStatus)
case OtherFailure(message) =>
taskStatus += " STATUS=FAILURE TID=" + taskInfo.taskId +
" STAGE_ID=" + task.stageId + " INFO=" + message
stageLogInfo(task.stageId, taskStatus)
case _ =>
}
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) {
eventQueue.put(jobEnd)
}
protected def processJobEndEvent(job: ActiveJob, reason: JobResult) {
var info = "JOB_ID=" + job.runId
reason match {
case JobSucceeded => info += " STATUS=SUCCESS"
case JobFailed(exception) =>
info += " STATUS=FAILED REASON="
exception.getMessage.split("\\s+").foreach(info += _ + "_")
case _ =>
}
jobLogInfo(job.runId, info.substring(0, info.length - 1).toUpperCase)
closeLogWriter(job.runId)
}
protected def recordJobProperties(jobID: Int, properties: Properties) {
if(properties != null) {
val annotation = properties.getProperty("spark.job.annotation", "")
jobLogInfo(jobID, annotation, false)
}
}
override def onJobStart(jobStart: SparkListenerJobStart) {
eventQueue.put(jobStart)
}
protected def processJobStartEvent(job: ActiveJob, properties: Properties) {
createLogWriter(job.runId)
recordJobProperties(job.runId, properties)
buildJobDep(job.runId, job.finalStage)
recordStageDep(job.runId)
recordStageDepGraph(job.runId, job.finalStage)
jobLogInfo(job.runId, "JOB_ID=" + job.runId + " STATUS=STARTED")
}
}

View file

@ -1,27 +1,59 @@
package spark.scheduler
import java.util.Properties
import spark.scheduler.cluster.TaskInfo
import spark.util.Distribution
import spark.{Utils, Logging}
import spark.{Logging, SparkContext, TaskEndReason, Utils}
import spark.executor.TaskMetrics
trait SparkListener {
/**
* called when a stage is completed, with information on the completed stage
*/
def onStageCompleted(stageCompleted: StageCompleted)
}
sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int) extends SparkListenerEvents
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
extends SparkListenerEvents
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
extends SparkListenerEvents
trait SparkListener {
/**
* Called when a stage is completed, with information on the completed stage
*/
def onStageCompleted(stageCompleted: StageCompleted) { }
/**
* Called when a stage is submitted
*/
def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { }
/**
* Called when a task ends
*/
def onTaskEnd(taskEnd: SparkListenerTaskEnd) { }
/**
* Called when a job starts
*/
def onJobStart(jobStart: SparkListenerJobStart) { }
/**
* Called when a job ends
*/
def onJobEnd(jobEnd: SparkListenerJobEnd) { }
}
/**
* Simple SparkListener that logs a few summary statistics when each stage completes
*/
class StatsReportListener extends SparkListener with Logging {
def onStageCompleted(stageCompleted: StageCompleted) {
override def onStageCompleted(stageCompleted: StageCompleted) {
import spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)

View file

@ -177,7 +177,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
val manager = new TaskSetManager(this, taskSet)
val manager = new ClusterTaskSetManager(this, taskSet)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
taskSetTaskIds(taskSet.id) = new HashSet[Long]()

View file

@ -0,0 +1,747 @@
package spark.scheduler.cluster
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
// Must not be the constraint.
assert (constraint != TaskLocality.PROCESS_LOCAL)
constraint match {
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
// For anything else, allow
case _ => true
}
}
def parse(str: String): TaskLocality = {
// better way to do this ?
try {
val retval = TaskLocality.withName(str)
// Must not specify PROCESS_LOCAL !
assert (retval != TaskLocality.PROCESS_LOCAL)
retval
} catch {
case nEx: NoSuchElementException => {
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
// default to preserve earlier behavior
NODE_LOCAL
}
}
}
}
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class ClusterTaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends TaskSetManager
with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
var weight = 1
var minShare = 0
var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent:Schedulable = null
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
// List of pending tasks for each node (process local to container). These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node.
// Essentially, similar to pendingTasksForHostPort, except at host level
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node based on rack locality.
// Essentially, similar to pendingTasksForHost, except at rack level
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the job fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and
// top stack frame) to duplicate count (how many times the same
// exception has appeared) and time the full exception was
// printed. This should ideally be an LRU map that can drop old
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
for (t <- tasks) {
t.generation = generation
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Note that it follows the hierarchy.
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
// straight forward comparison ! Special case it.
val retval = new HashSet[String]()
scheduler.synchronized {
for (location <- _taskPreferredLocations) {
if (scheduler.isExecutorAliveOnHostPort(location)) {
retval += location
}
}
}
return retval
}
val taskPreferredLocations =
if (TaskLocality.NODE_LOCAL == taskLocality) {
_taskPreferredLocations
} else {
assert (TaskLocality.RACK_LOCAL == taskLocality)
// Expand set to include all 'seen' rack local hosts.
// This works since container allocation/management happens within master - so any rack locality information is updated in msater.
// Best case effort, and maybe sort of kludge for now ... rework it later ?
val hosts = new HashSet[String]
_taskPreferredLocations.foreach(h => {
val rackOpt = scheduler.getRackForHost(h)
if (rackOpt.isDefined) {
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
if (hostsOpt.isDefined) {
hosts ++= hostsOpt.get
}
}
// Ensure that irrespective of what scheduler says, host is always added !
hosts += h
})
hosts
}
val retval = new HashSet[String]
scheduler.synchronized {
for (prefLocation <- taskPreferredLocations) {
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
if (aliveLocationsOpt.isDefined) {
retval ++= aliveLocationsOpt.get
}
}
}
retval
}
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
// hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
if (rackLocalLocations.size == 0) {
// Current impl ensures this.
assert (processLocalLocations.size == 0)
assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
// process local locality
for (hostPort <- processLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
hostPortList += index
}
// host locality (includes process local)
for (hostPort <- hostLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val host = Utils.parseHostPort(hostPort)._1
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
hostList += index
}
// rack locality (includes process local and host local)
for (rackLocalHostPort <- rackLocalLocations) {
// DEBUG Code
Utils.checkHostPort(rackLocalHostPort)
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
allPendingTasks += index
}
// Return the pending tasks list for a given host port (process local), or an empty list if
// there is no map entry for that host
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
// DEBUG Code
Utils.checkHostPort(hostPort)
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
}
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
// Return the pending tasks (rack level) list for a given host, or an empty list if
// there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
}
// Number of pending tasks for a given host Port (which would be process local)
def numPendingTasksForHostPort(hostPort: String): Int = {
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending tasks for a given host (which would be data local)
def numPendingTasksForHost(hostPort: String): Int = {
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending rack local tasks for a given host
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack/no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
if (speculatableTasks.size > 0) {
val localTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
// check for rack locality
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
}
if (rackTask != None) {
speculatableTasks -= rackTask.get
return rackTask
}
}
// Any task ...
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
// Check for attemptLocs also ?
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
}
return None
}
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
if (processLocalTask != None) {
return processLocalTask
}
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
if (rackLocalTask != None) {
return rackLocalTask
}
}
// Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(hostPort, locality)
}
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
Utils.checkHostPort(hostPort)
val locs = task.preferredLocations
locs.contains(hostPort)
}
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
// If no preference, consider it as host local
if (locs.isEmpty) return true
val host = Utils.parseHostPort(hostPort)._1
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
}
// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
// This is true if either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
val preferredRacks = new HashSet[String]()
for (preferredHost <- locs) {
val rack = sched.getRackForHost(preferredHost)
if (None != rack) preferredRacks += rack.get
}
if (preferredRacks.isEmpty) return false
val hostRack = sched.getRackForHost(hostPort)
return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
// If explicitly specified, use that
val locality = if (overrideLocality != null) overrideLocality else {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
}
findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val taskLocality =
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
TaskLocality.ANY
val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val time = System.currentTimeMillis
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (TaskLocality.NODE_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
val startTime = System.currentTimeMillis
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
increaseRunningTasks(1)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
}
return None
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskFinished(tid, state, serializedData)
case TaskState.LOST =>
taskLost(tid, state, serializedData)
case TaskState.FAILED =>
taskLost(tid, state, serializedData)
case TaskState.KILLED =>
taskLost(tid, state, serializedData)
case _ =>
}
}
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markSuccessful()
decreaseRunningTasks(1)
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread().getContextClassLoader
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
case ex => throw ex
}
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
}
}
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markFailed()
decreaseRunningTasks(1)
if (!finished(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (serializedData != null && serializedData.limit() > 0) {
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
decreaseRunningTasks(runningTasks)
return
case taskResultTooBig: TaskResultTooBigFailure =>
logInfo("Loss was due to task %s result exceeding Akka frame size; " +
"aborting job".format(tid))
abort("Task %s result exceeded Akka frame size".format(tid))
return
case ef: ExceptionFailure =>
val key = ef.description
val now = System.currentTimeMillis
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
// Count failed attempts only on FAILED and LOST state (not on KILLED)
if (state == TaskState.FAILED || state == TaskState.LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
decreaseRunningTasks(runningTasks)
sched.taskSetFinished(this)
}
override def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
override def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable:Schedulable) {
//nothing
}
override def removeSchedulable(schedulable:Schedulable) {
//nothing
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
override def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
// no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
// Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
// there is no host local node for the task (not if there is no process local node for the task)
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
copiesRunning(index) -= 1
tasksFinished -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the ClusterScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksFinished == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
val time = System.currentTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksFinished < numTasks
}
}

View file

@ -1,747 +1,17 @@
package spark.scheduler.cluster
import java.util.{HashMap => JHashMap, NoSuchElementException, Arrays}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.max
import scala.math.min
import spark._
import spark.scheduler._
import spark.TaskState.TaskState
import java.nio.ByteBuffer
private[spark] object TaskLocality extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") with Logging {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
type TaskLocality = Value
def isAllowed(constraint: TaskLocality, condition: TaskLocality): Boolean = {
// Must not be the constraint.
assert (constraint != TaskLocality.PROCESS_LOCAL)
constraint match {
case TaskLocality.NODE_LOCAL => condition == TaskLocality.NODE_LOCAL
case TaskLocality.RACK_LOCAL => condition == TaskLocality.NODE_LOCAL || condition == TaskLocality.RACK_LOCAL
// For anything else, allow
case _ => true
}
}
def parse(str: String): TaskLocality = {
// better way to do this ?
try {
val retval = TaskLocality.withName(str)
// Must not specify PROCESS_LOCAL !
assert (retval != TaskLocality.PROCESS_LOCAL)
retval
} catch {
case nEx: NoSuchElementException => {
logWarning("Invalid task locality specified '" + str + "', defaulting to NODE_LOCAL");
// default to preserve earlier behavior
NODE_LOCAL
}
}
}
}
/**
* Schedules the tasks within a single TaskSet in the ClusterScheduler.
*/
private[spark] class TaskSetManager(
sched: ClusterScheduler,
val taskSet: TaskSet)
extends Schedulable
with Logging {
// Maximum time to wait to run a task in a preferred location (in ms)
val LOCALITY_WAIT = System.getProperty("spark.locality.wait", "3000").toLong
// CPUs to request per task
val CPUS_PER_TASK = System.getProperty("spark.task.cpus", "1").toDouble
// Maximum times a task is allowed to fail before failing the job
val MAX_TASK_FAILURES = 4
// Quantile of tasks at which to start speculation
val SPECULATION_QUANTILE = System.getProperty("spark.speculation.quantile", "0.75").toDouble
val SPECULATION_MULTIPLIER = System.getProperty("spark.speculation.multiplier", "1.5").toDouble
// Serializer for closures and tasks.
val ser = SparkEnv.get.closureSerializer.newInstance()
val tasks = taskSet.tasks
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
var tasksFinished = 0
var weight = 1
var minShare = 0
var runningTasks = 0
var priority = taskSet.priority
var stageId = taskSet.stageId
var name = "TaskSet_"+taskSet.stageId.toString
var parent:Schedulable = null
// Last time when we launched a preferred task (for delay scheduling)
var lastPreferredLaunchTime = System.currentTimeMillis
// List of pending tasks for each node (process local to container). These collections are actually
// treated as stacks, in which new tasks are added to the end of the
// ArrayBuffer and removed from the end. This makes it faster to detect
// tasks that repeatedly fail because whenever a task failed, it is put
// back at the head of the stack. They are also only cleaned up lazily;
// when a task is launched, it remains in all the pending lists except
// the one that it was launched from, but gets removed from them later.
private val pendingTasksForHostPort = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node.
// Essentially, similar to pendingTasksForHostPort, except at host level
private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List of pending tasks for each node based on rack locality.
// Essentially, similar to pendingTasksForHost, except at rack level
private val pendingRackLocalTasksForHost = new HashMap[String, ArrayBuffer[Int]]
// List containing pending tasks with no locality preferences
val pendingTasksWithNoPrefs = new ArrayBuffer[Int]
// List containing all pending tasks (also used as a stack, as above)
val allPendingTasks = new ArrayBuffer[Int]
// Tasks that can be speculated. Since these will be a small fraction of total
// tasks, we'll just hold them in a HashSet.
val speculatableTasks = new HashSet[Int]
// Task index, start and finish time for each task attempt (indexed by task ID)
val taskInfos = new HashMap[Long, TaskInfo]
// Did the job fail?
var failed = false
var causeOfFailure = ""
// How frequently to reprint duplicate exceptions in full, in milliseconds
val EXCEPTION_PRINT_INTERVAL =
System.getProperty("spark.logging.exceptionPrintInterval", "10000").toLong
// Map of recent exceptions (identified by string representation and
// top stack frame) to duplicate count (how many times the same
// exception has appeared) and time the full exception was
// printed. This should ideally be an LRU map that can drop old
// exceptions automatically.
val recentExceptions = HashMap[String, (Int, Long)]()
// Figure out the current map output tracker generation and set it on all tasks
val generation = sched.mapOutputTracker.getGeneration
logDebug("Generation for " + taskSet.id + ": " + generation)
for (t <- tasks) {
t.generation = generation
}
// Add all our tasks to the pending lists. We do this in reverse order
// of task index so that tasks with low indices get launched first.
for (i <- (0 until numTasks).reverse) {
addPendingTask(i)
}
// Note that it follows the hierarchy.
// if we search for NODE_LOCAL, the output will include PROCESS_LOCAL and
// if we search for RACK_LOCAL, it will include PROCESS_LOCAL & NODE_LOCAL
private def findPreferredLocations(_taskPreferredLocations: Seq[String], scheduler: ClusterScheduler,
taskLocality: TaskLocality.TaskLocality): HashSet[String] = {
if (TaskLocality.PROCESS_LOCAL == taskLocality) {
// straight forward comparison ! Special case it.
val retval = new HashSet[String]()
scheduler.synchronized {
for (location <- _taskPreferredLocations) {
if (scheduler.isExecutorAliveOnHostPort(location)) {
retval += location
}
}
}
return retval
}
val taskPreferredLocations =
if (TaskLocality.NODE_LOCAL == taskLocality) {
_taskPreferredLocations
} else {
assert (TaskLocality.RACK_LOCAL == taskLocality)
// Expand set to include all 'seen' rack local hosts.
// This works since container allocation/management happens within master - so any rack locality information is updated in msater.
// Best case effort, and maybe sort of kludge for now ... rework it later ?
val hosts = new HashSet[String]
_taskPreferredLocations.foreach(h => {
val rackOpt = scheduler.getRackForHost(h)
if (rackOpt.isDefined) {
val hostsOpt = scheduler.getCachedHostsForRack(rackOpt.get)
if (hostsOpt.isDefined) {
hosts ++= hostsOpt.get
}
}
// Ensure that irrespective of what scheduler says, host is always added !
hosts += h
})
hosts
}
val retval = new HashSet[String]
scheduler.synchronized {
for (prefLocation <- taskPreferredLocations) {
val aliveLocationsOpt = scheduler.getExecutorsAliveOnHost(Utils.parseHostPort(prefLocation)._1)
if (aliveLocationsOpt.isDefined) {
retval ++= aliveLocationsOpt.get
}
}
}
retval
}
// Add a task to all the pending-task lists that it should be on.
private def addPendingTask(index: Int) {
// We can infer hostLocalLocations from rackLocalLocations by joining it against tasks(index).preferredLocations (with appropriate
// hostPort <-> host conversion). But not doing it for simplicity sake. If this becomes a performance issue, modify it.
val processLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.PROCESS_LOCAL)
val hostLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val rackLocalLocations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
if (rackLocalLocations.size == 0) {
// Current impl ensures this.
assert (processLocalLocations.size == 0)
assert (hostLocalLocations.size == 0)
pendingTasksWithNoPrefs += index
} else {
// process local locality
for (hostPort <- processLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val hostPortList = pendingTasksForHostPort.getOrElseUpdate(hostPort, ArrayBuffer())
hostPortList += index
}
// host locality (includes process local)
for (hostPort <- hostLocalLocations) {
// DEBUG Code
Utils.checkHostPort(hostPort)
val host = Utils.parseHostPort(hostPort)._1
val hostList = pendingTasksForHost.getOrElseUpdate(host, ArrayBuffer())
hostList += index
}
// rack locality (includes process local and host local)
for (rackLocalHostPort <- rackLocalLocations) {
// DEBUG Code
Utils.checkHostPort(rackLocalHostPort)
val rackLocalHost = Utils.parseHostPort(rackLocalHostPort)._1
val list = pendingRackLocalTasksForHost.getOrElseUpdate(rackLocalHost, ArrayBuffer())
list += index
}
}
allPendingTasks += index
}
// Return the pending tasks list for a given host port (process local), or an empty list if
// there is no map entry for that host
private def getPendingTasksForHostPort(hostPort: String): ArrayBuffer[Int] = {
// DEBUG Code
Utils.checkHostPort(hostPort)
pendingTasksForHostPort.getOrElse(hostPort, ArrayBuffer())
}
// Return the pending tasks list for a given host, or an empty list if
// there is no map entry for that host
private def getPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingTasksForHost.getOrElse(host, ArrayBuffer())
}
// Return the pending tasks (rack level) list for a given host, or an empty list if
// there is no map entry for that host
private def getRackLocalPendingTasksForHost(hostPort: String): ArrayBuffer[Int] = {
val host = Utils.parseHostPort(hostPort)._1
pendingRackLocalTasksForHost.getOrElse(host, ArrayBuffer())
}
// Number of pending tasks for a given host Port (which would be process local)
def numPendingTasksForHostPort(hostPort: String): Int = {
getPendingTasksForHostPort(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending tasks for a given host (which would be data local)
def numPendingTasksForHost(hostPort: String): Int = {
getPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Number of pending rack local tasks for a given host
def numRackLocalPendingTasksForHost(hostPort: String): Int = {
getRackLocalPendingTasksForHost(hostPort).count( index => copiesRunning(index) == 0 && !finished(index) )
}
// Dequeue a pending task from the given list and return its index.
// Return None if the list is empty.
// This method also cleans up any tasks in the list that have already
// been launched, since we want that to happen lazily.
private def findTaskFromList(list: ArrayBuffer[Int]): Option[Int] = {
while (!list.isEmpty) {
val index = list.last
list.trimEnd(1)
if (copiesRunning(index) == 0 && !finished(index)) {
return Some(index)
}
}
return None
}
// Return a speculative task for a given host if any are available. The task should not have an
// attempt running on this host, in case the host is slow. In addition, if locality is set, the
// task must have a preference for this host/rack/no preferred locations at all.
private def findSpeculativeTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
assert (TaskLocality.isAllowed(locality, TaskLocality.NODE_LOCAL))
speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
if (speculatableTasks.size > 0) {
val localTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
(locations.size == 0 || locations.contains(hostPort)) && !attemptLocs.contains(hostPort)
}
if (localTask != None) {
speculatableTasks -= localTask.get
return localTask
}
// check for rack locality
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackTask = speculatableTasks.find {
index =>
val locations = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val attemptLocs = taskAttempts(index).map(_.hostPort)
locations.contains(hostPort) && !attemptLocs.contains(hostPort)
}
if (rackTask != None) {
speculatableTasks -= rackTask.get
return rackTask
}
}
// Any task ...
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
// Check for attemptLocs also ?
val nonLocalTask = speculatableTasks.find(i => !taskAttempts(i).map(_.hostPort).contains(hostPort))
if (nonLocalTask != None) {
speculatableTasks -= nonLocalTask.get
return nonLocalTask
}
}
}
return None
}
// Dequeue a pending task for a given node and return its index.
// If localOnly is set to false, allow non-local tasks as well.
private def findTask(hostPort: String, locality: TaskLocality.TaskLocality): Option[Int] = {
val processLocalTask = findTaskFromList(getPendingTasksForHostPort(hostPort))
if (processLocalTask != None) {
return processLocalTask
}
val localTask = findTaskFromList(getPendingTasksForHost(hostPort))
if (localTask != None) {
return localTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
val rackLocalTask = findTaskFromList(getRackLocalPendingTasksForHost(hostPort))
if (rackLocalTask != None) {
return rackLocalTask
}
}
// Look for no pref tasks AFTER rack local tasks - this has side effect that we will get to failed tasks later rather than sooner.
// TODO: That code path needs to be revisited (adding to no prefs list when host:port goes down).
val noPrefTask = findTaskFromList(pendingTasksWithNoPrefs)
if (noPrefTask != None) {
return noPrefTask
}
if (TaskLocality.isAllowed(locality, TaskLocality.ANY)) {
val nonLocalTask = findTaskFromList(allPendingTasks)
if (nonLocalTask != None) {
return nonLocalTask
}
}
// Finally, if all else has failed, find a speculative task
return findSpeculativeTask(hostPort, locality)
}
private def isProcessLocalLocation(task: Task[_], hostPort: String): Boolean = {
Utils.checkHostPort(hostPort)
val locs = task.preferredLocations
locs.contains(hostPort)
}
private def isHostLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
// If no preference, consider it as host local
if (locs.isEmpty) return true
val host = Utils.parseHostPort(hostPort)._1
locs.find(h => Utils.parseHostPort(h)._1 == host).isDefined
}
// Does a host count as a rack local preferred location for a task? (assumes host is NOT preferred location).
// This is true if either the task has preferred locations and this host is one, or it has
// no preferred locations (in which we still count the launch as preferred).
private def isRackLocalLocation(task: Task[_], hostPort: String): Boolean = {
val locs = task.preferredLocations
val preferredRacks = new HashSet[String]()
for (preferredHost <- locs) {
val rack = sched.getRackForHost(preferredHost)
if (None != rack) preferredRacks += rack.get
}
if (preferredRacks.isEmpty) return false
val hostRack = sched.getRackForHost(hostPort)
return None != hostRack && preferredRacks.contains(hostRack.get)
}
// Respond to an offer of a single slave from the scheduler by finding a task
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
// If explicitly specified, use that
val locality = if (overrideLocality != null) overrideLocality else {
// expand only if we have waited for more than LOCALITY_WAIT for a host local task ...
val time = System.currentTimeMillis
if (time - lastPreferredLaunchTime < LOCALITY_WAIT) TaskLocality.NODE_LOCAL else TaskLocality.ANY
}
findTask(hostPort, locality) match {
case Some(index) => {
// Found a task; do some bookkeeping and return a Mesos task for it
val task = tasks(index)
val taskId = sched.newTaskId()
// Figure out whether this should count as a preferred launch
val taskLocality =
if (isProcessLocalLocation(task, hostPort)) TaskLocality.PROCESS_LOCAL else
if (isHostLocalLocation(task, hostPort)) TaskLocality.NODE_LOCAL else
if (isRackLocalLocation(task, hostPort)) TaskLocality.RACK_LOCAL else
TaskLocality.ANY
val prefStr = taskLocality.toString
logInfo("Starting task %s:%d as TID %s on slave %s: %s (%s)".format(
taskSet.id, index, taskId, execId, hostPort, prefStr))
// Do various bookkeeping
copiesRunning(index) += 1
val time = System.currentTimeMillis
val info = new TaskInfo(taskId, index, time, execId, hostPort, taskLocality)
taskInfos(taskId) = info
taskAttempts(index) = info :: taskAttempts(index)
if (TaskLocality.NODE_LOCAL == taskLocality) {
lastPreferredLaunchTime = time
}
// Serialize and return the task
val startTime = System.currentTimeMillis
val serializedTask = Task.serializeWithDependencies(
task, sched.sc.addedFiles, sched.sc.addedJars, ser)
val timeTaken = System.currentTimeMillis - startTime
increaseRunningTasks(1)
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
return Some(new TaskDescription(taskId, execId, taskName, serializedTask))
}
case _ =>
}
}
return None
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskFinished(tid, state, serializedData)
case TaskState.LOST =>
taskLost(tid, state, serializedData)
case TaskState.FAILED =>
taskLost(tid, state, serializedData)
case TaskState.KILLED =>
taskLost(tid, state, serializedData)
case _ =>
}
}
def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markSuccessful()
decreaseRunningTasks(1)
if (!finished(index)) {
tasksFinished += 1
logInfo("Finished TID %s in %d ms (progress: %d/%d)".format(
tid, info.duration, tasksFinished, numTasks))
// Deserialize task result and pass it to the scheduler
try {
val result = ser.deserialize[TaskResult[_]](serializedData)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
} catch {
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread().getContextClassLoader
throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
case ex => throw ex
}
// Mark finished and stop if we've finished all the tasks
finished(index) = true
if (tasksFinished == numTasks) {
sched.taskSetFinished(this)
}
} else {
logInfo("Ignoring task-finished event for TID " + tid +
" because task " + index + " is already finished")
}
}
def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
if (info.failed) {
// We might get two task-lost messages for the same task in coarse-grained Mesos mode,
// or even from Mesos itself when acks get delayed.
return
}
val index = info.index
info.markFailed()
decreaseRunningTasks(1)
if (!finished(index)) {
logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
copiesRunning(index) -= 1
// Check if the problem is a map output fetch failure. In that case, this
// task will never succeed on any node, so tell the scheduler about it.
if (serializedData != null && serializedData.limit() > 0) {
val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
reason match {
case fetchFailed: FetchFailed =>
logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
finished(index) = true
tasksFinished += 1
sched.taskSetFinished(this)
decreaseRunningTasks(runningTasks)
return
case taskResultTooBig: TaskResultTooBigFailure =>
logInfo("Loss was due to task %s result exceeding Akka frame size;" +
"aborting job".format(tid))
abort("Task %s result exceeded Akka frame size".format(tid))
return
case ef: ExceptionFailure =>
val key = ef.description
val now = System.currentTimeMillis
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
if (now - printTime > EXCEPTION_PRINT_INTERVAL) {
recentExceptions(key) = (0, now)
(true, 0)
} else {
recentExceptions(key) = (dupCount + 1, printTime)
(false, dupCount + 1)
}
} else {
recentExceptions(key) = (0, now)
(true, 0)
}
}
if (printFull) {
val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(
ef.className, ef.description, locs.mkString("\n")))
} else {
logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
}
case _ => {}
}
}
// On non-fetch failures, re-enqueue the task as pending for a max number of retries
addPendingTask(index)
// Count failed attempts only on FAILED and LOST state (not on KILLED)
if (state == TaskState.FAILED || state == TaskState.LOST) {
numFailures(index) += 1
if (numFailures(index) > MAX_TASK_FAILURES) {
logError("Task %s:%d failed more than %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
logInfo("Ignoring task-lost event for TID " + tid +
" because task " + index + " is already finished")
}
}
def error(message: String) {
// Save the error message
abort("Error: " + message)
}
def abort(message: String) {
failed = true
causeOfFailure = message
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.listener.taskSetFailed(taskSet, message)
decreaseRunningTasks(runningTasks)
sched.taskSetFinished(this)
}
override def increaseRunningTasks(taskNum: Int) {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
override def decreaseRunningTasks(taskNum: Int) {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
//TODO: for now we just find Pool not TaskSetManager, we can extend this function in future if needed
override def getSchedulableByName(name: String): Schedulable = {
return null
}
override def addSchedulable(schedulable:Schedulable) {
//nothing
}
override def removeSchedulable(schedulable:Schedulable) {
//nothing
}
override def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
override def executorLost(execId: String, hostPort: String) {
logInfo("Re-queueing tasks for " + execId + " from TaskSet " + taskSet.id)
// If some task has preferred locations only on hostname, and there are no more executors there,
// put it in the no-prefs list to avoid the wait from delay scheduling
// host local tasks - should we push this to rack local or no pref list ? For now, preserving behavior and moving to
// no prefs list. Note, this was done due to impliations related to 'waiting' for data local tasks, etc.
// Note: NOT checking process local list - since host local list is super set of that. We need to ad to no prefs only if
// there is no host local node for the task (not if there is no process local node for the task)
for (index <- getPendingTasksForHost(Utils.parseHostPort(hostPort)._1)) {
// val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.RACK_LOCAL)
val newLocs = findPreferredLocations(tasks(index).preferredLocations, sched, TaskLocality.NODE_LOCAL)
if (newLocs.isEmpty) {
pendingTasksWithNoPrefs += index
}
}
// Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage
if (tasks(0).isInstanceOf[ShuffleMapTask]) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (finished(index)) {
finished(index) = false
copiesRunning(index) -= 1
tasksFinished -= 1
addPendingTask(index)
// Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
// stage finishes when a total of tasks.size tasks finish.
sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
}
}
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
taskLost(tid, TaskState.KILLED, null)
}
}
/**
* Check for tasks to be speculated and return true if there are any. This is called periodically
* by the ClusterScheduler.
*
* TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that
* we don't scan the whole task set. It might also help to make this sorted by launch time.
*/
override def checkSpeculatableTasks(): Boolean = {
// Can't speculate if we only have one task, or if all tasks have finished.
if (numTasks == 1 || tasksFinished == numTasks) {
return false
}
var foundTasks = false
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksFinished >= minFinishedForSpeculation) {
val time = System.currentTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * numTasks).round.toInt, durations.size - 1))
val threshold = max(SPECULATION_MULTIPLIER * medianDuration, 100)
// TODO: Threshold should also look at standard deviation of task durations and have a lower
// bound based on that.
logDebug("Task length threshold for speculation: " + threshold)
for ((tid, info) <- taskInfos) {
val index = info.index
if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
!speculatableTasks.contains(index)) {
logInfo(
"Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
taskSet.id, index, info.hostPort, threshold))
speculatableTasks += index
foundTasks = true
}
}
}
return foundTasks
}
override def hasPendingTasks(): Boolean = {
numTasks > 0 && tasksFinished < numTasks
}
private[spark] trait TaskSetManager extends Schedulable {
def taskSet: TaskSet
def slaveOffer(execId: String, hostPort: String, availableCpus: Double,
overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription]
def numPendingTasksForHostPort(hostPort: String): Int
def numRackLocalPendingTasksForHost(hostPort :String): Int
def numPendingTasksForHost(hostPort: String): Int
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
def error(message: String)
}

View file

@ -2,19 +2,50 @@ package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.executor.ExecutorURLClassLoader
import spark.scheduler._
import spark.scheduler.cluster.{TaskLocality, TaskInfo}
import spark.scheduler.cluster._
import akka.actor._
/**
* A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* A FIFO or Fair TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
* the scheduler also allows each task to fail up to maxFailures times, which is useful for
* testing fault recovery.
*/
private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkContext)
private[spark] case class LocalReviveOffers()
private[spark] case class LocalStatusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer)
private[spark] class LocalActor(localScheduler: LocalScheduler, var freeCores: Int) extends Actor with Logging {
def receive = {
case LocalReviveOffers =>
launchTask(localScheduler.resourceOffer(freeCores))
case LocalStatusUpdate(taskId, state, serializeData) =>
freeCores += 1
localScheduler.statusUpdate(taskId, state, serializeData)
launchTask(localScheduler.resourceOffer(freeCores))
}
def launchTask(tasks : Seq[TaskDescription]) {
for (task <- tasks) {
freeCores -= 1
localScheduler.threadPool.submit(new Runnable {
def run() {
localScheduler.runTask(task.taskId,task.serializedTask)
}
})
}
}
}
private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with Logging {
@ -30,89 +61,127 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
val classLoader = new ExecutorURLClassLoader(Array(), Thread.currentThread.getContextClassLoader)
// TODO: Need to take into account stage priority in scheduling
var schedulableBuilder: SchedulableBuilder = null
var rootPool: Pool = null
val activeTaskSets = new HashMap[String, TaskSetManager]
val taskIdToTaskSetId = new HashMap[Long, String]
val taskSetTaskIds = new HashMap[String, HashSet[Long]]
override def start() { }
var localActor: ActorRef = null
override def start() {
//default scheduler is FIFO
val schedulingMode = System.getProperty("spark.cluster.schedulingmode", "FIFO")
//temporarily set rootPool name to empty
rootPool = new Pool("", SchedulingMode.withName(schedulingMode), 0, 0)
schedulableBuilder = {
schedulingMode match {
case "FIFO" =>
new FIFOSchedulableBuilder(rootPool)
case "FAIR" =>
new FairSchedulableBuilder(rootPool)
}
}
schedulableBuilder.buildPools()
localActor = env.actorSystem.actorOf(Props(new LocalActor(this, threads)), "Test")
}
override def setListener(listener: TaskSchedulerListener) {
this.listener = listener
}
override def submitTasks(taskSet: TaskSet) {
val tasks = taskSet.tasks
val failCount = new Array[Int](tasks.size)
def submitTask(task: Task[_], idInJob: Int) {
val myAttemptId = attemptId.getAndIncrement()
threadPool.submit(new Runnable {
def run() {
runTask(task, idInJob, myAttemptId)
}
})
synchronized {
var manager = new LocalTaskSetManager(this, taskSet)
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
activeTaskSets(taskSet.id) = manager
taskSetTaskIds(taskSet.id) = new HashSet[Long]()
localActor ! LocalReviveOffers
}
}
def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
logInfo("Running " + task)
val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
try {
Accumulators.clear()
Thread.currentThread().setContextClassLoader(classLoader)
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
val ser = SparkEnv.get.closureSerializer.newInstance()
val bytes = Task.serializeWithDependencies(task, sc.addedFiles, sc.addedJars, ser)
logInfo("Size of task " + idInJob + " is " + bytes.limit + " bytes")
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(attemptId)
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val serResult = ser.serialize(result)
deserializedTask.metrics.get.resultSize = serResult.limit()
val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + task)
info.markSuccessful()
deserializedTask.metrics.get.executorRunTime = info.duration.toInt //close enough
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
// If the threadpool has not already been shutdown, notify DAGScheduler
if (!Thread.currentThread().isInterrupted)
listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, deserializedTask.metrics.getOrElse(null))
} catch {
case t: Throwable => {
logError("Exception in task " + idInJob, t)
failCount.synchronized {
failCount(idInJob) += 1
if (failCount(idInJob) <= maxFailures) {
submitTask(task, idInJob)
} else {
// TODO: Do something nicer here to return all the way to the user
if (!Thread.currentThread().isInterrupted) {
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
listener.taskEnded(task, failure, null, null, info, null)
}
}
}
}
def resourceOffer(freeCores: Int): Seq[TaskDescription] = {
synchronized {
var freeCpuCores = freeCores
val tasks = new ArrayBuffer[TaskDescription](freeCores)
val sortedTaskSetQueue = rootPool.getSortedTaskSetQueue()
for (manager <- sortedTaskSetQueue) {
logDebug("parentName:%s,name:%s,runningTasks:%s".format(manager.parent.name, manager.name, manager.runningTasks))
}
}
for ((task, i) <- tasks.zipWithIndex) {
submitTask(task, i)
var launchTask = false
for (manager <- sortedTaskSetQueue) {
do {
launchTask = false
manager.slaveOffer(null,null,freeCpuCores) match {
case Some(task) =>
tasks += task
taskIdToTaskSetId(task.taskId) = manager.taskSet.id
taskSetTaskIds(manager.taskSet.id) += task.taskId
freeCpuCores -= 1
launchTask = true
case None => {}
}
} while(launchTask)
}
return tasks
}
}
def taskSetFinished(manager: TaskSetManager) {
synchronized {
activeTaskSets -= manager.taskSet.id
manager.parent.removeSchedulable(manager)
logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name))
taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id)
taskSetTaskIds -= manager.taskSet.id
}
}
def runTask(taskId: Long, bytes: ByteBuffer) {
logInfo("Running " + taskId)
val info = new TaskInfo(taskId, 0, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
// Set the Spark execution environment for the worker thread
SparkEnv.set(env)
val ser = SparkEnv.get.closureSerializer.newInstance()
try {
Accumulators.clear()
Thread.currentThread().setContextClassLoader(classLoader)
// Serialize and deserialize the task so that accumulators are changed to thread-local ones;
// this adds a bit of unnecessary overhead but matches how the Mesos Executor works.
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(bytes)
updateDependencies(taskFiles, taskJars) // Download any files added with addFile
val deserStart = System.currentTimeMillis()
val deserializedTask = ser.deserialize[Task[_]](
taskBytes, Thread.currentThread.getContextClassLoader)
val deserTime = System.currentTimeMillis() - deserStart
// Run it
val result: Any = deserializedTask.run(taskId)
// Serialize and deserialize the result to emulate what the Mesos
// executor does. This is useful to catch serialization errors early
// on in development (so when users move their local Spark programs
// to the cluster, they don't get surprised by serialization errors).
val serResult = ser.serialize(result)
deserializedTask.metrics.get.resultSize = serResult.limit()
val resultToReturn = ser.deserialize[Any](serResult)
val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
ser.serialize(Accumulators.values))
logInfo("Finished " + taskId)
deserializedTask.metrics.get.executorRunTime = deserTime.toInt//info.duration.toInt //close enough
deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
val serializedResult = ser.serialize(taskResult)
localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
} catch {
case t: Throwable => {
val failure = new ExceptionFailure(t.getClass.getName, t.toString, t.getStackTrace)
localActor ! LocalStatusUpdate(taskId, TaskState.FAILED, ser.serialize(failure))
}
}
}
@ -128,6 +197,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
Utils.fetchFile(name, new File(SparkFiles.getRootDirectory))
@ -143,7 +213,16 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
}
}
override def stop() {
def statusUpdate(taskId :Long, state: TaskState, serializedData: ByteBuffer) {
synchronized {
val taskSetId = taskIdToTaskSetId(taskId)
val taskSetManager = activeTaskSets(taskSetId)
taskSetTaskIds(taskSetId) -= taskId
taskSetManager.statusUpdate(taskId, state, serializedData)
}
}
override def stop() {
threadPool.shutdownNow()
}

View file

@ -0,0 +1,172 @@
package spark.scheduler.local
import java.io.File
import java.util.concurrent.atomic.AtomicInteger
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import spark._
import spark.TaskState.TaskState
import spark.scheduler._
import spark.scheduler.cluster._
private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet) extends TaskSetManager with Logging {
var parent: Schedulable = null
var weight: Int = 1
var minShare: Int = 0
var runningTasks: Int = 0
var priority: Int = taskSet.priority
var stageId: Int = taskSet.stageId
var name: String = "TaskSet_"+taskSet.stageId.toString
var failCount = new Array[Int](taskSet.tasks.size)
val taskInfos = new HashMap[Long, TaskInfo]
val numTasks = taskSet.tasks.size
var numFinished = 0
val ser = SparkEnv.get.closureSerializer.newInstance()
val copiesRunning = new Array[Int](numTasks)
val finished = new Array[Boolean](numTasks)
val numFailures = new Array[Int](numTasks)
val MAX_TASK_FAILURES = sched.maxFailures
def increaseRunningTasks(taskNum: Int): Unit = {
runningTasks += taskNum
if (parent != null) {
parent.increaseRunningTasks(taskNum)
}
}
def decreaseRunningTasks(taskNum: Int): Unit = {
runningTasks -= taskNum
if (parent != null) {
parent.decreaseRunningTasks(taskNum)
}
}
def addSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def removeSchedulable(schedulable: Schedulable): Unit = {
//nothing
}
def getSchedulableByName(name: String): Schedulable = {
return null
}
def executorLost(executorId: String, host: String): Unit = {
//nothing
}
def checkSpeculatableTasks(): Boolean = {
return true
}
def getSortedTaskSetQueue(): ArrayBuffer[TaskSetManager] = {
var sortedTaskSetQueue = new ArrayBuffer[TaskSetManager]
sortedTaskSetQueue += this
return sortedTaskSetQueue
}
def hasPendingTasks(): Boolean = {
return true
}
def findTask(): Option[Int] = {
for (i <- 0 to numTasks-1) {
if (copiesRunning(i) == 0 && !finished(i)) {
return Some(i)
}
}
return None
}
def slaveOffer(execId: String, hostPort: String, availableCpus: Double, overrideLocality: TaskLocality.TaskLocality = null): Option[TaskDescription] = {
SparkEnv.set(sched.env)
logDebug("availableCpus:%d,numFinished:%d,numTasks:%d".format(availableCpus.toInt, numFinished, numTasks))
if (availableCpus > 0 && numFinished < numTasks) {
findTask() match {
case Some(index) =>
val taskId = sched.attemptId.getAndIncrement()
val task = taskSet.tasks(index)
val info = new TaskInfo(taskId, index, System.currentTimeMillis(), "local", "local:1", TaskLocality.NODE_LOCAL)
taskInfos(taskId) = info
val bytes = Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
logInfo("Size of task " + taskId + " is " + bytes.limit + " bytes")
val taskName = "task %s:%d".format(taskSet.id, index)
copiesRunning(index) += 1
increaseRunningTasks(1)
return Some(new TaskDescription(taskId, null, taskName, bytes))
case None => {}
}
}
return None
}
def numPendingTasksForHostPort(hostPort: String): Int = {
return 0
}
def numRackLocalPendingTasksForHost(hostPort :String): Int = {
return 0
}
def numPendingTasksForHost(hostPort: String): Int = {
return 0
}
def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
state match {
case TaskState.FINISHED =>
taskEnded(tid, state, serializedData)
case TaskState.FAILED =>
taskFailed(tid, state, serializedData)
case _ => {}
}
}
def taskEnded(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markSuccessful()
val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
result.metrics.resultSize = serializedData.limit()
sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
numFinished += 1
decreaseRunningTasks(1)
finished(index) = true
if (numFinished == numTasks) {
sched.taskSetFinished(this)
}
}
def taskFailed(tid: Long, state: TaskState, serializedData: ByteBuffer) {
val info = taskInfos(tid)
val index = info.index
val task = taskSet.tasks(index)
info.markFailed()
decreaseRunningTasks(1)
val reason: ExceptionFailure = ser.deserialize[ExceptionFailure](serializedData, getClass.getClassLoader)
if (!finished(index)) {
copiesRunning(index) -= 1
numFailures(index) += 1
val locs = reason.stackTrace.map(loc => "\tat %s".format(loc.toString))
logInfo("Loss was due to %s\n%s\n%s".format(reason.className, reason.description, locs.mkString("\n")))
if (numFailures(index) > MAX_TASK_FAILURES) {
val errorMessage = "Task %s:%d failed more than %d times; aborting job %s".format(taskSet.id, index, 4, reason.description)
decreaseRunningTasks(runningTasks)
sched.listener.taskSetFailed(taskSet, errorMessage)
// need to delete failed Taskset from schedule queue
sched.taskSetFinished(this)
}
}
}
def error(message: String) {
}
}

View file

@ -67,11 +67,20 @@ object BlockFetcherIterator {
throw new IllegalArgumentException("BlocksByAddress is null")
}
protected var _totalBlocks = blocksByAddress.map(_._2.size).sum
logDebug("Getting " + _totalBlocks + " blocks")
// Total number blocks fetched (local + remote). Also number of FetchResults expected
protected var _numBlocksToFetch = 0
protected var startTime = System.currentTimeMillis
protected val localBlockIds = new ArrayBuffer[String]()
protected val remoteBlockIds = new HashSet[String]()
// This represents the number of local blocks, also counting zero-sized blocks
private var numLocal = 0
// BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
protected val localBlocksToFetch = new ArrayBuffer[String]()
// This represents the number of remote blocks, also counting zero-sized blocks
private var numRemote = 0
// BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
protected val remoteBlocksToFetch = new HashSet[String]()
// A queue to hold our results.
protected val results = new LinkedBlockingQueue[FetchResult]
@ -124,13 +133,15 @@ object BlockFetcherIterator {
protected def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
val originalTotalBlocks = _totalBlocks
val remoteRequests = new ArrayBuffer[FetchRequest]
for ((address, blockInfos) <- blocksByAddress) {
if (address == blockManagerId) {
localBlockIds ++= blockInfos.map(_._1)
numLocal = blockInfos.size
// Filter out zero-sized blocks
localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
_numBlocksToFetch += localBlocksToFetch.size
} else {
remoteBlockIds ++= blockInfos.map(_._1)
numRemote += blockInfos.size
// Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
@ -144,10 +155,10 @@ object BlockFetcherIterator {
// Skip empty blocks
if (size > 0) {
curBlocks += ((blockId, size))
remoteBlocksToFetch += blockId
_numBlocksToFetch += 1
curRequestSize += size
} else if (size == 0) {
_totalBlocks -= 1
} else {
} else if (size < 0) {
throw new BlockException(blockId, "Negative block size " + size)
}
if (curRequestSize >= minRequestSize) {
@ -163,8 +174,8 @@ object BlockFetcherIterator {
}
}
}
logInfo("Getting " + _totalBlocks + " non-zero-bytes blocks out of " +
originalTotalBlocks + " blocks")
logInfo("Getting " + _numBlocksToFetch + " non-zero-bytes blocks out of " +
totalBlocks + " blocks")
remoteRequests
}
@ -172,7 +183,7 @@ object BlockFetcherIterator {
// Get the local blocks while remote blocks are being fetched. Note that it's okay to do
// these all at once because they will just memory-map some files, so they won't consume
// any memory that might exceed our maxBytesInFlight
for (id <- localBlockIds) {
for (id <- localBlocksToFetch) {
getLocalFromDisk(id, serializer) match {
case Some(iter) => {
// Pass 0 as size since it's not in flight
@ -198,7 +209,7 @@ object BlockFetcherIterator {
sendRequest(fetchRequests.dequeue())
}
val numGets = remoteBlockIds.size - fetchRequests.size
val numGets = remoteRequests.size - fetchRequests.size
logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
// Get Local Blocks
@ -210,7 +221,7 @@ object BlockFetcherIterator {
//an iterator that will read fetched blocks off the queue as they arrive.
@volatile protected var resultsGotten = 0
override def hasNext: Boolean = resultsGotten < _totalBlocks
override def hasNext: Boolean = resultsGotten < _numBlocksToFetch
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
@ -227,9 +238,9 @@ object BlockFetcherIterator {
}
// Implementing BlockFetchTracker trait.
override def totalBlocks: Int = _totalBlocks
override def numLocalBlocks: Int = localBlockIds.size
override def numRemoteBlocks: Int = remoteBlockIds.size
override def totalBlocks: Int = numLocal + numRemote
override def numLocalBlocks: Int = numLocal
override def numRemoteBlocks: Int = numRemote
override def remoteFetchTime: Long = _remoteFetchTime
override def fetchWaitTime: Long = _fetchWaitTime
override def remoteBytesRead: Long = _remoteBytesRead
@ -265,7 +276,7 @@ object BlockFetcherIterator {
}).toList
}
//keep this to interrupt the threads when necessary
// keep this to interrupt the threads when necessary
private def stopCopiers() {
for (copier <- copiers) {
copier.interrupt()
@ -291,7 +302,7 @@ object BlockFetcherIterator {
private var copiers: List[_ <: Thread] = null
override def initialize() {
// Split Local Remote Blocks and adjust totalBlocks to include only the non 0-byte blocks
// Split Local Remote Blocks and set numBlocksToFetch
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
for (request <- Utils.randomize(remoteRequests)) {
@ -311,10 +322,7 @@ object BlockFetcherIterator {
override def next(): (String, Option[Iterator[Any]]) = {
resultsGotten += 1
val result = results.take()
// if all the results has been retrieved, shutdown the copiers
if (resultsGotten == _totalBlocks && copiers != null) {
stopCopiers()
}
// If all the results has been retrieved, copiers will exit automatically
(result.blockId, if (result.failed) None else Some(result.deserialize()))
}
}

View file

@ -35,21 +35,25 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos))
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
override def close() {
objOut.close()
bs.close()
channel = null
bs = null
objOut = null
if (initialized) {
objOut.close()
bs.close()
channel = null
bs = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}
@ -59,38 +63,48 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}
override def revertPartialWrites() {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}
override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}
override def size(): Long = lastValidPosition
}
val MAX_DIR_CREATION_ATTEMPTS: Int = 10
val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
var shuffleSender : ShuffleSender = null
private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
val localDirs = createLocalDirs()
val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
private val localDirs: Array[File] = createLocalDirs()
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
@ -99,7 +113,6 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
override def getSize(blockId: String): Long = {
getFile(blockId).length()
}
@ -197,7 +210,10 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createFile(blockId: String, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
throw new Exception("File for block " + blockId + " already exists on disk: " + file)
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
// was rescheduled on the same machine as the old task.
logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
file.delete()
}
file
}
@ -232,8 +248,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
rootDirs.split(",").map(rootDir => {
var foundLocalDir: Boolean = false
rootDirs.split(",").map { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
@ -248,7 +264,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir failed", e)
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
@ -258,7 +274,7 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
logInfo("Created local directory at " + localDir)
localDir
})
}
}
private def addShutdownHook() {
@ -266,15 +282,16 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
try {
localDirs.foreach { localDir =>
localDirs.foreach { localDir =>
try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
} catch {
case t: Throwable =>
logError("Exception while deleting local spark dir: " + localDir, t)
}
if (shuffleSender != null) {
shuffleSender.stop
}
} catch {
case t: Throwable => logError("Exception while deleting local spark dirs", t)
}
if (shuffleSender != null) {
shuffleSender.stop
}
}
})

View file

@ -24,7 +24,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockManager.blockId(shuffleId, bucketId, mapId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize).open()
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
}

View file

@ -0,0 +1,45 @@
package spark.util
import java.io.Serializable
import java.util.{PriorityQueue => JPriorityQueue}
import scala.collection.generic.Growable
import scala.collection.JavaConverters._
/**
* Bounded priority queue. This class wraps the original PriorityQueue
* class and modifies it such that only the top K elements are retained.
* The top K elements are defined by an implicit Ordering[A].
*/
class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
extends Iterable[A] with Growable[A] with Serializable {
private val underlying = new JPriorityQueue[A](maxSize, ord)
override def iterator: Iterator[A] = underlying.iterator.asScala
override def ++=(xs: TraversableOnce[A]): this.type = {
xs.foreach { this += _ }
this
}
override def +=(elem: A): this.type = {
if (size < maxSize) underlying.offer(elem)
else maybeReplaceLowest(elem)
this
}
override def +=(elem1: A, elem2: A, elems: A*): this.type = {
this += elem1 += elem2 ++= elems
}
override def clear() { underlying.clear() }
private def maybeReplaceLowest(a: A): Boolean = {
val head = underlying.peek()
if (head != null && ord.gt(a, head)) {
underlying.poll()
underlying.offer(a)
} else false
}
}

View file

@ -37,17 +37,23 @@ class StatCounter(values: TraversableOnce[Double]) extends Serializable {
if (other == this) {
merge(other.copy()) // Avoid overwriting fields in a weird order
} else {
val delta = other.mu - mu
if (other.n * 10 < n) {
mu = mu + (delta * other.n) / (n + other.n)
} else if (n * 10 < other.n) {
mu = other.mu - (delta * n) / (n + other.n)
} else {
mu = (mu * n + other.mu * other.n) / (n + other.n)
if (n == 0) {
mu = other.mu
m2 = other.m2
n = other.n
} else if (other.n != 0) {
val delta = other.mu - mu
if (other.n * 10 < n) {
mu = mu + (delta * other.n) / (n + other.n)
} else if (n * 10 < other.n) {
mu = other.mu - (delta * n) / (n + other.n)
} else {
mu = (mu * n + other.mu * other.n) / (n + other.n)
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
}
m2 += other.m2 + (delta * delta * n * other.n) / (n + other.n)
n += other.n
this
this
}
}

View file

@ -27,6 +27,16 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
}
test("basic checkpointing") {
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
flatMappedRDD.checkpoint()
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
val result = flatMappedRDD.collect()
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
}
test("RDDs with one-to-one dependencies") {
testCheckpointing(_.map(x => x.toString))
testCheckpointing(_.flatMap(x => 1 to x))

View file

@ -7,6 +7,8 @@ import scala.io.Source
import com.google.common.io.Files
import org.scalatest.FunSuite
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodec, GzipCodec}
import SparkContext._
@ -26,6 +28,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(sc.textFile(outputDir).collect().toList === List("1", "2", "3", "4"))
}
test("text files (compressed)") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val normalDir = new File(tempDir, "output_normal").getAbsolutePath
val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
val codec = new DefaultCodec()
val data = sc.parallelize("a" * 10000, 1)
data.saveAsTextFile(normalDir)
data.saveAsTextFile(compressedOutputDir, classOf[DefaultCodec])
val normalFile = new File(normalDir, "part-00000")
val normalContent = sc.textFile(normalDir).collect
assert(normalContent === Array.fill(10000)("a"))
val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
val compressedContent = sc.textFile(compressedOutputDir).collect
assert(compressedContent === Array.fill(10000)("a"))
assert(compressedFile.length < normalFile.length)
}
test("SequenceFiles") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
@ -37,6 +61,28 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
test("SequenceFile (compressed)") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()
val normalDir = new File(tempDir, "output_normal").getAbsolutePath
val compressedOutputDir = new File(tempDir, "output_compressed").getAbsolutePath
val codec = new DefaultCodec()
val data = sc.parallelize(Seq.fill(100)("abc"), 1).map(x => (x, x))
data.saveAsSequenceFile(normalDir)
data.saveAsSequenceFile(compressedOutputDir, Some(classOf[DefaultCodec]))
val normalFile = new File(normalDir, "part-00000")
val normalContent = sc.sequenceFile[String, String](normalDir).collect
assert(normalContent === Array.fill(100)("abc", "abc"))
val compressedFile = new File(compressedOutputDir, "part-00000" + codec.getDefaultExtension)
val compressedContent = sc.sequenceFile[String, String](compressedOutputDir).collect
assert(compressedContent === Array.fill(100)("abc", "abc"))
assert(compressedFile.length < normalFile.length)
}
test("SequenceFile with writable key") {
sc = new SparkContext("local", "test")
val tempDir = Files.createTempDir()

View file

@ -8,6 +8,7 @@ import java.util.*;
import scala.Tuple2;
import com.google.common.base.Charsets;
import org.apache.hadoop.io.compress.DefaultCodec;
import com.google.common.io.Files;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
@ -473,6 +474,19 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(expected, readRDD.collect());
}
@Test
public void textFilesCompressed() throws IOException {
File tempDir = Files.createTempDir();
String outputDir = new File(tempDir, "output").getAbsolutePath();
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4));
rdd.saveAsTextFile(outputDir, DefaultCodec.class);
// Try reading it in as a text file RDD
List<String> expected = Arrays.asList("1", "2", "3", "4");
JavaRDD<String> readRDD = sc.textFile(outputDir);
Assert.assertEquals(expected, readRDD.collect());
}
@Test
public void sequenceFile() {
File tempDir = Files.createTempDir();
@ -619,6 +633,37 @@ public class JavaAPISuite implements Serializable {
}).collect().toString());
}
@Test
public void hadoopFileCompressed() {
File tempDir = Files.createTempDir();
String outputDir = new File(tempDir, "output_compressed").getAbsolutePath();
List<Tuple2<Integer, String>> pairs = Arrays.asList(
new Tuple2<Integer, String>(1, "a"),
new Tuple2<Integer, String>(2, "aa"),
new Tuple2<Integer, String>(3, "aaa")
);
JavaPairRDD<Integer, String> rdd = sc.parallelizePairs(pairs);
rdd.map(new PairFunction<Tuple2<Integer, String>, IntWritable, Text>() {
@Override
public Tuple2<IntWritable, Text> call(Tuple2<Integer, String> pair) {
return new Tuple2<IntWritable, Text>(new IntWritable(pair._1()), new Text(pair._2()));
}
}).saveAsHadoopFile(outputDir, IntWritable.class, Text.class, SequenceFileOutputFormat.class,
DefaultCodec.class);
JavaPairRDD<IntWritable, Text> output = sc.hadoopFile(outputDir,
SequenceFileInputFormat.class, IntWritable.class, Text.class);
Assert.assertEquals(pairs.toString(), output.map(new Function<Tuple2<IntWritable, Text>,
String>() {
@Override
public String call(Tuple2<IntWritable, Text> x) {
return x.toString();
}
}).collect().toString());
}
@Test
public void zip() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));

View file

@ -0,0 +1,287 @@
package spark
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
import org.scalatest.prop.Checkers
import org.scalacheck.Arbitrary._
import org.scalacheck.Gen
import org.scalacheck.Prop._
import com.google.common.io.Files
import spark.rdd.ShuffledRDD
import spark.SparkContext._
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with duplicates") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with negative key hash codes") {
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesForMinus1 = groups.find(_._1 == -1).get._2
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with many output partitions") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey(10).collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("reduceByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with collectAsMap") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collectAsMap()
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
}
test("reduceByKey with many output partitons") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with partitioner") {
val p = new Partitioner() {
def numPartitions = 2
def getPartition(key: Any) = key.asInstanceOf[Int]
}
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
val sums = pairs.reduceByKey(_+_)
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
assert(sums.partitioner === Some(p))
// count the dependencies to make sure there is only 1 ShuffledRDD
val deps = new HashSet[RDD[_]]()
def visit(r: RDD[_]) {
for (dep <- r.dependencies) {
deps += dep.rdd
visit(dep.rdd)
}
}
visit(sums)
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
}
test("join") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
test("join all-to-all") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 6)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (1, 'y')),
(1, (2, 'x')),
(1, (2, 'y')),
(1, (3, 'x')),
(1, (3, 'y'))
))
}
test("leftOuterJoin") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.leftOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (1, Some('x'))),
(1, (2, Some('x'))),
(2, (1, Some('y'))),
(2, (1, Some('z'))),
(3, (1, None))
))
}
test("rightOuterJoin") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.rightOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (Some(1), 'x')),
(1, (Some(2), 'x')),
(2, (Some(1), 'y')),
(2, (Some(1), 'z')),
(4, (None, 'w'))
))
}
test("join with no matches") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
}
test("join with many output partitions") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2, 10).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
test("groupWith") {
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.groupWith(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
(2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
(3, (ArrayBuffer(1), ArrayBuffer())),
(4, (ArrayBuffer(), ArrayBuffer('w')))
))
}
test("zero-partition RDD") {
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
assert(file.partitions.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
test("keys and values") {
val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
test("default partitioner uses partition size") {
// specify 2000 partitions
val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
// do a map, which loses the partitioner
val b = a.map(a => (a, (a * 2).toString))
// then a group by, and see we didn't revert to 2 partitions
val c = b.groupByKey()
assert(c.partitions.size === 2000)
}
test("default partitioner uses largest partitioner") {
val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
val c = a.join(b)
assert(c.partitions.size === 2000)
}
test("subtract") {
val a = sc.parallelize(Array(1, 2, 3), 2)
val b = sc.parallelize(Array(2, 3, 4), 4)
val c = a.subtract(b)
assert(c.collect().toSet === Set(1))
assert(c.partitions.size === a.partitions.size)
}
test("subtract with narrow dependency") {
// use a deterministic partitioner
val p = new Partitioner() {
def numPartitions = 5
def getPartition(key: Any) = key.asInstanceOf[Int]
}
// partitionBy so we have a narrow dependency
val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
// more partitions/no partitioner so a shuffle dependency
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
val c = a.subtract(b)
assert(c.collect().toSet === Set((1, "a"), (3, "c")))
// Ideally we could keep the original partitioner...
assert(c.partitioner === None)
}
test("subtractByKey") {
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
val c = a.subtractByKey(b)
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitions.size === a.partitions.size)
}
test("subtractByKey with narrow dependency") {
// use a deterministic partitioner
val p = new Partitioner() {
def numPartitions = 5
def getPartition(key: Any) = key.asInstanceOf[Int]
}
// partitionBy so we have a narrow dependency
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
// more partitions/no partitioner so a shuffle dependency
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
val c = a.subtractByKey(b)
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitioner.get === p)
}
test("foldByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.foldByKey(0)(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("foldByKey with mutable result type") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val bufs = pairs.mapValues(v => ArrayBuffer(v)).cache()
// Fold the values using in-place mutation
val sums = bufs.foldByKey(new ArrayBuffer[Int])(_ ++= _).collect()
assert(sums.toSet === Set((1, ArrayBuffer(1, 2, 3, 1)), (2, ArrayBuffer(1))))
// Check that the mutable objects in the original RDD were not changed
assert(bufs.collect().toSet === Set(
(1, ArrayBuffer(1)),
(1, ArrayBuffer(2)),
(1, ArrayBuffer(3)),
(1, ArrayBuffer(1)),
(2, ArrayBuffer(1))))
}
}

View file

@ -1,13 +1,13 @@
package spark
import org.scalatest.FunSuite
import scala.collection.mutable.ArrayBuffer
import SparkContext._
import spark.util.StatCounter
import scala.math.abs
class PartitioningSuite extends FunSuite with SharedSparkContext {
class PartitioningSuite extends FunSuite with LocalSparkContext {
test("HashPartitioner equality") {
val p2 = new HashPartitioner(2)
val p4 = new HashPartitioner(4)
@ -21,8 +21,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("RangePartitioner equality") {
sc = new SparkContext("local", "test")
// Make an RDD where all the elements are the same so that the partition range bounds
// are deterministically all the same.
val rdd = sc.parallelize(Seq(1, 1, 1, 1)).map(x => (x, x))
@ -50,7 +48,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("HashPartitioner not equal to RangePartitioner") {
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10).map(x => (x, x))
val rangeP2 = new RangePartitioner(2, rdd)
val hashP2 = new HashPartitioner(2)
@ -61,8 +58,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioner preservation") {
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(1 to 10, 4).map(x => (x, x))
val grouped2 = rdd.groupByKey(2)
@ -101,7 +96,6 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
}
test("partitioning Java arrays should fail") {
sc = new SparkContext("local", "test")
val arrs: RDD[Array[Int]] = sc.parallelize(Array(1, 2, 3, 4), 2).map(x => Array(x))
val arrPairs: RDD[(Array[Int], Int)] =
sc.parallelize(Array(1, 2, 3, 4), 2).map(x => (Array(x), x))
@ -120,4 +114,20 @@ class PartitioningSuite extends FunSuite with LocalSparkContext {
assert(intercept[SparkException]{ arrPairs.reduceByKeyLocally(_ + _) }.getMessage.contains("array"))
assert(intercept[SparkException]{ arrPairs.reduceByKey(_ + _) }.getMessage.contains("array"))
}
test("zero-length partitions should be correctly handled") {
// Create RDD with some consecutive empty partitions (including the "first" one)
val rdd: RDD[Double] = sc
.parallelize(Array(-1.0, -1.0, -1.0, -1.0, 2.0, 4.0, -1.0, -1.0), 8)
.filter(_ >= 0.0)
// Run the partitions, including the consecutive empty ones, through StatCounter
val stats: StatCounter = rdd.stats();
assert(abs(6.0 - stats.sum) < 0.01);
assert(abs(6.0/2 - rdd.mean) < 0.01);
assert(abs(1.0 - rdd.variance) < 0.01);
assert(abs(1.0 - rdd.stdev) < 0.01);
// Add other tests here for classes that should be able to handle empty partitions correctly
}
}

View file

@ -3,10 +3,9 @@ package spark
import org.scalatest.FunSuite
import SparkContext._
class PipedRDDSuite extends FunSuite with LocalSparkContext {
class PipedRDDSuite extends FunSuite with SharedSparkContext {
test("basic pipe") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("cat"))
@ -19,8 +18,45 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
assert(c(3) === "4")
}
test("advanced pipe") {
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val bl = sc.broadcast(List("0"))
val piped = nums.pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Int, f: String=> Unit) => f(i + "_"))
val c = piped.collect()
assert(c.size === 8)
assert(c(0) === "0")
assert(c(1) === "\u0001")
assert(c(2) === "1_")
assert(c(3) === "2_")
assert(c(4) === "0")
assert(c(5) === "\u0001")
assert(c(6) === "3_")
assert(c(7) === "4_")
val nums1 = sc.makeRDD(Array("a\t1", "b\t2", "a\t3", "b\t4"), 2)
val d = nums1.groupBy(str=>str.split("\t")(0)).
pipe(Seq("cat"),
Map[String, String](),
(f: String => Unit) => {bl.value.map(f(_));f("\u0001")},
(i:Tuple2[String, Seq[String]], f: String=> Unit) => {for (e <- i._2){ f(e + "_")}}).collect()
assert(d.size === 8)
assert(d(0) === "0")
assert(d(1) === "\u0001")
assert(d(2) === "b\t2_")
assert(d(3) === "b\t4_")
assert(d(4) === "0")
assert(d(5) === "\u0001")
assert(d(6) === "a\t1_")
assert(d(7) === "a\t3_")
}
test("pipe with env variable") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe(Seq("printenv", "MY_TEST_ENV"), Map("MY_TEST_ENV" -> "LALALA"))
val c = piped.collect()
@ -30,7 +66,6 @@ class PipedRDDSuite extends FunSuite with LocalSparkContext {
}
test("pipe with non-zero exit status") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val piped = nums.pipe("cat nonexistent_file")
intercept[SparkException] {

View file

@ -7,10 +7,9 @@ import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
import spark.rdd.{CoalescedRDD, CoGroupedRDD, EmptyRDD, PartitionPruningRDD, ShuffledRDD}
class RDDSuite extends FunSuite with LocalSparkContext {
class RDDSuite extends FunSuite with SharedSparkContext {
test("basic operations") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(nums.collect().toList === List(1, 2, 3, 4))
val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
@ -46,7 +45,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("SparkContext.union") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
assert(sc.union(nums).collect().toList === List(1, 2, 3, 4))
assert(sc.union(nums, nums).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
@ -55,7 +53,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("aggregate") {
sc = new SparkContext("local", "test")
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
val emptyMap = new StringMap {
@ -75,57 +72,14 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5)))
}
test("basic checkpointing") {
import java.io.File
val checkpointDir = File.createTempFile("temp", "")
checkpointDir.delete()
sc = new SparkContext("local", "test")
sc.setCheckpointDir(checkpointDir.toString)
val parCollection = sc.makeRDD(1 to 4)
val flatMappedRDD = parCollection.flatMap(x => 1 to x)
flatMappedRDD.checkpoint()
assert(flatMappedRDD.dependencies.head.rdd == parCollection)
val result = flatMappedRDD.collect()
Thread.sleep(1000)
assert(flatMappedRDD.dependencies.head.rdd != parCollection)
assert(flatMappedRDD.collect() === result)
checkpointDir.deleteOnExit()
}
test("basic caching") {
sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
assert(rdd.collect().toList === List(1, 2, 3, 4))
}
test("unpersist RDD") {
sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
rdd.count
assert(sc.persistentRdds.isEmpty === false)
rdd.unpersist()
assert(sc.persistentRdds.isEmpty === true)
failAfter(Span(3000, Millis)) {
try {
while (! sc.getRDDStorageInfo.isEmpty) {
Thread.sleep(200)
}
} catch {
case _ => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
}
assert(sc.getRDDStorageInfo.isEmpty === true)
}
test("caching with failures") {
sc = new SparkContext("local", "test")
val onlySplit = new Partition { override def index: Int = 0 }
var shouldFail = true
val rdd = new RDD[Int](sc, Nil) {
@ -148,7 +102,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("empty RDD") {
sc = new SparkContext("local", "test")
val empty = new EmptyRDD[Int](sc)
assert(empty.count === 0)
assert(empty.collect().size === 0)
@ -168,37 +121,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("cogrouped RDDs") {
sc = new SparkContext("local", "test")
val rdd1 = sc.makeRDD(Array((1, "one"), (1, "another one"), (2, "two"), (3, "three")), 2)
val rdd2 = sc.makeRDD(Array((1, "one1"), (1, "another one1"), (2, "two1")), 2)
// Use cogroup function
val cogrouped = rdd1.cogroup(rdd2).collectAsMap()
assert(cogrouped(1) === (Seq("one", "another one"), Seq("one1", "another one1")))
assert(cogrouped(2) === (Seq("two"), Seq("two1")))
assert(cogrouped(3) === (Seq("three"), Seq()))
// Construct CoGroupedRDD directly, with map side combine enabled
val cogrouped1 = new CoGroupedRDD[Int](
Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
new HashPartitioner(3),
true).collectAsMap()
assert(cogrouped1(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
assert(cogrouped1(2).toSeq === Seq(Seq("two"), Seq("two1")))
assert(cogrouped1(3).toSeq === Seq(Seq("three"), Seq()))
// Construct CoGroupedRDD directly, with map side combine disabled
val cogrouped2 = new CoGroupedRDD[Int](
Seq(rdd1.asInstanceOf[RDD[(Int, Any)]], rdd2.asInstanceOf[RDD[(Int, Any)]]),
new HashPartitioner(3),
false).collectAsMap()
assert(cogrouped2(1).toSeq === Seq(Seq("one", "another one"), Seq("one1", "another one1")))
assert(cogrouped2(2).toSeq === Seq(Seq("two"), Seq("two1")))
assert(cogrouped2(3).toSeq === Seq(Seq("three"), Seq()))
}
test("coalesced RDDs") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
val coalesced1 = data.coalesce(2)
@ -236,7 +158,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("zipped RDDs") {
sc = new SparkContext("local", "test")
val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
val zipped = nums.zip(nums.map(_ + 1.0))
assert(zipped.glom().map(_.toList).collect().toList ===
@ -248,7 +169,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
}
test("partition pruning") {
sc = new SparkContext("local", "test")
val data = sc.parallelize(1 to 10, 10)
// Note that split number starts from 0, so > 8 means only 10th partition left.
val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8)
@ -260,7 +180,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("mapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.mapWith(
(index: Int) => new Random(index + 42))
@ -279,7 +198,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("flatMapWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
val randoms = ones.flatMapWith(
(index: Int) => new Random(index + 42))
@ -301,7 +219,6 @@ class RDDSuite extends FunSuite with LocalSparkContext {
test("filterWith") {
import java.util.Random
sc = new SparkContext("local", "test")
val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
val sample = ints.filterWith(
(index: Int) => new Random(index + 42))
@ -317,4 +234,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
assert(sample.size === checkSample.size)
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
test("top with predefined ordering") {
val nums = Array.range(1, 100000)
val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
val topK = ints.top(5)
assert(topK.size === 5)
assert(topK.sorted === nums.sorted.takeRight(5))
}
test("top with custom ordering") {
val words = Vector("a", "b", "c", "d")
implicit val ord = implicitly[Ordering[String]].reverse
val rdd = sc.makeRDD(words, 2)
val topK = rdd.top(2)
assert(topK.size === 2)
assert(topK.sorted === Array("b", "a"))
}
}

View file

@ -0,0 +1,25 @@
package spark
import org.scalatest.Suite
import org.scalatest.BeforeAndAfterAll
/** Shares a local `SparkContext` between all tests in a suite and closes it at the end */
trait SharedSparkContext extends BeforeAndAfterAll { self: Suite =>
@transient private var _sc: SparkContext = _
def sc: SparkContext = _sc
override def beforeAll() {
_sc = new SparkContext("local", "test")
super.beforeAll()
}
override def afterAll() {
if (_sc != null) {
LocalSparkContext.stop(_sc)
_sc = null
}
super.afterAll()
}
}

View file

@ -16,54 +16,9 @@ import spark.rdd.ShuffledRDD
import spark.SparkContext._
class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
test("groupByKey") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with duplicates") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with negative key hash codes") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((-1, 1), (-1, 2), (-1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
assert(groups.size === 2)
val valuesForMinus1 = groups.find(_._1 == -1).get._2
assert(valuesForMinus1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with many output partitions") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey(10).collect()
assert(groups.size === 2)
val valuesFor1 = groups.find(_._1 == 1).get._2
assert(valuesFor1.toList.sorted === List(1, 2, 3))
val valuesFor2 = groups.find(_._1 == 2).get._2
assert(valuesFor2.toList.sorted === List(1))
}
test("groupByKey with compression") {
try {
System.setProperty("spark.blockManager.compress", "true")
System.setProperty("spark.shuffle.compress", "true")
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
@ -77,234 +32,6 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
}
}
test("reduceByKey") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with collectAsMap") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_).collectAsMap()
assert(sums.size === 2)
assert(sums(1) === 7)
assert(sums(2) === 1)
}
test("reduceByKey with many output partitons") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (1, 1), (2, 1)))
val sums = pairs.reduceByKey(_+_, 10).collect()
assert(sums.toSet === Set((1, 7), (2, 1)))
}
test("reduceByKey with partitioner") {
sc = new SparkContext("local", "test")
val p = new Partitioner() {
def numPartitions = 2
def getPartition(key: Any) = key.asInstanceOf[Int]
}
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 1), (0, 1))).partitionBy(p)
val sums = pairs.reduceByKey(_+_)
assert(sums.collect().toSet === Set((1, 4), (0, 1)))
assert(sums.partitioner === Some(p))
// count the dependencies to make sure there is only 1 ShuffledRDD
val deps = new HashSet[RDD[_]]()
def visit(r: RDD[_]) {
for (dep <- r.dependencies) {
deps += dep.rdd
visit(dep.rdd)
}
}
visit(sums)
assert(deps.size === 2) // ShuffledRDD, ParallelCollection
}
test("join") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
test("join all-to-all") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (1, 3)))
val rdd2 = sc.parallelize(Array((1, 'x'), (1, 'y')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 6)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (1, 'y')),
(1, (2, 'x')),
(1, (2, 'y')),
(1, (3, 'x')),
(1, (3, 'y'))
))
}
test("leftOuterJoin") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.leftOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (1, Some('x'))),
(1, (2, Some('x'))),
(2, (1, Some('y'))),
(2, (1, Some('z'))),
(3, (1, None))
))
}
test("rightOuterJoin") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.rightOuterJoin(rdd2).collect()
assert(joined.size === 5)
assert(joined.toSet === Set(
(1, (Some(1), 'x')),
(1, (Some(2), 'x')),
(2, (Some(1), 'y')),
(2, (Some(1), 'z')),
(4, (None, 'w'))
))
}
test("join with no matches") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((4, 'x'), (5, 'y'), (5, 'z'), (6, 'w')))
val joined = rdd1.join(rdd2).collect()
assert(joined.size === 0)
}
test("join with many output partitions") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.join(rdd2, 10).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (1, 'x')),
(1, (2, 'x')),
(2, (1, 'y')),
(2, (1, 'z'))
))
}
test("groupWith") {
sc = new SparkContext("local", "test")
val rdd1 = sc.parallelize(Array((1, 1), (1, 2), (2, 1), (3, 1)))
val rdd2 = sc.parallelize(Array((1, 'x'), (2, 'y'), (2, 'z'), (4, 'w')))
val joined = rdd1.groupWith(rdd2).collect()
assert(joined.size === 4)
assert(joined.toSet === Set(
(1, (ArrayBuffer(1, 2), ArrayBuffer('x'))),
(2, (ArrayBuffer(1), ArrayBuffer('y', 'z'))),
(3, (ArrayBuffer(1), ArrayBuffer())),
(4, (ArrayBuffer(), ArrayBuffer('w')))
))
}
test("zero-partition RDD") {
sc = new SparkContext("local", "test")
val emptyDir = Files.createTempDir()
val file = sc.textFile(emptyDir.getAbsolutePath)
assert(file.partitions.size == 0)
assert(file.collect().toList === Nil)
// Test that a shuffle on the file works, because this used to be a bug
assert(file.map(line => (line, 1)).reduceByKey(_ + _).collect().toList === Nil)
}
test("keys and values") {
sc = new SparkContext("local", "test")
val rdd = sc.parallelize(Array((1, "a"), (2, "b")))
assert(rdd.keys.collect().toList === List(1, 2))
assert(rdd.values.collect().toList === List("a", "b"))
}
test("default partitioner uses partition size") {
sc = new SparkContext("local", "test")
// specify 2000 partitions
val a = sc.makeRDD(Array(1, 2, 3, 4), 2000)
// do a map, which loses the partitioner
val b = a.map(a => (a, (a * 2).toString))
// then a group by, and see we didn't revert to 2 partitions
val c = b.groupByKey()
assert(c.partitions.size === 2000)
}
test("default partitioner uses largest partitioner") {
sc = new SparkContext("local", "test")
val a = sc.makeRDD(Array((1, "a"), (2, "b")), 2)
val b = sc.makeRDD(Array((1, "a"), (2, "b")), 2000)
val c = a.join(b)
assert(c.partitions.size === 2000)
}
test("subtract") {
sc = new SparkContext("local", "test")
val a = sc.parallelize(Array(1, 2, 3), 2)
val b = sc.parallelize(Array(2, 3, 4), 4)
val c = a.subtract(b)
assert(c.collect().toSet === Set(1))
assert(c.partitions.size === a.partitions.size)
}
test("subtract with narrow dependency") {
sc = new SparkContext("local", "test")
// use a deterministic partitioner
val p = new Partitioner() {
def numPartitions = 5
def getPartition(key: Any) = key.asInstanceOf[Int]
}
// partitionBy so we have a narrow dependency
val a = sc.parallelize(Array((1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
// more partitions/no partitioner so a shuffle dependency
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
val c = a.subtract(b)
assert(c.collect().toSet === Set((1, "a"), (3, "c")))
// Ideally we could keep the original partitioner...
assert(c.partitioner === None)
}
test("subtractByKey") {
sc = new SparkContext("local", "test")
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 2)
val b = sc.parallelize(Array((2, 20), (3, 30), (4, 40)), 4)
val c = a.subtractByKey(b)
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitions.size === a.partitions.size)
}
test("subtractByKey with narrow dependency") {
sc = new SparkContext("local", "test")
// use a deterministic partitioner
val p = new Partitioner() {
def numPartitions = 5
def getPartition(key: Any) = key.asInstanceOf[Int]
}
// partitionBy so we have a narrow dependency
val a = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c"))).partitionBy(p)
// more partitions/no partitioner so a shuffle dependency
val b = sc.parallelize(Array((2, "b"), (3, "cc"), (4, "d")), 4)
val c = a.subtractByKey(b)
assert(c.collect().toSet === Set((1, "a"), (1, "a")))
assert(c.partitioner.get === p)
}
test("shuffle non-zero block size") {
sc = new SparkContext("local-cluster[2,1,512]", "test")
val NUM_BLOCKS = 3
@ -367,6 +94,30 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
assert(nonEmptyBlocks.size <= 4)
}
test("zero sized blocks without kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
sc = new SparkContext("local-cluster[2,1,512]", "test")
// 10 partitions from 4 keys
val NUM_BLOCKS = 10
val a = sc.parallelize(1 to 4, NUM_BLOCKS)
val b = a.map(x => (x, x*2))
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD(b, new HashPartitioner(10))
val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id)
statuses.map(x => x._2)
}
val nonEmptyBlocks = blockSizes.filter(x => x > 0)
// We should have at most 4 non-zero sized partitions
assert(nonEmptyBlocks.size <= 4)
}
}
object ShuffleSuite {

View file

@ -35,7 +35,7 @@ class SizeEstimatorSuite
var oldOops: String = _
override def beforeAll() {
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
oldOops = System.setProperty("spark.test.useCompressedOops", "true")
}
@ -46,54 +46,54 @@ class SizeEstimatorSuite
}
test("simple classes") {
expect(16)(SizeEstimator.estimate(new DummyClass1))
expect(16)(SizeEstimator.estimate(new DummyClass2))
expect(24)(SizeEstimator.estimate(new DummyClass3))
expect(24)(SizeEstimator.estimate(new DummyClass4(null)))
expect(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
assert(SizeEstimator.estimate(new DummyClass1) === 16)
assert(SizeEstimator.estimate(new DummyClass2) === 16)
assert(SizeEstimator.estimate(new DummyClass3) === 24)
assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
}
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
expect(40)(SizeEstimator.estimate(DummyString("")))
expect(48)(SizeEstimator.estimate(DummyString("a")))
expect(48)(SizeEstimator.estimate(DummyString("ab")))
expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
assert(SizeEstimator.estimate(DummyString("")) === 40)
assert(SizeEstimator.estimate(DummyString("a")) === 48)
assert(SizeEstimator.estimate(DummyString("ab")) === 48)
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
}
test("primitive arrays") {
expect(32)(SizeEstimator.estimate(new Array[Byte](10)))
expect(40)(SizeEstimator.estimate(new Array[Char](10)))
expect(40)(SizeEstimator.estimate(new Array[Short](10)))
expect(56)(SizeEstimator.estimate(new Array[Int](10)))
expect(96)(SizeEstimator.estimate(new Array[Long](10)))
expect(56)(SizeEstimator.estimate(new Array[Float](10)))
expect(96)(SizeEstimator.estimate(new Array[Double](10)))
expect(4016)(SizeEstimator.estimate(new Array[Int](1000)))
expect(8016)(SizeEstimator.estimate(new Array[Long](1000)))
assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
}
test("object arrays") {
// Arrays containing nulls should just have one pointer per element
expect(56)(SizeEstimator.estimate(new Array[String](10)))
expect(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
assert(SizeEstimator.estimate(new Array[String](10)) === 56)
assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
expect(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
expect(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
expect(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
// Past size 100, our samples 100 elements, but we should still get the right size.
expect(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
expect(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
expect(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
// Same thing with huge array containing the same element many times. Note that this won't
// return exactly 4032 because it can't tell that *all* the elements will equal the first
@ -111,10 +111,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
expect(40)(SizeEstimator.estimate(DummyString("")))
expect(48)(SizeEstimator.estimate(DummyString("a")))
expect(48)(SizeEstimator.estimate(DummyString("ab")))
expect(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
assert(SizeEstimator.estimate(DummyString("")) === 40)
assert(SizeEstimator.estimate(DummyString("a")) === 48)
assert(SizeEstimator.estimate(DummyString("ab")) === 48)
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
resetOrClear("os.arch", arch)
}
@ -128,10 +128,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
expect(56)(SizeEstimator.estimate(DummyString("")))
expect(64)(SizeEstimator.estimate(DummyString("a")))
expect(64)(SizeEstimator.estimate(DummyString("ab")))
expect(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
assert(SizeEstimator.estimate(DummyString("")) === 56)
assert(SizeEstimator.estimate(DummyString("a")) === 64)
assert(SizeEstimator.estimate(DummyString("ab")) === 64)
assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)

View file

@ -5,16 +5,14 @@ import org.scalatest.BeforeAndAfter
import org.scalatest.matchers.ShouldMatchers
import SparkContext._
class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers with Logging {
class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
test("sortByKey") {
sc = new SparkContext("local", "test")
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
assert(pairs.sortByKey().collect() === Array((0,0), (1,0), (2,0), (3,0)))
}
test("large array") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@ -24,7 +22,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("large array with one split") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@ -32,9 +29,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 1)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
test("large array with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@ -42,9 +38,8 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
assert(sorted.partitions.size === 20)
assert(sorted.collect() === pairArr.sortBy(_._1))
}
test("sort descending") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@ -52,15 +47,13 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("sort descending with one split") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 1)
assert(pairs.sortByKey(false, 1).collect() === pairArr.sortWith((x, y) => x._1 > y._1))
}
test("sort descending with many partitions") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(1000) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 2)
@ -68,7 +61,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("more partitions than elements") {
sc = new SparkContext("local", "test")
val rand = new scala.util.Random()
val pairArr = Array.fill(10) { (rand.nextInt(), rand.nextInt()) }
val pairs = sc.parallelize(pairArr, 30)
@ -76,14 +68,12 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("empty RDD") {
sc = new SparkContext("local", "test")
val pairArr = new Array[(Int, Int)](0)
val pairs = sc.parallelize(pairArr, 2)
assert(pairs.sortByKey().collect() === pairArr.sortBy(_._1))
}
test("partition balancing") {
sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey()
assert(sorted.collect() === pairArr.sortBy(_._1))
@ -99,7 +89,6 @@ class SortingSuite extends FunSuite with LocalSparkContext with ShouldMatchers w
}
test("partition balancing for descending sort") {
sc = new SparkContext("local", "test")
val pairArr = (1 to 1000).map(x => (x, x)).toArray
val sorted = sc.parallelize(pairArr, 4).sortByKey(false)
assert(sorted.collect() === pairArr.sortBy(_._1).reverse)

View file

@ -0,0 +1,30 @@
package spark
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts._
import org.scalatest.time.{Span, Millis}
import spark.SparkContext._
class UnpersistSuite extends FunSuite with LocalSparkContext {
test("unpersist RDD") {
sc = new SparkContext("local", "test")
val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache()
rdd.count
assert(sc.persistentRdds.isEmpty === false)
rdd.unpersist()
assert(sc.persistentRdds.isEmpty === true)
failAfter(Span(3000, Millis)) {
try {
while (! sc.getRDDStorageInfo.isEmpty) {
Thread.sleep(200)
}
} catch {
case _ => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
}
assert(sc.getRDDStorageInfo.isEmpty === true)
}
}

View file

@ -27,24 +27,49 @@ class UtilsSuite extends FunSuite {
assert(os.toByteArray.toList.equals(bytes.toList))
}
test("memoryStringToMb"){
assert(Utils.memoryStringToMb("1") == 0)
assert(Utils.memoryStringToMb("1048575") == 0)
assert(Utils.memoryStringToMb("3145728") == 3)
test("memoryStringToMb") {
assert(Utils.memoryStringToMb("1") === 0)
assert(Utils.memoryStringToMb("1048575") === 0)
assert(Utils.memoryStringToMb("3145728") === 3)
assert(Utils.memoryStringToMb("1024k") == 1)
assert(Utils.memoryStringToMb("5000k") == 4)
assert(Utils.memoryStringToMb("4024k") == Utils.memoryStringToMb("4024K"))
assert(Utils.memoryStringToMb("1024k") === 1)
assert(Utils.memoryStringToMb("5000k") === 4)
assert(Utils.memoryStringToMb("4024k") === Utils.memoryStringToMb("4024K"))
assert(Utils.memoryStringToMb("1024m") == 1024)
assert(Utils.memoryStringToMb("5000m") == 5000)
assert(Utils.memoryStringToMb("4024m") == Utils.memoryStringToMb("4024M"))
assert(Utils.memoryStringToMb("1024m") === 1024)
assert(Utils.memoryStringToMb("5000m") === 5000)
assert(Utils.memoryStringToMb("4024m") === Utils.memoryStringToMb("4024M"))
assert(Utils.memoryStringToMb("2g") == 2048)
assert(Utils.memoryStringToMb("3g") == Utils.memoryStringToMb("3G"))
assert(Utils.memoryStringToMb("2g") === 2048)
assert(Utils.memoryStringToMb("3g") === Utils.memoryStringToMb("3G"))
assert(Utils.memoryStringToMb("2t") == 2097152)
assert(Utils.memoryStringToMb("3t") == Utils.memoryStringToMb("3T"))
assert(Utils.memoryStringToMb("2t") === 2097152)
assert(Utils.memoryStringToMb("3t") === Utils.memoryStringToMb("3T"))
}
test("splitCommandString") {
assert(Utils.splitCommandString("") === Seq())
assert(Utils.splitCommandString("a") === Seq("a"))
assert(Utils.splitCommandString("aaa") === Seq("aaa"))
assert(Utils.splitCommandString("a b c") === Seq("a", "b", "c"))
assert(Utils.splitCommandString(" a b\t c ") === Seq("a", "b", "c"))
assert(Utils.splitCommandString("a 'b c'") === Seq("a", "b c"))
assert(Utils.splitCommandString("a 'b c' d") === Seq("a", "b c", "d"))
assert(Utils.splitCommandString("'b c'") === Seq("b c"))
assert(Utils.splitCommandString("a \"b c\"") === Seq("a", "b c"))
assert(Utils.splitCommandString("a \"b c\" d") === Seq("a", "b c", "d"))
assert(Utils.splitCommandString("\"b c\"") === Seq("b c"))
assert(Utils.splitCommandString("a 'b\" c' \"d' e\"") === Seq("a", "b\" c", "d' e"))
assert(Utils.splitCommandString("a\t'b\nc'\nd") === Seq("a", "b\nc", "d"))
assert(Utils.splitCommandString("a \"b\\\\c\"") === Seq("a", "b\\c"))
assert(Utils.splitCommandString("a \"b\\\"c\"") === Seq("a", "b\"c"))
assert(Utils.splitCommandString("a 'b\\\"c'") === Seq("a", "b\\\"c"))
assert(Utils.splitCommandString("'a'b") === Seq("ab"))
assert(Utils.splitCommandString("'a''b'") === Seq("ab"))
assert(Utils.splitCommandString("\"a\"b") === Seq("ab"))
assert(Utils.splitCommandString("\"a\"\"b\"") === Seq("ab"))
assert(Utils.splitCommandString("''") === Seq(""))
assert(Utils.splitCommandString("\"\"") === Seq(""))
}
}

View file

@ -17,9 +17,8 @@ object ZippedPartitionsSuite {
}
}
class ZippedPartitionsSuite extends FunSuite with LocalSparkContext {
class ZippedPartitionsSuite extends FunSuite with SharedSparkContext {
test("print sizes") {
sc = new SparkContext("local", "test")
val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
val data3 = sc.makeRDD(Array(1.0, 2.0), 2)

View file

@ -16,7 +16,7 @@ class DummyTaskSetManager(
initNumTasks: Int,
clusterScheduler: ClusterScheduler,
taskSet: TaskSet)
extends TaskSetManager(clusterScheduler,taskSet) {
extends ClusterTaskSetManager(clusterScheduler,taskSet) {
parent = null
weight = 1

View file

@ -0,0 +1,104 @@
package spark.scheduler
import java.util.Properties
import java.util.concurrent.LinkedBlockingQueue
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import scala.collection.mutable
import spark._
import spark.SparkContext._
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
test("inner method") {
sc = new SparkContext("local", "joblogger")
val joblogger = new JobLogger {
def createLogWriterTest(jobID: Int) = createLogWriter(jobID)
def closeLogWriterTest(jobID: Int) = closeLogWriter(jobID)
def getRddNameTest(rdd: RDD[_]) = getRddName(rdd)
def buildJobDepTest(jobID: Int, stage: Stage) = buildJobDep(jobID, stage)
}
type MyRDD = RDD[(Int, Int)]
def makeRdd(
numPartitions: Int,
dependencies: List[Dependency[_]]
): MyRDD = {
val maxPartition = numPartitions - 1
return new MyRDD(sc, dependencies) {
override def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] =
throw new RuntimeException("should not be reached")
override def getPartitions = (0 to maxPartition).map(i => new Partition {
override def index = i
}).toArray
}
}
val jobID = 5
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID)
val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4))
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")
joblogger.createLogWriterTest(jobID)
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.buildJobDepTest(jobID, rootStage)
joblogger.getJobIDToStages.get(jobID).get.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(jobID))
joblogger.getStageIDToJobID.get(1) should be (Some(jobID))
joblogger.closeLogWriterTest(jobID)
joblogger.getStageIDToJobID.size should be (0)
joblogger.getJobIDToStages.size should be (0)
joblogger.getJobIDtoPrintWriter.size should be (0)
}
test("inner variables") {
sc = new SparkContext("local[4]", "joblogger")
val joblogger = new JobLogger {
override protected def closeLogWriter(jobID: Int) =
getJobIDtoPrintWriter.get(jobID).foreach { fileWriter =>
fileWriter.close()
}
}
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
joblogger.getLogDir should be ("/tmp/spark")
joblogger.getJobIDtoPrintWriter.size should be (1)
joblogger.getStageIDToJobID.size should be (2)
joblogger.getStageIDToJobID.get(0) should be (Some(0))
joblogger.getStageIDToJobID.get(1) should be (Some(0))
joblogger.getJobIDToStages.size should be (1)
}
test("interface functions") {
sc = new SparkContext("local[4]", "joblogger")
val joblogger = new JobLogger {
var onTaskEndCount = 0
var onJobEndCount = 0
var onJobStartCount = 0
var onStageCompletedCount = 0
var onStageSubmittedCount = 0
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = onTaskEndCount += 1
override def onJobEnd(jobEnd: SparkListenerJobEnd) = onJobEndCount += 1
override def onJobStart(jobStart: SparkListenerJobStart) = onJobStartCount += 1
override def onStageCompleted(stageCompleted: StageCompleted) = onStageCompletedCount += 1
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) = onStageSubmittedCount += 1
}
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
joblogger.onJobStartCount should be (1)
joblogger.onJobEndCount should be (1)
joblogger.onTaskEndCount should be (8)
joblogger.onStageSubmittedCount should be (2)
joblogger.onStageCompletedCount should be (2)
}
}

View file

@ -0,0 +1,206 @@
package spark.scheduler
import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import spark._
import spark.scheduler._
import spark.scheduler.cluster._
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ConcurrentMap, HashMap}
import java.util.concurrent.Semaphore
import java.util.concurrent.CountDownLatch
import java.util.Properties
class Lock() {
var finished = false
def jobWait() = {
synchronized {
while(!finished) {
this.wait()
}
}
}
def jobFinished() = {
synchronized {
finished = true
this.notifyAll()
}
}
}
object TaskThreadInfo {
val threadToLock = HashMap[Int, Lock]()
val threadToRunning = HashMap[Int, Boolean]()
val threadToStarted = HashMap[Int, CountDownLatch]()
}
/*
* 1. each thread contains one job.
* 2. each job contains one stage.
* 3. each stage only contains one task.
* 4. each task(launched) must be lanched orderly(using threadToStarted) to make sure
* it will get cpu core resource, and will wait to finished after user manually
* release "Lock" and then cluster will contain another free cpu cores.
* 5. each task(pending) must use "sleep" to make sure it has been added to taskSetManager queue,
* thus it will be scheduled later when cluster has free cpu cores.
*/
class LocalSchedulerSuite extends FunSuite with LocalSparkContext {
def createThread(threadIndex: Int, poolName: String, sc: SparkContext, sem: Semaphore) {
TaskThreadInfo.threadToRunning(threadIndex) = false
val nums = sc.parallelize(threadIndex to threadIndex, 1)
TaskThreadInfo.threadToLock(threadIndex) = new Lock()
TaskThreadInfo.threadToStarted(threadIndex) = new CountDownLatch(1)
new Thread {
if (poolName != null) {
sc.addLocalProperties("spark.scheduler.cluster.fair.pool",poolName)
}
override def run() {
val ans = nums.map(number => {
TaskThreadInfo.threadToRunning(number) = true
TaskThreadInfo.threadToStarted(number).countDown()
TaskThreadInfo.threadToLock(number).jobWait()
TaskThreadInfo.threadToRunning(number) = false
number
}).collect()
assert(ans.toList === List(threadIndex))
sem.release()
}
}.start()
}
test("Local FIFO scheduler end-to-end test") {
System.setProperty("spark.cluster.schedulingmode", "FIFO")
sc = new SparkContext("local[4]", "test")
val sem = new Semaphore(0)
createThread(1,null,sc,sem)
TaskThreadInfo.threadToStarted(1).await()
createThread(2,null,sc,sem)
TaskThreadInfo.threadToStarted(2).await()
createThread(3,null,sc,sem)
TaskThreadInfo.threadToStarted(3).await()
createThread(4,null,sc,sem)
TaskThreadInfo.threadToStarted(4).await()
// thread 5 and 6 (stage pending)must meet following two points
// 1. stages (taskSetManager) of jobs in thread 5 and 6 should be add to taskSetManager
// queue before executing TaskThreadInfo.threadToLock(1).jobFinished()
// 2. priority of stage in thread 5 should be prior to priority of stage in thread 6
// So I just use "sleep" 1s here for each thread.
// TODO: any better solution?
createThread(5,null,sc,sem)
Thread.sleep(1000)
createThread(6,null,sc,sem)
Thread.sleep(1000)
assert(TaskThreadInfo.threadToRunning(1) === true)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === false)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(1).jobFinished()
TaskThreadInfo.threadToStarted(5).await()
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === true)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === false)
TaskThreadInfo.threadToLock(3).jobFinished()
TaskThreadInfo.threadToStarted(6).await()
assert(TaskThreadInfo.threadToRunning(1) === false)
assert(TaskThreadInfo.threadToRunning(2) === true)
assert(TaskThreadInfo.threadToRunning(3) === false)
assert(TaskThreadInfo.threadToRunning(4) === true)
assert(TaskThreadInfo.threadToRunning(5) === true)
assert(TaskThreadInfo.threadToRunning(6) === true)
TaskThreadInfo.threadToLock(2).jobFinished()
TaskThreadInfo.threadToLock(4).jobFinished()
TaskThreadInfo.threadToLock(5).jobFinished()
TaskThreadInfo.threadToLock(6).jobFinished()
sem.acquire(6)
}
test("Local fair scheduler end-to-end test") {
sc = new SparkContext("local[8]", "LocalSchedulerSuite")
val sem = new Semaphore(0)
System.setProperty("spark.cluster.schedulingmode", "FAIR")
val xmlPath = getClass.getClassLoader.getResource("fairscheduler.xml").getFile()
System.setProperty("spark.fairscheduler.allocation.file", xmlPath)
createThread(10,"1",sc,sem)
TaskThreadInfo.threadToStarted(10).await()
createThread(20,"2",sc,sem)
TaskThreadInfo.threadToStarted(20).await()
createThread(30,"3",sc,sem)
TaskThreadInfo.threadToStarted(30).await()
assert(TaskThreadInfo.threadToRunning(10) === true)
assert(TaskThreadInfo.threadToRunning(20) === true)
assert(TaskThreadInfo.threadToRunning(30) === true)
createThread(11,"1",sc,sem)
TaskThreadInfo.threadToStarted(11).await()
createThread(21,"2",sc,sem)
TaskThreadInfo.threadToStarted(21).await()
createThread(31,"3",sc,sem)
TaskThreadInfo.threadToStarted(31).await()
assert(TaskThreadInfo.threadToRunning(11) === true)
assert(TaskThreadInfo.threadToRunning(21) === true)
assert(TaskThreadInfo.threadToRunning(31) === true)
createThread(12,"1",sc,sem)
TaskThreadInfo.threadToStarted(12).await()
createThread(22,"2",sc,sem)
TaskThreadInfo.threadToStarted(22).await()
createThread(32,"3",sc,sem)
assert(TaskThreadInfo.threadToRunning(12) === true)
assert(TaskThreadInfo.threadToRunning(22) === true)
assert(TaskThreadInfo.threadToRunning(32) === false)
TaskThreadInfo.threadToLock(10).jobFinished()
TaskThreadInfo.threadToStarted(32).await()
assert(TaskThreadInfo.threadToRunning(32) === true)
//1. Similar with above scenario, sleep 1s for stage of 23 and 33 to be added to taskSetManager
// queue so that cluster will assign free cpu core to stage 23 after stage 11 finished.
//2. priority of 23 and 33 will be meaningless as using fair scheduler here.
createThread(23,"2",sc,sem)
createThread(33,"3",sc,sem)
Thread.sleep(1000)
TaskThreadInfo.threadToLock(11).jobFinished()
TaskThreadInfo.threadToStarted(23).await()
assert(TaskThreadInfo.threadToRunning(23) === true)
assert(TaskThreadInfo.threadToRunning(33) === false)
TaskThreadInfo.threadToLock(12).jobFinished()
TaskThreadInfo.threadToStarted(33).await()
assert(TaskThreadInfo.threadToRunning(33) === true)
TaskThreadInfo.threadToLock(20).jobFinished()
TaskThreadInfo.threadToLock(21).jobFinished()
TaskThreadInfo.threadToLock(22).jobFinished()
TaskThreadInfo.threadToLock(23).jobFinished()
TaskThreadInfo.threadToLock(30).jobFinished()
TaskThreadInfo.threadToLock(31).jobFinished()
TaskThreadInfo.threadToLock(32).jobFinished()
TaskThreadInfo.threadToLock(33).jobFinished()
sem.acquire(11)
}
}

View file

@ -77,7 +77,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
def onStageCompleted(stage: StageCompleted) {
override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
}
}

View file

@ -260,6 +260,13 @@ Apart from these, the following properties are also available, and may be useful
applications). Note that any RDD that persists in memory for more than this duration will be cleared as well.
</td>
</tr>
<tr>
<td>spark.streaming.blockInterval</td>
<td>200</td>
<td>
Duration (milliseconds) of how long to batch new objects coming from network receivers.
</td>
</tr>
</table>

View file

@ -27,14 +27,14 @@ Short functions can be passed to RDD methods using Python's [`lambda`](http://ww
{% highlight python %}
logData = sc.textFile(logFile).cache()
errors = logData.filter(lambda s: 'ERROR' in s.split())
errors = logData.filter(lambda line: "ERROR" in line)
{% endhighlight %}
You can also pass functions that are defined using the `def` keyword; this is useful for more complicated functions that cannot be expressed using `lambda`:
{% highlight python %}
def is_error(line):
return 'ERROR' in line.split()
return "ERROR" in line
errors = logData.filter(is_error)
{% endhighlight %}
@ -43,8 +43,7 @@ Functions can access objects in enclosing scopes, although modifications to thos
{% highlight python %}
error_keywords = ["Exception", "Error"]
def is_error(line):
words = line.split()
return any(keyword in words for keyword in error_keywords)
return any(keyword in line for keyword in error_keywords)
errors = logData.filter(is_error)
{% endhighlight %}

View file

@ -43,12 +43,18 @@ new SparkContext(master, appName, [sparkHome], [jars])
The `master` parameter is a string specifying a [Spark or Mesos cluster URL](#master-urls) to connect to, or a special "local" string to run in local mode, as described below. `appName` is a name for your application, which will be shown in the cluster web UI. Finally, the last two parameters are needed to deploy your code to a cluster if running in distributed mode, as described later.
In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable. For example, to run on four cores, use
In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the variable called `sc`. Making your own SparkContext will not work. You can set which master the context connects to using the `MASTER` environment variable, and you can add JARs to the classpath with the `ADD_JARS` variable. For example, to run `spark-shell` on four cores, use
{% highlight bash %}
$ MASTER=local[4] ./spark-shell
{% endhighlight %}
Or, to also add `code.jar` to its classpath, use:
{% highlight bash %}
$ MASTER=local[4] ADD_JARS=code.jar ./spark-shell
{% endhighlight %}
### Master URLs
The master URL passed to Spark can be in one of the following formats:
@ -78,7 +84,7 @@ If you want to run your job on a cluster, you will need to specify the two optio
* `sparkHome`: The path at which Spark is installed on your worker machines (it should be the same on all of them).
* `jars`: A list of JAR files on the local machine containing your job's code and any dependencies, which Spark will deploy to all the worker nodes. You'll need to package your job into a set of JARs using your build system. For example, if you're using SBT, the [sbt-assembly](https://github.com/sbt/sbt-assembly) plugin is a good way to make a single JAR with your code and dependencies.
If you run `spark-shell` on a cluster, any classes you define in the shell will automatically be distributed.
If you run `spark-shell` on a cluster, you can add JARs to it by specifying the `ADD_JARS` environment variable before you launch it. This variable should contain a comma-separated list of JARs. For example, `ADD_JARS=a.jar,b.jar ./spark-shell` will launch a shell with `a.jar` and `b.jar` on its classpath. In addition, any new classes you define in the shell will automatically be distributed.
# Resilient Distributed Datasets (RDDs)

View file

@ -34,6 +34,41 @@
<artifactId>scalacheck_${scala.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.cassandra</groupId>
<artifactId>cassandra-all</artifactId>
<version>1.2.5</version>
<exclusions>
<exclusion>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
</exclusion>
<exclusion>
<groupId>com.googlecode.concurrentlinkedhashmap</groupId>
<artifactId>concurrentlinkedhashmap-lru</artifactId>
</exclusion>
<exclusion>
<groupId>com.ning</groupId>
<artifactId>compress-lzf</artifactId>
</exclusion>
<exclusion>
<groupId>io.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
<exclusion>
<groupId>jline</groupId>
<artifactId>jline</artifactId>
</exclusion>
<exclusion>
<groupId>log4j</groupId>
<artifactId>log4j</artifactId>
</exclusion>
<exclusion>
<groupId>org.apache.cassandra.deps</groupId>
<artifactId>avro</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
@ -67,6 +102,11 @@
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase</artifactId>
<version>0.94.6</version>
</dependency>
</dependencies>
<build>
<plugins>
@ -105,6 +145,11 @@
<artifactId>hadoop-client</artifactId>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.apache.hbase</groupId>
<artifactId>hbase</artifactId>
<version>0.94.6</version>
</dependency>
</dependencies>
<build>
<plugins>

View file

@ -0,0 +1,196 @@
package spark.examples
import org.apache.hadoop.mapreduce.Job
import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat
import org.apache.cassandra.hadoop.ConfigHelper
import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
import org.apache.cassandra.thrift._
import spark.SparkContext
import spark.SparkContext._
import java.nio.ByteBuffer
import java.util.SortedMap
import org.apache.cassandra.db.IColumn
import org.apache.cassandra.utils.ByteBufferUtil
import scala.collection.JavaConversions._
/*
* This example demonstrates using Spark with Cassandra with the New Hadoop API and Cassandra
* support for Hadoop.
*
* To run this example, run this file with the following command params -
* <spark_master> <cassandra_node> <cassandra_port>
*
* So if you want to run this on localhost this will be,
* local[3] localhost 9160
*
* The example makes some assumptions:
* 1. You have already created a keyspace called casDemo and it has a column family named Words
* 2. There are column family has a column named "para" which has test content.
*
* You can create the content by running the following script at the bottom of this file with
* cassandra-cli.
*
*/
object CassandraTest {
def main(args: Array[String]) {
// Get a SparkContext
val sc = new SparkContext(args(0), "casDemo")
// Build the job configuration with ConfigHelper provided by Cassandra
val job = new Job()
job.setInputFormatClass(classOf[ColumnFamilyInputFormat])
val host: String = args(1)
val port: String = args(2)
ConfigHelper.setInputInitialAddress(job.getConfiguration(), host)
ConfigHelper.setInputRpcPort(job.getConfiguration(), port)
ConfigHelper.setOutputInitialAddress(job.getConfiguration(), host)
ConfigHelper.setOutputRpcPort(job.getConfiguration(), port)
ConfigHelper.setInputColumnFamily(job.getConfiguration(), "casDemo", "Words")
ConfigHelper.setOutputColumnFamily(job.getConfiguration(), "casDemo", "WordCount")
val predicate = new SlicePredicate()
val sliceRange = new SliceRange()
sliceRange.setStart(Array.empty[Byte])
sliceRange.setFinish(Array.empty[Byte])
predicate.setSlice_range(sliceRange)
ConfigHelper.setInputSlicePredicate(job.getConfiguration(), predicate)
ConfigHelper.setInputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
ConfigHelper.setOutputPartitioner(job.getConfiguration(), "Murmur3Partitioner")
// Make a new Hadoop RDD
val casRdd = sc.newAPIHadoopRDD(
job.getConfiguration(),
classOf[ColumnFamilyInputFormat],
classOf[ByteBuffer],
classOf[SortedMap[ByteBuffer, IColumn]])
// Let us first get all the paragraphs from the retrieved rows
val paraRdd = casRdd.map {
case (key, value) => {
ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value())
}
}
// Lets get the word count in paras
val counts = paraRdd.flatMap(p => p.split(" ")).map(word => (word, 1)).reduceByKey(_ + _)
counts.collect().foreach {
case (word, count) => println(word + ":" + count)
}
counts.map {
case (word, count) => {
val colWord = new org.apache.cassandra.thrift.Column()
colWord.setName(ByteBufferUtil.bytes("word"))
colWord.setValue(ByteBufferUtil.bytes(word))
colWord.setTimestamp(System.currentTimeMillis)
val colCount = new org.apache.cassandra.thrift.Column()
colCount.setName(ByteBufferUtil.bytes("wcount"))
colCount.setValue(ByteBufferUtil.bytes(count.toLong))
colCount.setTimestamp(System.currentTimeMillis)
val outputkey = ByteBufferUtil.bytes(word + "-COUNT-" + System.currentTimeMillis)
val mutations: java.util.List[Mutation] = new Mutation() :: new Mutation() :: Nil
mutations.get(0).setColumn_or_supercolumn(new ColumnOrSuperColumn())
mutations.get(0).column_or_supercolumn.setColumn(colWord)
mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn())
mutations.get(1).column_or_supercolumn.setColumn(colCount)
(outputkey, mutations)
}
}.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]],
classOf[ColumnFamilyOutputFormat], job.getConfiguration)
}
}
/*
create keyspace casDemo;
use casDemo;
create column family WordCount with comparator = UTF8Type;
update column family WordCount with column_metadata =
[{column_name: word, validation_class: UTF8Type},
{column_name: wcount, validation_class: LongType}];
create column family Words with comparator = UTF8Type;
update column family Words with column_metadata =
[{column_name: book, validation_class: UTF8Type},
{column_name: para, validation_class: UTF8Type}];
assume Words keys as utf8;
set Words['3musk001']['book'] = 'The Three Musketeers';
set Words['3musk001']['para'] = 'On the first Monday of the month of April, 1625, the market
town of Meung, in which the author of ROMANCE OF THE ROSE was born, appeared to
be in as perfect a state of revolution as if the Huguenots had just made
a second La Rochelle of it. Many citizens, seeing the women flying
toward the High Street, leaving their children crying at the open doors,
hastened to don the cuirass, and supporting their somewhat uncertain
courage with a musket or a partisan, directed their steps toward the
hostelry of the Jolly Miller, before which was gathered, increasing
every minute, a compact group, vociferous and full of curiosity.';
set Words['3musk002']['book'] = 'The Three Musketeers';
set Words['3musk002']['para'] = 'In those times panics were common, and few days passed without
some city or other registering in its archives an event of this kind. There were
nobles, who made war against each other; there was the king, who made
war against the cardinal; there was Spain, which made war against the
king. Then, in addition to these concealed or public, secret or open
wars, there were robbers, mendicants, Huguenots, wolves, and scoundrels,
who made war upon everybody. The citizens always took up arms readily
against thieves, wolves or scoundrels, often against nobles or
Huguenots, sometimes against the king, but never against cardinal or
Spain. It resulted, then, from this habit that on the said first Monday
of April, 1625, the citizens, on hearing the clamor, and seeing neither
the red-and-yellow standard nor the livery of the Duc de Richelieu,
rushed toward the hostel of the Jolly Miller. When arrived there, the
cause of the hubbub was apparent to all';
set Words['3musk003']['book'] = 'The Three Musketeers';
set Words['3musk003']['para'] = 'You ought, I say, then, to husband the means you have, however
large the sum may be; but you ought also to endeavor to perfect yourself in
the exercises becoming a gentleman. I will write a letter today to the
Director of the Royal Academy, and tomorrow he will admit you without
any expense to yourself. Do not refuse this little service. Our
best-born and richest gentlemen sometimes solicit it without being able
to obtain it. You will learn horsemanship, swordsmanship in all its
branches, and dancing. You will make some desirable acquaintances; and
from time to time you can call upon me, just to tell me how you are
getting on, and to say whether I can be of further service to you.';
set Words['thelostworld001']['book'] = 'The Lost World';
set Words['thelostworld001']['para'] = 'She sat with that proud, delicate profile of hers outlined
against the red curtain. How beautiful she was! And yet how aloof! We had been
friends, quite good friends; but never could I get beyond the same
comradeship which I might have established with one of my
fellow-reporters upon the Gazette,--perfectly frank, perfectly kindly,
and perfectly unsexual. My instincts are all against a woman being too
frank and at her ease with me. It is no compliment to a man. Where
the real sex feeling begins, timidity and distrust are its companions,
heritage from old wicked days when love and violence went often hand in
hand. The bent head, the averted eye, the faltering voice, the wincing
figure--these, and not the unshrinking gaze and frank reply, are the
true signals of passion. Even in my short life I had learned as much
as that--or had inherited it in that race memory which we call instinct.';
set Words['thelostworld002']['book'] = 'The Lost World';
set Words['thelostworld002']['para'] = 'I always liked McArdle, the crabbed, old, round-backed,
red-headed news editor, and I rather hoped that he liked me. Of course, Beaumont was
the real boss; but he lived in the rarefied atmosphere of some Olympian
height from which he could distinguish nothing smaller than an
international crisis or a split in the Cabinet. Sometimes we saw him
passing in lonely majesty to his inner sanctum, with his eyes staring
vaguely and his mind hovering over the Balkans or the Persian Gulf. He
was above and beyond us. But McArdle was his first lieutenant, and it
was he that we knew. The old man nodded as I entered the room, and he
pushed his spectacles far up on his bald forehead.';
*/

View file

@ -0,0 +1,35 @@
package spark.examples
import spark._
import spark.rdd.NewHadoopRDD
import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor}
import org.apache.hadoop.hbase.client.HBaseAdmin
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
object HBaseTest {
def main(args: Array[String]) {
val sc = new SparkContext(args(0), "HBaseTest",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val conf = HBaseConfiguration.create()
// Other options for configuring scan behavior are available. More information available at
// http://hbase.apache.org/apidocs/org/apache/hadoop/hbase/mapreduce/TableInputFormat.html
conf.set(TableInputFormat.INPUT_TABLE, args(1))
// Initialize hBase table if necessary
val admin = new HBaseAdmin(conf)
if(!admin.isTableAvailable(args(1))) {
val tableDesc = new HTableDescriptor(args(1))
admin.createTable(tableDesc)
}
val hBaseRDD = sc.newAPIHadoopRDD(conf, classOf[TableInputFormat],
classOf[org.apache.hadoop.hbase.io.ImmutableBytesWritable],
classOf[org.apache.hadoop.hbase.client.Result])
hBaseRDD.count()
System.exit(0)
}
}

View file

@ -37,7 +37,7 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
val lines = ssc.kafkaStream[String](zkQuorum, group, topicpMap)
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()

14
pom.xml
View file

@ -61,7 +61,7 @@
<cdh.version>4.1.2</cdh.version>
<log4j.version>1.2.17</log4j.version>
<PermGen>0m</PermGen>
<PermGen>64m</PermGen>
<MaxPermGen>512m</MaxPermGen>
</properties>
@ -191,9 +191,9 @@
<version>0.8.4</version>
</dependency>
<dependency>
<groupId>asm</groupId>
<artifactId>asm-all</artifactId>
<version>3.3.1</version>
<groupId>org.ow2.asm</groupId>
<artifactId>asm</artifactId>
<version>4.0</version>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
@ -396,10 +396,8 @@
<jvmArgs>
<jvmArg>-Xms64m</jvmArg>
<jvmArg>-Xmx1024m</jvmArg>
<jvmArg>-XX:PermSize</jvmArg>
<jvmArg>${PermGen}</jvmArg>
<jvmArg>-XX:MaxPermSize</jvmArg>
<jvmArg>${MaxPermGen}</jvmArg>
<jvmArg>-XX:PermSize=${PermGen}</jvmArg>
<jvmArg>-XX:MaxPermSize=${MaxPermGen}</jvmArg>
</jvmArgs>
<javacArgs>
<javacArg>-source</javacArg>

View file

@ -56,7 +56,7 @@ object SparkBuild extends Build {
// Fork new JVMs for tests and set Java options for those
fork := true,
javaOptions += "-Xmx2g",
javaOptions += "-Xmx2500m",
// Only allow one test at a time, even across projects, since they run in the same JVM
concurrentRestrictions in Global += Tags.limit(Tags.Test, 1),
@ -127,12 +127,13 @@ object SparkBuild extends Build {
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
)
) ++ net.virtualvoid.sbt.graph.Plugin.graphSettings
val slf4jVersion = "1.6.1"
val slf4jVersion = "1.7.2"
val excludeJackson = ExclusionRule(organization = "org.codehaus.jackson")
val excludeNetty = ExclusionRule(organization = "org.jboss.netty")
val excludeAsm = ExclusionRule(organization = "asm")
def coreSettings = sharedSettings ++ Seq(
name := "spark-core",
@ -150,7 +151,7 @@ object SparkBuild extends Build {
"org.slf4j" % "slf4j-log4j12" % slf4jVersion,
"commons-daemon" % "commons-daemon" % "1.0.10",
"com.ning" % "compress-lzf" % "0.8.4",
"asm" % "asm-all" % "3.3.1",
"org.ow2.asm" % "asm" % "4.0",
"com.google.protobuf" % "protobuf-java" % "2.4.1",
"de.javakaffee" % "kryo-serializers" % "0.22",
"com.typesafe.akka" % "akka-actor" % "2.0.3" excludeAll(excludeNetty),
@ -203,7 +204,20 @@ object SparkBuild extends Build {
def examplesSettings = sharedSettings ++ Seq(
name := "spark-examples",
libraryDependencies ++= Seq("com.twitter" % "algebird-core_2.9.2" % "0.1.11")
libraryDependencies ++= Seq(
"com.twitter" % "algebird-core_2.9.2" % "0.1.11",
"org.apache.hbase" % "hbase" % "0.94.6" excludeAll(excludeNetty, excludeAsm),
"org.apache.cassandra" % "cassandra-all" % "1.2.5"
exclude("com.google.guava", "guava")
exclude("com.googlecode.concurrentlinkedhashmap", "concurrentlinkedhashmap-lru")
exclude("com.ning","compress-lzf")
exclude("io.netty", "netty")
exclude("jline","jline")
exclude("log4j","log4j")
exclude("org.apache.cassandra.deps", "avro")
)
)
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
@ -212,9 +226,12 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
resolvers ++= Seq(
"Akka Repository" at "http://repo.akka.io/releases/"
),
libraryDependencies ++= Seq(
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty),
"com.github.sgroschupf" % "zkclient" % "0.1",
"com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.3" excludeAll(excludeNetty)
)
@ -223,7 +240,7 @@ object SparkBuild extends Build {
def extraAssemblySettings() = Seq(test in assembly := {}) ++ Seq(
mergeStrategy in assembly := {
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf/.*\\.sf$") => MergeStrategy.discard
case m if m.toLowerCase.matches("meta-inf.*\\.sf$") => MergeStrategy.discard
case "reference.conf" => MergeStrategy.concat
case _ => MergeStrategy.first
}

View file

@ -16,3 +16,5 @@ addSbtPlugin("io.spray" %% "sbt-twirl" % "0.6.1")
//resolvers += Resolver.url("sbt-plugin-releases", new URL("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases/"))(Resolver.ivyStylePatterns)
//addSbtPlugin("com.jsuereth" % "xsbt-gpg-plugin" % "0.6")
addSbtPlugin("net.virtual-void" % "sbt-dependency-graph" % "0.7.3")

158
python/pyspark/daemon.py Normal file
View file

@ -0,0 +1,158 @@
import os
import sys
import multiprocessing
from ctypes import c_bool
from errno import EINTR, ECHILD
from socket import socket, AF_INET, SOCK_STREAM, SOMAXCONN
from signal import signal, SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
from pyspark.worker import main as worker_main
from pyspark.serializers import write_int
try:
POOLSIZE = multiprocessing.cpu_count()
except NotImplementedError:
POOLSIZE = 4
exit_flag = multiprocessing.Value(c_bool, False)
def should_exit():
global exit_flag
return exit_flag.value
def compute_real_exit_code(exit_code):
# SystemExit's code can be integer or string, but os._exit only accepts integers
import numbers
if isinstance(exit_code, numbers.Integral):
return exit_code
else:
return 1
def worker(listen_sock):
# Redirect stdout to stderr
os.dup2(2, 1)
# Manager sends SIGHUP to request termination of workers in the pool
def handle_sighup(*args):
assert should_exit()
signal(SIGHUP, handle_sighup)
# Cleanup zombie children
def handle_sigchld(*args):
pid = status = None
try:
while (pid, status) != (0, 0):
pid, status = os.waitpid(0, os.WNOHANG)
except EnvironmentError as err:
if err.errno == EINTR:
# retry
handle_sigchld()
elif err.errno != ECHILD:
raise
signal(SIGCHLD, handle_sigchld)
# Handle clients
while not should_exit():
# Wait until a client arrives or we have to exit
sock = None
while not should_exit() and sock is None:
try:
sock, addr = listen_sock.accept()
except EnvironmentError as err:
if err.errno != EINTR:
raise
if sock is not None:
# Fork a child to handle the client.
# The client is handled in the child so that the manager
# never receives SIGCHLD unless a worker crashes.
if os.fork() == 0:
# Leave the worker pool
signal(SIGHUP, SIG_DFL)
listen_sock.close()
# Handle the client then exit
sockfile = sock.makefile()
exit_code = 0
try:
worker_main(sockfile, sockfile)
except SystemExit as exc:
exit_code = exc.code
finally:
sockfile.close()
sock.close()
os._exit(compute_real_exit_code(exit_code))
else:
sock.close()
def launch_worker(listen_sock):
if os.fork() == 0:
try:
worker(listen_sock)
except Exception as err:
import traceback
traceback.print_exc()
os._exit(1)
else:
assert should_exit()
os._exit(0)
def manager():
# Create a new process group to corral our children
os.setpgid(0, 0)
# Create a listening socket on the AF_INET loopback interface
listen_sock = socket(AF_INET, SOCK_STREAM)
listen_sock.bind(('127.0.0.1', 0))
listen_sock.listen(max(1024, 2 * POOLSIZE, SOMAXCONN))
listen_host, listen_port = listen_sock.getsockname()
write_int(listen_port, sys.stdout)
# Launch initial worker pool
for idx in range(POOLSIZE):
launch_worker(listen_sock)
listen_sock.close()
def shutdown():
global exit_flag
exit_flag.value = True
# Gracefully exit on SIGTERM, don't die on SIGHUP
signal(SIGTERM, lambda signum, frame: shutdown())
signal(SIGHUP, SIG_IGN)
# Cleanup zombie children
def handle_sigchld(*args):
try:
pid, status = os.waitpid(0, os.WNOHANG)
if status != 0 and not should_exit():
raise RuntimeError("worker crashed: %s, %s" % (pid, status))
except EnvironmentError as err:
if err.errno not in (ECHILD, EINTR):
raise
signal(SIGCHLD, handle_sigchld)
# Initialization complete
sys.stdout.close()
try:
while not should_exit():
try:
# Spark tells us to exit by closing stdin
if os.read(0, 512) == '':
shutdown()
except EnvironmentError as err:
if err.errno != EINTR:
shutdown()
raise
finally:
signal(SIGTERM, SIG_DFL)
exit_flag.value = True
# Send SIGHUP to notify workers of shutdown
os.kill(0, SIGHUP)
if __name__ == '__main__':
manager()

View file

@ -46,6 +46,10 @@ def read_long(stream):
return struct.unpack("!q", length)[0]
def write_long(value, stream):
stream.write(struct.pack("!q", value))
def read_int(stream):
length = stream.read(4)
if length == "":

View file

@ -12,6 +12,7 @@ import unittest
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.java_gateway import SPARK_HOME
from pyspark.serializers import read_int
class PySparkTestCase(unittest.TestCase):
@ -117,5 +118,47 @@ class TestIO(PySparkTestCase):
self.sc.parallelize([1]).foreach(func)
class TestDaemon(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
sock = socket(AF_INET, SOCK_STREAM)
sock.connect(('127.0.0.1', port))
# send a split index of -1 to shutdown the worker
sock.send("\xFF\xFF\xFF\xFF")
sock.close()
return True
def do_termination_test(self, terminator):
from subprocess import Popen, PIPE
from errno import ECONNREFUSED
# start daemon
daemon_path = os.path.join(os.path.dirname(__file__), "daemon.py")
daemon = Popen([sys.executable, daemon_path], stdin=PIPE, stdout=PIPE)
# read the port number
port = read_int(daemon.stdout)
# daemon should accept connections
self.assertTrue(self.connect(port))
# request shutdown
terminator(daemon)
time.sleep(1)
# daemon should no longer accept connections
with self.assertRaises(EnvironmentError) as trap:
self.connect(port)
self.assertEqual(trap.exception.errno, ECONNREFUSED)
def test_termination_stdin(self):
"""Ensure that daemon and workers terminate when stdin is closed."""
self.do_termination_test(lambda daemon: daemon.stdin.close())
def test_termination_sigterm(self):
"""Ensure that daemon and workers terminate on SIGTERM."""
from signal import SIGTERM
self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
if __name__ == "__main__":
unittest.main()

View file

@ -3,6 +3,7 @@ Worker that receives input from Piped RDD.
"""
import os
import sys
import time
import traceback
from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
@ -12,48 +13,60 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, read_with_length, write_int, \
read_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = os.fdopen(os.dup(1), 'w')
os.dup2(2, 1)
def load_obj(infile):
return load_pickle(standard_b64decode(infile.readline().strip()))
def load_obj():
return load_pickle(standard_b64decode(sys.stdin.readline().strip()))
def report_times(outfile, boot, init, finish):
write_int(-3, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
def main():
split_index = read_int(sys.stdin)
spark_files_dir = load_pickle(read_with_length(sys.stdin))
def main(infile, outfile):
boot_time = time.time()
split_index = read_int(infile)
if split_index == -1: # for unit tests
return
spark_files_dir = load_pickle(read_with_length(infile))
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
sys.path.append(spark_files_dir)
num_broadcast_variables = read_int(sys.stdin)
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(sys.stdin)
value = read_with_length(sys.stdin)
bid = read_long(infile)
value = read_with_length(infile)
_broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
func = load_obj()
bypassSerializer = load_obj()
func = load_obj(infile)
bypassSerializer = load_obj(infile)
if bypassSerializer:
dumps = lambda x: x
else:
dumps = dump_pickle
iterator = read_from_pickle_file(sys.stdin)
init_time = time.time()
iterator = read_from_pickle_file(infile)
try:
for obj in func(split_index, iterator):
write_with_length(dumps(obj), old_stdout)
write_with_length(dumps(obj), outfile)
except Exception as e:
write_int(-2, old_stdout)
write_with_length(traceback.format_exc(), old_stdout)
write_int(-2, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
write_int(-1, old_stdout)
write_int(-1, outfile)
for aid, accum in _accumulatorRegistry.items():
write_with_length(dump_pickle((aid, accum._value)), old_stdout)
write_with_length(dump_pickle((aid, accum._value)), outfile)
write_int(-1, outfile)
if __name__ == '__main__':
main()
# Redirect stdout to stderr so that users must return values from functions.
old_stdout = os.fdopen(os.dup(1), 'w')
os.dup2(2, 1)
main(sys.stdin, old_stdout)

View file

@ -8,7 +8,6 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.objectweb.asm._
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._
@ -83,7 +82,7 @@ extends ClassLoader(parent) {
}
class ConstructorCleaner(className: String, cv: ClassVisitor)
extends ClassAdapter(cv) {
extends ClassVisitor(ASM4, cv) {
override def visitMethod(access: Int, name: String, desc: String,
sig: String, exceptions: Array[String]): MethodVisitor = {
val mv = cv.visitMethod(access, name, desc, sig, exceptions)

View file

@ -822,7 +822,7 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
spark.repl.Main.interp.out.println("Spark context available as sc.");
spark.repl.Main.interp.out.flush();
""")
command("import spark.SparkContext._");
command("import spark.SparkContext._")
}
echo("Type in expressions to have them evaluated.")
echo("Type :help for more information.")
@ -838,7 +838,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
if (prop != null) prop else "local"
}
}
sparkContext = new SparkContext(master, "Spark shell")
val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
.getOrElse(new Array[String](0))
.map(new java.io.File(_).getAbsolutePath)
sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
sparkContext
}
@ -850,6 +853,10 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
printWelcome()
echo("Initializing interpreter...")
// Add JARS specified in Spark's ADD_JARS variable to classpath
val jars = Option(System.getenv("ADD_JARS")).map(_.split(',')).getOrElse(new Array[String](0))
jars.foreach(settings.classpath.append(_))
this.settings = settings
createInterpreter()

View file

@ -28,24 +28,25 @@ class ReplSuite extends FunSuite {
val separator = System.getProperty("path.separator")
interp.process(Array("-classpath", paths.mkString(separator)))
spark.repl.Main.interp = null
if (interp.sparkContext != null)
if (interp.sparkContext != null) {
interp.sparkContext.stop()
}
// To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
System.clearProperty("spark.driver.port")
System.clearProperty("spark.hostPort")
return out.toString
}
def assertContains(message: String, output: String) {
assert(output contains message,
assert(output.contains(message),
"Interpreter output did not contain '" + message + "':\n" + output)
}
def assertDoesNotContain(message: String, output: String) {
assert(!(output contains message),
assert(!output.contains(message),
"Interpreter output contained '" + message + "':\n" + output)
}
test ("simple foreach with accumulator") {
val output = runInterpreter("local", """
val accum = sc.accumulator(0)
@ -56,7 +57,7 @@ class ReplSuite extends FunSuite {
assertDoesNotContain("Exception", output)
assertContains("res1: Int = 55", output)
}
test ("external vars") {
val output = runInterpreter("local", """
var v = 7
@ -105,7 +106,7 @@ class ReplSuite extends FunSuite {
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
}
test ("broadcast vars") {
// Test that the value that a broadcast var had when it was created is used,
// even if that variable is then modified in the driver program
@ -143,6 +144,27 @@ class ReplSuite extends FunSuite {
assertContains("res2: Long = 3", output)
}
test ("local-cluster mode") {
val output = runInterpreter("local-cluster[1,1,512]", """
var v = 7
def getV() = v
sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
v = 10
sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
var array = new Array[Int](5)
val broadcastArray = sc.broadcast(array)
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
array(0) = 5
sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
""")
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertContains("res0: Int = 70", output)
assertContains("res1: Int = 100", output)
assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
}
if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
test ("running on Mesos") {
val output = runInterpreter("localquiet", """

100
run
View file

@ -23,29 +23,38 @@ fi
if [ "$1" = "spark.deploy.master.Master" -o "$1" = "spark.deploy.worker.Worker" ]; then
SPARK_MEM=${SPARK_DAEMON_MEMORY:-512m}
SPARK_DAEMON_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS -Dspark.akka.logLifecycleEvents=true"
SPARK_JAVA_OPTS=$SPARK_DAEMON_JAVA_OPTS # Empty by default
# Do not overwrite SPARK_JAVA_OPTS environment variable in this script
OUR_JAVA_OPTS="$SPARK_DAEMON_JAVA_OPTS" # Empty by default
else
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS"
fi
# Add java opts for master, worker, executor. The opts maybe null
case "$1" in
'spark.deploy.master.Master')
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_MASTER_OPTS"
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_MASTER_OPTS"
;;
'spark.deploy.worker.Worker')
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_WORKER_OPTS"
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_WORKER_OPTS"
;;
'spark.executor.StandaloneExecutorBackend')
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.executor.MesosExecutorBackend')
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
;;
'spark.repl.Main')
SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_REPL_OPTS"
OUR_JAVA_OPTS="$OUR_JAVA_OPTS $SPARK_REPL_OPTS"
;;
esac
# Figure out whether to run our class with java or with the scala launcher.
# In most cases, we'd prefer to execute our process with java because scala
# creates a shell script as the parent of its Java process, which makes it
# hard to kill the child with stuff like Process.destroy(). However, for
# the Spark shell, the wrapper is necessary to properly reset the terminal
# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
if [ "$SCALA_HOME" ]; then
RUNNER="${SCALA_HOME}/bin/scala"
@ -58,14 +67,15 @@ if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
fi
fi
else
if [ `command -v java` ]; then
RUNNER="java"
if [ -n "${JAVA_HOME}" ]; then
RUNNER="${JAVA_HOME}/bin/java"
else
if [ -z "$JAVA_HOME" ]; then
if [ `command -v java` ]; then
RUNNER="java"
else
echo "JAVA_HOME is not set" >&2
exit 1
fi
RUNNER="${JAVA_HOME}/bin/java"
fi
if [ -z "$SCALA_LIBRARY_PATH" ]; then
if [ -z "$SCALA_HOME" ]; then
@ -84,7 +94,7 @@ fi
export SPARK_MEM
# Set JAVA_OPTS to be able to load native libraries and to set heap size
JAVA_OPTS="$SPARK_JAVA_OPTS"
JAVA_OPTS="$OUR_JAVA_OPTS"
JAVA_OPTS="$JAVA_OPTS -Djava.library.path=$SPARK_LIBRARY_PATH"
JAVA_OPTS="$JAVA_OPTS -Xms$SPARK_MEM -Xmx$SPARK_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
@ -92,15 +102,11 @@ if [ -e $FWDIR/conf/java-opts ] ; then
JAVA_OPTS="$JAVA_OPTS `cat $FWDIR/conf/java-opts`"
fi
export JAVA_OPTS
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
CORE_DIR="$FWDIR/core"
REPL_DIR="$FWDIR/repl"
REPL_BIN_DIR="$FWDIR/repl-bin"
EXAMPLES_DIR="$FWDIR/examples"
BAGEL_DIR="$FWDIR/bagel"
GRAPH_DIR="$FWDIR/graph"
STREAMING_DIR="$FWDIR/streaming"
PYSPARK_DIR="$FWDIR/python"
REPL_DIR="$FWDIR/repl"
# Exit if the user hasn't compiled Spark
if [ ! -e "$CORE_DIR/target" ]; then
@ -115,36 +121,9 @@ if [[ "$@" = *repl* && ! -e "$REPL_DIR/target" ]]; then
exit 1
fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH"
CLASSPATH="$CLASSPATH:$FWDIR/conf"
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/classes"
if [ -n "$SPARK_TESTING" ] ; then
CLASSPATH="$CLASSPATH:$CORE_DIR/target/scala-$SCALA_VERSION/test-classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/test-classes"
fi
CLASSPATH="$CLASSPATH:$CORE_DIR/src/main/resources"
CLASSPATH="$CLASSPATH:$REPL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$EXAMPLES_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$STREAMING_DIR/lib/org/apache/kafka/kafka/0.7.2-spark/*" # <-- our in-project Kafka Jar
if [ -e "$FWDIR/lib_managed" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/jars/*"
CLASSPATH="$CLASSPATH:$FWDIR/lib_managed/bundles/*"
fi
CLASSPATH="$CLASSPATH:$REPL_DIR/lib/*"
if [ -e $REPL_BIN_DIR/target ]; then
for jar in `find "$REPL_BIN_DIR/target" -name 'spark-repl-*-shaded-hadoop*.jar'`; do
CLASSPATH="$CLASSPATH:$jar"
done
fi
CLASSPATH="$CLASSPATH:$BAGEL_DIR/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$GRAPH_DIR/target/scala-$SCALA_VERSION/classes"
for jar in `find $PYSPARK_DIR/lib -name '*jar'`; do
CLASSPATH="$CLASSPATH:$jar"
done
# Compute classpath using external script
CLASSPATH=`$FWDIR/bin/compute-classpath.sh`
export CLASSPATH
# Figure out the JAR file that our examples were packaged into. This includes a bit of a hack
# to avoid the -sources and -doc packages that are built by publish-local.
@ -152,37 +131,16 @@ if [ -e "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar ];
# Use the JAR from the SBT build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/scala-$SCALA_VERSION/spark-examples"*[0-9T].jar`
fi
if [ -e "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar ]; then
if [ -e "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar ]; then
# Use the JAR from the Maven build
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples-"*hadoop[12].jar`
export SPARK_EXAMPLES_JAR=`ls "$EXAMPLES_DIR/target/spark-examples"*[0-9T].jar`
fi
# Add hadoop conf dir - else FileSystem.*, etc fail !
# Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
# the configurtion files.
if [ "x" != "x$HADOOP_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$HADOOP_CONF_DIR"
fi
if [ "x" != "x$YARN_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
fi
# Figure out whether to run our class with java or with the scala launcher.
# In most cases, we'd prefer to execute our process with java because scala
# creates a shell script as the parent of its Java process, which makes it
# hard to kill the child with stuff like Process.destroy(). However, for
# the Spark shell, the wrapper is necessary to properly reset the terminal
# when we exit, so we allow it to set a variable to launch with scala.
if [ "$SPARK_LAUNCH_WITH_SCALA" == "1" ]; then
EXTRA_ARGS="" # Java options will be passed to scala as JAVA_OPTS
else
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-library.jar"
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/scala-compiler.jar"
CLASSPATH="$CLASSPATH:$SCALA_LIBRARY_PATH/jline.jar"
# The JVM doesn't read JAVA_OPTS by default so we need to pass it in
EXTRA_ARGS="$JAVA_OPTS"
fi
export CLASSPATH # Needed for spark-shell
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"
exec "$RUNNER" -cp "$CLASSPATH" $EXTRA_ARGS "$@"

View file

@ -23,7 +23,9 @@ if "%1"=="spark.deploy.worker.Worker" set RUNNING_DAEMON=1
if "x%SPARK_DAEMON_MEMORY%" == "x" set SPARK_DAEMON_MEMORY=512m
set SPARK_DAEMON_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS% -Dspark.akka.logLifecycleEvents=true
if "%RUNNING_DAEMON%"=="1" set SPARK_MEM=%SPARK_DAEMON_MEMORY%
if "%RUNNING_DAEMON%"=="1" set SPARK_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
rem Do not overwrite SPARK_JAVA_OPTS environment variable in this script
if "%RUNNING_DAEMON%"=="0" set OUR_JAVA_OPTS=%SPARK_JAVA_OPTS%
if "%RUNNING_DAEMON%"=="1" set OUR_JAVA_OPTS=%SPARK_DAEMON_JAVA_OPTS%
rem Check that SCALA_HOME has been specified
if not "x%SCALA_HOME%"=="x" goto scala_exists
@ -31,52 +33,22 @@ if not "x%SCALA_HOME%"=="x" goto scala_exists
goto exit
:scala_exists
rem If the user specifies a Mesos JAR, put it before our included one on the classpath
set MESOS_CLASSPATH=
if not "x%MESOS_JAR%"=="x" set MESOS_CLASSPATH=%MESOS_JAR%
rem Figure out how much memory to use per executor and set it as an environment
rem variable so that our process sees it and can report it to Mesos
if "x%SPARK_MEM%"=="x" set SPARK_MEM=512m
rem Set JAVA_OPTS to be able to load native libraries and to set heap size
set JAVA_OPTS=%SPARK_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
rem Load extra JAVA_OPTS from conf/java-opts, if it exists
if exist "%FWDIR%conf\java-opts.cmd" call "%FWDIR%conf\java-opts.cmd"
set JAVA_OPTS=%OUR_JAVA_OPTS% -Djava.library.path=%SPARK_LIBRARY_PATH% -Xms%SPARK_MEM% -Xmx%SPARK_MEM%
rem Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in ExecutorRunner.scala!
set CORE_DIR=%FWDIR%core
set REPL_DIR=%FWDIR%repl
set EXAMPLES_DIR=%FWDIR%examples
set BAGEL_DIR=%FWDIR%bagel
set GRAPH_DIR=%FWDIR%graph
set STREAMING_DIR=%FWDIR%streaming
set PYSPARK_DIR=%FWDIR%python
rem Build up classpath
set CLASSPATH=%SPARK_CLASSPATH%;%MESOS_CLASSPATH%;%FWDIR%conf;%CORE_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%CORE_DIR%\target\scala-%SCALA_VERSION%\test-classes;%CORE_DIR%\src\main\resources
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\classes;%STREAMING_DIR%\target\scala-%SCALA_VERSION%\test-classes
set CLASSPATH=%CLASSPATH%;%STREAMING_DIR%\lib\org\apache\kafka\kafka\0.7.2-spark\*
set CLASSPATH=%CLASSPATH%;%REPL_DIR%\target\scala-%SCALA_VERSION%\classes;%EXAMPLES_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\jars\*
set CLASSPATH=%CLASSPATH%;%FWDIR%lib_managed\bundles\*
set CLASSPATH=%CLASSPATH%;%FWDIR%repl\lib\*
set CLASSPATH=%CLASSPATH%;%FWDIR%python\lib\*
set CLASSPATH=%CLASSPATH%;%BAGEL_DIR%\target\scala-%SCALA_VERSION%\classes
set CLASSPATH=%CLASSPATH%;%GRAPH_DIR%\target\scala-%SCALA_VERSION%\classes
rem Add hadoop conf dir - else FileSystem.*, etc fail
rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
rem the configurtion files.
if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
:no_hadoop_conf_dir
if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
:no_yarn_conf_dir
set REPL_DIR=%FWDIR%repl
rem Compute classpath using external script
set DONT_PRINT_CLASSPATH=1
call "%FWDIR%bin\compute-classpath.cmd"
set DONT_PRINT_CLASSPATH=0
rem Figure out the JAR file that our examples were packaged into.
rem First search in the build path from SBT:
@ -108,4 +80,4 @@ if "%SPARK_LAUNCH_WITH_SCALA%" NEQ 1 goto java_runner
:run_spark
"%RUNNER%" -cp "%CLASSPATH%" %EXTRA_ARGS% %*
:exit
:exit

View file

@ -441,7 +441,12 @@ abstract class DStream[T: ClassManifest] (
* Return a new DStream in which each RDD has a single element generated by counting each RDD
* of this DStream.
*/
def count(): DStream[Long] = this.map(_ => 1L).reduce(_ + _)
def count(): DStream[Long] = {
this.map(_ => (null, 1L))
.transform(_.union(context.sparkContext.makeRDD(Seq((null, 0L)), 1)))
.reduceByKey(_ + _)
.map(_._2)
}
/**
* Return a new DStream in which each RDD contains the counts of each distinct value in
@ -457,7 +462,7 @@ abstract class DStream[T: ClassManifest] (
* this DStream will be registered as an output stream and therefore materialized.
*/
def foreach(foreachFunc: RDD[T] => Unit) {
foreach((r: RDD[T], t: Time) => foreachFunc(r))
this.foreach((r: RDD[T], t: Time) => foreachFunc(r))
}
/**

View file

@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
import org.apache.hadoop.fs.Path
import twitter4j.Status
/**
* A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic
* information (such as, cluster URL and job name) to internally create a SparkContext, it provides
@ -186,10 +187,11 @@ class StreamingContext private (
* should be same.
*/
def actorStream[T: ClassManifest](
props: Props,
name: String,
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy): DStream[T] = {
props: Props,
name: String,
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2,
supervisorStrategy: SupervisorStrategy = ReceiverSupervisorStrategy.defaultStrategy
): DStream[T] = {
networkStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy))
}
@ -197,9 +199,10 @@ class StreamingContext private (
* Create an input stream that receives messages pushed by a zeromq publisher.
* @param publisherUrl Url of remote zeromq publisher
* @param subscribe topic to subscribe to
* @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic and each frame has sequence
* of byte thus it needs the converter(which might be deserializer of bytes)
* to translate from sequence of sequence of bytes, where sequence refer to a frame
* @param bytesToObjects A zeroMQ stream publishes sequence of frames for each topic
* and each frame has sequence of byte thus it needs the converter
* (which might be deserializer of bytes) to translate from sequence
* of sequence of bytes, where sequence refer to a frame
* and sub sequence refer to its payload.
* @param storageLevel RDD storage level. Defaults to memory-only.
*/
@ -215,24 +218,39 @@ class StreamingContext private (
}
/**
* Create an input stream that pulls messages form a Kafka Broker.
* Create an input stream that pulls messages from a Kafka Broker.
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
* By default the value is pulled from zookeper.
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
* (default: StorageLevel.MEMORY_AND_DISK_SER_2)
*/
def kafkaStream[T: ClassManifest](
def kafkaStream(
zkQuorum: String,
groupId: String,
topics: Map[String, Int],
initialOffsets: Map[KafkaPartitionKey, Long] = Map[KafkaPartitionKey, Long](),
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
): DStream[String] = {
val kafkaParams = Map[String, String](
"zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
}
/**
* Create an input stream that pulls messages from a Kafka Broker.
* @param kafkaParams Map of kafka configuration paramaters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
*/
def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
): DStream[T] = {
val inputStream = new KafkaInputDStream[T](this, zkQuorum, groupId, topics, initialOffsets, storageLevel)
val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
@ -397,7 +415,8 @@ class StreamingContext private (
* it will process either one or all of the RDDs returned by the queue.
* @param queue Queue of RDDs
* @param oneAtATime Whether only one RDD should be consumed from the queue in every interval
* @param defaultRDD Default RDD is returned by the DStream when the queue is empty. Set as null if no RDD should be returned when empty
* @param defaultRDD Default RDD is returned by the DStream when the queue is empty.
* Set as null if no RDD should be returned when empty
* @tparam T Type of objects in the RDD
*/
def queueStream[T: ClassManifest](

View file

@ -121,14 +121,15 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
*/
def kafkaStream[T](
def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
ssc.kafkaStream[T](zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*))
: JavaDStream[String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
StorageLevel.MEMORY_ONLY_SER_2)
}
/**
@ -136,49 +137,45 @@ class JavaStreamingContext(val ssc: StreamingContext) {
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
* By default the value is pulled from zookeper.
* in its own thread.
* @param storageLevel RDD storage level. Defaults to memory-only
*
*/
def kafkaStream[T](
def kafkaStream(
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt],
initialOffsets: JMap[KafkaPartitionKey, JLong])
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
ssc.kafkaStream[T](
zkQuorum,
groupId,
Map(topics.mapValues(_.intValue()).toSeq: _*),
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*))
storageLevel: StorageLevel)
: JavaDStream[String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)
}
/**
* Create an input stream that pulls messages form a Kafka Broker.
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param typeClass Type of RDD
* @param decoderClass Type of kafka decoder
* @param kafkaParams Map of kafka configuration paramaters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
* By default the value is pulled from zookeper.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
def kafkaStream[T](
zkQuorum: String,
groupId: String,
def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
typeClass: Class[T],
decoderClass: Class[D],
kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
initialOffsets: JMap[KafkaPartitionKey, JLong],
storageLevel: StorageLevel)
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
ssc.kafkaStream[T](
zkQuorum,
groupId,
implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
ssc.kafkaStream[T, D](
kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
Map(initialOffsets.mapValues(_.longValue()).toSeq: _*),
storageLevel)
}

View file

@ -9,58 +9,51 @@ import java.util.concurrent.Executors
import kafka.consumer._
import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.StringDecoder
import kafka.serializer.Decoder
import kafka.utils.{Utils, ZKGroupTopicDirs}
import kafka.utils.ZkUtils._
import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient._
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
// Key for a specific Kafka Partition: (broker, topic, group, part)
case class KafkaPartitionKey(brokerId: Int, topic: String, groupId: String, partId: Int)
/**
* Input stream that pulls messages from a Kafka Broker.
*
* @param zkQuorum Zookeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
* @param kafkaParams Map of kafka configuration paramaters. See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param initialOffsets Optional initial offsets for each of the partitions to consume.
* By default the value is pulled from zookeper.
* @param storageLevel RDD storage level.
*/
private[streaming]
class KafkaInputDStream[T: ClassManifest](
class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
zkQuorum: String,
groupId: String,
kafkaParams: Map[String, String],
topics: Map[String, Int],
initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
def getReceiver(): NetworkReceiver[T] = {
new KafkaReceiver(zkQuorum, groupId, topics, initialOffsets, storageLevel)
new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
class KafkaReceiver(zkQuorum: String, groupId: String,
topics: Map[String, Int], initialOffsets: Map[KafkaPartitionKey, Long],
storageLevel: StorageLevel) extends NetworkReceiver[Any] {
// Timeout for establishing a connection to Zookeper in ms.
val ZK_TIMEOUT = 10000
class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
// Connection to Kafka
var consumerConnector : ZookeeperConsumerConnector = null
var consumerConnector : ConsumerConnector = null
def onStop() {
blockGenerator.stop()
@ -73,54 +66,59 @@ class KafkaReceiver(zkQuorum: String, groupId: String,
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
logInfo("Starting Kafka Consumer Stream with group: " + groupId)
logInfo("Initial offsets: " + initialOffsets.toString)
logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
// Zookeper connection properties
// Kafka connection properties
val props = new Properties()
props.put("zk.connect", zkQuorum)
props.put("zk.connectiontimeout.ms", ZK_TIMEOUT.toString)
props.put("groupid", groupId)
kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
logInfo("Connecting to Zookeper: " + zkQuorum)
logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig).asInstanceOf[ZookeeperConsumerConnector]
logInfo("Connected to " + zkQuorum)
consumerConnector = Consumer.create(consumerConfig)
logInfo("Connected to " + kafkaParams("zk.connect"))
// If specified, set the topic offset
setOffsets(initialOffsets)
// When autooffset.reset is defined, it is our responsibility to try and whack the
// consumer group zk node.
if (kafkaParams.contains("autooffset.reset")) {
tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
}
// Create Threads for each Topic/Message Stream we are listening
val topicMessageStreams = consumerConnector.createMessageStreams(topics, new StringDecoder())
val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
streams.foreach { stream => executorPool.submit(new MessageHandler(stream)) }
}
}
// Overwrites the offets in Zookeper.
private def setOffsets(offsets: Map[KafkaPartitionKey, Long]) {
offsets.foreach { case(key, offset) =>
val topicDirs = new ZKGroupTopicDirs(key.groupId, key.topic)
val partitionName = key.brokerId + "-" + key.partId
updatePersistentPath(consumerConnector.zkClient,
topicDirs.consumerOffsetDir + "/" + partitionName, offset.toString)
}
}
// Handles Kafka Messages
private class MessageHandler(stream: KafkaStream[String]) extends Runnable {
private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
stream.takeWhile { msgAndMetadata =>
for (msgAndMetadata <- stream) {
blockGenerator += msgAndMetadata.message
// Keep on handling messages
true
}
}
}
// It is our responsibility to delete the consumer group when specifying autooffset.reset. This is because
// Kafka 0.7.2 only honors this param when the group is not in zookeeper.
//
// The kafka high level consumer doesn't expose setting offsets currently, this is a trick copied from Kafkas'
// ConsoleConsumer. See code related to 'autooffset.reset' when it is set to 'smallest'/'largest':
// https://github.com/apache/kafka/blob/0.7.2/core/src/main/scala/kafka/consumer/ConsoleConsumer.scala
private def tryZookeeperConsumerGroupCleanup(zkUrl: String, groupId: String) {
try {
val dir = "/consumers/" + groupId
logInfo("Cleaning up temporary zookeeper data under " + dir + ".")
val zk = new ZkClient(zkUrl, 30*1000, 30*1000, ZKStringSerializer)
zk.deleteRecursive(dir)
zk.close()
} catch {
case _ => // swallow
}
}
}

View file

@ -198,7 +198,7 @@ abstract class NetworkReceiver[T: ClassManifest]() extends Serializable with Log
case class Block(id: String, iterator: Iterator[T], metadata: Any = null)
val clock = new SystemClock()
val blockInterval = 200L
val blockInterval = System.getProperty("spark.streaming.blockInterval", "200").toLong
val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer)
val blockStorageLevel = storageLevel
val blocksForPushing = new ArrayBlockingQueue[Block](1000)

View file

@ -4,6 +4,7 @@ import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.Files;
import kafka.serializer.StringDecoder;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.junit.After;
import org.junit.Assert;
@ -23,7 +24,6 @@ import spark.streaming.api.java.JavaPairDStream;
import spark.streaming.api.java.JavaStreamingContext;
import spark.streaming.JavaTestUtils;
import spark.streaming.JavaCheckpointTestUtils;
import spark.streaming.dstream.KafkaPartitionKey;
import spark.streaming.InputStreamsSuite;
import java.io.*;
@ -1203,10 +1203,14 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
HashMap<KafkaPartitionKey, Long> offsets = Maps.newHashMap();
JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics, offsets);
JavaDStream test3 = ssc.kafkaStream("localhost:12345", "group", topics, offsets,
JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
StorageLevel.MEMORY_AND_DISK());
HashMap<String, String> kafkaParams = Maps.newHashMap();
kafkaParams.put("zk.connect","localhost:12345");
kafkaParams.put("groupid","consumer-group");
JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
StorageLevel.MEMORY_AND_DISK());
}

View file

@ -93,9 +93,9 @@ class BasicOperationsSuite extends TestSuiteBase {
test("count") {
testOperation(
Seq(1 to 1, 1 to 2, 1 to 3, 1 to 4),
Seq(Seq(), 1 to 1, 1 to 2, 1 to 3, 1 to 4),
(s: DStream[Int]) => s.count(),
Seq(Seq(1L), Seq(2L), Seq(3L), Seq(4L))
Seq(Seq(0L), Seq(1L), Seq(2L), Seq(3L), Seq(4L))
)
}

View file

@ -243,6 +243,17 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
assert(output(i) === expectedOutput(i))
}
}
test("kafka input stream") {
val ssc = new StreamingContext(master, framework, batchDuration)
val topics = Map("my-topic" -> 1)
val test1 = ssc.kafkaStream("localhost:12345", "group", topics)
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
// Test specifying decoder
val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
}