Added task killing iterator to RDDs that take inputs.

Reynold Xin 2013-09-19 18:33:16 -07:00
parent f19984dafe
commit 70953810b4
7 changed files with 60 additions and 47 deletions
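
The change threads the TaskContext through shuffle fetches and wraps every input-reading iterator in an InterruptibleIterator, so a task that has been killed stops pulling records. The wrapper class itself is defined outside the files shown here; as a rough sketch of the idea (the `interrupted` flag and the exact constructor are assumptions, not taken from this diff), it is just a delegating iterator that checks the task's kill flag before handing out the next element:

import org.apache.spark.TaskContext

// Sketch only: a delegating iterator that stops yielding elements once the task
// has been flagged as interrupted. The `interrupted` flag on TaskContext is an
// assumption about how the kill signal reaches the iterator.
class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
  extends Iterator[T] {

  def hasNext: Boolean = !context.interrupted && delegate.hasNext

  def next(): T = delegate.next()
}

Because every record flows through hasNext/next, wrapping the iterators returned by compute() and fetch() gives the executor a single, cheap interception point: a kill takes effect between records without touching the Hadoop record reader or the shuffle block fetch logic.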

BlockStoreShuffleFetcher.scala

@@ -28,7 +28,11 @@ import org.apache.spark.util.CompletionIterator
 private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-  override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
+  override def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
+      serializer: Serializer)
     : Iterator[T] =
   {
@@ -74,7 +78,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
     val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
     val itr = blockFetcherItr.flatMap(unpackBlock)
-    CompletionIterator[T, Iterator[T]](itr, {
+    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
       val shuffleMetrics = new ShuffleReadMetrics
       shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
       shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
@@ -83,7 +87,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
       shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
       shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
       shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      metrics.shuffleReadMetrics = Some(shuffleMetrics)
+      context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
     })
+
+    new InterruptibleIterator[T](context, completionIter)
   }
 }

ShuffleFetcher.scala

@@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
    * Fetch the shuffle outputs for a given ShuffleDependency.
    * @return An iterator over the elements of the fetched shuffle outputs.
    */
-  def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
+  def fetch[T](
+      shuffleId: Int,
+      reduceId: Int,
+      context: TaskContext,
       serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
 
   /** Stop the fetcher */

CoGroupedRDD.scala

@@ -129,7 +129,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
+        fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
           kv => getSeq(kv._1)(depNum) += kv._2
         }
       }

NewHadoopRDD.scala

@@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
 import org.apache.hadoop.io.Writable
 import org.apache.hadoop.mapreduce._
 
-import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
 
 private[spark]
@@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
     result
   }
 
-  override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
-    val split = theSplit.asInstanceOf[NewHadoopPartition]
-    logInfo("Input split: " + split.serializableHadoopSplit)
-    val conf = confBroadcast.value.value
-    val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
-    val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
-    val format = inputFormatClass.newInstance
-    if (format.isInstanceOf[Configurable]) {
-      format.asInstanceOf[Configurable].setConf(conf)
-    }
-    val reader = format.createRecordReader(
-      split.serializableHadoopSplit.value, hadoopAttemptContext)
-    reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
-
-    // Register an on-task-completion callback to close the input stream.
-    context.addOnCompleteCallback(() => close())
-
-    var havePair = false
-    var finished = false
-
-    override def hasNext: Boolean = {
-      if (!finished && !havePair) {
-        finished = !reader.nextKeyValue
-        havePair = !finished
-      }
-      !finished
-    }
-
-    override def next: (K, V) = {
-      if (!hasNext) {
-        throw new java.util.NoSuchElementException("End of stream")
-      }
-      havePair = false
-      return (reader.getCurrentKey, reader.getCurrentValue)
-    }
-
-    private def close() {
-      try {
-        reader.close()
-      } catch {
-        case e: Exception => logWarning("Exception in RecordReader.close()", e)
-      }
-    }
-  }
+  override def compute(theSplit: Partition, context: TaskContext) = {
+    val iter = new Iterator[(K, V)] {
+      val split = theSplit.asInstanceOf[NewHadoopPartition]
+      logInfo("Input split: " + split.serializableHadoopSplit)
+      val conf = confBroadcast.value.value
+      val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
+      val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+      val format = inputFormatClass.newInstance
+      if (format.isInstanceOf[Configurable]) {
+        format.asInstanceOf[Configurable].setConf(conf)
+      }
+      val reader = format.createRecordReader(
+        split.serializableHadoopSplit.value, hadoopAttemptContext)
+      reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
+      // Register an on-task-completion callback to close the input stream.
+      context.addOnCompleteCallback(() => close())
+
+      var havePair = false
+      var finished = false
+
+      override def hasNext: Boolean = {
+        if (!finished && !havePair) {
+          finished = !reader.nextKeyValue
+          havePair = !finished
+        }
+        !finished
+      }
+
+      override def next(): (K, V) = {
+        if (!hasNext) {
+          throw new java.util.NoSuchElementException("End of stream")
+        }
+        havePair = false
+        (reader.getCurrentKey, reader.getCurrentValue)
+      }
+
+      private def close() {
+        try {
+          reader.close()
+        } catch {
+          case e: Exception => logWarning("Exception in RecordReader.close()", e)
+        }
+      }
+    }
+    new InterruptibleIterator(context, iter)
+  }
 
   override def getPreferredLocations(split: Partition): Seq[String] = {

ParallelCollectionRDD.scala

@@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest](
     slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
   }
 
-  override def compute(s: Partition, context: TaskContext) =
-    s.asInstanceOf[ParallelCollectionPartition[T]].iterator
+  override def compute(s: Partition, context: TaskContext) = {
+    new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
+  }
 
   override def getPreferredLocations(s: Partition): Seq[String] = {
     locationPrefs.getOrElse(s.index, Nil)

ShuffledRDD.scala

@@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
 
   override def compute(split: Partition, context: TaskContext): Iterator[P] = {
     val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
-    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
+    SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
       SparkEnv.get.serializerManager.get(serializerClass))
   }

SubtractedRDD.scala

@@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
       }
       case ShuffleCoGroupSplitDep(shuffleId) => {
         val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
-          context.taskMetrics, serializer)
+          context, serializer)
         iter.foreach(op)
       }
     }
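
The same wrapping pattern carries over to any RDD whose compute() reads from an outside source: build the raw iterator as before, then return it wrapped together with the task's context. A minimal, hypothetical example (LinesRDD is invented for illustration and is not part of this commit):

import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical one-partition RDD showing where the wrapper goes in compute().
class LinesRDD(sc: SparkContext, lines: Seq[String]) extends RDD[String](sc, Nil) {

  override def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  override def compute(split: Partition, context: TaskContext): Iterator[String] =
    // Wrap the raw iterator so a killed task stops consuming input between records.
    new InterruptibleIterator(context, lines.iterator)
}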