Merge remote-tracking branch 'spark-upstream/master'

Conflicts:
	project/SparkBuild.scala
Ankur Dave 2013-10-30 15:59:09 -07:00
commit 5064f9b2d2
123 changed files with 4007 additions and 3555 deletions


@ -33,12 +33,26 @@ fi
# Build up classpath
CLASSPATH="$SPARK_CLASSPATH:$FWDIR/conf"
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
# First check if we have a dependencies jar. If so, include binary classes with the deps jar
if [ -f "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar ]; then
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
DEPS_ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*-deps.jar`
CLASSPATH="$CLASSPATH:$DEPS_ASSEMBLY_JAR"
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
# Else use spark-assembly jar from either RELEASE or assembly directory
if [ -f "$FWDIR/RELEASE" ]; then
ASSEMBLY_JAR=`ls "$FWDIR"/jars/spark-assembly*.jar`
else
ASSEMBLY_JAR=`ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar`
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
fi
CLASSPATH="$CLASSPATH:$ASSEMBLY_JAR"
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then


@ -25,6 +25,7 @@ import io.netty.channel.ChannelInboundMessageHandlerAdapter;
import io.netty.channel.DefaultFileRegion;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@ -37,40 +38,34 @@ class FileServerHandler extends ChannelInboundMessageHandlerAdapter<String> {
@Override
public void messageReceived(ChannelHandlerContext ctx, String blockIdString) {
BlockId blockId = BlockId.apply(blockIdString);
String path = pResolver.getAbsolutePath(blockId.name());
// if getFilePath returns null, close the channel
if (path == null) {
FileSegment fileSegment = pResolver.getBlockLocation(blockId);
// if getBlockLocation returns null, close the channel
if (fileSegment == null) {
//ctx.close();
return;
}
File file = new File(path);
File file = fileSegment.file();
if (file.exists()) {
if (!file.isFile()) {
//logger.info("Not a file : " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
long length = file.length();
long length = fileSegment.length();
if (length > Integer.MAX_VALUE || length <= 0) {
//logger.info("too large file : " + file.getAbsolutePath() + " of size "+ length);
ctx.write(new FileHeader(0, blockId).buffer());
ctx.flush();
return;
}
int len = new Long(length).intValue();
//logger.info("Sending block "+blockId+" filelen = "+len);
//logger.info("header = "+ (new FileHeader(len, blockId)).buffer());
ctx.write((new FileHeader(len, blockId)).buffer());
try {
ctx.sendFile(new DefaultFileRegion(new FileInputStream(file)
.getChannel(), 0, file.length()));
.getChannel(), fileSegment.offset(), fileSegment.length()));
} catch (Exception e) {
//logger.warning("Exception when sending file : " + file.getAbsolutePath());
e.printStackTrace();
}
} else {
//logger.warning("File not found: " + file.getAbsolutePath());
ctx.write(new FileHeader(0, blockId).buffer());
}
ctx.flush();


@ -17,13 +17,10 @@
package org.apache.spark.network.netty;
import org.apache.spark.storage.BlockId;
import org.apache.spark.storage.FileSegment;
public interface PathResolver {
/**
* Get the absolute path of the file
*
* @param fileId
* @return the absolute path of file
*/
public String getAbsolutePath(String fileId);
/** Get the file segment in which the given block resides. */
public FileSegment getBlockLocation(BlockId blockId);
}
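The handler above now serves a FileSegment (file plus offset and length) instead of a raw path, which is what lets it send a slice of a larger file through DefaultFileRegion. A hedged sketch of what an implementation of the new interface could look like, assuming a simple one-file-per-block layout (class and directory names invented):

    import java.io.File
    import org.apache.spark.storage.{BlockId, FileSegment}

    class OneFilePerBlockResolver(rootDir: File) extends PathResolver {
      override def getBlockLocation(blockId: BlockId): FileSegment = {
        val file = new File(rootDir, blockId.name)
        // Per the hunk above, the handler treats a null result as "close the channel".
        if (!file.isFile) null else new FileSegment(file, 0L, file.length)
      }
    }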


@ -17,20 +17,29 @@
package org.apache.hadoop.mapred
private[apache]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", "org.apache.hadoop.mapred.JobContext");
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[org.apache.hadoop.mapreduce.JobID])
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", "org.apache.hadoop.mapred.TaskAttemptContext")
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId)
}


@ -17,9 +17,10 @@
package org.apache.hadoop.mapreduce
import org.apache.hadoop.conf.Configuration
import java.lang.{Integer => JInteger, Boolean => JBoolean}
import org.apache.hadoop.conf.Configuration
private[apache]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
@ -37,23 +38,31 @@ trait SparkHadoopMapReduceUtil {
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
def newTaskAttemptID(jtIdentifier: String, jobId: Int, isMap: Boolean, taskId: Int, attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID");
def newTaskAttemptID(
jtIdentifier: String,
jobId: Int,
isMap: Boolean,
taskId: Int,
attemptId: Int) = {
val klass = Class.forName("org.apache.hadoop.mapreduce.TaskAttemptID")
try {
// first, attempt to use the old-style constructor that takes a boolean isMap (not available in YARN)
// First, attempt to use the old-style constructor that takes a boolean isMap
// (not available in YARN)
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean],
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
} catch {
case exc: NoSuchMethodException => {
// failed, look for the new ctor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType").asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(taskTypeClass, if(isMap) "MAP" else "REDUCE")
// If that failed, look for the new constructor that takes a TaskType (not available in 1.x)
val taskTypeClass = Class.forName("org.apache.hadoop.mapreduce.TaskType")
.asInstanceOf[Class[Enum[_]]]
val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke(
taskTypeClass, if(isMap) "MAP" else "REDUCE")
val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass,
classOf[Int], classOf[Int])
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), new
JInteger(attemptId)).asInstanceOf[TaskAttemptID]
ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId),
new JInteger(attemptId)).asInstanceOf[TaskAttemptID]
}
}
}
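Both SparkHadoopMapRedUtil and SparkHadoopMapReduceUtil resolve Hadoop 1.x versus 2.x classes through a firstAvailableClass helper that these hunks reference but do not show. A plausible sketch of that helper, for illustration only (the real body may differ):

    private def firstAvailableClass(first: String, second: String): Class[_] = {
      try {
        Class.forName(first)   // prefer the Hadoop 2.x *Impl class if it is on the classpath
      } catch {
        case e: ClassNotFoundException => Class.forName(second)   // fall back to the Hadoop 1.x name
      }
    }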


@ -20,13 +20,11 @@ package org.apache.spark
import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import akka.actor._
import akka.dispatch._
import akka.pattern.ask
import akka.remote._
import akka.util.Duration
@ -40,11 +38,12 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int, requester: String
extends MapOutputTrackerMessage
private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage
private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Actor with Logging {
private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster)
extends Actor with Logging {
def receive = {
case GetMapOutputStatuses(shuffleId: Int, requester: String) =>
logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + requester)
sender ! tracker.getSerializedLocations(shuffleId)
sender ! tracker.getSerializedMapOutputStatuses(shuffleId)
case StopMapOutputTracker =>
logInfo("MapOutputTrackerActor stopped!")
@ -60,22 +59,19 @@ private[spark] class MapOutputTracker extends Logging {
// Set to the MapOutputTrackerActor living on the driver
var trackerActor: ActorRef = _
private var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
// Incremented every time a fetch fails so that client nodes know to clear
// their cache of map output locations if this happens.
private var epoch: Long = 0
private val epochLock = new java.lang.Object
protected var epoch: Long = 0
protected val epochLock = new java.lang.Object
// Cache a serialized version of the output statuses for each shuffle to send them out faster
var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup)
// Send a message to the trackerActor and get its result within a default timeout, or
// throw a SparkException if this fails.
def askTracker(message: Any): Any = {
private def askTracker(message: Any): Any = {
try {
val future = trackerActor.ask(message)(timeout)
return Await.result(future, timeout)
@ -86,50 +82,12 @@ private[spark] class MapOutputTracker extends Logging {
}
// Send a one-way message to the trackerActor, to which we expect it to reply with true.
def communicate(message: Any) {
private def communicate(message: Any) {
if (askTracker(message) != true) {
throw new SparkException("Error reply received from MapOutputTracker")
}
}
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
var array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(
shuffleId: Int,
statuses: Array[MapStatus],
changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
var arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
var array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
// Remembers which map output locations are currently being fetched on a worker
private val fetching = new HashSet[Int]
@ -168,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
try {
val fetchedBytes =
askTracker(GetMapOutputStatuses(shuffleId, hostPort)).asInstanceOf[Array[Byte]]
fetchedStatuses = deserializeStatuses(fetchedBytes)
fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
logInfo("Got the output locations")
mapStatuses.put(shuffleId, fetchedStatuses)
} finally {
@ -194,9 +152,8 @@ private[spark] class MapOutputTracker extends Logging {
}
}
private def cleanup(cleanupTime: Long) {
protected def cleanup(cleanupTime: Long) {
mapStatuses.clearOldValues(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
def stop() {
@ -206,15 +163,7 @@ private[spark] class MapOutputTracker extends Logging {
trackerActor = null
}
// Called on master to increment the epoch number
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
// Called on master or workers to get current epoch number
// Called to get current epoch number
def getEpoch: Long = {
epochLock.synchronized {
return epoch
@ -228,14 +177,62 @@ private[spark] class MapOutputTracker extends Logging {
epochLock.synchronized {
if (newEpoch > epoch) {
logInfo("Updating epoch to " + newEpoch + " and clearing cache")
// mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
mapStatuses.clear()
epoch = newEpoch
mapStatuses.clear()
}
}
}
}
def getSerializedLocations(shuffleId: Int): Array[Byte] = {
private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
// Cache a serialized version of the output statuses for each shuffle to send them out faster
private var cacheEpoch = epoch
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
def registerShuffle(shuffleId: Int, numMaps: Int) {
if (mapStatuses.putIfAbsent(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
}
}
def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
val array = mapStatuses(shuffleId)
array.synchronized {
array(mapId) = status
}
}
def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
mapStatuses.put(shuffleId, Array[MapStatus]() ++ statuses)
if (changeEpoch) {
incrementEpoch()
}
}
def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
val arrayOpt = mapStatuses.get(shuffleId)
if (arrayOpt.isDefined && arrayOpt.get != null) {
val array = arrayOpt.get
array.synchronized {
if (array(mapId) != null && array(mapId).location == bmAddress) {
array(mapId) = null
}
}
incrementEpoch()
} else {
throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
}
}
def incrementEpoch() {
epochLock.synchronized {
epoch += 1
logDebug("Increasing epoch to " + epoch)
}
}
def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
var statuses: Array[MapStatus] = null
var epochGotten: Long = -1
epochLock.synchronized {
@ -253,7 +250,7 @@ private[spark] class MapOutputTracker extends Logging {
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
// out a snapshot of the locations as "locs"; let's serialize and return that
val bytes = serializeStatuses(statuses)
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
epochLock.synchronized {
@ -261,13 +258,31 @@ private[spark] class MapOutputTracker extends Logging {
cachedSerializedStatuses(shuffleId) = bytes
}
}
return bytes
bytes
}
protected override def cleanup(cleanupTime: Long) {
super.cleanup(cleanupTime)
cachedSerializedStatuses.clearOldValues(cleanupTime)
}
override def stop() {
super.stop()
cachedSerializedStatuses.clear()
}
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}
}
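A hedged sketch of how the split divides work (identifiers taken from this diff; someStatus and the IDs are invented): the driver-only MapOutputTrackerMaster registers and serializes map outputs, while the base MapOutputTracker on executors only fetches and deserializes them.

    // Driver side
    val masterTracker = new MapOutputTrackerMaster()
    masterTracker.registerShuffle(shuffleId = 0, numMaps = 2)
    masterTracker.registerMapOutput(shuffleId = 0, mapId = 1, status = someStatus)   // someStatus: a MapStatus, assumed
    val bytes = masterTracker.getSerializedMapOutputStatuses(0)   // cached until the epoch changes

    // Executor side, after receiving the bytes from the master actor
    val statuses = MapOutputTracker.deserializeMapStatuses(bytes)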
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
// generally be pretty compressible because many map outputs will be on the same hostname.
private def serializeStatuses(statuses: Array[MapStatus]): Array[Byte] = {
def serializeMapStatuses(statuses: Array[MapStatus]): Array[Byte] = {
val out = new ByteArrayOutputStream
val objOut = new ObjectOutputStream(new GZIPOutputStream(out))
// Since statuses can be modified in parallel, sync on it
@ -278,18 +293,11 @@ private[spark] class MapOutputTracker extends Logging {
out.toByteArray
}
// Opposite of serializeStatuses.
def deserializeStatuses(bytes: Array[Byte]): Array[MapStatus] = {
// Opposite of serializeMapStatuses.
def deserializeMapStatuses(bytes: Array[Byte]): Array[MapStatus] = {
val objIn = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
objIn.readObject().
// // drop all null's from status - not sure why they are occuring though. Causes NPE downstream in slave if present
// comment this out - nulls could be due to missing location ?
asInstanceOf[Array[MapStatus]] // .filter( _ != null )
objIn.readObject().asInstanceOf[Array[MapStatus]]
}
}
private[spark] object MapOutputTracker {
private val LOG_BASE = 1.1
// Convert an array of MapStatuses to locations and sizes for a given reduce ID. If
// any of the statuses is null (indicating a missing location due to a failed mapper),


@ -51,25 +51,19 @@ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFor
import org.apache.mesos.MesosNativeLibrary
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.LocalSparkCluster
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend,
ClusterScheduler}
import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
SparkDeploySchedulerBackend, ClusterScheduler, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.storage.{StorageUtils, BlockManagerSource}
import org.apache.spark.scheduler.local.LocalScheduler
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.{BlockManagerSource, RDDInfo, StorageStatus, StorageUtils}
import org.apache.spark.ui.SparkUI
import org.apache.spark.util._
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus
import scala.Some
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.storage.RDDInfo
import org.apache.spark.storage.StorageStatus
import org.apache.spark.util.{ClosureCleaner, MetadataCleaner, MetadataCleanerType,
TimeStampedHashMap, Utils}
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@ -125,7 +119,7 @@ class SparkContext(
private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]]
private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup)
// Initalize the Spark UI
// Initialize the Spark UI
private[spark] val ui = new SparkUI(this)
ui.bind()
@ -161,8 +155,10 @@ class SparkContext(
val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
// Regular expression for connecting to Spark deploy clusters
val SPARK_REGEX = """spark://(.*)""".r
//Regular expression for connection to Mesos cluster
val MESOS_REGEX = """(mesos://.*)""".r
// Regular expression for connection to Mesos cluster
val MESOS_REGEX = """mesos://(.*)""".r
// Regular expression for connection to Simr cluster
val SIMR_REGEX = """simr://(.*)""".r
master match {
case "local" =>
@ -181,6 +177,12 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
case SIMR_REGEX(simrUrl) =>
val scheduler = new ClusterScheduler(this)
val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
scheduler.initialize(backend)
scheduler
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
val memoryPerSlaveInt = memoryPerSlave.toInt
@ -217,21 +219,20 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
case _ =>
if (MESOS_REGEX.findFirstIn(master).isEmpty) {
logWarning("Master %s does not match expected format, parsing as Mesos URL".format(master))
}
case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
val masterWithoutProtocol = master.replaceFirst("^mesos://", "") // Strip initial mesos://
val backend = if (coarseGrained) {
new CoarseMesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
} else {
new MesosSchedulerBackend(scheduler, this, masterWithoutProtocol, appName)
new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
}
scheduler.initialize(backend)
scheduler
case _ =>
throw new SparkException("Could not parse Master URL: '" + master + "'")
}
}
taskScheduler.start()
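For illustration (URLs invented), how the reworked regexes classify master strings: MESOS_REGEX now captures the address directly instead of stripping the prefix with replaceFirst, and SIMR_REGEX routes to the new SimrSchedulerBackend.

    val MESOS_REGEX = """mesos://(.*)""".r
    val SIMR_REGEX = """simr://(.*)""".r

    "mesos://zk://zk1:2181/mesos" match {
      case MESOS_REGEX(url) => println("Mesos master: " + url)   // prints "zk://zk1:2181/mesos"
    }
    "simr://hdfs://nn:8020/tmp/simr-meta" match {
      case SIMR_REGEX(url) => println("SIMR driver file: " + url)
    }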
@ -288,15 +289,46 @@ class SparkContext(
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
@deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
setJobGroup("", value)
}
/**
* Assigns a group id to all the jobs started by this thread until the group id is set to a
* different value or cleared.
*
* Often, a unit of execution in an application consists of multiple Spark actions or jobs.
* Application programmers can use this method to group all those jobs together and give a
* group description. Once set, the Spark web UI will associate such jobs with this group.
*
* The application can also use [[org.apache.spark.SparkContext.cancelJobGroup]] to cancel all
* running jobs in this group. For example,
* {{{
* // In the main thread:
* sc.setJobGroup("some_job_to_cancel", "some job description")
* sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
*
* // In a separate thread:
* sc.cancelJobGroup("some_job_to_cancel")
* }}}
*/
def setJobGroup(groupId: String, description: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, description)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, groupId)
}
/** Clear the job group id and its description. */
def clearJobGroup() {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, null)
setLocalProperty(SparkContext.SPARK_JOB_GROUP_ID, null)
}
// Post init
taskScheduler.postStartHook()
val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
private val dagSchedulerSource = new DAGSchedulerSource(this.dagScheduler, this)
private val blockManagerSource = new BlockManagerSource(SparkEnv.get.blockManager, this)
def initDriverMetrics() {
SparkEnv.get.metricsSystem.registerSource(dagSchedulerSource)
@ -651,12 +683,11 @@ class SparkContext(
/**
* Adds a JAR dependency for all tasks to be executed on this SparkContext in the future.
* The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
* filesystems), or an HTTP, HTTPS or FTP URI.
* filesystems), an HTTP, HTTPS or FTP URI, or local:/path for a file on every worker node.
*/
def addJar(path: String) {
if (path == null) {
logWarning("null specified as parameter to addJar",
new SparkException("null specified as parameter to addJar"))
logWarning("null specified as parameter to addJar")
} else {
var key = ""
if (path.contains("\\")) {
@ -665,6 +696,7 @@ class SparkContext(
} else {
val uri = new URI(path)
key = uri.getScheme match {
// A JAR file which exists only on the driver node
case null | "file" =>
if (env.hadoop.isYarnMode()) {
// In order for this to work on yarn the user must specify the --addjars option to
@ -682,6 +714,9 @@ class SparkContext(
} else {
env.httpFileServer.addJar(new File(uri.getPath))
}
// A JAR file which exists locally on every worker node
case "local" =>
"file:" + uri.getPath
case _ =>
path
}
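A hedged usage sketch of the newly documented local: scheme (paths invented; sc is a SparkContext): a local: jar is assumed to already exist at that path on every worker, so nothing is served from the driver.

    sc.addJar("hdfs://namenode:8020/libs/udfs.jar")   // executors fetch this from HDFS
    sc.addJar("local:/opt/spark-libs/udfs.jar")       // assumed pre-installed on every worker node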
@ -867,13 +902,19 @@ class SparkContext(
callSite,
allowLocal = false,
resultHandler,
null)
localProperties.get)
new SimpleFutureAction(waiter, resultFunc)
}
/**
* Cancel all jobs that have been scheduled or are running.
* Cancel active jobs for the specified group. See [[org.apache.spark.SparkContext.setJobGroup]]
* for more information.
*/
def cancelJobGroup(groupId: String) {
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
dagScheduler.cancelAllJobs()
}
@ -934,7 +975,10 @@ class SparkContext(
* various Spark features.
*/
object SparkContext {
val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2


@ -187,10 +187,14 @@ object SparkEnv extends Logging {
// Have to assign trackerActor after initialization as MapOutputTrackerActor
// requires the MapOutputTracker itself
val mapOutputTracker = new MapOutputTracker()
val mapOutputTracker = if (isDriver) {
new MapOutputTrackerMaster()
} else {
new MapOutputTracker()
}
mapOutputTracker.trackerActor = registerOrLookup(
"MapOutputTracker",
new MapOutputTrackerActor(mapOutputTracker))
new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]))
val shuffleFetcher = instantiateClass[ShuffleFetcher](
"spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")


@ -17,14 +17,14 @@
package org.apache.hadoop.mapred
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import java.io.IOException
import java.text.SimpleDateFormat
import java.text.NumberFormat
import java.io.IOException
import java.util.Date
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
import org.apache.spark.SerializableWritable
@ -36,6 +36,7 @@ import org.apache.spark.SerializableWritable
* Saves the RDD using a JobConf, which should contain an output key class, an output value class,
* a filename to write to, etc, exactly like in a Hadoop MapReduce job.
*/
private[apache]
class SparkHadoopWriter(@transient jobConf: JobConf)
extends Logging
with SparkHadoopMapRedUtil
@ -86,13 +87,11 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
getOutputCommitter().setupTask(getTaskContext())
writer = getOutputFormat().getRecordWriter(
fs, conf.value, outputName, Reporter.NULL)
writer = getOutputFormat().getRecordWriter(fs, conf.value, outputName, Reporter.NULL)
}
def write(key: AnyRef, value: AnyRef) {
if (writer!=null) {
//println (">>> Writing ("+key.toString+": " + key.getClass.toString + ", " + value.toString + ": " + value.getClass.toString + ")")
if (writer != null) {
writer.write(key, value)
} else {
throw new IOException("Writer is null, open() has not been called")
@ -182,6 +181,7 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
}
}
private[apache]
object SparkHadoopWriter {
def createJobID(time: Date, id: Int): JobID = {
val formatter = new SimpleDateFormat("yyyyMMddHHmm")


@ -48,6 +48,19 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
*/
def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaDoubleRDD = fromRDD(srdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaDoubleRDD = fromRDD(srdd.unpersist(blocking))
// first() has to be overriden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
@ -80,6 +93,17 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def coalesce(numPartitions: Int, shuffle: Boolean): JavaDoubleRDD =
fromRDD(srdd.coalesce(numPartitions, shuffle))
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaDoubleRDD = fromRDD(srdd.repartition(numPartitions))
/**
* Return an RDD with the elements from `this` that are not in `other`.
*


@ -65,6 +65,19 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd.persist(newLevel))
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaPairRDD[K, V] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**
@ -94,6 +107,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def coalesce(numPartitions: Int, shuffle: Boolean): JavaPairRDD[K, V] =
fromRDD(rdd.coalesce(numPartitions, shuffle))
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaPairRDD[K, V] = fromRDD(rdd.repartition(numPartitions))
/**
* Return a sampled subset of this RDD.
*/
@ -598,4 +622,15 @@ object JavaPairRDD {
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
implicit val cmk: ClassManifest[K] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
implicit val cmv: ClassManifest[V] =
implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
new JavaPairRDD[K, V](rdd.rdd)
}
}


@ -41,9 +41,17 @@ JavaRDDLike[T, JavaRDD[T]] {
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
* This method blocks until all blocks are deleted.
*/
def unpersist(): JavaRDD[T] = wrapRDD(rdd.unpersist())
/**
* Mark the RDD as non-persistent, and remove all blocks for it from memory and disk.
*
* @param blocking Whether to block until all blocks are deleted.
*/
def unpersist(blocking: Boolean): JavaRDD[T] = wrapRDD(rdd.unpersist(blocking))
// Transformations (return a new RDD)
/**
@ -73,6 +81,17 @@ JavaRDDLike[T, JavaRDD[T]] {
def coalesce(numPartitions: Int, shuffle: Boolean): JavaRDD[T] =
rdd.coalesce(numPartitions, shuffle)
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): JavaRDD[T] = rdd.repartition(numPartitions)
/**
* Return a sampled subset of this RDD.
*/
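The same unpersist and repartition additions appear in JavaDoubleRDD, JavaPairRDD and JavaRDD above, so one hedged usage sketch (paths and names invented) covers all three:

    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext("local[2]", "repartition-demo")
    val lines = jsc.textFile("hdfs://nn:8020/logs")   // JavaRDD[String]
    val wide = lines.repartition(100)    // always shuffles, ends up with exactly 100 partitions
    val narrow = lines.coalesce(2)       // shrinking the partition count can avoid a shuffle
    wide.cache().count()
    wide.unpersist(false)                // release cached blocks without waiting for deletion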


@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
@ -27,11 +25,7 @@ import java.io.Serializable;
*/
// DoubleFlatMapFunction does not extend FlatMapFunction because flatMap is
// overloaded for both FlatMapFunction and DoubleFlatMapFunction.
public abstract class DoubleFlatMapFunction<T> extends AbstractFunction1<T, Iterable<Double>>
public abstract class DoubleFlatMapFunction<T> extends WrappedFunction1<T, Iterable<Double>>
implements Serializable {
public abstract Iterable<Double> call(T t);
@Override
public final Iterable<Double> apply(T t) { return call(t); }
// Intentionally left blank
}


@ -18,8 +18,6 @@
package org.apache.spark.api.java.function;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
/**
@ -29,6 +27,5 @@ import java.io.Serializable;
// are overloaded for both Function and DoubleFunction.
public abstract class DoubleFunction<T> extends WrappedFunction1<T, Double>
implements Serializable {
public abstract Double call(T t) throws Exception;
// Intentionally left blank
}


@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
@throws(classOf[Exception])
def call(x: T) : java.lang.Iterable[R]
def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
}


@ -21,8 +21,5 @@ package org.apache.spark.api.java.function
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
@throws(classOf[Exception])
def call(a: A, b:B) : java.lang.Iterable[C]
def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
}


@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -30,8 +29,6 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
public abstract R call(T t) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}


@ -19,7 +19,6 @@ package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@ -29,8 +28,6 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
public abstract R call(T1 t1, T2 t2) throws Exception;
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}


@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
/**
* A three-argument function that takes arguments of type T1, T2 and T3 and returns an R.
*/
public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
implements Serializable {
public ClassManifest<R> returnType() {
return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
}
}


@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -34,8 +33,6 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
public abstract Iterable<Tuple2<K, V>> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}


@ -20,7 +20,6 @@ package org.apache.spark.api.java.function;
import scala.Tuple2;
import scala.reflect.ClassManifest;
import scala.reflect.ClassManifest$;
import scala.runtime.AbstractFunction1;
import java.io.Serializable;
@ -29,12 +28,9 @@ import java.io.Serializable;
*/
// PairFunction does not extend Function because some UDF functions, like map,
// are overloaded for both Function and PairFunction.
public abstract class PairFunction<T, K, V>
extends WrappedFunction1<T, Tuple2<K, V>>
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
public abstract Tuple2<K, V> call(T t) throws Exception;
public ClassManifest<K> keyType() {
return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
}


@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.api.java.function
import scala.runtime.AbstractFunction3
/**
* Subclass of Function3 for ease of calling from Java. The main thing it does is re-expose the
* apply() method as call() and declare that it can throw Exception (since AbstractFunction3.apply
* isn't marked to allow that).
*/
private[spark] abstract class WrappedFunction3[T1, T2, T3, R]
extends AbstractFunction3[T1, T2, T3, R] {
@throws(classOf[Exception])
def call(t1: T1, t2: T2, t3: T3): R
final def apply(t1: T1, t2: T2, t3: T3): R = call(t1, t2, t3)
}


@ -308,7 +308,7 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[
* Internal class that acts as an `AccumulatorParam` for Python accumulators. Inside, it
* collects a list of pickled strings that we pass to Python through a socket.
*/
class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
extends AccumulatorParam[JList[Array[Byte]]] {
Utils.checkHost(serverHost, "Expected hostname")


@ -1,410 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import java.util.Random
import scala.collection.mutable.Map
import org.apache.spark._
import org.apache.spark.util.Utils
private object MultiTracker
extends Logging {
// Tracker Messages
val REGISTER_BROADCAST_TRACKER = 0
val UNREGISTER_BROADCAST_TRACKER = 1
val FIND_BROADCAST_TRACKER = 2
// Map to keep track of guides of ongoing broadcasts
var valueToGuideMap = Map[Long, SourceInfo]()
// Random number generator
var ranGen = new Random
private var initialized = false
private var _isDriver = false
private var stopBroadcast = false
private var trackMV: TrackMultipleValues = null
def initialize(__isDriver: Boolean) {
synchronized {
if (!initialized) {
_isDriver = __isDriver
if (isDriver) {
trackMV = new TrackMultipleValues
trackMV.setDaemon(true)
trackMV.start()
// Set DriverHostAddress to the driver's IP address for the slaves to read
System.setProperty("spark.MultiTracker.DriverHostAddress", Utils.localIpAddress)
}
initialized = true
}
}
}
def stop() {
stopBroadcast = true
}
// Load common parameters
private var DriverHostAddress_ = System.getProperty(
"spark.MultiTracker.DriverHostAddress", "")
private var DriverTrackerPort_ = System.getProperty(
"spark.broadcast.driverTrackerPort", "11111").toInt
private var BlockSize_ = System.getProperty(
"spark.broadcast.blockSize", "4096").toInt * 1024
private var MaxRetryCount_ = System.getProperty(
"spark.broadcast.maxRetryCount", "2").toInt
private var TrackerSocketTimeout_ = System.getProperty(
"spark.broadcast.trackerSocketTimeout", "50000").toInt
private var ServerSocketTimeout_ = System.getProperty(
"spark.broadcast.serverSocketTimeout", "10000").toInt
private var MinKnockInterval_ = System.getProperty(
"spark.broadcast.minKnockInterval", "500").toInt
private var MaxKnockInterval_ = System.getProperty(
"spark.broadcast.maxKnockInterval", "999").toInt
// Load TreeBroadcast config params
private var MaxDegree_ = System.getProperty(
"spark.broadcast.maxDegree", "2").toInt
// Load BitTorrentBroadcast config params
private var MaxPeersInGuideResponse_ = System.getProperty(
"spark.broadcast.maxPeersInGuideResponse", "4").toInt
private var MaxChatSlots_ = System.getProperty(
"spark.broadcast.maxChatSlots", "4").toInt
private var MaxChatTime_ = System.getProperty(
"spark.broadcast.maxChatTime", "500").toInt
private var MaxChatBlocks_ = System.getProperty(
"spark.broadcast.maxChatBlocks", "1024").toInt
private var EndGameFraction_ = System.getProperty(
"spark.broadcast.endGameFraction", "0.95").toDouble
def isDriver = _isDriver
// Common config params
def DriverHostAddress = DriverHostAddress_
def DriverTrackerPort = DriverTrackerPort_
def BlockSize = BlockSize_
def MaxRetryCount = MaxRetryCount_
def TrackerSocketTimeout = TrackerSocketTimeout_
def ServerSocketTimeout = ServerSocketTimeout_
def MinKnockInterval = MinKnockInterval_
def MaxKnockInterval = MaxKnockInterval_
// TreeBroadcast configs
def MaxDegree = MaxDegree_
// BitTorrentBroadcast configs
def MaxPeersInGuideResponse = MaxPeersInGuideResponse_
def MaxChatSlots = MaxChatSlots_
def MaxChatTime = MaxChatTime_
def MaxChatBlocks = MaxChatBlocks_
def EndGameFraction = EndGameFraction_
class TrackMultipleValues
extends Thread with Logging {
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool("Track multiple values")
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(DriverTrackerPort)
logInfo("TrackMultipleValues started at " + serverSocket)
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(TrackerSocketTimeout)
clientSocket = serverSocket.accept()
} catch {
case e: Exception => {
if (stopBroadcast) {
logInfo("Stopping TrackMultipleValues...")
}
}
}
if (clientSocket != null) {
try {
threadPool.execute(new Thread {
override def run() {
val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
val ois = new ObjectInputStream(clientSocket.getInputStream)
try {
// First, read message type
val messageType = ois.readObject.asInstanceOf[Int]
if (messageType == REGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Receive hostAddress and listenPort
val gInfo = ois.readObject.asInstanceOf[SourceInfo]
// Add to the map
valueToGuideMap.synchronized {
valueToGuideMap += (id -> gInfo)
}
logInfo ("New broadcast " + id + " registered with TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == UNREGISTER_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
// Remove from the map
valueToGuideMap.synchronized {
valueToGuideMap(id) = SourceInfo("", SourceInfo.TxOverGoToDefault)
}
logInfo ("Broadcast " + id + " unregistered from TrackMultipleValues. Ongoing ones: " + valueToGuideMap)
// Send dummy ACK
oos.writeObject(-1)
oos.flush()
} else if (messageType == FIND_BROADCAST_TRACKER) {
// Receive Long
val id = ois.readObject.asInstanceOf[Long]
var gInfo =
if (valueToGuideMap.contains(id)) valueToGuideMap(id)
else SourceInfo("", SourceInfo.TxNotStartedRetry)
logDebug("Got new request: " + clientSocket + " for " + id + " : " + gInfo.listenPort)
// Send reply back
oos.writeObject(gInfo)
oos.flush()
} else {
throw new SparkException("Undefined messageType at TrackMultipleValues")
}
} catch {
case e: Exception => {
logError("TrackMultipleValues had a " + e)
}
} finally {
ois.close()
oos.close()
clientSocket.close()
}
}
})
} catch {
// In failure, close socket here; else, client thread will close
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
serverSocket.close()
}
// Shutdown the thread pool
threadPool.shutdown()
}
}
def getGuideInfo(variableLong: Long): SourceInfo = {
var clientSocketToTracker: Socket = null
var oosTracker: ObjectOutputStream = null
var oisTracker: ObjectInputStream = null
var gInfo: SourceInfo = SourceInfo("", SourceInfo.TxNotStartedRetry)
var retriesLeft = MultiTracker.MaxRetryCount
do {
try {
// Connect to the tracker to find out GuideInfo
clientSocketToTracker =
new Socket(MultiTracker.DriverHostAddress, MultiTracker.DriverTrackerPort)
oosTracker =
new ObjectOutputStream(clientSocketToTracker.getOutputStream)
oosTracker.flush()
oisTracker =
new ObjectInputStream(clientSocketToTracker.getInputStream)
// Send messageType/intention
oosTracker.writeObject(MultiTracker.FIND_BROADCAST_TRACKER)
oosTracker.flush()
// Send Long and receive GuideInfo
oosTracker.writeObject(variableLong)
oosTracker.flush()
gInfo = oisTracker.readObject.asInstanceOf[SourceInfo]
} catch {
case e: Exception => logError("getGuideInfo had a " + e)
} finally {
if (oisTracker != null) {
oisTracker.close()
}
if (oosTracker != null) {
oosTracker.close()
}
if (clientSocketToTracker != null) {
clientSocketToTracker.close()
}
}
Thread.sleep(MultiTracker.ranGen.nextInt(
MultiTracker.MaxKnockInterval - MultiTracker.MinKnockInterval) +
MultiTracker.MinKnockInterval)
retriesLeft -= 1
} while (retriesLeft > 0 && gInfo.listenPort == SourceInfo.TxNotStartedRetry)
logDebug("Got this guidePort from Tracker: " + gInfo.listenPort)
return gInfo
}
def registerBroadcast(id: Long, gInfo: SourceInfo) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(REGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Send this tracker's information
oosST.writeObject(gInfo)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
def unregisterBroadcast(id: Long) {
val socket = new Socket(MultiTracker.DriverHostAddress, DriverTrackerPort)
val oosST = new ObjectOutputStream(socket.getOutputStream)
oosST.flush()
val oisST = new ObjectInputStream(socket.getInputStream)
// Send messageType/intention
oosST.writeObject(UNREGISTER_BROADCAST_TRACKER)
oosST.flush()
// Send Long of this broadcast
oosST.writeObject(id)
oosST.flush()
// Receive ACK and throw it away
oisST.readObject.asInstanceOf[Int]
// Shut stuff down
oisST.close()
oosST.close()
socket.close()
}
// Helper method to convert an object to Array[BroadcastBlock]
def blockifyObject[IN](obj: IN): VariableInfo = {
val baos = new ByteArrayOutputStream
val oos = new ObjectOutputStream(baos)
oos.writeObject(obj)
oos.close()
baos.close()
val byteArray = baos.toByteArray
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BlockSize)
if (byteArray.length % BlockSize != 0)
blockNum += 1
var retVal = new Array[BroadcastBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BlockSize)) {
val thisBlockSize = math.min(BlockSize, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new BroadcastBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var variableInfo = VariableInfo(retVal, blockNum, byteArray.length)
variableInfo.hasBlocks = blockNum
return variableInfo
}
// Helper method to convert Array[BroadcastBlock] to object
def unBlockifyObject[OUT](arrayOfBlocks: Array[BroadcastBlock],
totalBytes: Int,
totalBlocks: Int): OUT = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BlockSize, arrayOfBlocks(i).byteArray.length)
}
byteArrayToObject(retByteArray)
}
private def byteArrayToObject[OUT](bytes: Array[Byte]): OUT = {
val in = new ObjectInputStream (new ByteArrayInputStream (bytes)){
override def resolveClass(desc: ObjectStreamClass) =
Class.forName(desc.getName, false, Thread.currentThread.getContextClassLoader)
}
val retVal = in.readObject.asInstanceOf[OUT]
in.close()
return retVal
}
}
private[spark] case class BroadcastBlock(blockID: Int, byteArray: Array[Byte])
extends Serializable
private[spark] case class VariableInfo(@transient arrayOfBlocks : Array[BroadcastBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}


@ -1,54 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.util.BitSet
import org.apache.spark._
/**
* Used to keep and pass around information of peers involved in a broadcast
*/
private[spark] case class SourceInfo (hostAddress: String,
listenPort: Int,
totalBlocks: Int = SourceInfo.UnusedParam,
totalBytes: Int = SourceInfo.UnusedParam)
extends Comparable[SourceInfo] with Logging {
var currentLeechers = 0
var receptionFailed = false
var hasBlocks = 0
var hasBlocksBitVector: BitSet = new BitSet (totalBlocks)
// Ascending sort based on leecher count
def compareTo (o: SourceInfo): Int = (currentLeechers - o.currentLeechers)
}
/**
* Helper Object of SourceInfo for its constants
*/
private[spark] object SourceInfo {
// Broadcast has not started yet! Should never happen.
val TxNotStartedRetry = -1
// Broadcast has already finished. Try default mechanism.
val TxOverGoToDefault = -3
// Other constants
val StopBroadcast = -2
val UnusedParam = 0
}


@ -0,0 +1,247 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import scala.math
import scala.util.Random
import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel}
import org.apache.spark.util.Utils
private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def broadcastId = BroadcastBlockId(id)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[TorrentBlock] = null
@transient var totalBlocks = -1
@transient var totalBytes = -1
@transient var hasBlocks = 0
if (!isLocal) {
sendBroadcast()
}
def sendBroadcast() {
var tInfo = TorrentBroadcast.blockifyObject(value_)
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
hasBlocks = tInfo.totalBlocks
// Store meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true)
}
// Store individual pieces
for (i <- 0 until totalBlocks) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.putSingle(
pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true)
}
}
}
// Called by JVM when deserializing an object
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(broadcastId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
val start = System.nanoTime
logInfo("Started reading broadcast variable " + id)
// Initialize @transient variables that will receive garbage values from the master.
resetWorkerVariables()
if (receiveBroadcast(id)) {
value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
// Store the merged copy in cache so that the next worker doesn't need to rebuild it.
// This creates a tradeoff between memory usage and latency.
// Storing copy doubles the memory footprint; not storing doubles deserialization cost.
SparkEnv.get.blockManager.putSingle(
broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false)
// Remove arrayOfBlocks from memory once value_ is on local cache
resetWorkerVariables()
} else {
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
private def resetWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
}
def receiveBroadcast(variableID: Long): Boolean = {
// Receive meta-info
val metaId = BroadcastHelperBlockId(broadcastId, "meta")
var attemptId = 10
while (attemptId > 0 && totalBlocks == -1) {
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(metaId) match {
case Some(x) =>
val tInfo = x.asInstanceOf[TorrentInfo]
totalBlocks = tInfo.totalBlocks
totalBytes = tInfo.totalBytes
arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
hasBlocks = 0
case None =>
Thread.sleep(500)
}
}
attemptId -= 1
}
if (totalBlocks == -1) {
return false
}
// Receive actual blocks
val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
for (pid <- recvOrder) {
val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid)
TorrentBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(pieceId) match {
case Some(x) =>
arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
hasBlocks += 1
SparkEnv.get.blockManager.putSingle(
pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true)
case None =>
throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
}
}
}
(hasBlocks == totalBlocks)
}
}
private object TorrentBroadcast
extends Logging {
private var initialized = false
def initialize(_isDriver: Boolean) {
synchronized {
if (!initialized) {
initialized = true
}
}
}
def stop() {
initialized = false
}
val BLOCK_SIZE = System.getProperty("spark.broadcast.blockSize", "4096").toInt * 1024
def blockifyObject[T](obj: T): TorrentInfo = {
val byteArray = Utils.serialize[T](obj)
val bais = new ByteArrayInputStream(byteArray)
var blockNum = (byteArray.length / BLOCK_SIZE)
if (byteArray.length % BLOCK_SIZE != 0)
blockNum += 1
var retVal = new Array[TorrentBlock](blockNum)
var blockID = 0
for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
var tempByteArray = new Array[Byte](thisBlockSize)
val hasRead = bais.read(tempByteArray, 0, thisBlockSize)
retVal(blockID) = new TorrentBlock(blockID, tempByteArray)
blockID += 1
}
bais.close()
var tInfo = TorrentInfo(retVal, blockNum, byteArray.length)
tInfo.hasBlocks = blockNum
return tInfo
}
def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock],
totalBytes: Int,
totalBlocks: Int): T = {
var retByteArray = new Array[Byte](totalBytes)
for (i <- 0 until totalBlocks) {
System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
}
Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader)
}
}
private[spark] case class TorrentBlock(
blockID: Int,
byteArray: Array[Byte])
extends Serializable
private[spark] case class TorrentInfo(
@transient arrayOfBlocks : Array[TorrentBlock],
totalBlocks: Int,
totalBytes: Int)
extends Serializable {
@transient var hasBlocks = 0
}
private[spark] class TorrentBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { TorrentBroadcast.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
}
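A minimal sketch of the blockify/unblockify round trip implemented above, assuming the caller lives inside the org.apache.spark.broadcast package (TorrentBroadcast is package-private) and that TorrentRoundTripSketch is a hypothetical name: the value is serialized, cut into pieces of at most BLOCK_SIZE bytes, and reassembled before deserialization.

package org.apache.spark.broadcast

object TorrentRoundTripSketch {
  def main(args: Array[String]) {
    val original = (1 to 100000).toArray                      // any serializable value
    val tInfo = TorrentBroadcast.blockifyObject(original)     // split into BLOCK_SIZE-sized pieces
    assert(tInfo.arrayOfBlocks.forall(_.byteArray.length <= TorrentBroadcast.BLOCK_SIZE))
    val copy = TorrentBroadcast.unBlockifyObject[Array[Int]](
      tInfo.arrayOfBlocks, tInfo.totalBytes, tInfo.totalBlocks)
    assert(copy.sameElements(original))                       // round trip preserves the value
  }
}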

View file

@ -1,601 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.broadcast
import java.io._
import java.net._
import scala.collection.mutable.{ListBuffer, Set}
import org.apache.spark._
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
import org.apache.spark.util.Utils
private[spark] class TreeBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def value = value_
def blockId = BroadcastBlockId(id)
MultiTracker.synchronized {
SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
}
@transient var arrayOfBlocks: Array[BroadcastBlock] = null
@transient var totalBytes = -1
@transient var totalBlocks = -1
@transient var hasBlocks = 0
@transient var listenPortLock = new Object
@transient var guidePortLock = new Object
@transient var totalBlocksLock = new Object
@transient var hasBlocksLock = new Object
@transient var listOfSources = ListBuffer[SourceInfo]()
@transient var serveMR: ServeMultipleRequests = null
@transient var guideMR: GuideMultipleRequests = null
@transient var hostAddress = Utils.localIpAddress
@transient var listenPort = -1
@transient var guidePort = -1
@transient var stopBroadcast = false
// Must call this after all the variables have been created/initialized
if (!isLocal) {
sendBroadcast()
}
def sendBroadcast() {
logInfo("Local host address: " + hostAddress)
// Create a variableInfo object and store it in valueInfos
var variableInfo = MultiTracker.blockifyObject(value_)
// Prepare the value being broadcasted
arrayOfBlocks = variableInfo.arrayOfBlocks
totalBytes = variableInfo.totalBytes
totalBlocks = variableInfo.totalBlocks
hasBlocks = variableInfo.totalBlocks
guideMR = new GuideMultipleRequests
guideMR.setDaemon(true)
guideMR.start()
logInfo("GuideMultipleRequests started...")
// Must always come AFTER guideMR is created
while (guidePort == -1) {
guidePortLock.synchronized { guidePortLock.wait() }
}
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
// Must always come AFTER serveMR is created
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
// Must always come AFTER listenPort is created
val masterSource =
SourceInfo(hostAddress, listenPort, totalBlocks, totalBytes)
listOfSources += masterSource
// Register with the Tracker
MultiTracker.registerBroadcast(id,
SourceInfo(hostAddress, guidePort, totalBlocks, totalBytes))
}
private def readObject(in: ObjectInputStream) {
in.defaultReadObject()
MultiTracker.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
case Some(x) =>
value_ = x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
// Initializing everything because Driver will only send null/0 values
// Only the 1st worker in a node can be here. Others will get from cache
initializeWorkerVariables()
logInfo("Local host address: " + hostAddress)
serveMR = new ServeMultipleRequests
serveMR.setDaemon(true)
serveMR.start()
logInfo("ServeMultipleRequests started...")
val start = System.nanoTime
val receptionSucceeded = receiveBroadcast(id)
if (receptionSucceeded) {
value_ = MultiTracker.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
SparkEnv.get.blockManager.putSingle(
blockId, value_, StorageLevel.MEMORY_AND_DISK, false)
} else {
logError("Reading broadcast variable " + id + " failed")
}
val time = (System.nanoTime - start) / 1e9
logInfo("Reading broadcast variable " + id + " took " + time + " s")
}
}
}
private def initializeWorkerVariables() {
arrayOfBlocks = null
totalBytes = -1
totalBlocks = -1
hasBlocks = 0
listenPortLock = new Object
totalBlocksLock = new Object
hasBlocksLock = new Object
serveMR = null
hostAddress = Utils.localIpAddress
listenPort = -1
stopBroadcast = false
}
def receiveBroadcast(variableID: Long): Boolean = {
val gInfo = MultiTracker.getGuideInfo(variableID)
if (gInfo.listenPort == SourceInfo.TxOverGoToDefault) {
return false
}
// Wait until hostAddress and listenPort are created by the
// ServeMultipleRequests thread
while (listenPort == -1) {
listenPortLock.synchronized { listenPortLock.wait() }
}
var clientSocketToDriver: Socket = null
var oosDriver: ObjectOutputStream = null
var oisDriver: ObjectInputStream = null
// Connect and receive broadcast from the specified source, retrying the
// specified number of times in case of failures
var retriesLeft = MultiTracker.MaxRetryCount
do {
// Connect to Driver and send this worker's Information
clientSocketToDriver = new Socket(MultiTracker.DriverHostAddress, gInfo.listenPort)
oosDriver = new ObjectOutputStream(clientSocketToDriver.getOutputStream)
oosDriver.flush()
oisDriver = new ObjectInputStream(clientSocketToDriver.getInputStream)
logDebug("Connected to Driver's guiding object")
// Send local source information
oosDriver.writeObject(SourceInfo(hostAddress, listenPort))
oosDriver.flush()
// Receive source information from Driver
var sourceInfo = oisDriver.readObject.asInstanceOf[SourceInfo]
totalBlocks = sourceInfo.totalBlocks
arrayOfBlocks = new Array[BroadcastBlock](totalBlocks)
totalBlocksLock.synchronized { totalBlocksLock.notifyAll() }
totalBytes = sourceInfo.totalBytes
logDebug("Received SourceInfo from Driver:" + sourceInfo + " My Port: " + listenPort)
val start = System.nanoTime
val receptionSucceeded = receiveSingleTransmission(sourceInfo)
val time = (System.nanoTime - start) / 1e9
// Updating some statistics in sourceInfo. Driver will be using them later
if (!receptionSucceeded) {
sourceInfo.receptionFailed = true
}
// Send back statistics to the Driver
oosDriver.writeObject(sourceInfo)
if (oisDriver != null) {
oisDriver.close()
}
if (oosDriver != null) {
oosDriver.close()
}
if (clientSocketToDriver != null) {
clientSocketToDriver.close()
}
retriesLeft -= 1
} while (retriesLeft > 0 && hasBlocks < totalBlocks)
return (hasBlocks == totalBlocks)
}
/**
* Tries to receive broadcast from the source and returns Boolean status.
* This might be called multiple times to retry a defined number of times.
*/
private def receiveSingleTransmission(sourceInfo: SourceInfo): Boolean = {
var clientSocketToSource: Socket = null
var oosSource: ObjectOutputStream = null
var oisSource: ObjectInputStream = null
var receptionSucceeded = false
try {
// Connect to the source to get the object itself
clientSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
oosSource = new ObjectOutputStream(clientSocketToSource.getOutputStream)
oosSource.flush()
oisSource = new ObjectInputStream(clientSocketToSource.getInputStream)
logDebug("Inside receiveSingleTransmission")
logDebug("totalBlocks: "+ totalBlocks + " " + "hasBlocks: " + hasBlocks)
// Send the range
oosSource.writeObject((hasBlocks, totalBlocks))
oosSource.flush()
for (i <- hasBlocks until totalBlocks) {
val recvStartTime = System.currentTimeMillis
val bcBlock = oisSource.readObject.asInstanceOf[BroadcastBlock]
val receptionTime = (System.currentTimeMillis - recvStartTime)
logDebug("Received block: " + bcBlock.blockID + " from " + sourceInfo + " in " + receptionTime + " millis.")
arrayOfBlocks(hasBlocks) = bcBlock
hasBlocks += 1
// Set to true if at least one block is received
receptionSucceeded = true
hasBlocksLock.synchronized { hasBlocksLock.notifyAll() }
}
} catch {
case e: Exception => logError("receiveSingleTransmission had a " + e)
} finally {
if (oisSource != null) {
oisSource.close()
}
if (oosSource != null) {
oosSource.close()
}
if (clientSocketToSource != null) {
clientSocketToSource.close()
}
}
return receptionSucceeded
}
class GuideMultipleRequests
extends Thread with Logging {
// Keep track of sources that have completed reception
private var setOfCompletedSources = Set[SourceInfo]()
override def run() {
var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast guide multiple requests")
var serverSocket: ServerSocket = null
serverSocket = new ServerSocket(0)
guidePort = serverSocket.getLocalPort
logInfo("GuideMultipleRequests => " + serverSocket + " " + guidePort)
guidePortLock.synchronized { guidePortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => {
// Stop broadcast if at least one worker has connected and
// everyone connected so far are done. Comparing with
// listOfSources.size - 1, because it includes the Guide itself
listOfSources.synchronized {
setOfCompletedSources.synchronized {
if (listOfSources.size > 1 &&
setOfCompletedSources.size == listOfSources.size - 1) {
stopBroadcast = true
logInfo("GuideMultipleRequests Timeout. stopBroadcast == true.")
}
}
}
}
}
if (clientSocket != null) {
logDebug("Guide: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new GuideSingleRequest(clientSocket))
} catch {
// On failure, close() the socket here; else, the thread will close() it
case ioe: IOException => clientSocket.close()
}
}
}
logInfo("Sending stopBroadcast notifications...")
sendStopBroadcastNotifications
MultiTracker.unregisterBroadcast(id)
} finally {
if (serverSocket != null) {
logInfo("GuideMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
private def sendStopBroadcastNotifications() {
listOfSources.synchronized {
var listIter = listOfSources.iterator
while (listIter.hasNext) {
var sourceInfo = listIter.next
var guideSocketToSource: Socket = null
var gosSource: ObjectOutputStream = null
var gisSource: ObjectInputStream = null
try {
// Connect to the source
guideSocketToSource = new Socket(sourceInfo.hostAddress, sourceInfo.listenPort)
gosSource = new ObjectOutputStream(guideSocketToSource.getOutputStream)
gosSource.flush()
gisSource = new ObjectInputStream(guideSocketToSource.getInputStream)
// Send stopBroadcast signal
gosSource.writeObject((SourceInfo.StopBroadcast, SourceInfo.StopBroadcast))
gosSource.flush()
} catch {
case e: Exception => {
logError("sendStopBroadcastNotifications had a " + e)
}
} finally {
if (gisSource != null) {
gisSource.close()
}
if (gosSource != null) {
gosSource.close()
}
if (guideSocketToSource != null) {
guideSocketToSource.close()
}
}
}
}
}
class GuideSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var selectedSourceInfo: SourceInfo = null
private var thisWorkerInfo: SourceInfo = null
override def run() {
try {
logInfo("new GuideSingleRequest is running")
// Connecting worker is sending in its hostAddress and listenPort it will
// be listening to. Other fields are invalid (SourceInfo.UnusedParam)
var sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// Select a suitable source and send it back to the worker
selectedSourceInfo = selectSuitableSource(sourceInfo)
logDebug("Sending selectedSourceInfo: " + selectedSourceInfo)
oos.writeObject(selectedSourceInfo)
oos.flush()
// Add this new (if it can finish) source to the list of sources
thisWorkerInfo = SourceInfo(sourceInfo.hostAddress,
sourceInfo.listenPort, totalBlocks, totalBytes)
logDebug("Adding possible new source to listOfSources: " + thisWorkerInfo)
listOfSources += thisWorkerInfo
}
// Wait till the whole transfer is done. Then receive and update source
// statistics in listOfSources
sourceInfo = ois.readObject.asInstanceOf[SourceInfo]
listOfSources.synchronized {
// This should work since SourceInfo is a case class
assert(listOfSources.contains(selectedSourceInfo))
// Remove first
// (Currently removing a source based on just one failure notification!)
listOfSources = listOfSources - selectedSourceInfo
// Update sourceInfo and put it back in, IF reception succeeded
if (!sourceInfo.receptionFailed) {
// Add thisWorkerInfo to sources that have completed reception
setOfCompletedSources.synchronized {
setOfCompletedSources += thisWorkerInfo
}
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
}
} catch {
case e: Exception => {
// Remove failed worker from listOfSources and update leecherCount of
// corresponding source worker
listOfSources.synchronized {
if (selectedSourceInfo != null) {
// Remove first
listOfSources = listOfSources - selectedSourceInfo
// Update leecher count and put it back in
selectedSourceInfo.currentLeechers -= 1
listOfSources += selectedSourceInfo
}
// Remove thisWorkerInfo
if (listOfSources != null) {
listOfSources = listOfSources - thisWorkerInfo
}
}
}
} finally {
logInfo("GuideSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
// Assuming the caller to have a synchronized block on listOfSources
// Select one with the most leechers. This will level-wise fill the tree
private def selectSuitableSource(skipSourceInfo: SourceInfo): SourceInfo = {
var maxLeechers = -1
var selectedSource: SourceInfo = null
listOfSources.foreach { source =>
if ((source.hostAddress != skipSourceInfo.hostAddress ||
source.listenPort != skipSourceInfo.listenPort) &&
source.currentLeechers < MultiTracker.MaxDegree &&
source.currentLeechers > maxLeechers) {
selectedSource = source
maxLeechers = source.currentLeechers
}
}
// Update leecher count
selectedSource.currentLeechers += 1
return selectedSource
}
}
}
class ServeMultipleRequests
extends Thread with Logging {
var threadPool = Utils.newDaemonCachedThreadPool("Tree broadcast serve multiple requests")
override def run() {
var serverSocket = new ServerSocket(0)
listenPort = serverSocket.getLocalPort
logInfo("ServeMultipleRequests started with " + serverSocket)
listenPortLock.synchronized { listenPortLock.notifyAll() }
try {
while (!stopBroadcast) {
var clientSocket: Socket = null
try {
serverSocket.setSoTimeout(MultiTracker.ServerSocketTimeout)
clientSocket = serverSocket.accept
} catch {
case e: Exception => { }
}
if (clientSocket != null) {
logDebug("Serve: Accepted new client connection: " + clientSocket)
try {
threadPool.execute(new ServeSingleRequest(clientSocket))
} catch {
// On failure, close the socket here; else, the thread will close it
case ioe: IOException => clientSocket.close()
}
}
}
} finally {
if (serverSocket != null) {
logInfo("ServeMultipleRequests now stopping...")
serverSocket.close()
}
}
// Shutdown the thread pool
threadPool.shutdown()
}
class ServeSingleRequest(val clientSocket: Socket)
extends Thread with Logging {
private val oos = new ObjectOutputStream(clientSocket.getOutputStream)
oos.flush()
private val ois = new ObjectInputStream(clientSocket.getInputStream)
private var sendFrom = 0
private var sendUntil = totalBlocks
override def run() {
try {
logInfo("new ServeSingleRequest is running")
// Receive range to send
var rangeToSend = ois.readObject.asInstanceOf[(Int, Int)]
sendFrom = rangeToSend._1
sendUntil = rangeToSend._2
// If not a valid range, stop broadcast
if (sendFrom == SourceInfo.StopBroadcast && sendUntil == SourceInfo.StopBroadcast) {
stopBroadcast = true
} else {
sendObject
}
} catch {
case e: Exception => logError("ServeSingleRequest had a " + e)
} finally {
logInfo("ServeSingleRequest is closing streams and sockets")
ois.close()
oos.close()
clientSocket.close()
}
}
private def sendObject() {
// Wait till receiving the SourceInfo from Driver
while (totalBlocks == -1) {
totalBlocksLock.synchronized { totalBlocksLock.wait() }
}
for (i <- sendFrom until sendUntil) {
while (i == hasBlocks) {
hasBlocksLock.synchronized { hasBlocksLock.wait() }
}
try {
oos.writeObject(arrayOfBlocks(i))
oos.flush()
} catch {
case e: Exception => logError("sendObject had a " + e)
}
logDebug("Sent block: " + i + " to " + clientSocket)
}
}
}
}
}
private[spark] class TreeBroadcastFactory
extends BroadcastFactory {
def initialize(isDriver: Boolean) { MultiTracker.initialize(isDriver) }
def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
new TreeBroadcast[T](value_, isLocal, id)
def stop() { MultiTracker.stop() }
}

View file

@ -17,26 +17,31 @@
package org.apache.spark.deploy
import com.google.common.collect.MapMaker
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.JobConf
import com.google.common.collect.MapMaker
/**
* Contains util methods to interact with Hadoop from spark.
* Contains util methods to interact with Hadoop from Spark.
*/
private[spark]
class SparkHadoopUtil {
// A general, soft-reference map for metadata needed during HadoopRDD split computation
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
// Return an appropriate (subclass) of Configuration. Creating config can initializes some hadoop
// subsystems
/**
* Return an appropriate (subclass) of Configuration. Creating a config can initialize some Hadoop
* subsystems.
*/
def newConfiguration(): Configuration = new Configuration()
// Add any user credentials to the job conf which are necessary for running on a secure Hadoop
// cluster
/**
* Add any user credentials to the job conf which are necessary for running on a secure Hadoop
* cluster.
*/
def addCredentials(conf: JobConf) {}
def isYarnMode(): Boolean = { false }
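
A minimal usage sketch of these hooks, assuming the caller lives inside the org.apache.spark package tree (the class is private[spark]) and that SparkHadoopUtilSketch is a hypothetical name.

package org.apache.spark.deploy

import org.apache.hadoop.mapred.JobConf

object SparkHadoopUtilSketch {
  def main(args: Array[String]) {
    val util = new SparkHadoopUtil
    val hadoopConf = util.newConfiguration()    // plain Configuration; may initialize Hadoop subsystems
    val jobConf = new JobConf(hadoopConf)
    util.addCredentials(jobConf)                // no-op here; secure/YARN variants add delegation tokens
    println("YARN mode: " + util.isYarnMode())  // false in this base implementation
  }
}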

View file

@ -80,6 +80,11 @@ private[spark] class CoarseGrainedExecutorBackend(
case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
logError("Driver terminated or disconnected! Shutting down.")
System.exit(1)
case StopExecutor =>
logInfo("Driver commanded a shutdown")
context.stop(self)
context.system.shutdown()
}
override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) {

View file

@ -74,30 +74,33 @@ private[spark] class Executor(
private val replClassLoader = addReplClassLoaderIfNeeded(urlClassLoader)
Thread.currentThread.setContextClassLoader(replClassLoader)
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
Thread.setDefaultUncaughtExceptionHandler(
new Thread.UncaughtExceptionHandler {
override def uncaughtException(thread: Thread, exception: Throwable) {
try {
logError("Uncaught exception in thread " + thread, exception)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
// We may have been called from a shutdown hook. If so, we must not call System.exit().
// (If we do, we will deadlock.)
if (!Utils.inShutdown()) {
if (exception.isInstanceOf[OutOfMemoryError]) {
System.exit(ExecutorExitCode.OOM)
} else {
System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
}
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
} catch {
case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
}
}
}
)
)
}
val executorSource = new ExecutorSource(this, executorId)

View file

@ -102,4 +102,9 @@ class ShuffleWriteMetrics extends Serializable {
* Number of bytes written for a shuffle
*/
var shuffleBytesWritten: Long = _
/**
* Time spent blocking on writes to disk or buffer cache, in nanoseconds.
*/
var shuffleWriteTime: Long = _
}
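
A minimal REPL-style sketch of how a shuffle writer might fill in these fields (mirroring the ShuffleMapTask change later in this commit; the numbers are illustrative only).

import org.apache.spark.executor.ShuffleWriteMetrics

val writeMetrics = new ShuffleWriteMetrics
writeMetrics.shuffleBytesWritten = 64L * 1024 * 1024   // bytes written to shuffle files
writeMetrics.shuffleWriteTime = 120L * 1000 * 1000     // time blocked on writes, in nanoseconds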

View file

@ -21,7 +21,7 @@ import java.io.File
import org.apache.spark.Logging
import org.apache.spark.util.Utils
import org.apache.spark.storage.BlockId
import org.apache.spark.storage.{BlockId, FileSegment}
private[spark] class ShuffleSender(portIn: Int, val pResolver: PathResolver) extends Logging {
@ -54,8 +54,7 @@ private[spark] object ShuffleSender {
val localDirs = args.drop(2).map(new File(_))
val pResovler = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId(blockIdString)
override def getBlockLocation(blockId: BlockId): FileSegment = {
if (!blockId.isShuffle) {
throw new Exception("Block " + blockId + " is not a shuffle block")
}
@ -65,7 +64,7 @@ private[spark] object ShuffleSender {
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
val subDir = new File(localDirs(dirId), "%02x".format(subDirId))
val file = new File(subDir, blockId.name)
return file.getAbsolutePath
return new FileSegment(file, 0, file.length())
}
}
val sender = new ShuffleSender(port, pResovler)

View file

@ -15,6 +15,8 @@
* limitations under the License.
*/
package org.apache
/**
* Core Spark functionality. [[org.apache.spark.SparkContext]] serves as the main entry point to
* Spark, while [[org.apache.spark.rdd.RDD]] is the data type representing a distributed collection,

View file

@ -265,6 +265,19 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
* Return a new RDD that has exactly numPartitions partitions.
*
* Can increase or decrease the level of parallelism in this RDD. Internally, this uses
* a shuffle to redistribute data.
*
* If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
* which can avoid performing a shuffle.
*/
def repartition(numPartitions: Int): RDD[T] = {
coalesce(numPartitions, true)
}
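// A minimal usage sketch of repartition vs. coalesce, assuming an existing SparkContext
// named `sc` (the names below are illustrative only):
//   val rdd = sc.parallelize(1 to 1000000, 8)   // 8 input partitions
//   val wider = rdd.repartition(64)             // shuffles data into 64 partitions
//   val fewer = wider.coalesce(16)              // shrinks to 16 partitions, usually without a shuffle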
/**
* Return a new RDD that is reduced into `numPartitions` partitions.
*

View file

@ -52,13 +52,14 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
private[spark]
class DAGScheduler(
taskSched: TaskScheduler,
mapOutputTracker: MapOutputTracker,
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv)
extends Logging {
def this(taskSched: TaskScheduler) {
this(taskSched, SparkEnv.get.mapOutputTracker, SparkEnv.get.blockManager.master, SparkEnv.get)
this(taskSched, SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
SparkEnv.get.blockManager.master, SparkEnv.get)
}
taskSched.setDAGScheduler(this)
@ -67,6 +68,11 @@ class DAGScheduler(
eventQueue.put(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(task: Task[_], taskInfo: TaskInfo) {
eventQueue.put(GettingResultEvent(task, taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
def taskEnded(
task: Task[_],
@ -182,7 +188,7 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
val stage = newStage(shuffleDep.rdd, Some(shuffleDep), jobId)
val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
@ -195,6 +201,7 @@ class DAGScheduler(
*/
private def newStage(
rdd: RDD[_],
numTasks: Int,
shuffleDep: Option[ShuffleDependency[_,_]],
jobId: Int,
callSite: Option[String] = None)
@ -207,9 +214,10 @@ class DAGScheduler(
mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
}
val id = nextStageId.getAndIncrement()
val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
stageToInfos(stage) = StageInfo(stage)
stageToInfos(stage) = new StageInfo(stage)
stage
}
@ -277,11 +285,6 @@ class DAGScheduler(
resultHandler: (Int, U) => Unit,
properties: Properties = null): JobWaiter[U] =
{
val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
return new JobWaiter[U](this, jobId, 0, resultHandler)
}
// Check to make sure we are not launching a task on a partition that does not exist.
val maxPartitions = rdd.partitions.length
partitions.find(p => p >= maxPartitions).foreach { p =>
@ -290,6 +293,11 @@ class DAGScheduler(
"Total number of partitions: " + maxPartitions)
}
val jobId = nextJobId.getAndIncrement()
if (partitions.size == 0) {
return new JobWaiter[U](this, jobId, 0, resultHandler)
}
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
@ -342,6 +350,11 @@ class DAGScheduler(
eventQueue.put(JobCancelled(jobId))
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
eventQueue.put(JobGroupCancelled(groupId))
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
@ -356,7 +369,7 @@ class DAGScheduler(
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
val finalStage = newStage(rdd, None, jobId, Some(callSite))
val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@ -381,6 +394,17 @@ class DAGScheduler(
taskSched.cancelTasks(stage.id)
}
case JobGroupCancelled(groupId) =>
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
.map(_.jobId)
if (!jobIds.isEmpty) {
running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
taskSched.cancelTasks(stage.id)
}
}
case AllJobsCancelled =>
// Cancel all running jobs.
running.foreach { stage =>
@ -396,6 +420,9 @@ class DAGScheduler(
case begin: BeginEvent =>
listenerBus.post(SparkListenerTaskStart(begin.task, begin.taskInfo))
case gettingResult: GettingResultEvent =>
listenerBus.post(SparkListenerTaskGettingResult(gettingResult.task, gettingResult.taskInfo))
case completion: CompletionEvent =>
listenerBus.post(SparkListenerTaskEnd(
completion.task, completion.reason, completion.taskInfo, completion.taskMetrics))
@ -568,7 +595,7 @@ class DAGScheduler(
// must be run listener before possible NotSerializableException
// should be "StageSubmitted" first and then "JobEnded"
listenerBus.post(SparkListenerStageSubmitted(stage, tasks.size, properties))
listenerBus.post(SparkListenerStageSubmitted(stageToInfos(stage), properties))
if (tasks.size > 0) {
// Preemptively serialize a task to make sure it can be serialized. We are catching this
@ -589,9 +616,7 @@ class DAGScheduler(
logDebug("New pending tasks: " + myPending)
taskSched.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
if (!stage.submissionTime.isDefined) {
stage.submissionTime = Some(System.currentTimeMillis())
}
stageToInfos(stage).submissionTime = Some(System.currentTimeMillis())
} else {
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@ -613,12 +638,12 @@ class DAGScheduler(
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage) = {
val serviceTime = stage.submissionTime match {
val serviceTime = stageToInfos(stage).submissionTime match {
case Some(t) => "%.03f".format((System.currentTimeMillis() - t) / 1000.0)
case _ => "Unkown"
case _ => "Unknown"
}
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
stage.completionTime = Some(System.currentTimeMillis)
stageToInfos(stage).completionTime = Some(System.currentTimeMillis())
listenerBus.post(StageCompleted(stageToInfos(stage)))
running -= stage
}
@ -788,7 +813,7 @@ class DAGScheduler(
*/
private def abortStage(failedStage: Stage, reason: String) {
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
failedStage.completionTime = Some(System.currentTimeMillis())
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)

View file

@ -46,11 +46,16 @@ private[scheduler] case class JobSubmitted(
private[scheduler] case class JobCancelled(jobId: Int) extends DAGSchedulerEvent
private[scheduler] case class JobGroupCancelled(groupId: String) extends DAGSchedulerEvent
private[scheduler] case object AllJobsCancelled extends DAGSchedulerEvent
private[scheduler]
case class BeginEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler]
case class GettingResultEvent(task: Task[_], taskInfo: TaskInfo) extends DAGSchedulerEvent
private[scheduler] case class CompletionEvent(
task: Task[_],
reason: TaskEndReason,

View file

@ -24,56 +24,54 @@ import java.text.SimpleDateFormat
import java.util.{Date, Properties}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{Map, HashMap, ListBuffer}
import scala.io.Source
import scala.collection.mutable.{HashMap, ListBuffer}
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.executor.TaskMetrics
// Used to record runtime information for each job, including the RDD graph,
// tasks' start/stop, shuffle information, and information from outside
/**
* A logger class to record runtime information for jobs in Spark. This class outputs one log file
* per Spark job with information such as RDD graph, tasks start/stop, shuffle information.
*
* @param logDirName The base directory for the log files.
*/
class JobLogger(val logDirName: String) extends SparkListener with Logging {
private val logDir =
if (System.getenv("SPARK_LOG_DIR") != null)
System.getenv("SPARK_LOG_DIR")
else
"/tmp/spark"
private val logDir = Option(System.getenv("SPARK_LOG_DIR")).getOrElse("/tmp/spark")
private val jobIDToPrintWriter = new HashMap[Int, PrintWriter]
private val stageIDToJobID = new HashMap[Int, Int]
private val jobIDToStages = new HashMap[Int, ListBuffer[Stage]]
private val DATE_FORMAT = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss")
private val eventQueue = new LinkedBlockingQueue[SparkListenerEvents]
createLogDir()
def this() = this(String.valueOf(System.currentTimeMillis()))
def getLogDir = logDir
def getJobIDtoPrintWriter = jobIDToPrintWriter
def getStageIDToJobID = stageIDToJobID
def getJobIDToStages = jobIDToStages
def getEventQueue = eventQueue
// The following 5 functions are used only in testing.
private[scheduler] def getLogDir = logDir
private[scheduler] def getJobIDtoPrintWriter = jobIDToPrintWriter
private[scheduler] def getStageIDToJobID = stageIDToJobID
private[scheduler] def getJobIDToStages = jobIDToStages
private[scheduler] def getEventQueue = eventQueue
// Create a folder for log files; the folder's name is the creation time of the JobLogger
protected def createLogDir() {
val dir = new File(logDir + "/" + logDirName + "/")
if (dir.exists()) {
return
}
if (dir.mkdirs() == false) {
logError("create log directory error:" + logDir + "/" + logDirName + "/")
if (!dir.exists() && !dir.mkdirs()) {
logError("Error creating log directory: " + logDir + "/" + logDirName + "/")
}
}
// Create a log file for one job; the file name is the job ID
protected def createLogWriter(jobID: Int) {
try{
try {
val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobID)
jobIDToPrintWriter += (jobID -> fileWriter)
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
} catch {
case e: FileNotFoundException => e.printStackTrace()
}
}
// Close log file, and clean the stage relationship in stageIDToJobID
@ -118,10 +116,9 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
def getRddsInStage(rdd: RDD[_]): ListBuffer[RDD[_]] = {
var rddList = new ListBuffer[RDD[_]]
rddList += rdd
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
case _ => rddList ++= getRddsInStage(dep.rdd)
}
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
case dep: Dependency[_] => rddList ++= getRddsInStage(dep.rdd)
}
rddList
}
@ -161,29 +158,27 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
protected def recordRddInStageGraph(jobID: Int, rdd: RDD[_], indent: Int) {
val rddInfo = "RDD_ID=" + rdd.id + "(" + getRddName(rdd) + "," + rdd.generator + ")"
jobLogInfo(jobID, indentString(indent) + rddInfo, false)
rdd.dependencies.foreach{ dep => dep match {
case shufDep: ShuffleDependency[_,_] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case _ => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
rdd.dependencies.foreach {
case shufDep: ShuffleDependency[_, _] =>
val depInfo = "SHUFFLE_ID=" + shufDep.shuffleId
jobLogInfo(jobID, indentString(indent + 1) + depInfo, false)
case dep: Dependency[_] => recordRddInStageGraph(jobID, dep.rdd, indent + 1)
}
}
protected def recordStageDepGraph(jobID: Int, stage: Stage, indent: Int = 0) {
var stageInfo: String = ""
if (stage.isShuffleMap) {
stageInfo = "STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" +
stage.shuffleDep.get.shuffleId
}else{
stageInfo = "STAGE_ID=" + stage.id + " RESULT_STAGE"
val stageInfo = if (stage.isShuffleMap) {
"STAGE_ID=" + stage.id + " MAP_STAGE SHUFFLE_ID=" + stage.shuffleDep.get.shuffleId
} else {
"STAGE_ID=" + stage.id + " RESULT_STAGE"
}
if (stage.jobId == jobID) {
jobLogInfo(jobID, indentString(indent) + stageInfo, false)
recordRddInStageGraph(jobID, stage.rdd, indent)
stage.parents.foreach(recordStageDepGraph(jobID, _, indent + 2))
} else
} else {
jobLogInfo(jobID, indentString(indent) + stageInfo + " JOB_ID=" + stage.jobId, false)
}
}
// Record task metrics into job log files
@ -193,39 +188,32 @@ class JobLogger(val logDirName: String) extends SparkListener with Logging {
" START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime +
" EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname
val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime
val readMetrics =
taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics =
taskMetrics.shuffleWriteMetrics match {
case Some(metrics) =>
" SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
val readMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" SHUFFLE_FINISH_TIME=" + metrics.shuffleFinishTime +
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
" REMOTE_FETCH_TIME=" + metrics.remoteFetchTime +
" REMOTE_BYTES_READ=" + metrics.remoteBytesRead
case None => ""
}
val writeMetrics = taskMetrics.shuffleWriteMetrics match {
case Some(metrics) => " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten
case None => ""
}
stageLogInfo(stageID, status + info + executorRunTime + readMetrics + writeMetrics)
}
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) {
stageLogInfo(
stageSubmitted.stage.id,
"STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.id, stageSubmitted.taskSize))
stageLogInfo(stageSubmitted.stage.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format(
stageSubmitted.stage.stageId, stageSubmitted.stage.numTasks))
}
override def onStageCompleted(stageCompleted: StageCompleted) {
stageLogInfo(
stageCompleted.stageInfo.stage.id,
"STAGE_ID=%d STATUS=COMPLETED".format(stageCompleted.stageInfo.stage.id))
stageLogInfo(stageCompleted.stage.stageId, "STAGE_ID=%d STATUS=COMPLETED".format(
stageCompleted.stage.stageId))
}
override def onTaskStart(taskStart: SparkListenerTaskStart) { }

View file

@ -164,17 +164,19 @@ private[spark] class ShuffleMapTask(
// Commit the writes. Get the size of each bucket block (total block size).
var totalBytes = 0L
var totalTime = 0L
val compressedSizes: Array[Byte] = buckets.writers.map { writer: BlockObjectWriter =>
writer.commit()
writer.close()
val size = writer.size()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
MapOutputTracker.compressSize(size)
}
// Update shuffle metrics.
val shuffleMetrics = new ShuffleWriteMetrics
shuffleMetrics.shuffleBytesWritten = totalBytes
shuffleMetrics.shuffleWriteTime = totalTime
metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
new MapStatus(blockManager.blockManagerId, compressedSizes)
@ -188,6 +190,7 @@ private[spark] class ShuffleMapTask(
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && buckets != null) {
buckets.writers.foreach(_.close())
shuffle.releaseWriters(buckets)
}
// Execute the callbacks on task completion.

View file

@ -24,13 +24,16 @@ import org.apache.spark.executor.TaskMetrics
sealed trait SparkListenerEvents
case class SparkListenerStageSubmitted(stage: Stage, taskSize: Int, properties: Properties)
case class SparkListenerStageSubmitted(stage: StageInfo, properties: Properties)
extends SparkListenerEvents
case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
case class StageCompleted(val stage: StageInfo) extends SparkListenerEvents
case class SparkListenerTaskStart(task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
case class SparkListenerTaskGettingResult(
task: Task[_], taskInfo: TaskInfo) extends SparkListenerEvents
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
@ -56,6 +59,12 @@ trait SparkListener {
*/
def onTaskStart(taskStart: SparkListenerTaskStart) { }
/**
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
*/
@ -80,7 +89,7 @@ class StatsReportListener extends SparkListener with Logging {
override def onStageCompleted(stageCompleted: StageCompleted) {
import org.apache.spark.scheduler.StatsReportListener._
implicit val sc = stageCompleted
this.logInfo("Finished stage: " + stageCompleted.stageInfo)
this.logInfo("Finished stage: " + stageCompleted.stage)
showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
//shuffle write
@ -93,7 +102,7 @@ class StatsReportListener extends SparkListener with Logging {
//runtime breakdown
val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
val runtimePcts = stageCompleted.stage.taskInfos.map{
case (info, metrics) => RuntimePercentage(info.duration, metrics)
}
showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
@ -111,7 +120,7 @@ object StatsReportListener extends Logging {
val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
Distribution(stage.stageInfo.taskInfos.flatMap{
Distribution(stage.stage.taskInfos.flatMap {
case ((info,metric)) => getMetric(info, metric)})
}

View file

@ -49,6 +49,8 @@ private[spark] class SparkListenerBus() extends Logging {
sparkListeners.foreach(_.onJobEnd(jobEnd))
case taskStart: SparkListenerTaskStart =>
sparkListeners.foreach(_.onTaskStart(taskStart))
case taskGettingResult: SparkListenerTaskGettingResult =>
sparkListeners.foreach(_.onTaskGettingResult(taskGettingResult))
case taskEnd: SparkListenerTaskEnd =>
sparkListeners.foreach(_.onTaskEnd(taskEnd))
case _ =>

View file

@ -39,6 +39,7 @@ import org.apache.spark.storage.BlockManagerId
private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
@ -49,11 +50,6 @@ private[spark] class Stage(
val numPartitions = rdd.partitions.size
val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
var numAvailableOutputs = 0
/** When first task was submitted to scheduler. */
var submissionTime: Option[Long] = None
var completionTime: Option[Long] = None
private var nextAttemptId = 0
def isAvailable: Boolean = {

View file

@ -21,9 +21,16 @@ import scala.collection._
import org.apache.spark.executor.TaskMetrics
case class StageInfo(
val stage: Stage,
class StageInfo(
stage: Stage,
val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
) {
override def toString = stage.rdd.toString
val stageId = stage.id
/** When this stage was submitted from the DAGScheduler to a TaskScheduler. */
var submissionTime: Option[Long] = None
var completionTime: Option[Long] = None
val rddName = stage.rdd.name
val name = stage.name
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
}

View file

@ -45,7 +45,7 @@ import org.apache.spark.util.ByteBufferInputStream
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
def run(attemptId: Long): T = {
final def run(attemptId: Long): T = {
context = new TaskContext(stageId, partitionId, attemptId, runningLocally = false)
if (_killed) {
kill()

View file

@ -31,9 +31,25 @@ class TaskInfo(
val host: String,
val taskLocality: TaskLocality.TaskLocality) {
/**
* The time when the task started remotely getting the result. Will not be set if the
* task result was sent immediately when the task finished (as opposed to sending an
* IndirectTaskResult and later fetching the result from the block manager).
*/
var gettingResultTime: Long = 0
/**
* The time when the task has completed successfully (including the time to remotely fetch
* results, if necessary).
*/
var finishTime: Long = 0
var failed = false
def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
def markSuccessful(time: Long = System.currentTimeMillis) {
finishTime = time
}
@ -43,6 +59,8 @@ class TaskInfo(
failed = true
}
def gettingResult: Boolean = gettingResultTime != 0
def finished: Boolean = finishTime != 0
def successful: Boolean = finished && !failed
@ -52,6 +70,8 @@ class TaskInfo(
def status: String = {
if (running)
"RUNNING"
else if (gettingResult)
"GET RESULT"
else if (failed)
"FAILED"
else if (successful)

View file

@ -306,6 +306,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
}
}
def handleTaskGettingResult(taskSetManager: ClusterTaskSetManager, tid: Long) {
taskSetManager.handleTaskGettingResult(tid)
}
def handleSuccessfulTask(
taskSetManager: ClusterTaskSetManager,
tid: Long,

View file

@ -418,6 +418,12 @@ private[spark] class ClusterTaskSetManager(
sched.dagScheduler.taskStarted(task, info)
}
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(tasks(info.index), info)
}
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/

View file

@ -60,6 +60,10 @@ private[spark] object CoarseGrainedClusterMessages {
case object StopDriver extends CoarseGrainedClusterMessage
case object StopExecutor extends CoarseGrainedClusterMessage
case object StopExecutors extends CoarseGrainedClusterMessage
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
}

View file

@ -101,6 +101,13 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
sender ! true
context.stop(self)
case StopExecutors =>
logInfo("Asking each executor to shut down")
for (executor <- executorActor.values) {
executor ! StopExecutor
}
sender ! true
case RemoveExecutor(executorId, reason) =>
removeExecutor(executorId, reason)
sender ! true
@ -170,11 +177,24 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
def stopExecutors() {
try {
if (driverActor != null) {
logInfo("Shutting down all executors")
val future = driverActor.ask(StopExecutors)(timeout)
Await.ready(future, timeout)
}
} catch {
case e: Exception =>
throw new SparkException("Error asking standalone scheduler to shut down executors", e)
}
}
override def stop() {
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
Await.result(future, timeout)
Await.ready(future, timeout)
}
} catch {
case e: Exception =>
@ -197,7 +217,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
def removeExecutor(executorId: String, reason: String) {
try {
val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout)
Await.result(future, timeout)
Await.ready(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error notifying standalone scheduler's driver actor", e)

View file

@ -0,0 +1,66 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.scheduler.cluster
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.{Logging, SparkContext}
private[spark] class SimrSchedulerBackend(
scheduler: ClusterScheduler,
sc: SparkContext,
driverFilePath: String)
extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
with Logging {
val tmpPath = new Path(driverFilePath + "_tmp")
val filePath = new Path(driverFilePath)
val maxCores = System.getProperty("spark.simr.executor.cores", "1").toInt
override def start() {
super.start()
val driverUrl = "akka://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val conf = new Configuration()
val fs = FileSystem.get(conf)
logInfo("Writing to HDFS file: " + driverFilePath)
logInfo("Writing Akka address: " + driverUrl)
// Create a temporary file to prevent a race condition where executors read an empty driverUrl file
val temp = fs.create(tmpPath, true)
temp.writeUTF(driverUrl)
temp.writeInt(maxCores)
temp.close()
// "Atomic" rename
fs.rename(tmpPath, filePath)
}
override def stop() {
val conf = new Configuration()
val fs = FileSystem.get(conf)
fs.delete(new Path(driverFilePath), false)
super.stopExecutors()
super.stop()
}
}
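
The temporary-file-plus-rename step above is a generic way to publish a small file atomically on HDFS. A standalone sketch of the same pattern, with hypothetical paths and payload:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

object AtomicPublishSketch {
  def publish(fs: FileSystem, dest: Path, payload: String) {
    val tmp = new Path(dest.toString + "_tmp")
    val out = fs.create(tmp, true)        // overwrite any stale temp file
    out.writeUTF(payload)
    out.close()
    fs.rename(tmp, dest)                  // readers never observe a partially written file
  }

  def main(args: Array[String]) {
    val fs = FileSystem.get(new Configuration())
    publish(fs, new Path("/tmp/driver_url"), "akka://spark@example-host:7077/user/driver")
  }
}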

View file

@ -50,6 +50,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case directResult: DirectTaskResult[_] => directResult
case IndirectTaskResult(blockId) =>
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
if (!serializedTaskResult.isDefined) {
/* We won't be able to get the task result if the machine that ran the task failed

View file

@ -32,7 +32,7 @@ private[spark] sealed abstract class BlockId {
def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None
def isRDD = isInstanceOf[RDDBlockId]
def isShuffle = isInstanceOf[ShuffleBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId]
def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId]
override def toString = name
override def hashCode = name.hashCode
@ -55,6 +55,10 @@ private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId {
def name = "broadcast_" + broadcastId
}
private[spark] case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId {
def name = broadcastId.name + "_" + hType
}
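// A minimal sketch of the new helper IDs (BroadcastHelperBlockIdSketch is a hypothetical
// name; the IDs are private[spark], so this assumes code inside the Spark package tree):
private[spark] object BroadcastHelperBlockIdSketch {
  val meta = BroadcastHelperBlockId(BroadcastBlockId(42), "meta")
  assert(meta.name == "broadcast_42_meta")
  assert(meta.isBroadcast)                      // helper blocks now count as broadcast blocks
  assert(BlockId("broadcast_42_meta") == meta)  // parsed back via the patterns in the companion object below
}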
private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId {
def name = "taskresult_" + taskId
}
@ -72,6 +76,7 @@ private[spark] object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
val BROADCAST = "broadcast_([0-9]+)".r
val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r
val TASKRESULT = "taskresult_([0-9]+)".r
val STREAM = "input-([0-9]+)-([0-9]+)".r
val TEST = "test_(.*)".r
@ -84,6 +89,8 @@ private[spark] object BlockId {
ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt)
case BROADCAST(broadcastId) =>
BroadcastBlockId(broadcastId.toLong)
case BROADCAST_HELPER(broadcastId, hType) =>
BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType)
case TASKRESULT(taskId) =>
TaskResultBlockId(taskId.toLong)
case STREAM(streamId, uniqueId) =>

View file

@ -0,0 +1,97 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.util.concurrent.ConcurrentHashMap
private[storage] trait BlockInfo {
def level: StorageLevel
def tellMaster: Boolean
// To save space, 'pending' and 'failed' are encoded as special sizes:
@volatile var size: Long = BlockInfo.BLOCK_PENDING
private def pending: Boolean = size == BlockInfo.BLOCK_PENDING
private def failed: Boolean = size == BlockInfo.BLOCK_FAILED
private def initThread: Thread = BlockInfo.blockInfoInitThreads.get(this)
setInitThread()
private def setInitThread() {
// Set current thread as init thread - waitForReady will not block this thread
// (in case there is non trivial initialization which ends up calling waitForReady as part of
// initialization itself)
BlockInfo.blockInfoInitThreads.put(this, Thread.currentThread())
}
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
if (pending && initThread != Thread.currentThread()) {
synchronized {
while (pending) this.wait()
}
}
!failed
}
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
require (sizeInBytes >= 0, "sizeInBytes was negative: " + sizeInBytes)
assert (pending)
size = sizeInBytes
BlockInfo.blockInfoInitThreads.remove(this)
synchronized {
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
assert (pending)
size = BlockInfo.BLOCK_FAILED
BlockInfo.blockInfoInitThreads.remove(this)
synchronized {
this.notifyAll()
}
}
}
private object BlockInfo {
// initThread is logically a BlockInfo field, but we store it here because
// it's only needed while this block is in the 'pending' state and we want
// to minimize BlockInfo's memory footprint.
private val blockInfoInitThreads = new ConcurrentHashMap[BlockInfo, Thread]
private val BLOCK_PENDING: Long = -1L
private val BLOCK_FAILED: Long = -2L
}
// All shuffle blocks have the same `level` and `tellMaster` properties,
// so we can save space by not storing them in each instance:
private[storage] class ShuffleBlockInfo extends BlockInfo {
// These need to be defined using 'def' instead of 'val' in order for
// the compiler to eliminate the fields:
def level: StorageLevel = StorageLevel.DISK_ONLY
def tellMaster: Boolean = false
}
private[storage] class BlockInfoImpl(val level: StorageLevel, val tellMaster: Boolean)
extends BlockInfo {
// Intentionally left blank
}
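
A minimal sketch of the pending/ready handshake defined above, assuming code inside the org.apache.spark.storage package (the classes are private[storage]) and that BlockInfoSketch is a hypothetical name: a reader thread blocks in waitForReady until the writer publishes the final size with markReady.

package org.apache.spark.storage

object BlockInfoSketch {
  def main(args: Array[String]) {
    val info = new BlockInfoImpl(StorageLevel.MEMORY_ONLY, tellMaster = false)
    val reader = new Thread(new Runnable {
      def run() {
        // Blocks here until another thread calls markReady or markFailure.
        println("ready=" + info.waitForReady() + " size=" + info.size)
      }
    })
    reader.start()
    Thread.sleep(100)
    info.markReady(1024L)   // publish the final size and wake up any waiters
    reader.join()
  }
}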

View file

@ -20,14 +20,15 @@ package org.apache.spark.storage
import java.io.{InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet}
import scala.collection.mutable.{HashMap, ArrayBuffer}
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
import akka.dispatch.{Await, Future}
import akka.util.Duration
import akka.util.duration._
import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
@ -45,74 +46,20 @@ private[spark] class BlockManager(
maxMemory: Long)
extends Logging {
private class BlockInfo(val level: StorageLevel, val tellMaster: Boolean) {
@volatile var pending: Boolean = true
@volatile var size: Long = -1L
@volatile var initThread: Thread = null
@volatile var failed = false
setInitThread()
private def setInitThread() {
// Set current thread as init thread - waitForReady will not block this thread
// (in case there is non trivial initialization which ends up calling waitForReady as part of
// initialization itself)
this.initThread = Thread.currentThread()
}
/**
* Wait for this BlockInfo to be marked as ready (i.e. block is finished writing).
* Return true if the block is available, false otherwise.
*/
def waitForReady(): Boolean = {
if (initThread != Thread.currentThread() && pending) {
synchronized {
while (pending) this.wait()
}
}
!failed
}
/** Mark this BlockInfo as ready (i.e. block is finished writing) */
def markReady(sizeInBytes: Long) {
assert (pending)
size = sizeInBytes
initThread = null
failed = false
initThread = null
pending = false
synchronized {
this.notifyAll()
}
}
/** Mark this BlockInfo as ready but failed */
def markFailure() {
assert (pending)
size = 0
initThread = null
failed = true
initThread = null
pending = false
synchronized {
this.notifyAll()
}
}
}
val shuffleBlockManager = new ShuffleBlockManager(this)
val diskBlockManager = new DiskBlockManager(
System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory)
private[storage] val diskStore: DiskStore =
new DiskStore(this, System.getProperty("spark.local.dir", System.getProperty("java.io.tmpdir")))
private[storage] val diskStore = new DiskStore(this, diskBlockManager)
// If we use Netty for shuffle, start a new Netty-based shuffle sender service.
private val nettyPort: Int = {
val useNetty = System.getProperty("spark.shuffle.use.netty", "false").toBoolean
val nettyPortConfig = System.getProperty("spark.shuffle.sender.port", "0").toInt
if (useNetty) diskStore.startShuffleBlockSender(nettyPortConfig) else 0
if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0
}
val connectionManager = new ConnectionManager(0)
@ -269,7 +216,7 @@ private[spark] class BlockManager(
}
/**
* Actually send a UpdateBlockInfo message. Returns the mater's response,
* Actually send a UpdateBlockInfo message. Returns the master's response,
* which will be true if the block was successfully recorded and false if
* the slave needs to re-register.
*/
@ -320,89 +267,14 @@ private[spark] class BlockManager(
*/
def getLocal(blockId: BlockId): Option[Iterator[Any]] = {
logDebug("Getting local block " + blockId)
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
// In the another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
return None
}
val level = info.level
logDebug("Level for block " + blockId + " is " + level)
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
memoryStore.getValues(blockId) match {
case Some(iterator) =>
return Some(iterator)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
// Look for block on disk, potentially loading it back into memory if required
if (level.useDisk) {
logDebug("Getting block " + blockId + " from disk")
if (level.useMemory && level.deserialized) {
diskStore.getValues(blockId) match {
case Some(iterator) =>
// Put the block back in memory before returning it
// TODO: Consider creating a putValues that also takes in a iterator ?
val elements = new ArrayBuffer[Any]
elements ++= iterator
memoryStore.putValues(blockId, elements, level, true).data match {
case Left(iterator2) =>
return Some(iterator2)
case _ =>
throw new Exception("Memory store did not return back an iterator")
}
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
} else if (level.useMemory && !level.deserialized) {
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
// Put a copy of the block back in memory before returning it. Note that we can't
// put the ByteBuffer returned by the disk store as that's a memory-mapped file.
// The use of rewind assumes this.
assert (0 == bytes.position())
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
bytes.rewind()
return Some(dataDeserialize(blockId, bytes))
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
} else {
diskStore.getValues(blockId) match {
case Some(iterator) =>
return Some(iterator)
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
return None
doGetLocal(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from the local block manager as serialized bytes.
*/
def getLocalBytes(blockId: BlockId): Option[ByteBuffer] = {
// TODO: This whole thing is very similar to getLocal; we need to refactor it somehow
logDebug("Getting local block " + blockId + " as bytes")
// As an optimization for map output fetches, if the block is for a shuffle, return it
// without acquiring a lock; the disk store never deletes (recent) items so this should work
if (blockId.isShuffle) {
@ -413,12 +285,15 @@ private[spark] class BlockManager(
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
}
doGetLocal(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
}
private def doGetLocal(blockId: BlockId, asValues: Boolean): Option[Any] = {
val info = blockInfo.get(blockId).orNull
if (info != null) {
info.synchronized {
// In the another thread is writing the block, wait for it to become ready.
// If another thread is writing the block, wait for it to become ready.
if (!info.waitForReady()) {
// If we get here, the block write failed.
logWarning("Block " + blockId + " was marked as failure.")
@ -431,94 +306,111 @@ private[spark] class BlockManager(
// Look for the block in memory
if (level.useMemory) {
logDebug("Getting block " + blockId + " from memory")
memoryStore.getBytes(blockId) match {
case Some(bytes) =>
return Some(bytes)
val result = if (asValues) {
memoryStore.getValues(blockId)
} else {
memoryStore.getBytes(blockId)
}
result match {
case Some(values) =>
return Some(values)
case None =>
logDebug("Block " + blockId + " not found in memory")
}
}
// Look for block on disk
// Look for block on disk, potentially storing it back into memory if required:
if (level.useDisk) {
// Read it as a byte buffer into memory first, then return it
diskStore.getBytes(blockId) match {
case Some(bytes) =>
assert (0 == bytes.position())
if (level.useMemory) {
if (level.deserialized) {
memoryStore.putBytes(blockId, bytes, level)
} else {
// The memory store will hang onto the ByteBuffer, so give it a copy instead of
// the memory-mapped file buffer we got from the disk store
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
}
}
bytes.rewind()
return Some(bytes)
logDebug("Getting block " + blockId + " from disk")
val bytes: ByteBuffer = diskStore.getBytes(blockId) match {
case Some(bytes) => bytes
case None =>
throw new Exception("Block " + blockId + " not found on disk, though it should be")
}
assert (0 == bytes.position())
if (!level.useMemory) {
// If the block shouldn't be stored in memory, we can just return it:
if (asValues) {
return Some(dataDeserialize(blockId, bytes))
} else {
return Some(bytes)
}
} else {
// Otherwise, we also have to store something in the memory store:
if (!level.deserialized || !asValues) {
// We'll store the bytes in memory if the block's storage level includes
// "memory serialized", or if it should be cached as objects in memory
// but we only requested its serialized bytes:
val copyForMemory = ByteBuffer.allocate(bytes.limit)
copyForMemory.put(bytes)
memoryStore.putBytes(blockId, copyForMemory, level)
bytes.rewind()
}
if (!asValues) {
return Some(bytes)
} else {
val values = dataDeserialize(blockId, bytes)
if (level.deserialized) {
// Cache the values before returning them:
// TODO: Consider creating a putValues that also takes in an iterator?
val valuesBuffer = new ArrayBuffer[Any]
valuesBuffer ++= values
memoryStore.putValues(blockId, valuesBuffer, level, true).data match {
case Left(values2) =>
return Some(values2)
case _ =>
throw new Exception("Memory store did not return back an iterator")
}
} else {
return Some(values)
}
}
}
}
}
} else {
logDebug("Block " + blockId + " not registered locally")
}
return None
None
}
/**
* Get block from remote block managers.
*/
def getRemote(blockId: BlockId): Option[Iterator[Any]] = {
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
logDebug("Getting remote block " + blockId)
// Get locations of block
val locations = master.getLocations(blockId)
// Get block from remote locations
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(dataDeserialize(blockId, data))
}
logDebug("The value of block " + blockId + " is null")
}
logDebug("Block " + blockId + " not found")
return None
doGetRemote(blockId, asValues = true).asInstanceOf[Option[Iterator[Any]]]
}
/**
* Get block from remote block managers as serialized bytes.
*/
def getRemoteBytes(blockId: BlockId): Option[ByteBuffer] = {
// TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
// refactored.
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
logDebug("Getting remote block " + blockId + " as bytes")
val locations = master.getLocations(blockId)
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
return Some(data)
}
logDebug("The value of block " + blockId + " is null")
}
logDebug("Block " + blockId + " not found")
return None
logDebug("Getting remote block " + blockId + " as bytes")
doGetRemote(blockId, asValues = false).asInstanceOf[Option[ByteBuffer]]
}
private def doGetRemote(blockId: BlockId, asValues: Boolean): Option[Any] = {
require(blockId != null, "BlockId is null")
val locations = Random.shuffle(master.getLocations(blockId))
for (loc <- locations) {
logDebug("Getting remote block " + blockId + " from " + loc)
val data = BlockManagerWorker.syncGetBlock(
GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
if (data != null) {
if (asValues) {
return Some(dataDeserialize(blockId, data))
} else {
return Some(data)
}
}
logDebug("The value of block " + blockId + " is null")
}
logDebug("Block " + blockId + " not found")
return None
}
/**
* Get a block from the block manager (either local or remote).
*/
@ -566,16 +458,22 @@ private[spark] class BlockManager(
/**
* A short circuited method to get a block writer that can write data directly to disk.
* The Block will be appended to the File specified by filename.
* This is currently used for writing shuffle files out. Callers should handle error
* cases.
*/
def getDiskBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
def getDiskWriter(blockId: BlockId, filename: String, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
val writer = diskStore.getBlockWriter(blockId, serializer, bufferSize)
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val file = diskBlockManager.createBlockFile(blockId, filename, allowAppending = true)
val writer = new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream)
writer.registerCloseEventHandler(() => {
val myInfo = new BlockInfo(StorageLevel.DISK_ONLY, false)
if (shuffleBlockManager.consolidateShuffleFiles) {
diskBlockManager.mapBlockToFileSegment(blockId, writer.fileSegment())
}
val myInfo = new ShuffleBlockInfo()
blockInfo.put(blockId, myInfo)
myInfo.markReady(writer.size())
myInfo.markReady(writer.fileSegment().length)
})
writer
}
@ -584,23 +482,30 @@ private[spark] class BlockManager(
* Put a new block of values to the block manager. Returns its (estimated) size in bytes.
*/
def put(blockId: BlockId, values: ArrayBuffer[Any], level: StorageLevel,
tellMaster: Boolean = true) : Long = {
tellMaster: Boolean = true) : Long = {
require(values != null, "Values is null")
doPut(blockId, Left(values), level, tellMaster)
}
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
if (values == null) {
throw new IllegalArgumentException("Values is null")
}
if (level == null || !level.isValid) {
throw new IllegalArgumentException("Storage level is null or invalid")
}
/**
* Put a new block of serialized bytes to the block manager.
*/
def putBytes(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel,
tellMaster: Boolean = true) {
require(bytes != null, "Bytes is null")
doPut(blockId, Right(bytes), level, tellMaster)
}
private def doPut(blockId: BlockId, data: Either[ArrayBuffer[Any], ByteBuffer],
level: StorageLevel, tellMaster: Boolean = true): Long = {
require(blockId != null, "BlockId is null")
require(level != null && level.isValid, "StorageLevel is null or invalid")
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
val tinfo = new BlockInfo(level, tellMaster)
val tinfo = new BlockInfoImpl(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
@ -610,7 +515,8 @@ private[spark] class BlockManager(
return oldBlockOpt.get.size
}
// TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
// TODO: So the block info exists - but previous attempt to load it (?) failed.
// What do we do now ? Retry on it ?
oldBlockOpt.get
} else {
tinfo
@ -619,10 +525,10 @@ private[spark] class BlockManager(
val startTimeMs = System.currentTimeMillis
// If we need to replicate the data, we'll want access to the values, but because our
// put will read the whole iterator, there will be no values left. For the case where
// the put serializes data, we'll remember the bytes, above; but for the case where it
// doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
// If we're storing values and we need to replicate the data, we'll want access to the values,
// but because our put will read the whole iterator, there will be no values left. For the
// case where the put serializes data, we'll remember the bytes, above; but for the case where
// it doesn't, such as deserialized storage, let's rely on the put returning an Iterator.
var valuesAfterPut: Iterator[Any] = null
// Ditto for the bytes after the put
@ -631,30 +537,51 @@ private[spark] class BlockManager(
// Size of the block in bytes (to return to caller)
var size = 0L
// If we're storing bytes, then initiate the replication before storing them locally.
// This is faster as data is already serialized and ready to send.
val replicationFuture = if (data.isRight && level.replication > 1) {
val bufferView = data.right.get.duplicate() // Doesn't copy the bytes, just creates a wrapper
Future {
replicate(blockId, bufferView, level)
}
} else {
null
}
myInfo.synchronized {
logTrace("Put for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will later
// drop it to disk if the memory store can't hold it.
val res = memoryStore.putValues(blockId, values, level, true)
size = res.size
res.data match {
case Right(newBytes) => bytesAfterPut = newBytes
case Left(newIterator) => valuesAfterPut = newIterator
data match {
case Left(values) => {
if (level.useMemory) {
// Save it just to memory first, even if it also has useDisk set to true; we will
// drop it to disk later if the memory store can't hold it.
val res = memoryStore.putValues(blockId, values, level, true)
size = res.size
res.data match {
case Right(newBytes) => bytesAfterPut = newBytes
case Left(newIterator) => valuesAfterPut = newIterator
}
} else {
// Save directly to disk.
// Don't get back the bytes unless we replicate them.
val askForBytes = level.replication > 1
val res = diskStore.putValues(blockId, values, level, askForBytes)
size = res.size
res.data match {
case Right(newBytes) => bytesAfterPut = newBytes
case _ =>
}
}
}
} else {
// Save directly to disk.
// Don't get back the bytes unless we replicate them.
val askForBytes = level.replication > 1
val res = diskStore.putValues(blockId, values, level, askForBytes)
size = res.size
res.data match {
case Right(newBytes) => bytesAfterPut = newBytes
case _ =>
case Right(bytes) => {
bytes.rewind()
// Store it only in memory at first, even if useDisk is also set to true
(if (level.useMemory) memoryStore else diskStore).putBytes(blockId, bytes, level)
size = bytes.limit
}
}
@ -679,125 +606,39 @@ private[spark] class BlockManager(
}
logDebug("Put block " + blockId + " locally took " + Utils.getUsedTimeMs(startTimeMs))
// Replicate block if required
// Either we're storing bytes and we asynchronously started replication, or we're storing
// values and need to serialize and replicate them now:
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
throw new SparkException(
"Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
data match {
case Right(bytes) => Await.ready(replicationFuture, Duration.Inf)
case Left(values) => {
val remoteStartTime = System.currentTimeMillis
// Serialize the block if not already done
if (bytesAfterPut == null) {
if (valuesAfterPut == null) {
throw new SparkException(
"Underlying put returned neither an Iterator nor bytes! This shouldn't happen.")
}
bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
replicate(blockId, bytesAfterPut, level)
logDebug("Put block " + blockId + " remotely took " +
Utils.getUsedTimeMs(remoteStartTime))
}
bytesAfterPut = dataSerialize(blockId, valuesAfterPut)
}
replicate(blockId, bytesAfterPut, level)
logDebug("Put block " + blockId + " remotely took " + Utils.getUsedTimeMs(remoteStartTime))
}
BlockManager.dispose(bytesAfterPut)
return size
}
/**
* Put a new block of serialized bytes to the block manager.
*/
def putBytes(
blockId: BlockId, bytes: ByteBuffer, level: StorageLevel, tellMaster: Boolean = true) {
if (blockId == null) {
throw new IllegalArgumentException("Block Id is null")
}
if (bytes == null) {
throw new IllegalArgumentException("Bytes is null")
}
if (level == null || !level.isValid) {
throw new IllegalArgumentException("Storage level is null or invalid")
}
// Remember the block's storage level so that we can correctly drop it to disk if it needs
// to be dropped right after it got put into memory. Note, however, that other threads will
// not be able to get() this block until we call markReady on its BlockInfo.
val myInfo = {
val tinfo = new BlockInfo(level, tellMaster)
// Do atomically !
val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo)
if (oldBlockOpt.isDefined) {
if (oldBlockOpt.get.waitForReady()) {
logWarning("Block " + blockId + " already exists on this machine; not re-adding it")
return
}
// TODO: So the block info exists - but previous attempt to load it (?) failed. What do we do now ? Retry on it ?
oldBlockOpt.get
} else {
tinfo
}
}
val startTimeMs = System.currentTimeMillis
// Initiate the replication before storing it locally. This is faster as
// data is already serialized and ready for sending
val replicationFuture = if (level.replication > 1) {
val bufferView = bytes.duplicate() // Doesn't copy the bytes, just creates a wrapper
Future {
replicate(blockId, bufferView, level)
}
} else {
null
}
myInfo.synchronized {
logDebug("PutBytes for block " + blockId + " took " + Utils.getUsedTimeMs(startTimeMs)
+ " to get into synchronized block")
var marked = false
try {
if (level.useMemory) {
// Store it only in memory at first, even if useDisk is also set to true
bytes.rewind()
memoryStore.putBytes(blockId, bytes, level)
} else {
bytes.rewind()
diskStore.putBytes(blockId, bytes, level)
}
// assert (0 == bytes.position(), "" + bytes)
// Now that the block is in either the memory or disk store, let other threads read it,
// and tell the master about it.
marked = true
myInfo.markReady(bytes.limit)
if (tellMaster) {
reportBlockStatus(blockId, myInfo)
}
} finally {
// If we failed at putting the block to memory/disk, notify other possible readers
// that it has failed, and then remove it from the block info map.
if (! marked) {
// Note that the remove must happen before markFailure otherwise another thread
// could've inserted a new BlockInfo before we remove it.
blockInfo.remove(blockId)
myInfo.markFailure()
logWarning("Putting block " + blockId + " failed")
}
}
}
// If replication had started, then wait for it to finish
if (level.replication > 1) {
Await.ready(replicationFuture, Duration.Inf)
}
if (level.replication > 1) {
logDebug("PutBytes for block " + blockId + " with replication took " +
logDebug("Put for block " + blockId + " with replication took " +
Utils.getUsedTimeMs(startTimeMs))
} else {
logDebug("PutBytes for block " + blockId + " without replication took " +
logDebug("Put for block " + blockId + " without replication took " +
Utils.getUsedTimeMs(startTimeMs))
}
size
}
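The put/putBytes pair above now funnels into a single doPut that branches on an Either: Left carries deserialized values, Right carries already-serialized bytes (whose replication can be started up front). A minimal standalone sketch of that dispatch shape, with the stores reduced to stubs and the block ids purely illustrative:

import scala.collection.mutable.ArrayBuffer
import java.nio.ByteBuffer

object DoPutSketch {
  // Stub "stores": just report what would be written and return a size.
  def doPut(blockId: String, data: Either[ArrayBuffer[Any], ByteBuffer]): Long = data match {
    case Left(values) =>
      println(blockId + ": storing " + values.size + " deserialized values")
      values.size.toLong
    case Right(bytes) =>
      bytes.rewind()
      println(blockId + ": storing " + bytes.limit + " serialized bytes")
      bytes.limit.toLong
  }
  def main(args: Array[String]) {
    doPut("rdd_0_0", Left(ArrayBuffer[Any](1, 2, 3)))
    doPut("rdd_0_1", Right(ByteBuffer.wrap(Array[Byte](1, 2, 3, 4))))
  }
}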
/**
@ -922,34 +763,20 @@ private[spark] class BlockManager(
private def dropOldNonBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping non broadcast blocks older than " + cleanupTime)
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
if (time < cleanupTime && !id.isBroadcast) {
info.synchronized {
val level = info.level
if (level.useMemory) {
memoryStore.remove(id)
}
if (level.useDisk) {
diskStore.remove(id)
}
iterator.remove()
logInfo("Dropped block " + id)
}
reportBlockStatus(id, info)
}
}
dropOldBlocks(cleanupTime, !_.isBroadcast)
}
private def dropOldBroadcastBlocks(cleanupTime: Long) {
logInfo("Dropping broadcast blocks older than " + cleanupTime)
dropOldBlocks(cleanupTime, _.isBroadcast)
}
private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) {
val iterator = blockInfo.internalMap.entrySet().iterator()
while (iterator.hasNext) {
val entry = iterator.next()
val (id, info, time) = (entry.getKey, entry.getValue._1, entry.getValue._2)
if (time < cleanupTime && id.isBroadcast) {
if (time < cleanupTime && shouldDrop(id)) {
info.synchronized {
val level = info.level
if (level.useMemory) {
@ -987,13 +814,24 @@ private[spark] class BlockManager(
if (shouldCompress(blockId)) compressionCodec.compressedInputStream(s) else s
}
/** Serializes into a stream. */
def dataSerializeStream(
blockId: BlockId,
outputStream: OutputStream,
values: Iterator[Any],
serializer: Serializer = defaultSerializer) {
val byteStream = new FastBufferedOutputStream(outputStream)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
}
/** Serializes into a byte buffer. */
def dataSerialize(
blockId: BlockId,
values: Iterator[Any],
serializer: Serializer = defaultSerializer): ByteBuffer = {
val byteStream = new FastByteArrayOutputStream(4096)
val ser = serializer.newInstance()
ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
dataSerializeStream(blockId, byteStream, values, serializer)
byteStream.trim()
ByteBuffer.wrap(byteStream.array)
}
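dataSerialize above is now just dataSerializeStream pointed at an in-memory byte stream. A standalone sketch of that split, using java.io.ObjectOutputStream and ByteArrayOutputStream as stand-ins for Spark's Serializer and FastByteArrayOutputStream (both stand-ins are assumptions for illustration):

import java.io.{ByteArrayOutputStream, ObjectOutputStream, OutputStream}
import java.nio.ByteBuffer

object SerializeSketch {
  // Serialize into any OutputStream (a file stream could be passed here just as well).
  def serializeStream(out: OutputStream, values: Iterator[AnyRef]) {
    val objOut = new ObjectOutputStream(out)
    values.foreach(v => objOut.writeObject(v))
    objOut.close()
  }
  // Serialize into a ByteBuffer by writing into an in-memory stream first.
  def serialize(values: Iterator[AnyRef]): ByteBuffer = {
    val byteStream = new ByteArrayOutputStream(4096)
    serializeStream(byteStream, values)
    ByteBuffer.wrap(byteStream.toByteArray)
  }
  def main(args: Array[String]) {
    val buf = serialize(Iterator("a", "b", "c"))
    println("serialized " + buf.limit + " bytes")
  }
}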


@ -227,9 +227,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
}
private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
if (id.executorId == "<driver>" && !isLocal) {
// Got a register message from the master node; don't register it
} else if (!blockManagerInfo.contains(id)) {
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
case Some(manager) =>
// A block manager of the same executor already exists.


@ -26,6 +26,7 @@ import org.apache.spark.storage.BlockManagerMessages._
* An actor to take commands from the master to execute options. For example,
* this is used to remove blocks from the slave's BlockManager.
*/
private[storage]
class BlockManagerSlaveActor(blockManager: BlockManager) extends Actor {
override def receive = {


@ -17,6 +17,13 @@
package org.apache.spark.storage
import java.io.{FileOutputStream, File, OutputStream}
import java.nio.channels.FileChannel
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.Logging
import org.apache.spark.serializer.{SerializationStream, Serializer}
/**
* An interface for writing JVM objects to some underlying storage. This interface allows
@ -59,7 +66,129 @@ abstract class BlockObjectWriter(val blockId: BlockId) {
def write(value: Any)
/**
* Size of the valid writes, in bytes.
* Returns the file segment of committed data that this Writer has written.
*/
def size(): Long
def fileSegment(): FileSegment
/**
* Cumulative time spent performing blocking writes, in ns.
*/
def timeWriting(): Long
}
/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
class DiskBlockObjectWriter(
blockId: BlockId,
file: File,
serializer: Serializer,
bufferSize: Int,
compressStream: OutputStream => OutputStream)
extends BlockObjectWriter(blockId)
with Logging
{
/** Intercepts write calls and tracks total time spent writing. Not thread safe. */
private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
def timeWriting = _timeWriting
private var _timeWriting = 0L
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
_timeWriting += (System.nanoTime() - start)
}
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
}
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
/** The file channel, used for repositioning / truncating the file. */
private var channel: FileChannel = null
private var bs: OutputStream = null
private var fos: FileOutputStream = null
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialPosition = 0L
private var lastValidPosition = 0L
private var initialized = false
private var _timeWriting = 0L
override def open(): BlockObjectWriter = {
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
initialPosition = channel.position
lastValidPosition = initialPosition
bs = compressStream(new FastBufferedOutputStream(ts, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
override def close() {
if (initialized) {
if (syncWrites) {
// Force outstanding writes to disk and track how long it takes
objOut.flush()
val start = System.nanoTime()
fos.getFD.sync()
_timeWriting += System.nanoTime() - start
}
objOut.close()
_timeWriting += ts.timeWriting
channel = null
bs = null
fos = null
ts = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}
override def isOpen: Boolean = objOut != null
override def commit(): Long = {
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}
override def revertPartialWrites() {
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncating the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}
override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}
override def fileSegment(): FileSegment = {
val bytesWritten = lastValidPosition - initialPosition
new FileSegment(file, initialPosition, bytesWritten)
}
// Only valid if called after close()
override def timeWriting() = _timeWriting
}
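The writer above measures blocking write time by wrapping the raw FileOutputStream in a TimeTrackingOutputStream before buffering and compression are layered on top. A standalone sketch of that wrapper pattern (writing into an in-memory stream here, purely for illustration):

import java.io.{ByteArrayOutputStream, OutputStream}

class TimingOutputStream(out: OutputStream) extends OutputStream {
  private var nanos = 0L
  def timeWriting: Long = nanos
  private def timed(f: => Unit) {
    val start = System.nanoTime()
    f
    nanos += System.nanoTime() - start
  }
  def write(i: Int) { timed(out.write(i)) }
  override def write(b: Array[Byte]) { timed(out.write(b)) }
  override def write(b: Array[Byte], off: Int, len: Int) { timed(out.write(b, off, len)) }
}

object TimingOutputStreamSketch {
  def main(args: Array[String]) {
    val ts = new TimingOutputStream(new ByteArrayOutputStream())
    ts.write(Array.fill[Byte](1 << 20)(1.toByte))
    println("ns spent inside write(): " + ts.timeWriting)
  }
}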


@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.io.File
import java.text.SimpleDateFormat
import java.util.{Date, Random}
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.Logging
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.network.netty.{PathResolver, ShuffleSender}
import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
/**
* Creates and maintains the logical mapping between logical blocks and physical on-disk
* locations. By default, one block is mapped to one file with a name given by its BlockId.
* However, it is also possible to have a block map to only a segment of a file, by calling
* mapBlockToFileSegment().
*
* @param rootDirs The directories to use for storing block files. Data will be hashed among these.
*/
private[spark] class DiskBlockManager(rootDirs: String) extends PathResolver with Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
private val localDirs: Array[File] = createLocalDirs()
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
private var shuffleSender : ShuffleSender = null
// Stores only Blocks which have been specifically mapped to segments of files
// (rather than the default, which maps a Block to a whole file).
// This keeps our bookkeeping down, since the file system itself tracks the standalone Blocks.
private val blockToFileSegmentMap = new TimeStampedHashMap[BlockId, FileSegment]
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DISK_BLOCK_MANAGER, this.cleanup)
addShutdownHook()
/**
* Creates a logical mapping from the given BlockId to a segment of a file.
* This will cause any accesses of the logical BlockId to be directed to the specified
* physical location.
*/
def mapBlockToFileSegment(blockId: BlockId, fileSegment: FileSegment) {
blockToFileSegmentMap.put(blockId, fileSegment)
}
/**
* Returns the physical file segment in which the given BlockId is located.
* If the BlockId has been mapped to a specific FileSegment, that will be returned.
* Otherwise, we assume the Block is mapped to a whole file identified by the BlockId directly.
*/
def getBlockLocation(blockId: BlockId): FileSegment = {
if (blockToFileSegmentMap.internalMap.containsKey(blockId)) {
blockToFileSegmentMap.get(blockId).get
} else {
val file = getFile(blockId.name)
new FileSegment(file, 0, file.length())
}
}
/**
* Simply returns a File to place the given Block into. This does not physically create the file.
* If filename is given, that file will be used. Otherwise, we will use the BlockId to get
* a unique filename.
*/
def createBlockFile(blockId: BlockId, filename: String = "", allowAppending: Boolean): File = {
val actualFilename = if (filename == "") blockId.name else filename
val file = getFile(actualFilename)
if (!allowAppending && file.exists()) {
throw new IllegalStateException(
"Attempted to create file that already exists: " + actualFilename)
}
file
}
private def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
// Create the subdirectory if it doesn't already exist
var subDir = subDirs(dirId)(subDirId)
if (subDir == null) {
subDir = subDirs(dirId).synchronized {
val old = subDirs(dirId)(subDirId)
if (old != null) {
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
newDir.mkdir()
subDirs(dirId)(subDirId) = newDir
newDir
}
}
}
new File(subDir, filename)
}
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
rootDirs.split(",").map { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
val rand = new Random()
while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
tries += 1
try {
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
" attempts to create local dir in " + rootDir)
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
logInfo("Created local directory at " + localDir)
localDir
}
}
private def cleanup(cleanupTime: Long) {
blockToFileSegmentMap.clearOldValues(cleanupTime)
}
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
localDirs.foreach { localDir =>
try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
} catch {
case t: Throwable =>
logError("Exception while deleting local spark dir: " + localDir, t)
}
}
if (shuffleSender != null) {
shuffleSender.stop()
}
}
})
}
private[storage] def startShuffleBlockSender(port: Int): Int = {
shuffleSender = new ShuffleSender(port, this)
logInfo("Created ShuffleSender binding to port : " + shuffleSender.port)
shuffleSender.port
}
}
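getFile above spreads files across the configured local dirs and 64 subdirectories per dir so that no single directory accumulates a huge number of files. A standalone sketch of just the hashing arithmetic (the root paths and the simplified hash are illustrative, not Spark's Utils.nonNegativeHash):

object DirHashSketch {
  // Simplified non-negative hash; Spark's Utils.nonNegativeHash is slightly more careful.
  def nonNegativeHash(s: String): Int = {
    val h = s.hashCode
    if (h == Int.MinValue) 0 else math.abs(h)
  }
  def main(args: Array[String]) {
    val localDirs = Array("/tmp/spark-local-a", "/tmp/spark-local-b")   // hypothetical roots
    val subDirsPerLocalDir = 64
    val filename = "shuffle_1_2_3"
    val hash = nonNegativeHash(filename)
    val dirId = hash % localDirs.length
    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
    println(localDirs(dirId) + "/" + "%02x".format(subDirId) + "/" + filename)
  }
}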


@ -17,120 +17,25 @@
package org.apache.spark.storage
import java.io.{File, FileOutputStream, OutputStream, RandomAccessFile}
import java.io.{FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel
import java.nio.channels.FileChannel.MapMode
import java.util.{Random, Date}
import java.text.SimpleDateFormat
import scala.collection.mutable.ArrayBuffer
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
import org.apache.spark.executor.ExecutorExitCode
import org.apache.spark.serializer.{Serializer, SerializationStream}
import org.apache.spark.Logging
import org.apache.spark.network.netty.ShuffleSender
import org.apache.spark.network.netty.PathResolver
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
/**
* Stores BlockManager blocks on disk.
*/
private class DiskStore(blockManager: BlockManager, rootDirs: String)
private class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
class DiskBlockObjectWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
extends BlockObjectWriter(blockId) {
private val f: File = createFile(blockId /*, allowAppendExisting */)
// The file channel, used for repositioning / truncating the file.
private var channel: FileChannel = null
private var bs: OutputStream = null
private var objOut: SerializationStream = null
private var lastValidPosition = 0L
private var initialized = false
override def open(): DiskBlockObjectWriter = {
val fos = new FileOutputStream(f, true)
channel = fos.getChannel()
bs = blockManager.wrapForCompression(blockId, new FastBufferedOutputStream(fos, bufferSize))
objOut = serializer.newInstance().serializeStream(bs)
initialized = true
this
}
override def close() {
if (initialized) {
objOut.close()
channel = null
bs = null
objOut = null
}
// Invoke the close callback handler.
super.close()
}
override def isOpen: Boolean = objOut != null
// Flush the partial writes, and set valid length to be the length of the entire file.
// Return the number of bytes written for this commit.
override def commit(): Long = {
if (initialized) {
// NOTE: Flush the serializer first and then the compressed/buffered output stream
objOut.flush()
bs.flush()
val prevPos = lastValidPosition
lastValidPosition = channel.position()
lastValidPosition - prevPos
} else {
// lastValidPosition is zero if stream is uninitialized
lastValidPosition
}
}
override def revertPartialWrites() {
if (initialized) {
// Discard current writes. We do this by flushing the outstanding writes and
// truncate the file to the last valid position.
objOut.flush()
bs.flush()
channel.truncate(lastValidPosition)
}
}
override def write(value: Any) {
if (!initialized) {
open()
}
objOut.writeObject(value)
}
override def size(): Long = lastValidPosition
}
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private val subDirsPerLocalDir = System.getProperty("spark.diskStore.subDirectories", "64").toInt
private var shuffleSender : ShuffleSender = null
// Create one local directory for each path mentioned in spark.local.dir; then, inside this
// directory, create multiple subdirectories that we will hash files into, in order to avoid
// having really large inodes at the top level.
private val localDirs: Array[File] = createLocalDirs()
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
addShutdownHook()
def getBlockWriter(blockId: BlockId, serializer: Serializer, bufferSize: Int)
: BlockObjectWriter = {
new DiskBlockObjectWriter(blockId, serializer, bufferSize)
}
override def getSize(blockId: BlockId): Long = {
getFile(blockId).length()
diskManager.getBlockLocation(blockId).length
}
override def putBytes(blockId: BlockId, _bytes: ByteBuffer, level: StorageLevel) {
@ -139,27 +44,15 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
val bytes = _bytes.duplicate()
logDebug("Attempting to put block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
val channel = new RandomAccessFile(file, "rw").getChannel()
val file = diskManager.createBlockFile(blockId, allowAppending = false)
val channel = new FileOutputStream(file).getChannel()
while (bytes.remaining > 0) {
channel.write(bytes)
}
channel.close()
val finishTime = System.currentTimeMillis
logDebug("Block %s stored as %s file on disk in %d ms".format(
blockId, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
private def getFileBytes(file: File): ByteBuffer = {
val length = file.length()
val channel = new RandomAccessFile(file, "r").getChannel()
val buffer = try {
channel.map(MapMode.READ_ONLY, 0, length)
} finally {
channel.close()
}
buffer
file.getName, Utils.bytesToString(bytes.limit), (finishTime - startTime)))
}
override def putValues(
@ -171,21 +64,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
logDebug("Attempting to write values for block " + blockId)
val startTime = System.currentTimeMillis
val file = createFile(blockId)
val fileOut = blockManager.wrapForCompression(blockId,
new FastBufferedOutputStream(new FileOutputStream(file)))
val objOut = blockManager.defaultSerializer.newInstance().serializeStream(fileOut)
objOut.writeAll(values.iterator)
objOut.close()
val length = file.length()
val file = diskManager.createBlockFile(blockId, allowAppending = false)
val outputStream = new FileOutputStream(file)
blockManager.dataSerializeStream(blockId, outputStream, values.iterator)
val length = file.length
val timeTaken = System.currentTimeMillis - startTime
logDebug("Block %s stored as %s file on disk in %d ms".format(
blockId, Utils.bytesToString(length), timeTaken))
file.getName, Utils.bytesToString(length), timeTaken))
if (returnValues) {
// Return a byte buffer for the contents of the file
val buffer = getFileBytes(file)
val buffer = getBytes(blockId).get
PutResult(length, Right(buffer))
} else {
PutResult(length, null)
@ -193,13 +83,18 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def getBytes(blockId: BlockId): Option[ByteBuffer] = {
val file = getFile(blockId)
val bytes = getFileBytes(file)
Some(bytes)
val segment = diskManager.getBlockLocation(blockId)
val channel = new RandomAccessFile(segment.file, "r").getChannel()
val buffer = try {
channel.map(MapMode.READ_ONLY, segment.offset, segment.length)
} finally {
channel.close()
}
Some(buffer)
}
override def getValues(blockId: BlockId): Option[Iterator[Any]] = {
getBytes(blockId).map(bytes => blockManager.dataDeserialize(blockId, bytes))
getBytes(blockId).map(buffer => blockManager.dataDeserialize(blockId, buffer))
}
/**
@ -211,118 +106,20 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String)
}
override def remove(blockId: BlockId): Boolean = {
val file = getFile(blockId)
if (file.exists()) {
val fileSegment = diskManager.getBlockLocation(blockId)
val file = fileSegment.file
if (file.exists() && file.length() == fileSegment.length) {
file.delete()
} else {
if (fileSegment.length < file.length()) {
logWarning("Could not delete block associated with only a part of a file: " + blockId)
}
false
}
}
override def contains(blockId: BlockId): Boolean = {
getFile(blockId).exists()
}
private def createFile(blockId: BlockId, allowAppendExisting: Boolean = false): File = {
val file = getFile(blockId)
if (!allowAppendExisting && file.exists()) {
// NOTE(shivaram): Delete the file if it exists. This might happen if a ShuffleMap task
// was rescheduled on the same machine as the old task.
logWarning("File for block " + blockId + " already exists on disk: " + file + ". Deleting")
file.delete()
}
file
}
private def getFile(blockId: BlockId): File = {
logDebug("Getting file for block " + blockId)
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(blockId)
val dirId = hash % localDirs.length
val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
// Create the subdirectory if it doesn't already exist
var subDir = subDirs(dirId)(subDirId)
if (subDir == null) {
subDir = subDirs(dirId).synchronized {
val old = subDirs(dirId)(subDirId)
if (old != null) {
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
newDir.mkdir()
subDirs(dirId)(subDirId) = newDir
newDir
}
}
}
new File(subDir, blockId.name)
}
private def createLocalDirs(): Array[File] = {
logDebug("Creating local directories at root dirs '" + rootDirs + "'")
val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
rootDirs.split(",").map { rootDir =>
var foundLocalDir = false
var localDir: File = null
var localDirId: String = null
var tries = 0
val rand = new Random()
while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
tries += 1
try {
localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
localDir = new File(rootDir, "spark-local-" + localDirId)
if (!localDir.exists) {
foundLocalDir = localDir.mkdirs()
}
} catch {
case e: Exception =>
logWarning("Attempt " + tries + " to create local dir " + localDir + " failed", e)
}
}
if (!foundLocalDir) {
logError("Failed " + MAX_DIR_CREATION_ATTEMPTS +
" attempts to create local dir in " + rootDir)
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
}
logInfo("Created local directory at " + localDir)
localDir
}
}
private def addShutdownHook() {
localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run() {
logDebug("Shutdown hook called")
localDirs.foreach { localDir =>
try {
if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
} catch {
case t: Throwable =>
logError("Exception while deleting local spark dir: " + localDir, t)
}
}
if (shuffleSender != null) {
shuffleSender.stop()
}
}
})
}
private[storage] def startShuffleBlockSender(port: Int): Int = {
val pResolver = new PathResolver {
override def getAbsolutePath(blockIdString: String): String = {
val blockId = BlockId(blockIdString)
if (!blockId.isShuffle) null
else DiskStore.this.getFile(blockId).getAbsolutePath
}
}
shuffleSender = new ShuffleSender(port, pResolver)
logInfo("Created ShuffleSender binding to port : "+ shuffleSender.port)
shuffleSender.port
val file = diskManager.getBlockLocation(blockId).file
file.exists()
}
}


@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.storage
import java.io.File
/**
* References a particular segment of a file (potentially the entire file),
* based off an offset and a length.
*/
private[spark] class FileSegment(val file: File, val offset: Long, val length : Long) {
override def toString = "(name=%s, offset=%d, length=%d)".format(file.getName, offset, length)
}
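A FileSegment is just a (file, offset, length) window, and DiskStore.getBytes above memory-maps exactly that window. A standalone sketch of a segment of a scratch file being mapped and read on its own (the temp file and its contents are illustrative):

import java.io.{File, FileOutputStream, RandomAccessFile}
import java.nio.channels.FileChannel.MapMode

object FileSegmentSketch {
  def main(args: Array[String]) {
    val file = File.createTempFile("segment-sketch", ".bin")
    val out = new FileOutputStream(file)
    out.write("0123456789".getBytes("UTF-8"))
    out.close()
    val (offset, length) = (2L, 5L)                 // describe just the middle of the file
    val channel = new RandomAccessFile(file, "r").getChannel()
    val buffer = try {
      channel.map(MapMode.READ_ONLY, offset, length)
    } finally {
      channel.close()
    }
    val bytes = new Array[Byte](buffer.remaining())
    buffer.get(bytes)
    println(new String(bytes, "UTF-8"))             // prints 23456
    file.delete()
  }
}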


@ -17,12 +17,13 @@
package org.apache.spark.storage
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
import org.apache.spark.serializer.Serializer
private[spark]
class ShuffleWriterGroup(val id: Int, val writers: Array[BlockObjectWriter])
class ShuffleWriterGroup(val id: Int, val fileId: Int, val writers: Array[BlockObjectWriter])
private[spark]
trait ShuffleBlocks {
@ -30,24 +31,66 @@ trait ShuffleBlocks {
def releaseWriters(group: ShuffleWriterGroup)
}
/**
* Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one writer
* per reducer.
*
* As an optimization to reduce the number of physical shuffle files produced, multiple shuffle
* blocks are aggregated into the same file. There is one "combined shuffle file" per reducer
* per concurrently executing shuffle task. As soon as a task finishes writing to its shuffle files,
* it releases them for another task.
* Regarding the implementation of this feature, shuffle files are identified by a 3-tuple:
* - shuffleId: The unique id given to the entire shuffle stage.
* - bucketId: The id of the output partition (i.e., reducer id)
* - fileId: The unique id identifying a group of "combined shuffle files." Only one task at a
* time owns a particular fileId, and this id is returned to a pool when the task finishes.
*/
private[spark]
class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer): ShuffleBlocks = {
var nextFileId = new AtomicInteger(0)
val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()
def forShuffle(shuffleId: Int, numBuckets: Int, serializer: Serializer) = {
new ShuffleBlocks {
// Get a group of writers for a map task.
override def acquireWriters(mapId: Int): ShuffleWriterGroup = {
val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
val fileId = getUnusedFileId()
val writers = Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
blockManager.getDiskBlockWriter(blockId, serializer, bufferSize)
if (consolidateShuffleFiles) {
val filename = physicalFileName(shuffleId, bucketId, fileId)
blockManager.getDiskWriter(blockId, filename, serializer, bufferSize)
} else {
blockManager.getDiskWriter(blockId, blockId.name, serializer, bufferSize)
}
}
new ShuffleWriterGroup(mapId, writers)
new ShuffleWriterGroup(mapId, fileId, writers)
}
override def releaseWriters(group: ShuffleWriterGroup) = {
// Nothing really to release here.
override def releaseWriters(group: ShuffleWriterGroup) {
recycleFileId(group.fileId)
}
}
}
private def getUnusedFileId(): Int = {
val fileId = unusedFileIds.poll()
if (fileId == null) nextFileId.getAndIncrement() else fileId
}
private def recycleFileId(fileId: Int) {
if (consolidateShuffleFiles) {
unusedFileIds.add(fileId)
}
}
private def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) = {
"merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)
}
}
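Consolidation above hinges on the fileId pool: a map task takes an unused id (or mints a new one), writes into merged_shuffle_<shuffleId>_<bucketId>_<fileId> files, and returns the id when it finishes so a later task can append to the same files. A standalone sketch of that pool and naming scheme:

import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger

object FileIdPoolSketch {
  private val nextFileId = new AtomicInteger(0)
  private val unusedFileIds = new ConcurrentLinkedQueue[java.lang.Integer]()

  def acquire(): Int = {
    val fileId = unusedFileIds.poll()
    if (fileId == null) nextFileId.getAndIncrement() else fileId
  }
  def release(fileId: Int) { unusedFileIds.add(fileId) }
  def physicalFileName(shuffleId: Int, bucketId: Int, fileId: Int) =
    "merged_shuffle_%d_%d_%d".format(shuffleId, bucketId, fileId)

  def main(args: Array[String]) {
    val id = acquire()
    println(physicalFileName(1, 7, id))   // e.g. merged_shuffle_1_7_0
    release(id)
    println(acquire() == id)              // true: the released id is handed out again
  }
}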


@ -0,0 +1,86 @@
package org.apache.spark.storage
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.{CountDownLatch, Executors}
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.SparkContext
import org.apache.spark.util.Utils
/**
* Utility for micro-benchmarking shuffle write performance.
*
* Writes simulated shuffle output from several threads and records the observed throughput.
*/
object StoragePerfTester {
def main(args: Array[String]) = {
/** Total amount of data to generate. Distributed evenly amongst maps and reduce splits. */
val dataSizeMb = Utils.memoryStringToMb(sys.env.getOrElse("OUTPUT_DATA", "1g"))
/** Number of map tasks. All tasks execute concurrently. */
val numMaps = sys.env.get("NUM_MAPS").map(_.toInt).getOrElse(8)
/** Number of reduce splits for each map task. */
val numOutputSplits = sys.env.get("NUM_REDUCERS").map(_.toInt).getOrElse(500)
val recordLength = 1000 // ~1KB records
val totalRecords = dataSizeMb * 1000
val recordsPerMap = totalRecords / numMaps
val writeData = "1" * recordLength
val executor = Executors.newFixedThreadPool(numMaps)
System.setProperty("spark.shuffle.compress", "false")
System.setProperty("spark.shuffle.sync", "true")
// This is only used to instantiate a BlockManager. All thread scheduling is done manually.
val sc = new SparkContext("local[4]", "Write Tester")
val blockManager = sc.env.blockManager
def writeOutputBytes(mapId: Int, total: AtomicLong) = {
val shuffle = blockManager.shuffleBlockManager.forShuffle(1, numOutputSplits,
new KryoSerializer())
val buckets = shuffle.acquireWriters(mapId)
for (i <- 1 to recordsPerMap) {
buckets.writers(i % numOutputSplits).write(writeData)
}
buckets.writers.map {w =>
w.commit()
total.addAndGet(w.fileSegment().length)
w.close()
}
shuffle.releaseWriters(buckets)
}
val start = System.currentTimeMillis()
val latch = new CountDownLatch(numMaps)
val totalBytes = new AtomicLong()
for (task <- 1 to numMaps) {
executor.submit(new Runnable() {
override def run() = {
try {
writeOutputBytes(task, totalBytes)
latch.countDown()
} catch {
case e: Exception =>
println("Exception in child thread: " + e + " " + e.getMessage)
System.exit(1)
}
}
})
}
latch.await()
val end = System.currentTimeMillis()
val time = (end - start) / 1000.0
val bytesPerSecond = totalBytes.get() / time
val bytesPerFile = (totalBytes.get() / (numOutputSplits * numMaps.toDouble)).toLong
System.err.println("files_total\t\t%s".format(numMaps * numOutputSplits))
System.err.println("bytes_per_file\t\t%s".format(Utils.bytesToString(bytesPerFile)))
System.err.println("agg_throughput\t\t%s/s".format(Utils.bytesToString(bytesPerSecond.toLong)))
executor.shutdown()
sc.stop()
}
}


@ -38,7 +38,7 @@ private[spark] class IndexPage(parent: JobProgressUI) {
val now = System.currentTimeMillis()
var activeTime = 0L
for (tasks <- listener.stageToTasksActive.values; t <- tasks) {
for (tasks <- listener.stageIdToTasksActive.values; t <- tasks) {
activeTime += t.timeRunning(now)
}


@ -36,52 +36,52 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val RETAINED_STAGES = System.getProperty("spark.ui.retained_stages", "1000").toInt
val DEFAULT_POOL_NAME = "default"
val stageToPool = new HashMap[Stage, String]()
val stageToDescription = new HashMap[Stage, String]()
val poolToActiveStages = new HashMap[String, HashSet[Stage]]()
val stageIdToPool = new HashMap[Int, String]()
val stageIdToDescription = new HashMap[Int, String]()
val poolToActiveStages = new HashMap[String, HashSet[StageInfo]]()
val activeStages = HashSet[Stage]()
val completedStages = ListBuffer[Stage]()
val failedStages = ListBuffer[Stage]()
val activeStages = HashSet[StageInfo]()
val completedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
// Total metrics reflect metrics only for completed tasks
var totalTime = 0L
var totalShuffleRead = 0L
var totalShuffleWrite = 0L
val stageToTime = HashMap[Int, Long]()
val stageToShuffleRead = HashMap[Int, Long]()
val stageToShuffleWrite = HashMap[Int, Long]()
val stageToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
val stageToTasksComplete = HashMap[Int, Int]()
val stageToTasksFailed = HashMap[Int, Int]()
val stageToTaskInfos =
val stageIdToTime = HashMap[Int, Long]()
val stageIdToShuffleRead = HashMap[Int, Long]()
val stageIdToShuffleWrite = HashMap[Int, Long]()
val stageIdToTasksActive = HashMap[Int, HashSet[TaskInfo]]()
val stageIdToTasksComplete = HashMap[Int, Int]()
val stageIdToTasksFailed = HashMap[Int, Int]()
val stageIdToTaskInfos =
HashMap[Int, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]]()
override def onJobStart(jobStart: SparkListenerJobStart) {}
override def onStageCompleted(stageCompleted: StageCompleted) = synchronized {
val stage = stageCompleted.stageInfo.stage
poolToActiveStages(stageToPool(stage)) -= stage
val stage = stageCompleted.stage
poolToActiveStages(stageIdToPool(stage.stageId)) -= stage
activeStages -= stage
completedStages += stage
trimIfNecessary(completedStages)
}
/** If stages is too large, remove and garbage collect old stages */
def trimIfNecessary(stages: ListBuffer[Stage]) = synchronized {
def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
if (stages.size > RETAINED_STAGES) {
val toRemove = RETAINED_STAGES / 10
stages.takeRight(toRemove).foreach( s => {
stageToTaskInfos.remove(s.id)
stageToTime.remove(s.id)
stageToShuffleRead.remove(s.id)
stageToShuffleWrite.remove(s.id)
stageToTasksActive.remove(s.id)
stageToTasksComplete.remove(s.id)
stageToTasksFailed.remove(s.id)
stageToPool.remove(s)
if (stageToDescription.contains(s)) {stageToDescription.remove(s)}
stageIdToTaskInfos.remove(s.stageId)
stageIdToTime.remove(s.stageId)
stageIdToShuffleRead.remove(s.stageId)
stageIdToShuffleWrite.remove(s.stageId)
stageIdToTasksActive.remove(s.stageId)
stageIdToTasksComplete.remove(s.stageId)
stageIdToTasksFailed.remove(s.stageId)
stageIdToPool.remove(s.stageId)
if (stageIdToDescription.contains(s.stageId)) {stageIdToDescription.remove(s.stageId)}
})
stages.trimEnd(toRemove)
}
@ -95,63 +95,69 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
val poolName = Option(stageSubmitted.properties).map {
p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
}.getOrElse(DEFAULT_POOL_NAME)
stageToPool(stage) = poolName
stageIdToPool(stage.stageId) = poolName
val description = Option(stageSubmitted.properties).flatMap {
p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
}
description.map(d => stageToDescription(stage) = d)
description.map(d => stageIdToDescription(stage.stageId) = d)
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[Stage]())
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]())
stages += stage
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
val sid = taskStart.task.stageId
val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive += taskStart.taskInfo
val taskList = stageToTaskInfos.getOrElse(
val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList += ((taskStart.taskInfo, None, None))
stageToTaskInfos(sid) = taskList
stageIdToTaskInfos(sid) = taskList
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult)
= synchronized {
// Do nothing: because we don't do a deep copy of the TaskInfo, the TaskInfo in
// stageToTaskInfos already has the updated status.
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
val sid = taskEnd.task.stageId
val tasksActive = stageToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]())
tasksActive -= taskEnd.taskInfo
val (failureInfo, metrics): (Option[ExceptionFailure], Option[TaskMetrics]) =
taskEnd.reason match {
case e: ExceptionFailure =>
stageToTasksFailed(sid) = stageToTasksFailed.getOrElse(sid, 0) + 1
stageIdToTasksFailed(sid) = stageIdToTasksFailed.getOrElse(sid, 0) + 1
(Some(e), e.metrics)
case _ =>
stageToTasksComplete(sid) = stageToTasksComplete.getOrElse(sid, 0) + 1
stageIdToTasksComplete(sid) = stageIdToTasksComplete.getOrElse(sid, 0) + 1
(None, Option(taskEnd.taskMetrics))
}
stageToTime.getOrElseUpdate(sid, 0L)
stageIdToTime.getOrElseUpdate(sid, 0L)
val time = metrics.map(m => m.executorRunTime).getOrElse(0)
stageToTime(sid) += time
stageIdToTime(sid) += time
totalTime += time
stageToShuffleRead.getOrElseUpdate(sid, 0L)
stageIdToShuffleRead.getOrElseUpdate(sid, 0L)
val shuffleRead = metrics.flatMap(m => m.shuffleReadMetrics).map(s =>
s.remoteBytesRead).getOrElse(0L)
stageToShuffleRead(sid) += shuffleRead
stageIdToShuffleRead(sid) += shuffleRead
totalShuffleRead += shuffleRead
stageToShuffleWrite.getOrElseUpdate(sid, 0L)
stageIdToShuffleWrite.getOrElseUpdate(sid, 0L)
val shuffleWrite = metrics.flatMap(m => m.shuffleWriteMetrics).map(s =>
s.shuffleBytesWritten).getOrElse(0L)
stageToShuffleWrite(sid) += shuffleWrite
stageIdToShuffleWrite(sid) += shuffleWrite
totalShuffleWrite += shuffleWrite
val taskList = stageToTaskInfos.getOrElse(
val taskList = stageIdToTaskInfos.getOrElse(
sid, HashSet[(TaskInfo, Option[TaskMetrics], Option[ExceptionFailure])]())
taskList -= ((taskEnd.taskInfo, None, None))
taskList += ((taskEnd.taskInfo, metrics, failureInfo))
stageToTaskInfos(sid) = taskList
stageIdToTaskInfos(sid) = taskList
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@ -159,10 +165,15 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList
case end: SparkListenerJobEnd =>
end.jobResult match {
case JobFailed(ex, Some(stage)) =>
activeStages -= stage
poolToActiveStages(stageToPool(stage)) -= stage
failedStages += stage
trimIfNecessary(failedStages)
/* If two jobs share a stage we could get this failure message twice. So we first
* check whether we've already retired this stage. */
val stageInfo = activeStages.filter(s => s.stageId == stage.id).headOption
stageInfo.foreach {s =>
activeStages -= s
poolToActiveStages(stageIdToPool(stage.id)) -= s
failedStages += s
trimIfNecessary(failedStages)
}
case _ =>
}
case _ =>

View file

@ -21,13 +21,13 @@ import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.xml.Node
import org.apache.spark.scheduler.{Schedulable, Stage}
import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressListener) {
var poolToActiveStages: HashMap[String, HashSet[Stage]] = listener.poolToActiveStages
var poolToActiveStages: HashMap[String, HashSet[StageInfo]] = listener.poolToActiveStages
def toNodeSeq(): Seq[Node] = {
listener.synchronized {
@ -35,7 +35,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
}
}
private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[Stage]]) => Seq[Node],
private def poolTable(makeRow: (Schedulable, HashMap[String, HashSet[StageInfo]]) => Seq[Node],
rows: Seq[Schedulable]
): Seq[Node] = {
<table class="table table-bordered table-striped table-condensed sortable table-fixed">
@ -53,7 +53,7 @@ private[spark] class PoolTable(pools: Seq[Schedulable], listener: JobProgressLis
</table>
}
private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[Stage]])
private def poolRow(p: Schedulable, poolToActiveStages: HashMap[String, HashSet[StageInfo]])
: Seq[Node] = {
val activeStages = poolToActiveStages.get(p.name) match {
case Some(stages) => stages.size

View file

@ -40,7 +40,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val stageId = request.getParameter("id").toInt
val now = System.currentTimeMillis()
if (!listener.stageToTaskInfos.contains(stageId)) {
if (!listener.stageIdToTaskInfos.contains(stageId)) {
val content =
<div>
<h4>Summary Metrics</h4> No tasks have started yet
@ -49,23 +49,23 @@ private[spark] class StagePage(parent: JobProgressUI) {
return headerSparkPage(content, parent.sc, "Details for Stage %s".format(stageId), Stages)
}
val tasks = listener.stageToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val tasks = listener.stageIdToTaskInfos(stageId).toSeq.sortBy(_._1.launchTime)
val numCompleted = tasks.count(_._1.finished)
val shuffleReadBytes = listener.stageToShuffleRead.getOrElse(stageId, 0L)
val shuffleReadBytes = listener.stageIdToShuffleRead.getOrElse(stageId, 0L)
val hasShuffleRead = shuffleReadBytes > 0
val shuffleWriteBytes = listener.stageToShuffleWrite.getOrElse(stageId, 0L)
val shuffleWriteBytes = listener.stageIdToShuffleWrite.getOrElse(stageId, 0L)
val hasShuffleWrite = shuffleWriteBytes > 0
var activeTime = 0L
listener.stageToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
val summary =
<div>
<ul class="unstyled">
<li>
<strong>CPU time: </strong>
{parent.formatDuration(listener.stageToTime.getOrElse(stageId, 0L) + activeTime)}
{parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
<li>
@ -83,10 +83,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
</div>
val taskHeaders: Seq[String] =
Seq("Task ID", "Status", "Locality Level", "Executor", "Launch Time", "Duration") ++
Seq("GC Time") ++
Seq("Task Index", "Task ID", "Status", "Locality Level", "Executor", "Launch Time") ++
Seq("Duration", "GC Time") ++
{if (hasShuffleRead) Seq("Shuffle Read") else Nil} ++
{if (hasShuffleWrite) Seq("Shuffle Write") else Nil} ++
{if (hasShuffleWrite) Seq("Write Time", "Shuffle Write") else Nil} ++
Seq("Errors")
val taskTable = listingTable(taskHeaders, taskRow(hasShuffleRead, hasShuffleWrite), tasks)
@ -153,6 +153,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
<tr>
<td>{info.index}</td>
<td>{info.taskId}</td>
<td>{info.status}</td>
<td>{info.taskLocality}</td>
@ -169,6 +170,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")}</td>
}}
{if (shuffleWrite) {
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}</td>
<td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")}</td>
}}

View file

@ -22,13 +22,13 @@ import java.util.Date
import scala.xml.Node
import scala.collection.mutable.HashSet
import org.apache.spark.scheduler.{SchedulingMode, Stage, TaskInfo}
import org.apache.spark.scheduler.{SchedulingMode, StageInfo, TaskInfo}
import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
/** Page showing list of all ongoing and recently finished stages */
private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressUI) {
private[spark] class StageTable(val stages: Seq[StageInfo], val parent: JobProgressUI) {
val listener = parent.listener
val dateFmt = parent.dateFmt
@ -73,40 +73,40 @@ private[spark] class StageTable(val stages: Seq[Stage], val parent: JobProgressU
}
private def stageRow(s: Stage): Seq[Node] = {
private def stageRow(s: StageInfo): Seq[Node] = {
val submissionTime = s.submissionTime match {
case Some(t) => dateFmt.format(new Date(t))
case None => "Unknown"
}
val shuffleRead = listener.stageToShuffleRead.getOrElse(s.id, 0L) match {
val shuffleRead = listener.stageIdToShuffleRead.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
val shuffleWrite = listener.stageToShuffleWrite.getOrElse(s.id, 0L) match {
val shuffleWrite = listener.stageIdToShuffleWrite.getOrElse(s.stageId, 0L) match {
case 0 => ""
case b => Utils.bytesToString(b)
}
val startedTasks = listener.stageToTasksActive.getOrElse(s.id, HashSet[TaskInfo]()).size
val completedTasks = listener.stageToTasksComplete.getOrElse(s.id, 0)
val failedTasks = listener.stageToTasksFailed.getOrElse(s.id, 0) match {
val startedTasks = listener.stageIdToTasksActive.getOrElse(s.stageId, HashSet[TaskInfo]()).size
val completedTasks = listener.stageIdToTasksComplete.getOrElse(s.stageId, 0)
val failedTasks = listener.stageIdToTasksFailed.getOrElse(s.stageId, 0) match {
case f if f > 0 => "(%s failed)".format(f)
case _ => ""
}
val totalTasks = s.numPartitions
val totalTasks = s.numTasks
val poolName = listener.stageToPool.get(s)
val poolName = listener.stageIdToPool.get(s.stageId)
val nameLink =
<a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.id)}>{s.name}</a>
val description = listener.stageToDescription.get(s)
<a href={"%s/stages/stage?id=%s".format(UIUtils.prependBaseUri(),s.stageId)}>{s.name}</a>
val description = listener.stageIdToDescription.get(s.stageId)
.map(d => <div><em>{d}</em></div><div>{nameLink}</div>).getOrElse(nameLink)
val finishTime = s.completionTime.getOrElse(System.currentTimeMillis())
val duration = s.submissionTime.map(t => finishTime - t)
<tr>
<td>{s.id}</td>
<td>{s.stageId}</td>
{if (isFairScheduler) {
<td><a href={"%s/stages/pool?poolname=%s".format(UIUtils.prependBaseUri(),poolName.get)}>
{poolName.get}</a></td>}

View file

@ -56,9 +56,10 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
"ShuffleMapTask", "BlockManager", "BroadcastVars") {
"ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, BROADCAST_VARS = Value
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, DISK_BLOCK_MANAGER, BROADCAST_VARS = Value
type MetadataCleanerType = Value

View file

@ -20,8 +20,14 @@ package org.apache.spark
import org.scalatest.FunSuite
class BroadcastSuite extends FunSuite with LocalSparkContext {
test("basic broadcast") {
override def afterEach() {
super.afterEach()
System.clearProperty("spark.broadcast.factory")
}
test("Using HttpBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
@ -29,11 +35,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
test("broadcast variables accessed in multiple threads") {
test("Accessing HttpBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing HttpBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
test("Using TorrentBroadcast locally") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === Set((1, 10), (2, 10)))
}
test("Accessing TorrentBroadcast variables from multiple threads") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
sc = new SparkContext("local[10]", "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet)
}
test("Accessing TorrentBroadcast variables in a local cluster") {
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory")
val numSlaves = 4
sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test")
val list = List(1, 2, 3, 4)
val listBroadcast = sc.broadcast(list)
val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum))
assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
}

View file

@ -120,4 +120,20 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
test ("Dynamically adding JARS on a standalone cluster using local: URL") {
sc = new SparkContext("local-cluster[1,1,512]", "test")
val sampleJarFile = getClass.getClassLoader.getResource("uncommons-maths-1.2.2.jar").getFile()
sc.addJar(sampleJarFile.replace("file", "local"))
val testData = Array((1,1), (1,1), (2,1), (3,5), (2,3), (3,0))
val result = sc.parallelize(testData).reduceByKey { (x,y) =>
val fac = Thread.currentThread.getContextClassLoader()
.loadClass("org.uncommons.maths.Maths")
.getDeclaredMethod("factorial", classOf[Int])
val a = fac.invoke(null, x.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
val b = fac.invoke(null, y.asInstanceOf[java.lang.Integer]).asInstanceOf[Long].toInt
a + b
}.collect()
assert(result.toSet === Set((1,2), (2,7), (3,121)))
}
}

View file

@ -472,6 +472,27 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
}
@Test
public void repartition() {
// Shrinking number of partitions
JavaRDD<Integer> in1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 2);
JavaRDD<Integer> repartitioned1 = in1.repartition(4);
List<List<Integer>> result1 = repartitioned1.glom().collect();
Assert.assertEquals(4, result1.size());
for (List<Integer> l: result1) {
Assert.assertTrue(l.size() > 0);
}
// Growing number of partitions
JavaRDD<Integer> in2 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8), 4);
JavaRDD<Integer> repartitioned2 = in2.repartition(2);
List<List<Integer>> result2 = repartitioned2.glom().collect();
Assert.assertEquals(2, result2.size());
for (List<Integer> l: result2) {
Assert.assertTrue(l.size() > 0);
}
}
@Test
public void persist() {
JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));

View file

@ -19,6 +19,8 @@ package org.apache.spark
import java.util.concurrent.Semaphore
import scala.concurrent.Await
import scala.concurrent.duration.Duration
import scala.concurrent.future
import scala.concurrent.ExecutionContext.Implicits.global
@ -83,6 +85,36 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
assert(sc.parallelize(1 to 10, 2).count === 10)
}
test("job group") {
sc = new SparkContext("local[2]", "test")
// Add a listener to release the semaphore once any tasks are launched.
val sem = new Semaphore(0)
sc.dagScheduler.addSparkListener(new SparkListener {
override def onTaskStart(taskStart: SparkListenerTaskStart) {
sem.release()
}
})
// jobA is the one to be cancelled.
val jobA = future {
sc.setJobGroup("jobA", "this is a job to be cancelled")
sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.count()
}
sc.clearJobGroup()
val jobB = sc.parallelize(1 to 100, 2).countAsync()
// Block until both tasks of job A have started and cancel job A.
sem.acquire(2)
sc.cancelJobGroup("jobA")
val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) }
assert(e.getMessage contains "cancel")
// Once A is cancelled, job B should finish fairly quickly.
assert(jobB.get() === 100)
}
test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued

View file

@ -48,15 +48,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@ -74,19 +74,17 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTracker()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
val tracker = new MapOutputTrackerMaster()
tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
val size10000 = MapOutputTracker.decompressSize(compressedSize10000)
tracker.registerMapOutput(10, 0, new MapStatus(BlockManagerId("a", "hostA", 1000, 0),
Array(compressedSize1000, compressedSize1000, compressedSize1000)))
tracker.registerMapOutput(10, 1, new MapStatus(BlockManagerId("b", "hostB", 1000, 0),
Array(compressedSize10000, compressedSize1000, compressedSize1000)))
// As if we had two simulatenous fetch failures
// As if we had two simultaneous fetch failures
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
tracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000, 0))
@ -102,9 +100,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.driver.port", boundPort.toString) // Will be cleared by LocalSparkContext
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
val masterTracker = new MapOutputTracker()
val masterTracker = new MapOutputTrackerMaster()
masterTracker.trackerActor = actorSystem.actorOf(
Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()

View file

@ -139,6 +139,26 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
test("repartitioned RDDs") {
val data = sc.parallelize(1 to 1000, 10)
// Coalesce partitions
val repartitioned1 = data.repartition(2)
assert(repartitioned1.partitions.size == 2)
val partitions1 = repartitioned1.glom().collect()
assert(partitions1(0).length > 0)
assert(partitions1(1).length > 0)
assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
// Split partitions
val repartitioned2 = data.repartition(20)
assert(repartitioned2.partitions.size == 20)
val partitions2 = repartitioned2.glom().collect()
assert(partitions2(0).length > 0)
assert(partitions2(19).length > 0)
assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
}
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)

View file

@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.BeforeAndAfter
import org.apache.spark.LocalSparkContext
import org.apache.spark.MapOutputTracker
import org.apache.spark.MapOutputTrackerMaster
import org.apache.spark.SparkContext
import org.apache.spark.Partition
import org.apache.spark.TaskContext
@ -64,7 +64,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
override def defaultParallelism() = 2
}
var mapOutputTracker: MapOutputTracker = null
var mapOutputTracker: MapOutputTrackerMaster = null
var scheduler: DAGScheduler = null
/**
@ -99,7 +99,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
taskSets.clear()
cacheLocations.clear()
results.clear()
mapOutputTracker = new MapOutputTracker()
mapOutputTracker = new MapOutputTrackerMaster()
scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
override def runLocally(job: ActiveJob) {
// don't bother with the thread while unit testing

View file

@ -58,10 +58,13 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val parentRdd = makeRdd(4, Nil)
val shuffleDep = new ShuffleDependency(parentRdd, null)
val rootRdd = makeRdd(4, List(shuffleDep))
val shuffleMapStage = new Stage(1, parentRdd, Some(shuffleDep), Nil, jobID, None)
val rootStage = new Stage(0, rootRdd, None, List(shuffleMapStage), jobID, None)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStage, 4, null))
val shuffleMapStage =
new Stage(1, parentRdd, parentRdd.partitions.size, Some(shuffleDep), Nil, jobID, None)
val rootStage =
new Stage(0, rootRdd, rootRdd.partitions.size, None, List(shuffleMapStage), jobID, None)
val rootStageInfo = new StageInfo(rootStage)
joblogger.onStageSubmitted(SparkListenerStageSubmitted(rootStageInfo, null))
joblogger.getRddNameTest(parentRdd) should be (parentRdd.getClass.getName)
parentRdd.setName("MyRDD")
joblogger.getRddNameTest(parentRdd) should be ("MyRDD")

View file

@ -17,16 +17,62 @@
package org.apache.spark.scheduler
import org.scalatest.FunSuite
import org.apache.spark.{SparkContext, LocalSparkContext}
import scala.collection.mutable
import scala.collection.mutable.{Buffer, HashSet}
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.scalatest.matchers.ShouldMatchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
val WAIT_TIMEOUT_MILLIS = 10000
override def afterAll {
System.clearProperty("spark.akka.frameSize")
}
test("basic creation of StageInfo") {
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
val rdd2 = rdd1.map(x => x.toString)
rdd2.setName("Target RDD")
rdd2.count
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {1}
val first = listener.stageInfos.head
first.rddName should be {"Target RDD"}
first.numTasks should be {4}
first.numPartitions should be {4}
first.submissionTime should be ('defined)
first.completionTime should be ('defined)
first.taskInfos.length should be {4}
}
test("StageInfo with fewer tasks than partitions") {
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
val rdd1 = sc.parallelize(1 to 100, 4)
val rdd2 = rdd1.map(x => x.toString)
sc.runJob(rdd2, (items: Iterator[String]) => items.size, Seq(0, 1), true)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be {1}
val first = listener.stageInfos.head
first.numTasks should be {2}
first.numPartitions should be {4}
}
test("local metrics") {
sc = new SparkContext("local[4]", "test")
sc = new SparkContext("local", "DAGSchedulerSuite")
val listener = new SaveStageInfo
sc.addSparkListener(listener)
sc.addSparkListener(new StatsReportListener)
@ -39,7 +85,6 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
val d = sc.parallelize(1 to 1e4.toInt, 64).map{i => w(i)}
d.count()
val WAIT_TIMEOUT_MILLIS = 10000
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
listener.stageInfos.size should be (1)
@ -64,7 +109,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.executorDeserializeTime.toLong},
stageInfo + " executorDeserializeTime")
if (stageInfo.stage.rdd.name == d4.name) {
if (stageInfo.rddName == d4.name) {
checkNonZeroAvg(
stageInfo.taskInfos.map{_._2.shuffleReadMetrics.get.fetchWaitTime},
stageInfo + " fetchWaitTime")
@ -72,11 +117,11 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
stageInfo.taskInfos.foreach { case (taskInfo, taskMetrics) =>
taskMetrics.resultSize should be > (0l)
if (isStage(stageInfo, Set(d2.name, d3.name), Set(d4.name))) {
if (stageInfo.rddName == d2.name || stageInfo.rddName == d3.name) {
taskMetrics.shuffleWriteMetrics should be ('defined)
taskMetrics.shuffleWriteMetrics.get.shuffleBytesWritten should be > (0l)
}
if (stageInfo.stage.rdd.name == d4.name) {
if (stageInfo.rddName == d4.name) {
taskMetrics.shuffleReadMetrics should be ('defined)
val sm = taskMetrics.shuffleReadMetrics.get
sm.totalBlocksFetched should be > (0)
@ -89,20 +134,73 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
}
}
test("onTaskGettingResult() called when result fetched remotely") {
// Need to use local cluster mode here, because results are not ever returned through the
// block manager when using the LocalScheduler.
sc = new SparkContext("local-cluster[1,1,512]", "test")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
// Make a task whose result is larger than the akka frame size
System.setProperty("spark.akka.frameSize", "1")
val akkaFrameSize =
sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.contains(TASK_INDEX))
assert(listener.endedTasks.contains(TASK_INDEX))
}
test("onTaskGettingResult() not called when result sent directly") {
// Need to use local cluster mode here, because results are not ever returned through the
// block manager when using the LocalScheduler.
sc = new SparkContext("local-cluster[1,1,512]", "test")
val listener = new SaveTaskEvents
sc.addSparkListener(listener)
// Make a task whose result is smaller than the akka frame size, so it is sent back directly
val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
assert(result === 2)
assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
val TASK_INDEX = 0
assert(listener.startedTasks.contains(TASK_INDEX))
assert(listener.startedGettingResultTasks.isEmpty == true)
assert(listener.endedTasks.contains(TASK_INDEX))
}
def checkNonZeroAvg(m: Traversable[Long], msg: String) {
assert(m.sum / m.size.toDouble > 0.0, msg)
}
def isStage(stageInfo: StageInfo, rddNames: Set[String], excludedNames: Set[String]) = {
val names = Set(stageInfo.stage.rdd.name) ++ stageInfo.stage.rdd.dependencies.map{_.rdd.name}
!names.intersect(rddNames).isEmpty && names.intersect(excludedNames).isEmpty
}
class SaveStageInfo extends SparkListener {
val stageInfos = mutable.Buffer[StageInfo]()
val stageInfos = Buffer[StageInfo]()
override def onStageCompleted(stage: StageCompleted) {
stageInfos += stage.stageInfo
stageInfos += stage.stage
}
}
class SaveTaskEvents extends SparkListener {
val startedTasks = new HashSet[Int]()
val startedGettingResultTasks = new HashSet[Int]()
val endedTasks = new HashSet[Int]()
override def onTaskStart(taskStart: SparkListenerTaskStart) {
startedTasks += taskStart.taskInfo.index
}
override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
endedTasks += taskEnd.taskInfo.index
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
startedGettingResultTasks += taskGettingResult.taskInfo.index
}
}
}

View file

@ -1,10 +1,11 @@
Spark Docker files usable for testing and development purposes.
These images are intended to be run like so:
docker run -v $SPARK_HOME:/opt/spark spark-test-master
docker run -v $SPARK_HOME:/opt/spark spark-test-worker <master_ip>
docker run -v $SPARK_HOME:/opt/spark spark-test-master
docker run -v $SPARK_HOME:/opt/spark spark-test-worker spark://<master_ip>:7077
Using this configuration, the containers will have their Spark directories
mounted to your actual SPARK_HOME, allowing you to modify and recompile
mounted to your actual `SPARK_HOME`, allowing you to modify and recompile
your Spark source and have them immediately usable in the docker images
(without rebuilding them).

View file

@ -149,7 +149,7 @@ Apart from these, the following properties are also available, and may be useful
<td>spark.io.compression.codec</td>
<td>org.apache.spark.io.<br />LZFCompressionCodec</td>
<td>
The compression codec class to use for various compressions. By default, Spark provides two
The codec used to compress internal data such as RDD partitions and shuffle outputs. By default, Spark provides two
codecs: <code>org.apache.spark.io.LZFCompressionCodec</code> and <code>org.apache.spark.io.SnappyCompressionCodec</code>.
</td>
</tr>
@ -319,6 +319,14 @@ Apart from these, the following properties are also available, and may be useful
Should be greater than or equal to 1. Number of allowed retries = this value - 1.
</td>
</tr>
<tr>
<td>spark.broadcast.blockSize</td>
<td>4096</td>
<td>
Size of each piece of a block in kilobytes for <code>TorrentBroadcastFactory</code>.
Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, <code>BlockManager</code> might take a performance hit.
</td>
</tr>
</table>
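As an illustrative sketch (not part of the original documentation), properties such as the compression codec and the Torrent broadcast block size can be set as system properties before the SparkContext is created; the property names come from the table above, while the concrete values here are only examples:

{% highlight scala %}
import org.apache.spark.SparkContext

// Hypothetical driver setup; both property names appear in the configuration table above.
System.setProperty("spark.io.compression.codec", "org.apache.spark.io.SnappyCompressionCodec")
System.setProperty("spark.broadcast.blockSize", "8192") // kilobytes per TorrentBroadcast piece

val sc = new SparkContext("local", "Config Example")
// Jobs that shuffle or broadcast data will pick up these settings.
sc.stop()
{% endhighlight %}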

View file

@ -131,6 +131,17 @@ sc = SparkContext("local", "App Name", pyFiles=['MyFile.py', 'lib.zip', 'app.egg
Files listed here will be added to the `PYTHONPATH` and shipped to remote worker machines.
Code dependencies can be added to an existing SparkContext using its `addPyFile()` method.
You can set [system properties](configuration.html#system-properties)
using `SparkContext.setSystemProperty()` class method *before*
instantiating SparkContext. For example, to set the amount of memory
per executor process:
{% highlight python %}
from pyspark import SparkContext
SparkContext.setSystemProperty('spark.executor.memory', '2g')
sc = SparkContext("local", "App Name")
{% endhighlight %}
# API Docs
[API documentation](api/pyspark/index.html) for PySpark is available as Epydoc.

View file

@ -142,7 +142,7 @@ All transformations in Spark are <i>lazy</i>, in that they do not compute their
By default, each transformed RDD is recomputed each time you run an action on it. However, you may also *persist* an RDD in memory using the `persist` (or `cache`) method, in which case Spark will keep the elements around on the cluster for much faster access the next time you query it. There is also support for persisting datasets on disk, or replicated across the cluster. The next section in this document describes these options.
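For example (an illustrative sketch assuming an existing SparkContext `sc` and a hypothetical HDFS path), caching an RDD lets later actions reuse the in-memory data instead of recomputing it:

{% highlight scala %}
val lines = sc.textFile("hdfs://namenode:8020/logs/app.log") // hypothetical input path
val errors = lines.filter(line => line.contains("ERROR"))
errors.cache() // keep the filtered RDD in memory once it has been computed

val total = errors.count() // first action: computes and caches the partitions
val recent = errors.filter(line => line.contains("2013-10")).count() // reuses the cached data
{% endhighlight %}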
The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.RDD) for details):
The following tables list the transformations and actions currently supported (see also the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD) for details):
### Transformations
@ -211,7 +211,7 @@ The following tables list the transformations and actions currently supported (s
</tr>
</table>
A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
A complete list of transformations is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
### Actions
@ -259,7 +259,7 @@ A complete list of transformations is available in the [RDD API doc](api/core/in
</tr>
</table>
A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.RDD).
A complete list of actions is available in the [RDD API doc](api/core/index.html#org.apache.spark.rdd.RDD).
## RDD Persistence

View file

@ -72,6 +72,10 @@ DStreams support many of the transformations available on normal Spark RDD's:
<td> Similar to map, but runs separately on each partition (block) of the DStream, so <i>func</i> must be of type
Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
</tr>
<tr>
<td> <b>repartition</b>(<i>numPartitions</i>) </td>
<td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
</tr>
<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>

View file

@ -32,13 +32,20 @@
<url>http://spark.incubator.apache.org/</url>
<repositories>
<!-- A repository in the local filesystem for the Kafka JAR, which we modified for Scala 2.9 -->
<repository>
<id>lib</id>
<url>file://${project.basedir}/lib</url>
<id>apache-repo</id>
<name>Apache Repository</name>
<url>https://repository.apache.org/content/repositories/releases</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<dependencies>
<dependency>
<groupId>org.apache.spark</groupId>
@ -81,9 +88,18 @@
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version> <!-- Comes from our in-project repository -->
<scope>provided</scope>
<artifactId>kafka_2.9.2</artifactId>
<version>0.8.0-beta1</version>
<exclusions>
<exclusion>
<groupId>com.sun.jmx</groupId>
<artifactId>jmxri</artifactId>
</exclusion>
<exclusion>
<groupId>com.sun.jdmk</groupId>
<artifactId>jmxtools</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.eclipse.jetty</groupId>
@ -137,6 +153,14 @@
<groupId>org.apache.cassandra.deps</groupId>
<artifactId>avro</artifactId>
</exclusion>
<exclusion>
<groupId>org.sonatype.sisu.inject</groupId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>org.xerial.snappy</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
</dependencies>

View file

@ -0,0 +1,98 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.examples;
import java.util.Map;
import java.util.HashMap;
import com.google.common.collect.Lists;
import org.apache.spark.api.java.function.FlatMapFunction;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.Function2;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.streaming.Duration;
import org.apache.spark.streaming.api.java.JavaDStream;
import org.apache.spark.streaming.api.java.JavaPairDStream;
import org.apache.spark.streaming.api.java.JavaStreamingContext;
import scala.Tuple2;
/**
* Consumes messages from one or more topics in Kafka and does wordcount.
* Usage: JavaKafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>
* <master> is the Spark master URL. In local mode, <master> should be 'local[n]' with n > 1.
* <zkQuorum> is a list of one or more zookeeper servers that make up the quorum
* <group> is the name of the kafka consumer group
* <topics> is a list of one or more kafka topics to consume from
* <numThreads> is the number of threads the kafka consumer should use
*
* Example:
* `./run-example org.apache.spark.streaming.examples.JavaKafkaWordCount local[2] zoo01,zoo02,
* zoo03 my-consumer-group topic1,topic2 1`
*/
public class JavaKafkaWordCount {
public static void main(String[] args) {
if (args.length < 5) {
System.err.println("Usage: KafkaWordCount <master> <zkQuorum> <group> <topics> <numThreads>");
System.exit(1);
}
// Create the context with a 2 second batch size
JavaStreamingContext ssc = new JavaStreamingContext(args[0], "JavaKafkaWordCount",
new Duration(2000), System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
int numThreads = Integer.parseInt(args[4]);
Map<String, Integer> topicMap = new HashMap<String, Integer>();
String[] topics = args[3].split(",");
for (String topic: topics) {
topicMap.put(topic, numThreads);
}
JavaPairDStream<String, String> messages = ssc.kafkaStream(args[1], args[2], topicMap);
JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
@Override
public String call(Tuple2<String, String> tuple2) throws Exception {
return tuple2._2();
}
});
JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
@Override
public Iterable<String> call(String x) {
return Lists.newArrayList(x.split(" "));
}
});
JavaPairDStream<String, Integer> wordCounts = words.map(
new PairFunction<String, String, Integer>() {
@Override
public Tuple2<String, Integer> call(String s) throws Exception {
return new Tuple2<String, Integer>(s, 1);
}
}).reduceByKey(new Function2<Integer, Integer, Integer>() {
@Override
public Integer call(Integer i1, Integer i2) throws Exception {
return i1 + i2;
}
});
wordCounts.print();
ssc.start();
}
}

View file

@ -22,12 +22,19 @@ import org.apache.spark.SparkContext
object BroadcastTest {
def main(args: Array[String]) {
if (args.length == 0) {
System.err.println("Usage: BroadcastTest <master> [<slices>] [numElem]")
System.err.println("Usage: BroadcastTest <master> [slices] [numElem] [broadcastAlgo] [blockSize]")
System.exit(1)
}
val sc = new SparkContext(args(0), "Broadcast Test",
val bcName = if (args.length > 3) args(3) else "Http"
val blockSize = if (args.length > 4) args(4) else "4096"
System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast." + bcName + "BroadcastFactory")
System.setProperty("spark.broadcast.blockSize", blockSize)
val sc = new SparkContext(args(0), "Broadcast Test 2",
System.getenv("SPARK_HOME"), Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val slices = if (args.length > 1) args(1).toInt else 2
val num = if (args.length > 2) args(2).toInt else 1000000
@ -36,13 +43,15 @@ object BroadcastTest {
arr1(i) = i
}
for (i <- 0 until 2) {
for (i <- 0 until 3) {
println("Iteration " + i)
println("===========")
val startTime = System.nanoTime
val barr1 = sc.broadcast(arr1)
sc.parallelize(1 to 10, slices).foreach {
i => println(barr1.value.size)
}
println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
}
System.exit(0)

View file

@ -38,6 +38,6 @@ object SparkPi {
if (x*x + y*y < 1) 1 else 0
}.reduce(_ + _)
println("Pi is roughly " + 4.0 * count / n)
System.exit(0)
spark.stop()
}
}

View file

@ -18,13 +18,11 @@
package org.apache.spark.streaming.examples
import java.util.Properties
import kafka.message.Message
import kafka.producer.SyncProducerConfig
import kafka.producer._
import org.apache.spark.SparkContext
import org.apache.spark.streaming._
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.util.RawTextHelper._
/**
@ -54,9 +52,10 @@ object KafkaWordCount {
ssc.checkpoint("checkpoint")
val topicpMap = topics.split(",").map((_,numThreads.toInt)).toMap
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap)
val lines = ssc.kafkaStream(zkQuorum, group, topicpMap).map(_._2)
val words = lines.flatMap(_.split(" "))
val wordCounts = words.map(x => (x, 1l)).reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
val wordCounts = words.map(x => (x, 1l))
.reduceByKeyAndWindow(add _, subtract _, Minutes(10), Seconds(2), 2)
wordCounts.print()
ssc.start()
@ -68,15 +67,16 @@ object KafkaWordCountProducer {
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: KafkaWordCountProducer <zkQuorum> <topic> <messagesPerSec> <wordsPerMessage>")
System.err.println("Usage: KafkaWordCountProducer <metadataBrokerList> <topic> " +
"<messagesPerSec> <wordsPerMessage>")
System.exit(1)
}
val Array(zkQuorum, topic, messagesPerSec, wordsPerMessage) = args
val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
// Kafka broker connection properties
val props = new Properties()
props.put("zk.connect", zkQuorum)
props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
val config = new ProducerConfig(props)
@ -85,11 +85,13 @@ object KafkaWordCountProducer {
// Send some messages
while(true) {
val messages = (1 to messagesPerSec.toInt).map { messageNum =>
(1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString).mkString(" ")
val str = (1 to wordsPerMessage.toInt).map(x => scala.util.Random.nextInt(10).toString)
.mkString(" ")
new KeyedMessage[String, String](topic, str)
}.toArray
println(messages.mkString(","))
val data = new ProducerData[String, String](topic, messages)
producer.send(data)
producer.send(messages: _*)
Thread.sleep(100)
}
}

View file

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.streaming.examples
import org.apache.spark.streaming.{ Seconds, StreamingContext }
import org.apache.spark.streaming.StreamingContext._
import org.apache.spark.streaming.dstream.MQTTReceiver
import org.apache.spark.storage.StorageLevel
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
/**
* A simple MQTT publisher for demonstration purposes; it repeatedly publishes
* the space-separated string message "hello mqtt demo for spark streaming".
*/
object MQTTPublisher {
var client: MqttClient = _
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: MQTTPublisher <MqttBrokerUrl> <topic>")
System.exit(1)
}
val Seq(brokerUrl, topic) = args.toSeq
try {
var persistence: MqttClientPersistence = new MqttDefaultFilePersistence("/tmp")
client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
} catch {
case e: MqttException => println("Exception Caught: " + e)
}
client.connect()
val msgtopic: MqttTopic = client.getTopic(topic);
val msg: String = "hello mqtt demo for spark streaming"
while (true) {
val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes())
msgtopic.publish(message);
println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
}
client.disconnect()
}
}
/**
* A sample wordcount with MqttStream stream
*
* To work with MQTT, an MQTT message broker/server is required.
* Mosquitto (http://mosquitto.org/) is an open source MQTT broker.
* On Ubuntu, Mosquitto can be installed with the command `$ sudo apt-get install mosquitto`.
* The Eclipse Paho project provides a Java library for MQTT clients: http://www.eclipse.org/paho/
* Example Java code for an MQTT publisher and subscriber can be found at https://bitbucket.org/mkjinesh/mqttclient
* Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>
* In local mode, <master> should be 'local[n]' with n > 1
* <MqttbrokerUrl> and <topic> describe where the MQTT publisher is running.
*
* To run this example locally, you may run publisher as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTPublisher tcp://localhost:1883 foo`
* and run the example as
* `$ ./run-example org.apache.spark.streaming.examples.MQTTWordCount local[2] tcp://localhost:1883 foo`
*/
object MQTTWordCount {
def main(args: Array[String]) {
if (args.length < 3) {
System.err.println(
"Usage: MQTTWordCount <master> <MqttbrokerUrl> <topic>" +
" In local mode, <master> should be 'local[n]' with n > 1")
System.exit(1)
}
val Seq(master, brokerUrl, topic) = args.toSeq
val ssc = new StreamingContext(master, "MqttWordCount", Seconds(2), System.getenv("SPARK_HOME"),
Seq(System.getenv("SPARK_EXAMPLES_JAR")))
val lines = ssc.mqttStream(brokerUrl, topic, StorageLevel.MEMORY_ONLY)
val words = lines.flatMap(x => x.toString.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
wordCounts.print()
ssc.start()
}
}

pom.xml
View file

@ -148,6 +148,17 @@
<enabled>false</enabled>
</snapshots>
</repository>
<repository>
<id>mqtt-repo</id>
<name>MQTT Repository</name>
<url>https://repo.eclipse.org/content/repositories/paho-releases/</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</repository>
</repositories>
<pluginRepositories>
<pluginRepository>
@ -251,16 +262,34 @@
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-actor</artifactId>
<version>${akka.version}</version>
<exclusions>
<exclusion>
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-remote</artifactId>
<version>${akka.version}</version>
<exclusions>
<exclusion>
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-slf4j</artifactId>
<version>${akka.version}</version>
<exclusions>
<exclusion>
<groupId>org.jboss.netty</groupId>
<artifactId>netty</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
@ -395,19 +424,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-jaxrs</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-xc</artifactId>
<groupId>org.sonatype.sisu.inject</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@ -431,19 +452,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-jaxrs</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-xc</artifactId>
<groupId>org.sonatype.sisu.inject</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@ -462,19 +475,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-jaxrs</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-xc</artifactId>
<groupId>org.sonatype.sisu.inject</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>
@ -493,19 +498,11 @@
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-core-asl</artifactId>
<artifactId>*</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-mapper-asl</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-jaxrs</artifactId>
</exclusion>
<exclusion>
<groupId>org.codehaus.jackson</groupId>
<artifactId>jackson-xc</artifactId>
<groupId>org.sonatype.sisu.inject</groupId>
<artifactId>*</artifactId>
</exclusion>
</exclusions>
</dependency>

View file

@ -62,6 +62,8 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, graph, bagel, mllib, repl, streaming) dependsOn(maybeYarn: _*)
lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@ -76,8 +78,11 @@ object SparkBuild extends Build {
// Conditionally include the yarn sub-project
lazy val maybeYarn = if(isYarnEnabled) Seq[ClasspathDependency](yarn) else Seq[ClasspathDependency]()
lazy val maybeYarnRef = if(isYarnEnabled) Seq[ProjectReference](yarn) else Seq[ProjectReference]()
lazy val allProjects = Seq[ProjectReference](
core, repl, examples, graph, bagel, streaming, mllib, tools, assemblyProj) ++ maybeYarnRef
// Everything except assembly, tools and examples belongs to packageProjects
lazy val packageProjects = Seq[ProjectReference](core, repl, bagel, streaming, mllib, graph) ++ maybeYarnRef
lazy val allProjects = packageProjects ++ Seq[ProjectReference](examples, tools, assemblyProj)
def sharedSettings = Defaults.defaultSettings ++ Seq(
organization := "org.apache.spark",
@ -105,7 +110,10 @@ object SparkBuild extends Build {
// Shared between both core and streaming.
resolvers ++= Seq("Akka Repository" at "http://repo.akka.io/releases/"),
// For Sonatype publishing
// Shared between both examples and streaming.
resolvers ++= Seq("Mqtt Repository" at "https://repo.eclipse.org/content/repositories/paho-releases/"),
// For Sonatype publishing
resolvers ++= Seq("sonatype-snapshots" at "https://oss.sonatype.org/content/repositories/snapshots",
"sonatype-staging" at "https://oss.sonatype.org/service/local/staging/deploy/maven2/"),
@ -279,13 +287,18 @@ object SparkBuild extends Build {
def streamingSettings = sharedSettings ++ Seq(
name := "spark-streaming",
resolvers ++= Seq(
"Akka Repository" at "http://repo.akka.io/releases/"
"Akka Repository" at "http://repo.akka.io/releases/",
"Apache repo" at "https://repository.apache.org/content/repositories/releases"
),
libraryDependencies ++= Seq(
"org.eclipse.paho" % "mqtt-client" % "0.4.0",
"org.apache.flume" % "flume-ng-sdk" % "1.2.0" % "compile" excludeAll(excludeNetty, excludeSnappy),
"com.github.sgroschupf" % "zkclient" % "0.1" excludeAll(excludeNetty),
"org.twitter4j" % "twitter4j-stream" % "3.0.3" excludeAll(excludeNetty),
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty)
"com.typesafe.akka" % "akka-zeromq" % "2.0.5" excludeAll(excludeNetty),
"org.apache.kafka" % "kafka_2.9.2" % "0.8.0-beta1"
exclude("com.sun.jdmk", "jmxtools")
exclude("com.sun.jmx", "jmxri")
exclude("net.sf.jopt-simple", "jopt-simple")
)
)
@ -309,7 +322,9 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(

View file

@ -42,6 +42,13 @@
>>> a.value
13
>>> b = sc.accumulator(0)
>>> def g(x):
... b.add(x)
>>> rdd.foreach(g)
>>> b.value
6
>>> from pyspark.accumulators import AccumulatorParam
>>> class VectorAccumulatorParam(AccumulatorParam):
... def zero(self, value):
@ -139,9 +146,13 @@ class Accumulator(object):
raise Exception("Accumulator.value cannot be accessed inside tasks")
self._value = value
def add(self, term):
"""Adds a term to this accumulator's value"""
self._value = self.accum_param.addInPlace(self._value, term)
def __iadd__(self, term):
"""The += operator; adds a term to this accumulator's value"""
self._value = self.accum_param.addInPlace(self._value, term)
self.add(term)
return self
def __str__(self):

View file

@ -49,6 +49,7 @@ class SparkContext(object):
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
environment=None, batchSize=1024):
"""
@ -66,19 +67,18 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
>>> from pyspark.context import SparkContext
>>> sc = SparkContext('local', 'test')
>>> sc2 = SparkContext('local', 'test2') # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
ValueError:...
"""
with SparkContext._lock:
if SparkContext._active_spark_context:
raise ValueError("Cannot run multiple SparkContexts at once")
else:
SparkContext._active_spark_context = self
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeIteratorToPickleFile = \
SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
SparkContext._ensure_initialized(self)
self.master = master
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
@ -119,6 +119,32 @@ class SparkContext(object):
self._temp_dir = \
self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
@classmethod
def _ensure_initialized(cls, instance=None):
with SparkContext._lock:
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
SparkContext._writeIteratorToPickleFile = \
SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
if instance:
if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
raise ValueError("Cannot run multiple SparkContexts at once")
else:
SparkContext._active_spark_context = instance
@classmethod
def setSystemProperty(cls, key, value):
"""
Set a system property, such as spark.executor.memory. This must be
invoked before instantiating SparkContext.
"""
SparkContext._ensure_initialized()
SparkContext._jvm.java.lang.System.setProperty(key, value)
@property
def defaultParallelism(self):
"""

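A minimal sketch of how the new setSystemProperty classmethod and the batchSize constructor argument are used together, per the docstrings above; the memory value, app name, and batch size are illustrative assumptions, not values taken from this commit.

# Minimal sketch: system properties must be set before the SparkContext exists,
# since _ensure_initialized() starts the JVM gateway only once.
from pyspark import SparkContext

SparkContext.setSystemProperty("spark.executor.memory", "2g")  # illustrative value

# batchSize controls how many Python objects are bundled into one Java object
# (1 disables batching, -1 means unlimited), per the docstring above.
sc = SparkContext("local", "sysprops-sketch", batchSize=2)
print(sc.defaultParallelism)
sc.stop()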
View file

@ -633,6 +633,20 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
Result(true, shouldReplay)
}
def addAllClasspath(args: Seq[String]): Unit = {
var added = false
var totalClasspath = ""
for (arg <- args) {
val f = File(arg).normalize
if (f.exists) {
added = true
addedClasspath = ClassPath.join(addedClasspath, f.path)
totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
}
}
if (added) replay()
}
def addClasspath(arg: String): Unit = {
val f = File(arg).normalize
if (f.exists) {
@ -845,7 +859,14 @@ class SparkILoop(in0: Option[BufferedReader], val out: PrintWriter, val master:
val jars = Option(System.getenv("ADD_JARS")).map(_.split(','))
.getOrElse(new Array[String](0))
.map(new java.io.File(_).getAbsolutePath)
sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
try {
sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars)
} catch {
case e: Exception =>
e.printStackTrace()
echo("Failed to create SparkContext, exiting...")
sys.exit(1)
}
sparkContext
}

View file

@ -95,10 +95,17 @@ export JAVA_OPTS
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
if [[ $? != 0 ]]; then
echo "Failed to find Spark assembly in $FWDIR/assembly/target" >&2
echo "You need to build Spark with sbt/sbt assembly before running this program" >&2
num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l)
jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar")
if [ "$num_jars" -eq "0" ]; then
echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2
echo "You need to build Spark with 'sbt/sbt assembly' before running this program." >&2
exit 1
fi
if [ "$num_jars" -gt "1" ]; then
echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2
echo "$jars_list"
echo "Please remove all but one jar."
exit 1
fi
fi

View file

@ -1 +0,0 @@
18876b8bc2e4cef28b6d191aa49d963f

View file

@ -1 +0,0 @@
06b27270ffa52250a2c08703b397c99127b72060

View file

@ -1,9 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd" xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
<modelVersion>4.0.0</modelVersion>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka</artifactId>
<version>0.7.2-spark</version>
<description>POM was created from install:install-file</description>
</project>

View file

@ -1 +0,0 @@
7bc4322266e6032bdf9ef6eebdd8097d

View file

@ -1 +0,0 @@
d0f79e8eff0db43ca7bcf7dce2c8cd2972685c9d

Some files were not shown because too many files have changed in this diff.