[SPARK-4433] fix a racing condition in zipWithIndex
Spark hangs with the following code:

~~~
sc.parallelize(1 to 10).zipWithIndex.repartition(10).count()
~~~

This happens because ZippedWithIndexRDD triggers a job inside getPartitions, which deadlocks against DAGScheduler.getPreferredLocs (a synchronized method). The fix is to compute `startIndices` during construction instead.

This should be applied to branch-1.0, branch-1.1, and branch-1.2. pwendell

Author: Xiangrui Meng <meng@databricks.com>

Closes #3291 from mengxr/SPARK-4433 and squashes the following commits:

c284d9f [Xiangrui Meng] fix a racing condition in zipWithIndex
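For context, here is a self-contained version of the reproduction above. The `local[2]` master setting and the wrapper object are illustrative additions, not part of the patch; with this fix applied the job completes instead of hanging:

~~~
import org.apache.spark.{SparkConf, SparkContext}

object Spark4433Repro {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("SPARK-4433-repro")
    val sc = new SparkContext(conf)
    // Before the fix: repartition() makes the DAGScheduler compute preferred
    // locations inside a synchronized method, and ZippedWithIndexRDD used to
    // launch its partition-counting job from getPartitions during that call,
    // deadlocking the scheduler. After the fix, startIndices is computed when
    // the RDD is constructed, so this completes and returns 10.
    val count = sc.parallelize(1 to 10).zipWithIndex().repartition(10).count()
    assert(count == 10L)
    sc.stop()
  }
}
~~~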
parent 4a377aff2d
commit bb46046154
~~~
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
 private[spark]
 class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
 
-  override def getPartitions: Array[Partition] = {
+  /** The start index of each partition. */
+  @transient private val startIndices: Array[Long] = {
     val n = prev.partitions.size
-    val startIndices: Array[Long] =
-      if (n == 0) {
-        Array[Long]()
-      } else if (n == 1) {
-        Array(0L)
-      } else {
-        prev.context.runJob(
-          prev,
-          Utils.getIteratorSize _,
-          0 until n - 1, // do not need to count the last partition
-          false
-        ).scanLeft(0L)(_ + _)
-      }
+    if (n == 0) {
+      Array[Long]()
+    } else if (n == 1) {
+      Array(0L)
+    } else {
+      prev.context.runJob(
+        prev,
+        Utils.getIteratorSize _,
+        0 until n - 1, // do not need to count the last partition
+        allowLocal = false
+      ).scanLeft(0L)(_ + _)
+    }
+  }
+
+  override def getPartitions: Array[Partition] = {
     firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
   }
~~~
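As a side note on the `scanLeft` above: it turns the per-partition element counts (the last partition's count is never needed) into each partition's global start offset. A tiny plain-Scala illustration with made-up partition sizes:

~~~
// Suppose counting the first n - 1 partitions returned 3 and 4 elements.
val counts = Array(3L, 4L)
// Running sum starting at 0: each partition's first global index.
val startIndices = counts.scanLeft(0L)(_ + _)
// Partition 0 starts at 0, partition 1 at 3, partition 2 at 7.
assert(startIndices.sameElements(Array(0L, 3L, 7L)))
~~~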
~~~
@@ -739,6 +739,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     }
   }
 
+  test("zipWithIndex chained with other RDDs (SPARK-4433)") {
+    val count = sc.parallelize(0 until 10, 2).zipWithIndex().repartition(4).count()
+    assert(count === 10)
+  }
+
   test("zipWithUniqueId") {
     val n = 10
     val data = sc.parallelize(0 until n, 3)
~~~