[SPARK-32822][SQL] Change the number of partitions to zero when a range is empty with WholeStageCodegen disabled or fallen back
### What changes were proposed in this pull request? This PR changes the behavior of RangeExec, when WholeStageCodegen is disabled or falls back, so that the number of partitions becomes zero when a range is empty. In the current master, if WholeStageCodegen takes effect, the number of partitions of an empty range is changed to zero. ``` spark.range(1, 1, 1, 1000).rdd.getNumPartitions res0: Int = 0 ``` But it is not changed if WholeStageCodegen is disabled or falls back. ``` spark.conf.set("spark.sql.codegen.wholeStage", false) spark.range(1, 1, 1, 1000).rdd.getNumPartitions res2: Int = 1000 ``` ### Why are the changes needed? To achieve better performance even when WholeStageCodegen is disabled or falls back. ### Does this PR introduce _any_ user-facing change? Yes. The number of partitions returned by `getNumPartitions` for an empty range will change when WholeStageCodegen is disabled. ### How was this patch tested? New test. Closes #29681 from sarutak/zero-size-range. Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
a22871f50a
commit
5f468cc21e
|
@ -371,6 +371,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
|
|||
val step: Long = range.step
|
||||
val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism)
|
||||
val numElements: BigInt = range.numElements
|
||||
val isEmptyRange: Boolean = start == end || (start < end ^ 0 < step)
|
||||
|
||||
override val output: Seq[Attribute] = range.output
|
||||
|
||||
|
@ -396,7 +397,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
|
|||
}
|
||||
|
||||
override def inputRDDs(): Seq[RDD[InternalRow]] = {
|
||||
val rdd = if (start == end || (start < end ^ 0 < step)) {
|
||||
val rdd = if (isEmptyRange) {
|
||||
new EmptyRDD[InternalRow](sqlContext.sparkContext)
|
||||
} else {
|
||||
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices).map(i => InternalRow(i))
|
||||
|
@ -562,58 +563,64 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
|
|||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val numOutputRows = longMetric("numOutputRows")
|
||||
sqlContext
|
||||
.sparkContext
|
||||
.parallelize(0 until numSlices, numSlices)
|
||||
.mapPartitionsWithIndex { (i, _) =>
|
||||
val partitionStart = (i * numElements) / numSlices * step + start
|
||||
val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
|
||||
def getSafeMargin(bi: BigInt): Long =
|
||||
if (bi.isValidLong) {
|
||||
bi.toLong
|
||||
} else if (bi > 0) {
|
||||
Long.MaxValue
|
||||
} else {
|
||||
Long.MinValue
|
||||
}
|
||||
val safePartitionStart = getSafeMargin(partitionStart)
|
||||
val safePartitionEnd = getSafeMargin(partitionEnd)
|
||||
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
|
||||
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
|
||||
val taskContext = TaskContext.get()
|
||||
if (isEmptyRange) {
|
||||
new EmptyRDD[InternalRow](sqlContext.sparkContext)
|
||||
} else {
|
||||
sqlContext
|
||||
.sparkContext
|
||||
.parallelize(0 until numSlices, numSlices)
|
||||
.mapPartitionsWithIndex { (i, _) =>
|
||||
val partitionStart = (i * numElements) / numSlices * step + start
|
||||
val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start
|
||||
|
||||
val iter = new Iterator[InternalRow] {
|
||||
private[this] var number: Long = safePartitionStart
|
||||
private[this] var overflow: Boolean = false
|
||||
private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
|
||||
|
||||
override def hasNext =
|
||||
if (!overflow) {
|
||||
if (step > 0) {
|
||||
number < safePartitionEnd
|
||||
} else {
|
||||
number > safePartitionEnd
|
||||
}
|
||||
} else false
|
||||
|
||||
override def next() = {
|
||||
val ret = number
|
||||
number += step
|
||||
if (number < ret ^ step < 0) {
|
||||
// we have Long.MaxValue + Long.MaxValue < Long.MaxValue
|
||||
// and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
|
||||
// back, we are pretty sure that we have an overflow.
|
||||
overflow = true
|
||||
def getSafeMargin(bi: BigInt): Long =
|
||||
if (bi.isValidLong) {
|
||||
bi.toLong
|
||||
} else if (bi > 0) {
|
||||
Long.MaxValue
|
||||
} else {
|
||||
Long.MinValue
|
||||
}
|
||||
|
||||
numOutputRows += 1
|
||||
inputMetrics.incRecordsRead(1)
|
||||
unsafeRow.setLong(0, ret)
|
||||
unsafeRow
|
||||
val safePartitionStart = getSafeMargin(partitionStart)
|
||||
val safePartitionEnd = getSafeMargin(partitionEnd)
|
||||
val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize
|
||||
val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1)
|
||||
val taskContext = TaskContext.get()
|
||||
|
||||
val iter = new Iterator[InternalRow] {
|
||||
private[this] var number: Long = safePartitionStart
|
||||
private[this] var overflow: Boolean = false
|
||||
private[this] val inputMetrics = taskContext.taskMetrics().inputMetrics
|
||||
|
||||
override def hasNext =
|
||||
if (!overflow) {
|
||||
if (step > 0) {
|
||||
number < safePartitionEnd
|
||||
} else {
|
||||
number > safePartitionEnd
|
||||
}
|
||||
} else false
|
||||
|
||||
override def next() = {
|
||||
val ret = number
|
||||
number += step
|
||||
if (number < ret ^ step < 0) {
|
||||
// we have Long.MaxValue + Long.MaxValue < Long.MaxValue
|
||||
// and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step
|
||||
// back, we are pretty sure that we have an overflow.
|
||||
overflow = true
|
||||
}
|
||||
|
||||
numOutputRows += 1
|
||||
inputMetrics.incRecordsRead(1)
|
||||
unsafeRow.setLong(0, ret)
|
||||
unsafeRow
|
||||
}
|
||||
}
|
||||
new InterruptibleIterator(taskContext, iter)
|
||||
}
|
||||
new InterruptibleIterator(taskContext, iter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def simpleString(maxFields: Int): String = {
|
||||
|
|
|
@ -994,6 +994,13 @@ class PlannerSuite extends SharedSparkSession with AdaptiveSparkPlanHelper {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
testWithWholeStageCodegenOnAndOff("Change the number of partitions to zero " +
|
||||
"when a range is empty") { _ =>
|
||||
val range = spark.range(1, 1, 1, 1000)
|
||||
val numPartitions = range.rdd.getNumPartitions
|
||||
assert(numPartitions == 0)
|
||||
}
|
||||
}
|
||||
|
||||
// Used for unit-testing EnsureRequirements
|
||||
|
|
Loading…
Reference in a new issue