Merge remote-tracking branch 'upstream/master' into ui_new_state

Conflicts:
	core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
This commit is contained in:
Kay Ousterhout 2013-10-23 16:06:28 -07:00
commit a5f8f54ecd
70 changed files with 1426 additions and 810 deletions

View file

@ -25,6 +25,7 @@ import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@ -37,40 +38,34 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
String path = pResolver.getAbsolutePath(blockId.name());
// if getFilePath returns null, close the channel
if (path == null) {
FileSegment fileSegment = pResolver.getBlockLocation(blockId);
// if getBlockLocation returns null, close the channel
if (fileSegment == null) {
//ctx.close();
return;
}
File file = new File(path);
File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
.getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();

View file

@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
/** Get the file segment in which the given block resides. */
public FileSegment getBlockLocation(BlockId blockId);
}

View file

@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}

View file

@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
import org.apache.hadoop.conf.Configuration
private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
// failed, look for the new ctor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}

View file

@ -20,13 +20,11 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var epoch: Long = 0
private val epochLock = new java.lang.Object
protected var epoch: Long = 0
protected val epochLock = new java.lang.Object
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
private def cleanup(cleanupTime: Long) {
protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
// Called on master to increment the epoch number
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
// Called on master or workers to get current epoch number
// Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
epoch = newEpoch
mapStatuses.clear()
}
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
private var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
val array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
val array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
bytes
}
protected override def cleanup(cleanupTime: Long) {
super.cleanup(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
override def stop() {
super.stop()
cachedSerializedStatuses.clear()
}
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
// Opposite of serializeMapStatuses.
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
objIn.readObject().asInstanceOf[Array[MapStatus]]
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),

View file

@ -156,8 +156,8 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """spark://(.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
// Regular expression for connection to Mesos cluster
val MESOS_REGEX = """mesos://(.*)""".r
master match {
case "local" =>

View file

@ -187,10 +187,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster()
} else {
new MapOutputTracker()
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")

View file

@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.util.Date
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@ -36,6 +36,7 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
private[apache]
class SparkHadoopWriter(@transient jobConf: JobConf)
extends Logging
with SparkHadoopMapRedUtil
@ -86,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
getOutputCommitter().setupTask(getTaskContext())
writer = getOutputFormat().getRecordWriter(
fs, conf.value, outputName, Reporter.NULL)
writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
if (writer!=null) {
//println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@ -182,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
}
private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")

View file

@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
public abstract Iterable<Double> call(T t);
@Override
public final Iterable<Double> apply(T t) { return call(t); }
// Intentionally left blank
}

View file

@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
@ -29,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
public abstract Double call(T t) throws Exception;
// Intentionally left blank
}

View file

@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
}

View file

@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
@throws(classOf[Exception])
def call(a: A, b:B) : java.lang.Iterable[C]
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}

View file

@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -30,8 +29,6 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}

View file

@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@ -29,8 +28,6 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
public abstract R call(T1 t1, T2 t2) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}

View file

@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -34,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}

View file

@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -29,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
extends WrappedFunction1<T, Tuple2<K, V>>
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
public abstract Tuple2<K, V> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}

View file

@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")

View file

@ -17,26 +17,31 @@
package org.apache.spark.deploy
import com.google.common.collect.MapMaker
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
import com.google.common.collect.MapMaker
/**
* Contains util methods to interact with Hadoop from spark.
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
// Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
// subsystems
/**
* Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop
* subsystems.
*/
def newConfiguration(): Configuration = new Configuration()
// Add any user credentials to the job conf which are necessary for running on a secure Hadoop
// cluster
/**
* Add any user credentials to the job conf which are necessary for running on a secure Hadoop
* cluster.
*/
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }

View file

@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId(blockIdString)
override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
return file.getAbsolutePath
return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)

View file

@ -52,13 +52,14 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setDAGScheduler(this)
@ -187,7 +188,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@ -200,6 +201,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@ -212,9 +214,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
stageToInfos(stage) = StageInfo(stage)
stageToInfos(stage) = new StageInfo(stage)
stage
}
@ -366,7 +369,7 @@ class DAGScheduler(
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
val finalStage = newStage(rdd, None, jobId, Some(callSite))
val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@ -592,7 +595,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@ -613,9 +616,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@ -637,12 +638,12 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
case _ => "Unkown"
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.completionTime = Some(System.currentTimeMillis)
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@ -812,7 +813,7 @@ class DAGScheduler(
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
failedStage.completionTime = Some(System.currentTimeMillis())
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)

View file

@ -24,56 +24,54 @@ import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import scala.io.Source
import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
// Used to record runtime information for each job, including RDD graph
// tasks' start/stop shuffle information and information from outside
/**
* A logger class to record runtime information for jobs in Spark. This class outputs one log file
* per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
*
* @param logDirName The base directory for the log files.
*/
class JobLogger(val logDirName: String) extends SparkListener with Logging {
private val logDir =
if (System.getenv("SPARK_LOG_DIR") != null)
System.getenv("SPARK_LOG_DIR")
else
"/tmp/spark"
private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
def getLogDir = logDir
def getJobIDtoPrintWriter = jobIDToPrintWriter
def getStageIDToJobID = stageIDToJobID
def getJobIDToStages = jobIDToStages
def getEventQueue = eventQueue
// The following 5 functions are used only in testing.
private[scheduler] def getLogDir = logDir
private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
private[scheduler] def getStageIDToJobID = stageIDToJobID
private[scheduler] def getJobIDToStages = jobIDToStages
private[scheduler] def getEventQueue = eventQueue
// Create a folder for log files, the folder's name is the creation time of the jobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
if (dir.mkdirs() == false) {
logError("create log directory error:" + logDir + "/" + logDirName + "/")
if (!dir.exists() && !dir.mkdirs()) {
logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job, the file name is the jobID
protected def createLogWriter(jobID: Int) {
try{
try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
// Close log file, and clean the stage relationship in stageIDToJobID
@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
case _ => rddList ++= getRddsInStage(dep.rdd)
}
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
var stageInfo: String = ""
if (stage.isShuffleMap) {
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
stage.shuffleDep.get.shuffleId
}else{
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
val stageInfo = if (stage.isShuffleMap) {
"STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
} else {
"STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
} else
} else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
}
}
// Record task metrics into job log files
@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
val readMetrics =
taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics =
taskMetrics.shuffleWriteMetrics match {
case Some(metrics) =>
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
val readMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics = taskMetrics.shuffleWriteMetrics match {
case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
stageLogInfo(
stageSubmitted.stage.id,
"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.id, stageSubmitted.taskSize))
stageLogInfo(stageSubmitted.stage.stageId,"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
stageLogInfo(
stageCompleted.stageInfo.stage.id,
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }

View file

@ -167,8 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
val size = writer.size()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
@ -191,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.

View file

@ -24,10 +24,10 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
@ -89,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@ -102,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@ -120,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
Distribution(stage.stageInfo.taskInfos.flatMap{
Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}

View file

@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
/** When first task was submitted to scheduler. */
var submissionTime: Option[Long] = None
var completionTime: Option[Long] = None
private var nextAttemptId = 0
def isAvailable: Boolean = {

View file

@ -21,9 +21,18 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
case class StageInfo(
val stage: Stage,
class StageInfo(
stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
override def toString = stage.rdd.toString
val stageId = stage.id
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
var completionTime: Option[Long] = None
val rddName = stage.rdd.name
val name = stage.name
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
override def toString = rddName
}

View file

@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
def run(attemptId: Long): T = {
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()

View file

@ -28,7 +28,7 @@ import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@ -102,18 +102,19 @@ private[spark] class BlockManager(
}
val shuffleBlockManager = new ShuffleBlockManager(this)
val diskBlockManager = new DiskBlockManager(
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@ -512,16 +513,20 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
* The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
blockInfo.put(blockId, myInfo)
myInfo.markReady(writer.size())
myInfo.markReady(writer.fileSegment().length)
})
writer
}
@ -862,13 +867,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
/** Serializes into a stream. */
def dataSerializeStream(
blockId: BlockId,
outputStream: OutputStream,
values: Iterator[Any],
serializer: Serializer = defaultSerializer) {
val byteStream = new FastBufferedOutputStream(outputStream)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}
/** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}

View file

@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {

View file

@ -17,6 +17,13 @@
package org.apache.spark.storage
import java.io.{FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@ -59,12 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
* Size of the valid writes, in bytes.
* Returns the file segment of committed data that this Writer has written.
*/
def size(): Long
def fileSegment(): FileSegment
/**
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
}
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
bufferSize: Int,
compressStream: OutputStream => OutputStream)
extends BlockObjectWriter(blockId)
with Logging
{
/** Intercepts write calls and tracks total time spent writing. Not thread safe. */
private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
def timeWriting = _timeWriting
private var _timeWriting = 0L
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
_timeWriting += (System.nanoTime() - start)
}
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
}
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialPosition = 0L
private var lastValidPosition = 0L
private var initialized = false
private var _timeWriting = 0L
override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
initialPosition = channel.position
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
override def close() {
if (initialized) {
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
val start = System.nanoTime()
fos.getFD.sync()
_timeWriting += System.nanoTime() - start
}
objOut.close()
_timeWriting += ts.timeWriting
channel = null
bs = null
fos = null
ts = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}
override def isOpen: Boolean = objOut != null
override def commit(): Long = {
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}
override def revertPartialWrites() {
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}
override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}
override def fileSegment(): FileSegment = {
val bytesWritten = lastValidPosition - initialPosition
new FileSegment(file, initialPosition, bytesWritten)
}
// Only valid if called after close()
override def timeWriting() = _timeWriting
}

View file

@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random}
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
* locations. By default, one block is mapped to one file with a name given by its BlockId.
* However, it is also possible to have a block map to only a segment of a file, by calling
* mapBlockToFileSegment().
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
private val localDirs: Array[File] = createLocalDirs()
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
private var shuffleSender : ShuffleSender = null
// Stores only Blocks which have been specifically mapped to segments of files
// (rather than the default, which maps a Block to a whole file).
// This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
addShutdownHook()
/**
* Creates a logical mapping from the given BlockId to a segment of a file.
* This will cause any accesses of the logical BlockId to be directed to the specified
* physical location.
*/
def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
blockToFileSegmentMap.put(blockId, fileSegment)
}
/**
* Returns the phyiscal file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
blockToFileSegmentMap.get(blockId).get
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
}
}
/**
* Simply returns a File to place the given Block into. This does not physically create the file.
* If filename is given, that file will be used. Otherwise, we will use the BlockId to get
* a unique filename.
*/
def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
val actualFilename = if (filename == "") blockId.name else filename
val file = getFile(actualFilename)
if (!allowAppending && file.exists()) {
throw new IllegalStateException(
"Attempted to create file that already exists: " + actualFilename)
}
file
}
private def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
// Create the subdirectory if it doesn't already exist
var subDir = subDirs(dirId)(subDirId)
if (subDir == null) {
subDir = subDirs(dirId).synchronized {
val old = subDirs(dirId)(subDirId)
if (old != null) {
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
newDir.mkdir()
subDirs(dirId)(subDirId) = newDir
newDir
}
}
}
new File(subDir, filename)
}
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
rootDirs.split(",").map { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
val rand = new Random()
while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
tries += 1
try {
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
" attempts to create local dir in " + rootDir)
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
logInfo("Created local directory at " + localDir)
localDir
}
}
private def cleanup(cleanupTime: Long) {
blockToFileSegmentMap.clearOldValues(cleanupTime)
}
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
localDirs.foreach { localDir =>
try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
} catch {
case t: Throwable =>
logError("Exception while deleting local spark dir: " + localDir, t)
}
}
if (shuffleSender != null) {
shuffleSender.stop()
}
}
})
}
private[storage] def startShuffleBlockSender(port: Int): Int = {
shuffleSender = new ShuffleSender(port, this)
logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
shuffleSender.port
}
}

View file

@ -17,158 +17,25 @@
package org.apache.spark.storage
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
import org.apache.spark.network.netty.ShuffleSender
import org.apache.spark.network.netty.PathResolver
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) {
/** Intercepts write calls and tracks total time spent writing. Not thread safe. */
private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
def timeWriting = _timeWriting
private var _timeWriting = 0L
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
_timeWriting += (System.nanoTime() - start)
}
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
}
private val f: File = createFile(blockId /*, allowAppendExisting */)
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
// The file channel, used for repositioning / truncating the file.
private var channel: FileChannel = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
private var initialized = false
private var _timeWriting = 0L
override def open(): DiskBlockObjectWriter = {
fos = new FileOutputStream(f, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
override def close() {
if (initialized) {
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
val start = System.nanoTime()
fos.getFD.sync()
_timeWriting += System.nanoTime() - start
objOut.close()
} else {
objOut.close()
}
_timeWriting += ts.timeWriting
channel = null
bs = null
fos = null
ts = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}
override def isOpen: Boolean = objOut != null
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}
override def revertPartialWrites() {
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}
override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}
override def size(): Long = lastValidPosition
// Only valid if called after close()
override def timeWriting = _timeWriting
}
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
private val localDirs: Array[File] = createLocalDirs()
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
override def getSize(blockId: BlockId): Long = {
getFile(blockId).length()
diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@ -177,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
val channel = new RandomAccessFile(file, "rw").getChannel()
val file = diskManager.createBlockFile(blockId, allowAppending = false)
val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
private def getFileBytes(file: File): ByteBuffer = {
val length = file.length()
val channel = new RandomAccessFile(file, "r").getChannel()
val buffer = try {
channel.map(MapMode.READ_ONLY, 0, length)
} finally {
channel.close()
}
buffer
file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@ -209,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
val file = diskManager.createBlockFile(blockId, allowAppending = false)
val outputStream = new FileOutputStream(file)
blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
blockId, Utils.bytesToString(length), timeTaken))
file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
val buffer = getFileBytes(file)
val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@ -231,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val file = getFile(blockId)
val bytes = getFileBytes(file)
Some(bytes)
val segment = diskManager.getBlockLocation(blockId)
val channel = new RandomAccessFile(segment.file, "r").getChannel()
val buffer = try {
channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
} finally {
channel.close()
}
Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@ -249,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
val fileSegment = diskManager.getBlockLocation(blockId)
val file = fileSegment.file
if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
if (fileSegment.length < file.length()) {
logWarning("Could not delete block associated with only a part of a file: " + blockId)
}
false
}
}
override def contains(blockId: BlockId): Boolean = {
getFile(blockId).exists()
}
private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
// was rescheduled on the same machine as the old task.
logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
file.delete()
}
file
}
private def getFile(blockId: BlockId): File = {
logDebug("Getting file for block " + blockId)
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(blockId)
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
// Create the subdirectory if it doesn't already exist
var subDir = subDirs(dirId)(subDirId)
if (subDir == null) {
subDir = subDirs(dirId).synchronized {
val old = subDirs(dirId)(subDirId)
if (old != null) {
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
newDir.mkdir()
subDirs(dirId)(subDirId) = newDir
newDir
}
}
}
new File(subDir, blockId.name)
}
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
rootDirs.split(",").map { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
val rand = new Random()
while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
tries += 1
try {
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
" attempts to create local dir in " + rootDir)
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
logInfo("Created local directory at " + localDir)
localDir
}
}
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
localDirs.foreach { localDir =>
try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
} catch {
case t: Throwable =>
logError("Exception while deleting local spark dir: " + localDir, t)
}
}
if (shuffleSender != null) {
shuffleSender.stop()
}
}
})
}
private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) null
else DiskStore.this.getFile(blockId).getAbsolutePath
}
}
shuffleSender = new ShuffleSender(port, pResolver)
logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
shuffleSender.port
val file = diskManager.getBlockLocation(blockId).file
file.exists()
}
}

View file

@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.io.File
/**
* References a particular segment of a file (potentially the entire file),
* based off an offset and a length.
*/
private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
}

View file

@ -17,12 +17,13 @@
package org.apache.spark.storage
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.serializer.Serializer
private[spark]
class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@ -30,24 +31,61 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
/**
* Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
* per reducer.
*
* As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
* blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
* per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
* it releases them for another task.
* Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
* - shuffleId: The unique id given to the entire shuffle stage.
* - bucketId: The id of the output partition (i.e., reducer id)
* - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
* time owns a particular fileId, and this id is returned to a pool when the task finishes.
*/
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
var nextFileId = new AtomicInteger(0)
val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
val filename = physicalFileName(shuffleId, bucketId, fileId)
blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
}
new ShuffleWriterGroup(mapId, writers)
new ShuffleWriterGroup(mapId, fileId, writers)
}
override def releaseWriters(group: ShuffleWriterGroup) = {
// Nothing really to release here.
override def releaseWriters(group: ShuffleWriterGroup) {
recycleFileId(group.fileId)
}
}
}
private def getUnusedFileId(): Int = {
val fileId = unusedFileIds.poll()
if (fileId == null) nextFileId.getAndIncrement() else fileId
}
private def recycleFileId(fileId: Int) {
if (!consolidateShuffleFiles) { return } // ensures we always generate new file id
unusedFileIds.add(fileId)
}
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
}

View file

@ -7,9 +7,11 @@ import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.SparkContext
import org.apache.spark.util.Utils
/** Utility for micro-benchmarking shuffle write performance.
*
* Writes simulated shuffle output from several threads and records the observed throughput*/
/**
* Utility for micro-benchmarking shuffle write performance.
*
* Writes simulated shuffle output from several threads and records the observed throughput.
*/
object StoragePerfTester {
def main(args: Array[String]) = {
/** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
@ -44,7 +46,7 @@ object StoragePerfTester {
}
buckets.writers.map {w =>
w.commit()
total.addAndGet(w.size())
total.addAndGet(w.fileSegment().length)
w.close()
}

View file

@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}

View file

@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
val stageToPool = new HashMap[Stage, String]()
val stageToDescription = new HashMap[Stage, String]()
val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
val stageIdToPool = new HashMap[Int, String]()
val stageIdToDescription = new HashMap[Int, String]()
val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
val activeStages = HashSet[Stage]()
val completedStages = ListBuffer[Stage]()
val failedStages = ListBuffer[Stage]()
val activeStages = HashSet[StageInfo]()
val completedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
val stageToTime = HashMap[Int, Long]()
val stageToShuffleRead = HashMap[Int, Long]()
val stageToShuffleWrite = HashMap[Int, Long]()
val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
val stageToTasksComplete = HashMap[Int, Int]()
val stageToTasksFailed = HashMap[Int, Int]()
val stageToTaskInfos =
val stageIdToTime = HashMap[Int, Long]()
val stageIdToShuffleRead = HashMap[Int, Long]()
val stageIdToShuffleWrite = HashMap[Int, Long]()
val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
val stageIdToTasksComplete = HashMap[Int, Int]()
val stageIdToTasksFailed = HashMap[Int, Int]()
val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
val stage = stageCompleted.stageInfo.stage
poolToActiveStages(stageToPool(stage)) -= stage
val stage = stageCompleted.stage
poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
stageToTaskInfos.remove(s.id)
stageToTime.remove(s.id)
stageToShuffleRead.remove(s.id)
stageToShuffleWrite.remove(s.id)
stageToTasksActive.remove(s.id)
stageToTasksComplete.remove(s.id)
stageToTasksFailed.remove(s.id)
stageToPool.remove(s)
if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
stageIdToTaskInfos.remove(s.stageId)
stageIdToTime.remove(s.stageId)
stageIdToShuffleRead.remove(s.stageId)
stageIdToShuffleWrite.remove(s.stageId)
stageIdToTasksActive.remove(s.stageId)
stageIdToTasksComplete.remove(s.stageId)
stageIdToTasksFailed.remove(s.stageId)
stageIdToPool.remove(s.stageId)
if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@ -95,25 +95,25 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
stageToPool(stage) = poolName
stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
description.map(d => stageToDescription(stage) = d)
description.map(d => stageIdToDescription(stage.stageId) = d)
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
val taskList = stageToTaskInfos.getOrElse(
val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
stageToTaskInfos(sid) = taskList
stageIdToTaskInfos(sid) = taskList
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
@ -124,40 +124,40 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
stageToTime.getOrElseUpdate(sid, 0L)
stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
stageToTime(sid) += time
stageIdToTime(sid) += time
totalTime += time
stageToShuffleRead.getOrElseUpdate(sid, 0L)
stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
stageToShuffleRead(sid) += shuffleRead
stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
stageToShuffleWrite.getOrElseUpdate(sid, 0L)
stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
stageToShuffleWrite(sid) += shuffleWrite
stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
val taskList = stageToTaskInfos.getOrElse(
val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
stageToTaskInfos(sid) = taskList
stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@ -165,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
activeStages -= stage
poolToActiveStages(stageToPool(stage)) -= stage
failedStages += stage
trimIfNecessary(failedStages)
/* If two jobs share a stage we could get this failure message twice. So we first
* check whether we've already retired this stage. */
val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
stageInfo.foreach {s =>
activeStages -= s
poolToActiveStages(stageIdToPool(stage.id)) -= s
failedStages += s
trimIfNecessary(failedStages)
}
case _ =>
}
case _ =>

View file

@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
import org.apache.spark.scheduler.{Schedulable, Stage}
import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size

View file

@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
if (!listener.stageToTaskInfos.contains(stageId)) {
if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
{parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
{parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>

View file

@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
private def stageRow(s: Stage): Seq[Node] = {
private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
val totalTasks = s.numPartitions
val totalTasks = s.numTasks
val poolName = listener.stageToPool.get(s)
val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
<a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
val description = listener.stageToDescription.get(s)
<a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.toString}</a>
val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
<td>{s.id}</td>
<td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}

View file

@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
"ShuffleMapTask", "BlockManager", "BroadcastVars") {
"ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value

View file

@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
// As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
val masterTracker = new MapOutputTracker()
val masterTracker = new MapOutputTrackerMaster()
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()

View file

@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
import org.apache.spark.MapOutputTracker
import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
var mapOutputTracker: MapOutputTracker = null
var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@ -99,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
mapOutputTracker = new MapOutputTracker()
mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing

View file

@ -58,10 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
val shuffleMapStage =
new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
val rootStage =
new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
val rootStageInfo = new StageInfo(rootStage)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")

View file

@ -19,17 +19,57 @@ package org.apache.spark.scheduler
import scala.collection.mutable.{Buffer, HashSet}
import org.scalatest.FunSuite
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.{SparkContext, LocalSparkContext}
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
with BeforeAndAfter {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
before {
sc = new SparkContext("local", "DAGSchedulerSuite")
}
test("basic creation of StageInfo") {
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
val rdd2 = rdd1.map(x => x.toString)
rdd2.setName("Target RDD")
rdd2.count
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {1}
val first = listener.stageInfos.head
first.rddName should be {"Target RDD"}
first.numTasks should be {4}
first.numPartitions should be {4}
first.submissionTime should be ('defined)
first.completionTime should be ('defined)
first.taskInfos.length should be {4}
}
test("StageInfo with fewer tasks than partitions") {
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
val rdd2 = rdd1.map(x => x.toString)
sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {1}
val first = listener.stageInfos.head
first.numTasks should be {2}
first.numPartitions should be {4}
}
test("local metrics") {
sc = new SparkContext("local[4]", "test")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@ -66,7 +106,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
if (stageInfo.stage.rdd.name == d4.name) {
if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@ -74,11 +114,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
if (stageInfo.stage.rdd.name == d4.name) {
if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@ -136,15 +176,10 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
assert(m.sum / m.size.toDouble > 0.0, msg)
}
def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
!names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
}
class SaveStageInfo extends SparkListener {
val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
stageInfos += stage.stage
}
}

View file

@ -131,6 +131,17 @@ sc = SparkContext("local", "App Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg
Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
You can set [system properties](configuration.html#system-properties)
using `SparkContext.setSystemProperty()` class method *before*
instantiating SparkContext. For example, to set the amount of memory
per executor process:
{% highlight python %}
from pyspark import SparkContext
SparkContext.setSystemProperty('spark.executor.memory', '2g')
sc = SparkContext("local", "App Name")
{% endhighlight %}
# API Docs
[API documentation](api/pyspark/index.html) for PySpark is available as Epydoc.

View file

@ -142,7 +142,7 @@ All transformations in Spark are <i>lazy</i>, in that they do not compute their
By default, each transformed RDD is recomputed each time you run an action on it. However, you may also *persist* an RDD in memory using the `persist` (or `cache`) method, in which case Spark will keep the elements around on the cluster for much faster access the next time you query it. There is also support for persisting datasets on disk, or replicated across the cluster. The next section in this document describes these options.
The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.RDD) for details):
The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD) for details):
### Transformations
@ -211,7 +211,7 @@ The following tables list the transformations and actions currently supported (s
</tr>
</table>
A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
### Actions
@ -259,7 +259,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in
</tr>
</table>
A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
## RDD Persistence

View file

@ -32,13 +32,20 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
<!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
<id>lib</id>
<url>file://${project.basedir}/lib</url>
<id>apache-repo</id>
<name>Apache Repository</name>
<url>https://repository.apache.org/content/repositories/releases</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
@ -81,9 +88,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
<scope>provided</scope>
<artifactId>kafka_2.9.2</artifactId>
<version>0.8.0-beta1</version>
<exclusions>
<exclusion>
<groupId>com.sun.jmx</groupId>
<artifactId>jmxri</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.jdmk</groupId>
<artifactId>jmxtools</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>

View file

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.examples;
import java.util.Map;
import java.util.HashMap;
import com.google.common.collect.Lists;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import scala.Tuple2;
/**
* Consumes messages from one or more topics in Kafka and does wordcount.
* Usage: JavaKafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
* <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
* <zkQuorum> is a list of one or more zookeeper servers that make quorum
* <group> is the name of kafka consumer group
* <topics> is a list of one or more kafka topics to consume from
* <numThreads> is the number of threads the kafka consumer should use
*
* Example:
* `./run-example org.apache.spark.streaming.examples.JavaKafkaWordCount local[2] zoo01,zoo02,
* zoo03 my-consumer-group topic1,topic2 1`
*/
public class JavaKafkaWordCount {
public static void main(String[] args) {
if (args.length < 5) {
System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>");
System.exit(1);
}
// Create the context with a 1 second batch size
JavaStreamingContext ssc = new JavaStreamingContext(args[0], "NetworkWordCount",
new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
int numThreads = Integer.parseInt(args[4]);
Map<String, Integer> topicMap = new HashMap<String, Integer>();
String[] topics = args[3].split(",");
for (String topic: topics) {
topicMap.put(topic, numThreads);
}
JavaPairDStream<String, String> messages = ssc.kafkaStream(args[1], args[2], topicMap);
JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
@Override
public String call(Tuple2<String, String> tuple2) throws Exception {
return tuple2._2();
}
});
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable<String> call(String x) {
return Lists.newArrayList(x.split(" "));
}
});
JavaPairDStream<String, Integer> wordCounts = words.map(
new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) throws Exception {
return new Tuple2<String, Integer>(s, 1);
}
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer i1, Integer i2) throws Exception {
return i1 + i2;
}
});
wordCounts.print();
ssc.start();
}
}

View file

@ -18,13 +18,11 @@
package org.apache.spark.streaming.examples
import java.util.Properties
import kafka.message.Message
import kafka.producer.SyncProducerConfig
import kafka.producer._
import org.apache.spark.SparkContext
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.RawTextHelper._
/**
@ -54,9 +52,10 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap).map(_._2)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
val wordCounts = words.map(x => (x, 1l))
.reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
ssc.start()
@ -68,15 +67,16 @@ object KafkaWordCountProducer {
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
System.err.println("Usage: KafkaWordCountProducer <metadataBrokerList> <topic> " +
"<messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
// Zookeper connection properties
val props = new Properties()
props.put("zk.connect", zkQuorum)
props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
@ -85,11 +85,13 @@ object KafkaWordCountProducer {
// Send some messages
while(true) {
val messages = (1 to messagesPerSec.toInt).map { messageNum =>
(1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
.mkString(" ")
new KeyedMessage[String, String](topic, str)
}.toArray
println(messages.mkString(","))
val data = new ProducerData[String, String](topic, messages)
producer.send(data)
producer.send(messages: _*)
Thread.sleep(100)
}
}

View file

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.examples
import org.apache.spark.streaming.{ Seconds, StreamingContext }
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.MQTTReceiver
import org.apache.spark.storage.StorageLevel
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
/**
* A simple Mqtt publisher for demonstration purposes, repeatedly publishes
* Space separated String Message "hello mqtt demo for spark streaming"
*/
object MQTTPublisher {
var client: MqttClient = _
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
System.exit(1)
}
val Seq(brokerUrl, topic) = args.toSeq
try {
var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp")
client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance)
} catch {
case e: MqttException => println("Exception Caught: " + e)
}
client.connect()
val msgtopic: MqttTopic = client.getTopic(topic);
val msg: String = "hello mqtt demo for spark streaming"
while (true) {
val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes())
msgtopic.publish(message);
println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
}
client.disconnect()
}
}
/**
* A sample wordcount with MqttStream stream
*
* To work with Mqtt, Mqtt Message broker/server required.
* Mosquitto (http://mosquitto.org/) is an open source Mqtt Broker
* In ubuntu mosquitto can be installed using the command `$ sudo apt-get install mosquitto`
* Eclipse paho project provides Java library for Mqtt Client http://www.eclipse.org/paho/
* Example Java code for Mqtt Publisher and Subscriber can be found here https://bitbucket.org/mkjinesh/mqttclient
* Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
* In local mode, <master> should be 'local[n]' with n > 1
* <MqttbrokerUrl> and <topic> describe where Mqtt publisher is running.
*
* To run this example locally, you may run publisher as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
* and run the example as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
*/
object MQTTWordCount {
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println(
"Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
" In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
}
val Seq(master, brokerUrl, topic) = args.toSeq
val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)
val words = lines.flatMap(x => x.toString.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
wordCounts.print()
ssc.start()
}
}

View file

@ -108,7 +108,10 @@ object SparkBuild extends Build {
// Shared between both core and streaming.
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
// For Sonatype publishing
// Shared between both examples and streaming.
resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"),
// For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@ -278,13 +281,17 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
resolvers ++= Seq(
"Akka Repository" at "http://repo.akka.io/releases/"
"Akka Repository" at "http://repo.akka.io/releases/",
"Apache repo" at "https://repository.apache.org/content/repositories/releases"
),
libraryDependencies ++= Seq(
"org.eclipse.paho" % "mqtt-client" % "0.4.0",
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
"com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty)
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty),
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
exclude("com.sun.jdmk", "jmxtools")
exclude("com.sun.jmx", "jmxri")
)
)

View file

@ -49,6 +49,7 @@ class SparkContext(object):
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
"""
@ -66,19 +67,18 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
"""
with SparkContext._lock:
if SparkContext._active_spark_context:
raise ValueError("Cannot run multiple SparkContexts at once")
else:
SparkContext._active_spark_context = self
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeIteratorToPickleFile = \
SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
SparkContext._ensure_initialized(self)
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@ -119,6 +119,32 @@ class SparkContext(object):
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
@classmethod
def _ensure_initialized(cls, instance=None):
with SparkContext._lock:
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeIteratorToPickleFile = \
SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
if instance:
if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
raise ValueError("Cannot run multiple SparkContexts at once")
else:
SparkContext._active_spark_context = instance
@classmethod
def setSystemProperty(cls, key, value):
"""
Set a system property, such as spark.executor.memory. This must be
invoked before instantiating SparkContext.
"""
SparkContext._ensure_initialized()
SparkContext._jvm.java.lang.System.setProperty(key, value)
@property
def defaultParallelism(self):
"""

View file

@ -1 +0,0 @@
18876b8bc2e4cef28b6d191aa49d963f

View file

@ -1 +0,0 @@
06b27270ffa52250a2c08703b397c99127b72060

View file

@ -1,9 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version>
<description>POM was created from install:install-file</description>
</project>

View file

@ -1 +0,0 @@
7bc4322266e6032bdf9ef6eebdd8097d

View file

@ -1 +0,0 @@
d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d

View file

@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<metadata>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<versioning>
<release>0.7.2-spark</release>
<versions>
<version>0.7.2-spark</version>
</versions>
<lastUpdated>20130121015225</lastUpdated>
</versioning>
</metadata>

View file

@ -1 +0,0 @@
e2b9c7c5f6370dd1d21a0aae5e8dcd77

View file

@ -1 +0,0 @@
2a4341da936b6c07a09383d17ffb185ac558ee91

View file

@ -32,10 +32,16 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
<!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
<id>lib</id>
<url>file://${project.basedir}/lib</url>
<id>apache-repo</id>
<name>Apache Repository</name>
<url>https://repository.apache.org/content/repositories/releases</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
@ -56,9 +62,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
<scope>provided</scope>
<artifactId>kafka_2.9.2</artifactId>
<version>0.8.0-beta1</version>
<exclusions>
<exclusion>
<groupId>com.sun.jmx</groupId>
<artifactId>jmxri</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.jdmk</groupId>
<artifactId>jmxtools</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.flume</groupId>
@ -75,17 +90,6 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.github.sgroschupf</groupId>
<artifactId>zkclient</artifactId>
<version>0.1</version>
<exclusions>
<exclusion>
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.twitter4j</groupId>
<artifactId>twitter4j-stream</artifactId>
@ -132,6 +136,11 @@
<artifactId>slf4j-log4j12</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.eclipse.paho</groupId>
<artifactId>mqtt-client</artifactId>
<version>0.4.0</version>
</dependency>
</dependencies>
<build>
<outputDirectory>target/scala-${scala.version}/classes</outputDirectory>

View file

@ -256,10 +256,14 @@ class StreamingContext private (
groupId: String,
topics: Map[String, Int],
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2
): DStream[String] = {
): DStream[(String, String)] = {
val kafkaParams = Map[String, String](
"zk.connect" -> zkQuorum, "groupid" -> groupId, "zk.connectiontimeout.ms" -> "10000")
kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, storageLevel)
"zookeeper.connect" -> zkQuorum, "group.id" -> groupId,
"zookeeper.connection.timeout.ms" -> "10000")
kafkaStream[String, String, kafka.serializer.StringDecoder, kafka.serializer.StringDecoder](
kafkaParams,
topics,
storageLevel)
}
/**
@ -270,12 +274,16 @@ class StreamingContext private (
* in its own thread.
* @param storageLevel Storage level to use for storing the received objects
*/
def kafkaStream[T: ClassManifest, D <: kafka.serializer.Decoder[_]: Manifest](
def kafkaStream[
K: ClassManifest,
V: ClassManifest,
U <: kafka.serializer.Decoder[_]: Manifest,
T <: kafka.serializer.Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
): DStream[T] = {
val inputStream = new KafkaInputDStream[T, D](this, kafkaParams, topics, storageLevel)
): DStream[(K, V)] = {
val inputStream = new KafkaInputDStream[K, V, U, T](this, kafkaParams, topics, storageLevel)
registerInputStream(inputStream)
inputStream
}
@ -454,6 +462,21 @@ class StreamingContext private (
inputStream
}
/**
* Create an input stream that receives messages pushed by a mqtt publisher.
* @param brokerUrl Url of remote mqtt publisher
* @param topic topic name to subscribe to
* @param storageLevel RDD storage level. Defaults to memory-only.
*/
def mqttStream(
brokerUrl: String,
topic: String,
storageLevel: StorageLevel = StorageLevel.MEMORY_ONLY_SER_2): DStream[String] = {
val inputStream = new MQTTInputDStream[String](this, brokerUrl, topic, storageLevel)
registerInputStream(inputStream)
inputStream
}
/**
* Create a unified DStream from multiple DStreams of the same type and same interval
*/

View file

@ -141,7 +141,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
zkQuorum: String,
groupId: String,
topics: JMap[String, JInt])
: JavaDStream[String] = {
: JavaPairDStream[String, String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@ -162,7 +162,7 @@ class JavaStreamingContext(val ssc: StreamingContext) {
groupId: String,
topics: JMap[String, JInt],
storageLevel: StorageLevel)
: JavaDStream[String] = {
: JavaPairDStream[String, String] = {
implicit val cmt: ClassManifest[String] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[String]]
ssc.kafkaStream(zkQuorum, groupId, Map(topics.mapValues(_.intValue()).toSeq: _*),
@ -171,25 +171,34 @@ class JavaStreamingContext(val ssc: StreamingContext) {
/**
* Create an input stream that pulls messages form a Kafka Broker.
* @param typeClass Type of RDD
* @param decoderClass Type of kafka decoder
* @param keyTypeClass Key type of RDD
* @param valueTypeClass value type of RDD
* @param keyDecoderClass Type of kafka key decoder
* @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration paramaters.
* See: http://kafka.apache.org/configuration.html
* @param topics Map of (topic_name -> numPartitions) to consume. Each partition is consumed
* in its own thread.
* @param storageLevel RDD storage level. Defaults to memory-only
*/
def kafkaStream[T, D <: kafka.serializer.Decoder[_]](
typeClass: Class[T],
decoderClass: Class[D],
def kafkaStream[K, V, U <: kafka.serializer.Decoder[_], T <: kafka.serializer.Decoder[_]](
keyTypeClass: Class[K],
valueTypeClass: Class[V],
keyDecoderClass: Class[U],
valueDecoderClass: Class[T],
kafkaParams: JMap[String, String],
topics: JMap[String, JInt],
storageLevel: StorageLevel)
: JavaDStream[T] = {
implicit val cmt: ClassManifest[T] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
implicit val cmd: Manifest[D] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[D]]
ssc.kafkaStream[T, D](
: JavaPairDStream[K, V] = {
implicit val keyCmt: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val valueCmt: ClassManifest[V] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
implicit val keyCmd: Manifest[U] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[U]]
implicit val valueCmd: Manifest[T] = implicitly[Manifest[AnyRef]].asInstanceOf[Manifest[T]]
ssc.kafkaStream[K, V, U, T](
kafkaParams.toMap,
Map(topics.mapValues(_.intValue()).toSeq: _*),
storageLevel)

View file

@ -19,22 +19,18 @@ package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Time, DStreamCheckpointData, StreamingContext}
import org.apache.spark.streaming.StreamingContext
import java.util.Properties
import java.util.concurrent.Executors
import kafka.consumer._
import kafka.message.{Message, MessageSet, MessageAndMetadata}
import kafka.serializer.Decoder
import kafka.utils.{Utils, ZKGroupTopicDirs}
import kafka.utils.ZkUtils._
import kafka.utils.VerifiableProperties
import kafka.utils.ZKStringSerializer
import org.I0Itec.zkclient._
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
/**
@ -46,25 +42,32 @@ import scala.collection.JavaConversions._
* @param storageLevel RDD storage level.
*/
private[streaming]
class KafkaInputDStream[T: ClassManifest, D <: Decoder[_]: Manifest](
class KafkaInputDStream[
K: ClassManifest,
V: ClassManifest,
U <: Decoder[_]: Manifest,
T <: Decoder[_]: Manifest](
@transient ssc_ : StreamingContext,
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_ ) with Logging {
) extends NetworkInputDStream[(K, V)](ssc_) with Logging {
def getReceiver(): NetworkReceiver[T] = {
new KafkaReceiver[T, D](kafkaParams, topics, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
def getReceiver(): NetworkReceiver[(K, V)] = {
new KafkaReceiver[K, V, U, T](kafkaParams, topics, storageLevel)
.asInstanceOf[NetworkReceiver[(K, V)]]
}
}
private[streaming]
class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
class KafkaReceiver[
K: ClassManifest,
V: ClassManifest,
U <: Decoder[_]: Manifest,
T <: Decoder[_]: Manifest](
kafkaParams: Map[String, String],
topics: Map[String, Int],
storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
// Handles pushing data into the BlockManager
@ -83,27 +86,34 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
// In case we are using multiple Threads to handle Kafka Messages
val executorPool = Executors.newFixedThreadPool(topics.values.reduce(_ + _))
logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("groupid"))
logInfo("Starting Kafka Consumer Stream with group: " + kafkaParams("group.id"))
// Kafka connection properties
val props = new Properties()
kafkaParams.foreach(param => props.put(param._1, param._2))
// Create the connection to the cluster
logInfo("Connecting to Zookeper: " + kafkaParams("zk.connect"))
logInfo("Connecting to Zookeper: " + kafkaParams("zookeeper.connect"))
val consumerConfig = new ConsumerConfig(props)
consumerConnector = Consumer.create(consumerConfig)
logInfo("Connected to " + kafkaParams("zk.connect"))
logInfo("Connected to " + kafkaParams("zookeeper.connect"))
// When autooffset.reset is defined, it is our responsibility to try and whack the
// consumer group zk node.
if (kafkaParams.contains("autooffset.reset")) {
tryZookeeperConsumerGroupCleanup(kafkaParams("zk.connect"), kafkaParams("groupid"))
if (kafkaParams.contains("auto.offset.reset")) {
tryZookeeperConsumerGroupCleanup(kafkaParams("zookeeper.connect"), kafkaParams("group.id"))
}
// Create Threads for each Topic/Message Stream we are listening
val decoder = manifest[D].erasure.newInstance.asInstanceOf[Decoder[T]]
val topicMessageStreams = consumerConnector.createMessageStreams(topics, decoder)
val keyDecoder = manifest[U].erasure.getConstructor(classOf[VerifiableProperties])
.newInstance(consumerConfig.props)
.asInstanceOf[Decoder[K]]
val valueDecoder = manifest[T].erasure.getConstructor(classOf[VerifiableProperties])
.newInstance(consumerConfig.props)
.asInstanceOf[Decoder[V]]
val topicMessageStreams = consumerConnector.createMessageStreams(
topics, keyDecoder, valueDecoder)
// Start the messages handler for each partition
topicMessageStreams.values.foreach { streams =>
@ -112,11 +122,12 @@ class KafkaReceiver[T: ClassManifest, D <: Decoder[_]: Manifest](
}
// Handles Kafka Messages
private class MessageHandler[T: ClassManifest](stream: KafkaStream[T]) extends Runnable {
private class MessageHandler[K: ClassManifest, V: ClassManifest](stream: KafkaStream[K, V])
extends Runnable {
def run() {
logInfo("Starting MessageHandler.")
for (msgAndMetadata <- stream) {
blockGenerator += msgAndMetadata.message
blockGenerator += (msgAndMetadata.key, msgAndMetadata.message)
}
}
}

View file

@ -0,0 +1,109 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.dstream
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{ Time, DStreamCheckpointData, StreamingContext }
import java.util.Properties
import java.util.concurrent.Executors
import java.io.IOException
import org.eclipse.paho.client.mqttv3.MqttCallback
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.collection.JavaConversions._
/**
* Input stream that subscribe messages from a Mqtt Broker.
* Uses eclipse paho as MqttClient http://www.eclipse.org/paho/
* @param brokerUrl Url of remote mqtt publisher
* @param topic topic name to subscribe to
* @param storageLevel RDD storage level.
*/
private[streaming]
class MQTTInputDStream[T: ClassManifest](
@transient ssc_ : StreamingContext,
brokerUrl: String,
topic: String,
storageLevel: StorageLevel
) extends NetworkInputDStream[T](ssc_) with Logging {
def getReceiver(): NetworkReceiver[T] = {
new MQTTReceiver(brokerUrl, topic, storageLevel)
.asInstanceOf[NetworkReceiver[T]]
}
}
private[streaming]
class MQTTReceiver(brokerUrl: String,
topic: String,
storageLevel: StorageLevel
) extends NetworkReceiver[Any] {
lazy protected val blockGenerator = new BlockGenerator(storageLevel)
def onStop() {
blockGenerator.stop()
}
def onStart() {
blockGenerator.start()
// Set up persistence for messages
var peristance: MqttClientPersistence = new MemoryPersistence()
// Initializing Mqtt Client specifying brokerUrl, clientID and MqttClientPersistance
var client: MqttClient = new MqttClient(brokerUrl, "MQTTSub", peristance)
// Connect to MqttBroker
client.connect()
// Subscribe to Mqtt topic
client.subscribe(topic)
// Callback automatically triggers as and when new message arrives on specified topic
var callback: MqttCallback = new MqttCallback() {
// Handles Mqtt message
override def messageArrived(arg0: String, arg1: MqttMessage) {
blockGenerator += new String(arg1.getPayload())
}
override def deliveryComplete(arg0: IMqttDeliveryToken) {
}
override def connectionLost(arg0: Throwable) {
logInfo("Connection lost " + arg0)
}
}
// Set up callback for MqttClient
client.setCallback(callback)
}
}

View file

@ -1220,14 +1220,20 @@ public class JavaAPISuite implements Serializable {
@Test
public void testKafkaStream() {
HashMap<String, Integer> topics = Maps.newHashMap();
JavaDStream test1 = ssc.kafkaStream("localhost:12345", "group", topics);
JavaDStream test2 = ssc.kafkaStream("localhost:12345", "group", topics,
JavaPairDStream<String, String> test1 = ssc.kafkaStream("localhost:12345", "group", topics);
JavaPairDStream<String, String> test2 = ssc.kafkaStream("localhost:12345", "group", topics,
StorageLevel.MEMORY_AND_DISK());
HashMap<String, String> kafkaParams = Maps.newHashMap();
kafkaParams.put("zk.connect","localhost:12345");
kafkaParams.put("groupid","consumer-group");
JavaDStream test3 = ssc.kafkaStream(String.class, StringDecoder.class, kafkaParams, topics,
kafkaParams.put("zookeeper.connect","localhost:12345");
kafkaParams.put("group.id","consumer-group");
JavaPairDStream<String, String> test3 = ssc.kafkaStream(
String.class,
String.class,
StringDecoder.class,
StringDecoder.class,
kafkaParams,
topics,
StorageLevel.MEMORY_AND_DISK());
}

View file

@ -268,8 +268,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
val test2 = ssc.kafkaStream("localhost:12345", "group", topics, StorageLevel.MEMORY_AND_DISK)
// Test specifying decoder
val kafkaParams = Map("zk.connect"->"localhost:12345","groupid"->"consumer-group")
val test3 = ssc.kafkaStream[String, kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
val kafkaParams = Map("zookeeper.connect"->"localhost:12345","group.id"->"consumer-group")
val test3 = ssc.kafkaStream[
String,
String,
kafka.serializer.StringDecoder,
kafka.serializer.StringDecoder](kafkaParams, topics, StorageLevel.MEMORY_AND_DISK)
}
}

View file

@ -265,11 +265,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val env = new HashMap[String, String]()
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
Apps.addToEnvironment(env, Environment.CLASSPATH.name,
Environment.PWD.$() + Path.SEPARATOR + "*")
Client.populateHadoopClasspath(yarnConf, env)
Client.populateClasspath(yarnConf, log4jConfLocalRes != null, env)
env("SPARK_YARN_MODE") = "true"
env("SPARK_YARN_JAR_PATH") =
localResources("spark.jar").getResource().getScheme.toString() + "://" +
@ -451,4 +447,30 @@ object Client {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, c.trim)
}
}
def populateClasspath(conf: Configuration, addLog4j: Boolean, env: HashMap[String, String]) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
// If log4j present, ensure ours overrides all others
if (addLog4j) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "log4j.properties")
}
// normally the users app.jar is last in case conflicts with spark jars
val userClasspathFirst = System.getProperty("spark.yarn.user.classpath.first", "false")
.toBoolean
if (userClasspathFirst) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "app.jar")
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "spark.jar")
Client.populateHadoopClasspath(conf, env)
if (!userClasspathFirst) {
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "app.jar")
}
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() +
Path.SEPARATOR + "*")
}
}

View file

@ -121,7 +121,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do 'something' to fail job ... akin to blacklisting trackers in mapred ?
" -XX:OnOutOfMemoryError='kill %p' " +
JAVA_OPTS +
" org.apache.spark.executor.StandaloneExecutorBackend " +
" org.apache.spark.executor.CoarseGrainedExecutorBackend " +
masterAddress + " " +
slaveId + " " +
hostname + " " +
@ -216,10 +216,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S
def prepareEnvironment: HashMap[String, String] = {
val env = new HashMap[String, String]()
Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$())
Apps.addToEnvironment(env, Environment.CLASSPATH.name,
Environment.PWD.$() + Path.SEPARATOR + "*")
Client.populateHadoopClasspath(yarnConf, env)
Client.populateClasspath(yarnConf, System.getenv("SPARK_YARN_LOG4J_PATH") != null, env)
// allow users to specify some environment variables
Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV"))