diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
index cc576bbc4c..f98ae82574 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortExec.scala
@@ -177,6 +177,8 @@ case class SortExec(
     """.stripMargin.trim
   }
 
+  protected override val shouldStopRequired = false
+
   override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     s"""
        |${row.code}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index c58474eba0..c31fd92447 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -206,6 +206,21 @@ trait CodegenSupport extends SparkPlan {
   def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
     throw new UnsupportedOperationException
   }
+
+  /**
+   * Used to optimize away shouldStop() checks in the row-producing loops of WholeStageCodegen.
+   * Returning true means a shouldStop() check must be inserted into the loop producing rows, if any.
+   */
+  def isShouldStopRequired: Boolean = {
+    shouldStopRequired && (this.parent == null || this.parent.isShouldStopRequired)
+  }
+
+  /**
+   * Set to false if this plan consumes all rows produced by children but doesn't output rows
+   * to the buffer by calling append(), so the children don't require a shouldStop() check
+   * in their row-producing loops.
+   */
+  protected def shouldStopRequired: Boolean = true
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 4529ed067e..68c8e6ce62 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -238,6 +238,8 @@ case class HashAggregateExec(
     """.stripMargin
   }
 
+  protected override val shouldStopRequired = false
+
   private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
     // only have DeclarativeAggregate
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 87e90ed685..d876688a8a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -387,8 +387,8 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     // How many values should be generated in the next batch.
     val nextBatchTodo = ctx.freshName("nextBatchTodo")
 
-    // The default size of a batch.
-    val batchSize = 1000L
+    // The default size of a batch, which must be a positive integer.
+    val batchSize = 1000
 
     ctx.addNewFunction("initRange",
       s"""
@@ -434,6 +434,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     val input = ctx.freshName("input")
     // Right now, Range is only used when there is one upstream.
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
+
+    val localIdx = ctx.freshName("localIdx")
+    val localEnd = ctx.freshName("localEnd")
+    val range = ctx.freshName("range")
+    val shouldStop = if (isShouldStopRequired) {
+      s"if (shouldStop()) { $number = $value + ${step}L; return; }"
+    } else {
+      "// shouldStop check is eliminated"
+    }
     s"""
       | // initialize Range
       | if (!$initTerm) {
@@ -442,11 +451,15 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
       | }
      |
       | while (true) {
-      |   while ($number != $batchEnd) {
-      |     long $value = $number;
-      |     $number += ${step}L;
-      |     ${consume(ctx, Seq(ev))}
-      |     if (shouldStop()) return;
+      |   long $range = $batchEnd - $number;
+      |   if ($range != 0L) {
+      |     int $localEnd = (int)($range / ${step}L);
+      |     for (int $localIdx = 0; $localIdx < $localEnd; $localIdx++) {
+      |       long $value = ((long)$localIdx * ${step}L) + $number;
+      |       ${consume(ctx, Seq(ev))}
+      |       $shouldStop
+      |     }
+      |     $number = $batchEnd;
       |   }
       |
       |   if ($taskContext.isInterrupted()) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
index acf393a9b0..5e323c02b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala
@@ -89,6 +89,22 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall
     val n = 9L * 1000 * 1000 * 1000 * 1000 * 1000 * 1000
     val res13 = spark.range(-n, n, n / 9).select("id")
     assert(res13.count == 18)
+
+    // range with a non-aggregation operation
+    val res14 = spark.range(0, 100, 2).toDF.filter("50 <= id")
+    val len14 = res14.collect.length
+    assert(len14 == 25)
+
+    val res15 = spark.range(100, -100, -2).toDF.filter("id <= 0")
+    val len15 = res15.collect.length
+    assert(len15 == 50)
+
+    val res16 = spark.range(-1500, 1500, 3).toDF.filter("0 <= id")
+    val len16 = res16.collect.length
+    assert(len16 == 500)
+
+    val res17 = spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id")
+    assert(res17.collect === (1 to 10).map(i => Row(i)).toArray)
   }
 
   test("Range with randomized parameters") {
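
To make the control flow above easier to review: isShouldStopRequired walks the parent chain, so a single consumer that never calls append() per input row (HashAggregateExec and SortExec above) suppresses the check for every operator below it. A minimal standalone sketch of that propagation; Node, ShouldStopDemo, and the variable names are illustrative, not Spark code:

// Toy model of the parent-chain propagation added to CodegenSupport.
trait Node {
  var parent: Node = null
  protected def shouldStopRequired: Boolean = true
  def isShouldStopRequired: Boolean =
    shouldStopRequired && (parent == null || parent.isShouldStopRequired)
}

object ShouldStopDemo extends App {
  val range = new Node {}              // plays the role of RangeExec
  val agg = new Node {                 // plays the role of HashAggregateExec
    override protected def shouldStopRequired = false
  }
  range.parent = agg
  // Prints false: Range's generated loop can omit the shouldStop() check.
  println(range.isShouldStopRequired)
}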
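
The new DataFrameRangeSuite cases exercise both code paths. Roughly, assuming a SparkSession named spark (the queries below are illustrative restatements, not part of the patch):

// Check eliminated: SortExec (and HashAggregateExec) consume all child rows
// before emitting output, so Range's generated loop needs no shouldStop();
// res17 above covers this path.
spark.range(10, 0, -1, 1).toDF.sortWithinPartitions("id").collect()

// Check kept: FilterExec appends each surviving row to the buffer, so the
// loop still polls shouldStop(); res14-res16 above cover this path.
spark.range(0, 100, 2).toDF.filter("50 <= id").collect()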