Made output of CoGroup and aggregations interruptible.

commit 1d87616b61
parent c5e40954eb
Author: Reynold Xin
Date:   2013-09-19 23:31:36 -07:00

4 changed files with 46 additions and 2 deletions
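This change threads Spark's InterruptibleIterator through CoGroupedRDD and the aggregation paths. For reference, a minimal self-contained sketch of the idea behind that wrapper (not part of this diff; the real class reads a kill flag on TaskContext, and the names below are illustrative stand-ins):

// Toy stand-in for TaskContext's kill flag.
class KillFlag { @volatile var interrupted = false }

// Delegates to an underlying iterator but stops yielding once the flag is
// set. Checking the flag in hasNext means a killed task stops before it
// computes another element, which is what makes long-running output
// interruptible rather than only killable at task boundaries.
class InterruptibleIteratorSketch[T](flag: KillFlag, delegate: Iterator[T])
  extends Iterator[T] {
  def hasNext: Boolean = !flag.interrupted && delegate.hasNext
  def next(): T = delegate.next()
}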

core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala

@@ -23,7 +23,7 @@ import java.util.{HashMap => JHashMap}
 
 import scala.collection.JavaConversions
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
 import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
@@ -134,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
         }
       }
     }
-    JavaConversions.mapAsScalaMap(map).iterator
+    new InterruptibleIterator(context, JavaConversions.mapAsScalaMap(map).iterator)
   }
 
   override def clearDependencies() {
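Downstream consumers pull from compute's returned iterator lazily, so wrapping only the output is enough for a kill to take effect mid-partition. A toy demonstration of that pull behavior (plain Scala, not Spark code; takeWhile stands in for the wrapper's flag check):

object LazyPullDemo extends App {
  @volatile var killed = false

  // Stand-in for the grouped-values iterator that compute returns.
  val underlying = Iterator.from(1)

  // Re-checks the flag before handing out each element, as the wrapper does.
  val interruptible = underlying.takeWhile(_ => !killed)

  // The consumer flips the flag partway through, as a task kill would.
  val seen = interruptible.map { i => if (i >= 5) killed = true; i }.toList
  println(seen)  // List(1, 2, 3, 4, 5): nothing is computed past the kill
}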

core/src/main/scala/org/apache/spark/rdd/InterruptibleRDD.scala

@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
+
+/**
+ * Wraps around an existing RDD to make it interruptible (can be killed).
+ */
+private[spark]
+class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
+
+  override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+  override val partitioner = prev.partitioner
+
+  override def compute(split: Partition, context: TaskContext) = {
+    new InterruptibleIterator(context, firstParent[T].iterator(split, context))
+  }
+}
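Because getPartitions and partitioner both delegate to the parent, the wrapper is invisible to scheduling; only compute's iterator changes. A hypothetical usage, assuming an existing SparkContext named sc from this era:

// Hypothetical usage (assumes an existing SparkContext `sc`):
val big = sc.parallelize(1 to 1000000).map(_ * 2)
val killable = big.interruptible()
// `killable` has the same partitions and partitioner as `big`, so the
// scheduler treats it identically; each partition's iterator is now wrapped
// so a task kill is noticed between elements.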

core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala

@@ -85,17 +85,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
     val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
     if (self.partitioner == Some(partitioner)) {
       self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     } else if (mapSideCombine) {
       val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
       val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
         .setSerializer(serializerClass)
       partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
+        .interruptible()
     } else {
       // Don't apply map-side combiner.
       // A sanity check to make sure mergeCombiners is not defined.
       assert(mergeCombiners == null)
       val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
       values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
+        .interruptible()
     }
   }
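All three return paths of this combineByKey now end in .interruptible(). For orientation, a sketch of which branch fires for common calls (assumes an existing SparkContext sc and the era's API, where reduceByKey funnels into combineByKey):

import org.apache.spark.HashPartitioner

// Illustrative only (assumes an existing SparkContext `sc`):
val pairs = sc.parallelize(Seq(("a", 1), ("b", 2), ("a", 3)))

// No matching partitioner yet and mapSideCombine defaults to true, so this
// takes the middle branch: combine locally, shuffle, combine again.
val counts = pairs.reduceByKey(_ + _)

// Already hash-partitioned, so self.partitioner == Some(partitioner) and
// the first branch runs: a single local combine pass, no shuffle.
val prePartitioned = pairs.partitionBy(new HashPartitioner(4))
val counts2 = prePartitioned.reduceByKey(_ + _)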

core/src/main/scala/org/apache/spark/rdd/RDD.scala

@@ -852,6 +852,11 @@ abstract class RDD[T: ClassManifest](
     map(x => (f(x), x))
   }
 
+  /**
+   * Creates an interruptible version of this RDD.
+   */
+  def interruptible(): RDD[T] = new InterruptibleRDD(this)
+
   /** A private method for tests, to look at the contents of each partition */
   private[spark] def collectPartitions(): Array[Array[T]] = {
     sc.runJob(this, (iter: Iterator[T]) => iter.toArray)