[SPARK-23005][CORE] Improve RDD.take on small number of partitions
## What changes were proposed in this pull request? In current implementation of RDD.take, we overestimate the number of partitions we need to try by 50%: `(1.5 * num * partsScanned / buf.size).toInt` However, when the number is small, the result of `.toInt` is not what we want. E.g, 2.9 will become 2, which should be 3. Use Math.ceil to fix the problem. Also clean up the code in RDD.scala. ## How was this patch tested? Unit test Author: Wang Gengliang <ltnwgl@gmail.com> Closes #20200 from gengliangwang/Take.
This commit is contained in:
parent
2250cb75b9
commit
96ba217a06
|
@ -150,7 +150,7 @@ abstract class RDD[T: ClassTag](
|
|||
val id: Int = sc.newRddId()
|
||||
|
||||
/** A friendly name for this RDD */
|
||||
@transient var name: String = null
|
||||
@transient var name: String = _
|
||||
|
||||
/** Assign a name to this RDD */
|
||||
def setName(_name: String): this.type = {
|
||||
|
@ -224,8 +224,8 @@ abstract class RDD[T: ClassTag](
|
|||
|
||||
// Our dependencies and partitions will be gotten by calling subclass's methods below, and will
|
||||
// be overwritten when we're checkpointed
|
||||
private var dependencies_ : Seq[Dependency[_]] = null
|
||||
@transient private var partitions_ : Array[Partition] = null
|
||||
private var dependencies_ : Seq[Dependency[_]] = _
|
||||
@transient private var partitions_ : Array[Partition] = _
|
||||
|
||||
/** An Option holding our checkpoint RDD, if we are checkpointed */
|
||||
private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
|
||||
|
@ -297,7 +297,7 @@ abstract class RDD[T: ClassTag](
|
|||
private[spark] def getNarrowAncestors: Seq[RDD[_]] = {
|
||||
val ancestors = new mutable.HashSet[RDD[_]]
|
||||
|
||||
def visit(rdd: RDD[_]) {
|
||||
def visit(rdd: RDD[_]): Unit = {
|
||||
val narrowDependencies = rdd.dependencies.filter(_.isInstanceOf[NarrowDependency[_]])
|
||||
val narrowParents = narrowDependencies.map(_.rdd)
|
||||
val narrowParentsNotVisited = narrowParents.filterNot(ancestors.contains)
|
||||
|
@ -449,7 +449,7 @@ abstract class RDD[T: ClassTag](
|
|||
if (shuffle) {
|
||||
/** Distributes elements evenly across output partitions, starting from a random partition. */
|
||||
val distributePartition = (index: Int, items: Iterator[T]) => {
|
||||
var position = (new Random(hashing.byteswap32(index))).nextInt(numPartitions)
|
||||
var position = new Random(hashing.byteswap32(index)).nextInt(numPartitions)
|
||||
items.map { t =>
|
||||
// Note that the hash code of the key will just be the key itself. The HashPartitioner
|
||||
// will mod it with the number of total partitions.
|
||||
|
@ -951,7 +951,7 @@ abstract class RDD[T: ClassTag](
|
|||
def collectPartition(p: Int): Array[T] = {
|
||||
sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p)).head
|
||||
}
|
||||
(0 until partitions.length).iterator.flatMap(i => collectPartition(i))
|
||||
partitions.indices.iterator.flatMap(i => collectPartition(i))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -1338,6 +1338,7 @@ abstract class RDD[T: ClassTag](
|
|||
// The number of partitions to try in this iteration. It is ok for this number to be
|
||||
// greater than totalParts because we actually cap it at totalParts in runJob.
|
||||
var numPartsToTry = 1L
|
||||
val left = num - buf.size
|
||||
if (partsScanned > 0) {
|
||||
// If we didn't find any rows after the previous iteration, quadruple and retry.
|
||||
// Otherwise, interpolate the number of partitions we need to try, but overestimate
|
||||
|
@ -1345,13 +1346,12 @@ abstract class RDD[T: ClassTag](
|
|||
if (buf.isEmpty) {
|
||||
numPartsToTry = partsScanned * scaleUpFactor
|
||||
} else {
|
||||
// the left side of max is >=1 whenever partsScanned >= 2
|
||||
numPartsToTry = Math.max((1.5 * num * partsScanned / buf.size).toInt - partsScanned, 1)
|
||||
// As left > 0, numPartsToTry is always >= 1
|
||||
numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
|
||||
numPartsToTry = Math.min(numPartsToTry, partsScanned * scaleUpFactor)
|
||||
}
|
||||
}
|
||||
|
||||
val left = num - buf.size
|
||||
val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
|
||||
val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p)
|
||||
|
||||
|
@ -1677,8 +1677,7 @@ abstract class RDD[T: ClassTag](
|
|||
// an RDD and its parent in every batch, in which case the parent may never be checkpointed
|
||||
// and its lineage never truncated, leading to OOMs in the long run (SPARK-6847).
|
||||
private val checkpointAllMarkedAncestors =
|
||||
Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS))
|
||||
.map(_.toBoolean).getOrElse(false)
|
||||
Option(sc.getLocalProperty(RDD.CHECKPOINT_ALL_MARKED_ANCESTORS)).exists(_.toBoolean)
|
||||
|
||||
/** Returns the first parent RDD */
|
||||
protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
|
||||
|
@ -1686,7 +1685,7 @@ abstract class RDD[T: ClassTag](
|
|||
}
|
||||
|
||||
/** Returns the jth parent RDD: e.g. rdd.parent[T](0) is equivalent to rdd.firstParent[T] */
|
||||
protected[spark] def parent[U: ClassTag](j: Int) = {
|
||||
protected[spark] def parent[U: ClassTag](j: Int): RDD[U] = {
|
||||
dependencies(j).rdd.asInstanceOf[RDD[U]]
|
||||
}
|
||||
|
||||
|
@ -1754,7 +1753,7 @@ abstract class RDD[T: ClassTag](
|
|||
* collected. Subclasses of RDD may override this method for implementing their own cleaning
|
||||
* logic. See [[org.apache.spark.rdd.UnionRDD]] for an example.
|
||||
*/
|
||||
protected def clearDependencies() {
|
||||
protected def clearDependencies(): Unit = {
|
||||
dependencies_ = null
|
||||
}
|
||||
|
||||
|
@ -1790,7 +1789,7 @@ abstract class RDD[T: ClassTag](
|
|||
val lastDepStrings =
|
||||
debugString(lastDep.rdd, prefix, lastDep.isInstanceOf[ShuffleDependency[_, _, _]], true)
|
||||
|
||||
(frontDepStrings ++ lastDepStrings)
|
||||
frontDepStrings ++ lastDepStrings
|
||||
}
|
||||
}
|
||||
// The first RDD in the dependency stack has no parents, so no need for a +-
|
||||
|
|
|
@ -351,8 +351,9 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
|
|||
if (buf.isEmpty) {
|
||||
numPartsToTry = partsScanned * limitScaleUpFactor
|
||||
} else {
|
||||
// the left side of max is >=1 whenever partsScanned >= 2
|
||||
numPartsToTry = Math.max((1.5 * n * partsScanned / buf.size).toInt - partsScanned, 1)
|
||||
val left = n - buf.size
|
||||
// As left > 0, numPartsToTry is always >= 1
|
||||
numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
|
||||
numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue