Made output of CoGroup and aggregations interruptible.
This commit is contained in:
parent
c5e40954eb
commit
1d87616b61
|
@ -23,7 +23,7 @@ import java.util.{HashMap => JHashMap}
|
||||||
import scala.collection.JavaConversions
|
import scala.collection.JavaConversions
|
||||||
import scala.collection.mutable.ArrayBuffer
|
import scala.collection.mutable.ArrayBuffer
|
||||||
|
|
||||||
import org.apache.spark.{Partition, Partitioner, SparkEnv, TaskContext}
|
import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv, TaskContext}
|
||||||
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
|
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
|
||||||
|
|
||||||
|
|
||||||
|
@ -134,7 +134,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
JavaConversions.mapAsScalaMap(map).iterator
|
new InterruptibleIterator(context, JavaConversions.mapAsScalaMap(map).iterator)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def clearDependencies() {
|
override def clearDependencies() {
|
||||||
|
|
|
@ -0,0 +1,36 @@
|
||||||
|
/*
|
||||||
|
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||||
|
* contributor license agreements. See the NOTICE file distributed with
|
||||||
|
* this work for additional information regarding copyright ownership.
|
||||||
|
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||||
|
* (the "License"); you may not use this file except in compliance with
|
||||||
|
* the License. You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package org.apache.spark.rdd
|
||||||
|
|
||||||
|
import org.apache.spark.{InterruptibleIterator, Partition, TaskContext}
|
||||||
|
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Wraps around an existing RDD to make it interruptible (can be killed).
* It keeps the partitions and partitioner of the wrapped RDD and returns each
* partition's iterator wrapped in an InterruptibleIterator built from the TaskContext.
|
||||||
|
*/
|
||||||
|
private[spark]
|
||||||
|
// prev is the RDD being wrapped; it is also handed to the RDD[T] superclass constructor.
class InterruptibleRDD[T: ClassManifest](prev: RDD[T]) extends RDD[T](prev) {
|
||||||
|
|
||||||
|
// Partitions are taken as-is from firstParent[T] (presumably prev -- confirm in RDD).
override def getPartitions: Array[Partition] = firstParent[T].partitions
|
||||||
|
|
||||||
|
// Carries over the wrapped RDD's partitioner unchanged.
override val partitioner = prev.partitioner
|
||||||
|
|
||||||
|
// Computes nothing itself: obtains the parent's iterator for this split and wraps it.
override def compute(split: Partition, context: TaskContext) = {
|
||||||
|
// NOTE(review): InterruptibleIterator presumably consults `context` while iterating so the
// task can be killed -- its implementation is not visible here; confirm.
new InterruptibleIterator(context, firstParent[T].iterator(split, context))
|
||||||
|
}
|
||||||
|
}
|
|
@ -85,17 +85,20 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
|
||||||
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
|
||||||
if (self.partitioner == Some(partitioner)) {
|
if (self.partitioner == Some(partitioner)) {
|
||||||
self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||||
|
.interruptible()
|
||||||
} else if (mapSideCombine) {
|
} else if (mapSideCombine) {
|
||||||
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
val combined = self.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||||
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
|
val partitioned = new ShuffledRDD[K, C, (K, C)](combined, partitioner)
|
||||||
.setSerializer(serializerClass)
|
.setSerializer(serializerClass)
|
||||||
partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
|
partitioned.mapPartitions(aggregator.combineCombinersByKey, preservesPartitioning = true)
|
||||||
|
.interruptible()
|
||||||
} else {
|
} else {
|
||||||
// Don't apply map-side combiner.
|
// Don't apply map-side combiner.
|
||||||
// A sanity check to make sure mergeCombiners is not defined.
|
// A sanity check to make sure mergeCombiners is not defined.
|
||||||
assert(mergeCombiners == null)
|
assert(mergeCombiners == null)
|
||||||
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
|
val values = new ShuffledRDD[K, V, (K, V)](self, partitioner).setSerializer(serializerClass)
|
||||||
values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
values.mapPartitions(aggregator.combineValuesByKey, preservesPartitioning = true)
|
||||||
|
.interruptible()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -852,6 +852,11 @@ abstract class RDD[T: ClassManifest](
|
||||||
map(x => (f(x), x))
|
map(x => (f(x), x))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates an interruptible version of this RDD by wrapping it in a new
* InterruptibleRDD constructed over this RDD; the result is typed as RDD[T].
|
||||||
|
*/
|
||||||
|
def interruptible(): RDD[T] = new InterruptibleRDD(this)
|
||||||
|
|
||||||
/** A private method for tests, to look at the contents of each partition */
|
/** A private method for tests, to look at the contents of each partition */
|
||||||
private[spark] def collectPartitions(): Array[Array[T]] = {
|
private[spark] def collectPartitions(): Array[Array[T]] = {
|
||||||
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
|
sc.runJob(this, (iter: Iterator[T]) => iter.toArray)
|
||||||
|
|
Loading…
Reference in a new issue