Added a task-killing iterator (InterruptibleIterator) to RDDs that take inputs.
This commit is contained in:
parent
f19984dafe
commit
70953810b4
|
@ -28,7 +28,11 @@ import org.apache.spark.util.CompletionIterator
|
|||
|
||||
private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
|
||||
|
||||
override def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics, serializer: Serializer)
|
||||
override def fetch[T](
|
||||
shuffleId: Int,
|
||||
reduceId: Int,
|
||||
context: TaskContext,
|
||||
serializer: Serializer)
|
||||
: Iterator[T] =
|
||||
{
|
||||
|
||||
|
@ -74,7 +78,7 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
|
|||
val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
|
||||
val itr = blockFetcherItr.flatMap(unpackBlock)
|
||||
|
||||
CompletionIterator[T, Iterator[T]](itr, {
|
||||
val completionIter = CompletionIterator[T, Iterator[T]](itr, {
|
||||
val shuffleMetrics = new ShuffleReadMetrics
|
||||
shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
|
||||
shuffleMetrics.remoteFetchTime = blockFetcherItr.remoteFetchTime
|
||||
|
@ -83,7 +87,9 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
|
|||
shuffleMetrics.totalBlocksFetched = blockFetcherItr.totalBlocks
|
||||
shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
|
||||
shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
|
||||
metrics.shuffleReadMetrics = Some(shuffleMetrics)
|
||||
context.taskMetrics.shuffleReadMetrics = Some(shuffleMetrics)
|
||||
})
|
||||
|
||||
new InterruptibleIterator[T](context, completionIter)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -27,7 +27,10 @@ private[spark] abstract class ShuffleFetcher {
|
|||
* Fetch the shuffle outputs for a given ShuffleDependency.
|
||||
* @return An iterator over the elements of the fetched shuffle outputs.
|
||||
*/
|
||||
def fetch[T](shuffleId: Int, reduceId: Int, metrics: TaskMetrics,
|
||||
def fetch[T](
|
||||
shuffleId: Int,
|
||||
reduceId: Int,
|
||||
context: TaskContext,
|
||||
serializer: Serializer = SparkEnv.get.serializerManager.default): Iterator[T]
|
||||
|
||||
/** Stop the fetcher */
|
||||
|
|
|
@ -129,7 +129,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
|
|||
case ShuffleCoGroupSplitDep(shuffleId) => {
|
||||
// Read map outputs of shuffle
|
||||
val fetcher = SparkEnv.get.shuffleFetcher
|
||||
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context.taskMetrics, ser).foreach {
|
||||
fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser).foreach {
|
||||
kv => getSeq(kv._1)(depNum) += kv._2
|
||||
}
|
||||
}
|
||||
|
|
|
@ -24,7 +24,7 @@ import org.apache.hadoop.conf.{Configurable, Configuration}
|
|||
import org.apache.hadoop.io.Writable
|
||||
import org.apache.hadoop.mapreduce._
|
||||
|
||||
import org.apache.spark.{Dependency, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
|
||||
import org.apache.spark.{InterruptibleIterator, Logging, Partition, SerializableWritable, SparkContext, TaskContext}
|
||||
|
||||
|
||||
private[spark]
|
||||
|
@ -71,49 +71,52 @@ class NewHadoopRDD[K, V](
|
|||
result
|
||||
}
|
||||
|
||||
override def compute(theSplit: Partition, context: TaskContext) = new Iterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[NewHadoopPartition]
|
||||
logInfo("Input split: " + split.serializableHadoopSplit)
|
||||
val conf = confBroadcast.value.value
|
||||
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
|
||||
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
|
||||
val format = inputFormatClass.newInstance
|
||||
if (format.isInstanceOf[Configurable]) {
|
||||
format.asInstanceOf[Configurable].setConf(conf)
|
||||
}
|
||||
val reader = format.createRecordReader(
|
||||
split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback(() => close())
|
||||
|
||||
var havePair = false
|
||||
var finished = false
|
||||
|
||||
override def hasNext: Boolean = {
|
||||
if (!finished && !havePair) {
|
||||
finished = !reader.nextKeyValue
|
||||
havePair = !finished
|
||||
override def compute(theSplit: Partition, context: TaskContext) = {
|
||||
val iter = new Iterator[(K, V)] {
|
||||
val split = theSplit.asInstanceOf[NewHadoopPartition]
|
||||
logInfo("Input split: " + split.serializableHadoopSplit)
|
||||
val conf = confBroadcast.value.value
|
||||
val attemptId = newTaskAttemptID(jobtrackerId, id, true, split.index, 0)
|
||||
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
|
||||
val format = inputFormatClass.newInstance
|
||||
if (format.isInstanceOf[Configurable]) {
|
||||
format.asInstanceOf[Configurable].setConf(conf)
|
||||
}
|
||||
!finished
|
||||
}
|
||||
val reader = format.createRecordReader(
|
||||
split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
|
||||
|
||||
override def next: (K, V) = {
|
||||
if (!hasNext) {
|
||||
throw new java.util.NoSuchElementException("End of stream")
|
||||
// Register an on-task-completion callback to close the input stream.
|
||||
context.addOnCompleteCallback(() => close())
|
||||
|
||||
var havePair = false
|
||||
var finished = false
|
||||
|
||||
override def hasNext: Boolean = {
|
||||
if (!finished && !havePair) {
|
||||
finished = !reader.nextKeyValue
|
||||
havePair = !finished
|
||||
}
|
||||
!finished
|
||||
}
|
||||
havePair = false
|
||||
return (reader.getCurrentKey, reader.getCurrentValue)
|
||||
}
|
||||
|
||||
private def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
override def next(): (K, V) = {
|
||||
if (!hasNext) {
|
||||
throw new java.util.NoSuchElementException("End of stream")
|
||||
}
|
||||
havePair = false
|
||||
(reader.getCurrentKey, reader.getCurrentValue)
|
||||
}
|
||||
|
||||
private def close() {
|
||||
try {
|
||||
reader.close()
|
||||
} catch {
|
||||
case e: Exception => logWarning("Exception in RecordReader.close()", e)
|
||||
}
|
||||
}
|
||||
}
|
||||
new InterruptibleIterator(context, iter)
|
||||
}
|
||||
|
||||
override def getPreferredLocations(split: Partition): Seq[String] = {
|
||||
|
|
|
@ -94,8 +94,9 @@ private[spark] class ParallelCollectionRDD[T: ClassManifest](
|
|||
slices.indices.map(i => new ParallelCollectionPartition(id, i, slices(i))).toArray
|
||||
}
|
||||
|
||||
// First form shown by this hunk (presumably the pre-change text — confirm against the commit):
// casts the partition to ParallelCollectionPartition[T] and returns its iterator directly.
override def compute(s: Partition, context: TaskContext) =
|
||||
s.asInstanceOf[ParallelCollectionPartition[T]].iterator
|
||||
// Second form shown by this hunk (presumably the post-change text): the same partition
// iterator, now passed together with the task's context to an InterruptibleIterator.
// NOTE(review): InterruptibleIterator is not defined in this view; per the commit summary it
// is the "task killing" wrapper — confirm its exact semantics at its definition.
override def compute(s: Partition, context: TaskContext) = {
|
||||
new InterruptibleIterator(context, s.asInstanceOf[ParallelCollectionPartition[T]].iterator)
|
||||
}
|
||||
|
||||
override def getPreferredLocations(s: Partition): Seq[String] = {
|
||||
locationPrefs.getOrElse(s.index, Nil)
|
||||
|
|
|
@ -56,7 +56,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
|
|||
|
||||
// Computes one partition of the shuffled RDD: reads the shuffle id from the first dependency
// (cast to ShuffleDependency[K, V]) and asks SparkEnv's shuffleFetcher for the outputs of that
// shuffle at this partition's index, using the serializer looked up from serializerClass.
// NOTE(review): the two fetch lines below look like the old and new forms of the same call in
// this diff rendering (third argument context.taskMetrics vs. context) — confirm which one
// the commit keeps.
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
|
||||
val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
|
||||
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context.taskMetrics,
|
||||
SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context,
|
||||
SparkEnv.get.serializerManager.get(serializerClass))
|
||||
}
|
||||
|
||||
|
|
|
@ -108,7 +108,7 @@ private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassM
|
|||
}
|
||||
case ShuffleCoGroupSplitDep(shuffleId) => {
|
||||
val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
|
||||
context.taskMetrics, serializer)
|
||||
context, serializer)
|
||||
iter.foreach(op)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue