[SPARK-34882][SQL] Replace if with filter clause in RewriteDistinctAggregates
### What changes were proposed in this pull request? Replaced the `agg(if (('gid = 1)) 'cat1 else null)` pattern in `RewriteDistinctAggregates` with `agg('cat1) FILTER (WHERE 'gid = 1)`. ### Why are the changes needed? For aggregate functions that do not ignore NULL values (`First`, `Last`, or `UDAF`s), the current approach can return wrong results, because rows belonging to other groups are passed to the function as NULL instead of being skipped. In the added UT there are no nulls in the input `testData`, yet the query returned `Row(0, 1, 0, 51, 100)` before this PR instead of the expected `Row(0, 0, 0, 0, 100)`. ### Does this PR introduce _any_ user-facing change? Yes, this is a bugfix. ### How was this patch tested? Added a UT. Closes #31983 from tanelk/SPARK-34882_distinct_agg_filter. Lead-authored-by: Tanel Kiis <tanel.kiis@gmail.com> Co-authored-by: tanel.kiis@gmail.com <tanel.kiis@gmail.com> Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
parent
3951e3371a
commit
90f2d4d9cf
|
@ -59,9 +59,9 @@ import org.apache.spark.sql.types.IntegerType
|
|||
* {{{
|
||||
* Aggregate(
|
||||
* key = ['key]
|
||||
* functions = [count(if (('gid = 1)) 'cat1 else null),
|
||||
* count(if (('gid = 2)) 'cat2 else null),
|
||||
* first(if (('gid = 0)) 'total else null) ignore nulls]
|
||||
* functions = [count('cat1) FILTER (WHERE 'gid = 1),
|
||||
* count('cat2) FILTER (WHERE 'gid = 2),
|
||||
* first('total) ignore nulls FILTER (WHERE 'gid = 0)]
|
||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
|
||||
* Aggregate(
|
||||
* key = ['key, 'cat1, 'cat2, 'gid]
|
||||
|
@ -102,9 +102,9 @@ import org.apache.spark.sql.types.IntegerType
|
|||
* {{{
|
||||
* Aggregate(
|
||||
* key = ['key]
|
||||
* functions = [count(if (('gid = 1)) 'cat1 else null),
|
||||
* count(if (('gid = 2)) 'cat2 else null),
|
||||
* first(if (('gid = 0)) 'total else null) ignore nulls]
|
||||
* functions = [count('cat1) FILTER (WHERE 'gid = 1),
|
||||
* count('cat2) FILTER (WHERE 'gid = 2),
|
||||
* first('total) ignore nulls FILTER (WHERE 'gid = 0)]
|
||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
|
||||
* Aggregate(
|
||||
* key = ['key, 'cat1, 'cat2, 'gid]
|
||||
|
@ -145,9 +145,9 @@ import org.apache.spark.sql.types.IntegerType
|
|||
* {{{
|
||||
* Aggregate(
|
||||
* key = ['key]
|
||||
* functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null),
|
||||
* count(if (('gid = 2) and 'max_cond2) 'cat2 else null),
|
||||
* first(if (('gid = 0)) 'total else null) ignore nulls]
|
||||
* functions = [count('cat1) FILTER (WHERE 'gid = 1 and 'max_cond1),
|
||||
* count('cat2) FILTER (WHERE 'gid = 2 and 'max_cond2),
|
||||
* first('total) ignore nulls FILTER (WHERE 'gid = 0)]
|
||||
* output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
|
||||
* Aggregate(
|
||||
* key = ['key, 'cat1, 'cat2, 'gid]
|
||||
|
@ -242,14 +242,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
}
|
||||
val groupByAttrs = groupByMap.map(_._2)
|
||||
|
||||
// Functions used to modify aggregate functions and their inputs.
|
||||
def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) =
|
||||
if (condition.isDefined) {
|
||||
If(And(EqualTo(gid, id), condition.get), e, nullify(e))
|
||||
} else {
|
||||
If(EqualTo(gid, id), e, nullify(e))
|
||||
}
|
||||
|
||||
def patchAggregateFunctionChildren(
|
||||
af: AggregateFunction)(
|
||||
attrs: Expression => Option[Expression]): AggregateFunction = {
|
||||
|
@ -294,17 +286,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
val af = e.aggregateFunction
|
||||
val condition = e.filter.flatMap(distinctAggFilterAttrLookup.get)
|
||||
val naf = if (af.children.forall(_.foldable)) {
|
||||
// If aggregateFunction's children are all foldable, we only put the first child in
|
||||
// distinctAggGroups. So here we only need to rewrite the first child to
|
||||
// `if (gid = ...) ...` or `if (gid = ... and condition) ...`.
|
||||
val firstChild = evalWithinGroup(id, af.children.head, condition)
|
||||
af.withNewChildren(firstChild +: af.children.drop(1)).asInstanceOf[AggregateFunction]
|
||||
af
|
||||
} else {
|
||||
patchAggregateFunctionChildren(af) { x =>
|
||||
distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition))
|
||||
distinctAggChildAttrLookup.get(x)
|
||||
}
|
||||
}
|
||||
(e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None))
|
||||
val newCondition = if (condition.isDefined) {
|
||||
And(EqualTo(gid, id), condition.get)
|
||||
} else {
|
||||
EqualTo(gid, id)
|
||||
}
|
||||
|
||||
(e, e.copy(aggregateFunction = naf, isDistinct = false, filter = Some(newCondition)))
|
||||
}
|
||||
|
||||
(projection ++ filterProjection, operators)
|
||||
|
@ -335,9 +329,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
|
|||
|
||||
// Select the result of the first aggregate in the last aggregate.
|
||||
val result = AggregateExpression(
|
||||
aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true),
|
||||
aggregate.First(operator.toAttribute, ignoreNulls = true),
|
||||
mode = Complete,
|
||||
isDistinct = false)
|
||||
isDistinct = false,
|
||||
filter = Some(EqualTo(gid, regularGroupId)))
|
||||
|
||||
// Some aggregate functions (COUNT) have the special property that they can return a
|
||||
// non-null result without any input. We need to make sure we return a result in this case.
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql
|
||||
|
||||
import java.io.{ByteArrayOutputStream, File}
|
||||
import java.lang.{Long => JLong}
|
||||
import java.nio.charset.StandardCharsets
|
||||
import java.sql.{Date, Timestamp}
|
||||
import java.util.{Locale, UUID}
|
||||
|
@ -41,7 +42,7 @@ import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCod
|
|||
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
||||
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
|
||||
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
|
||||
import org.apache.spark.sql.expressions.Window
|
||||
import org.apache.spark.sql.expressions.{Aggregator, Window}
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
|
||||
|
@ -2834,6 +2835,32 @@ class DataFrameSuite extends QueryTest
|
|||
df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
|
||||
checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-34882: Aggregate with multiple distinct null sensitive aggregators") {
|
||||
withUserDefinedFunction(("countNulls", true)) {
|
||||
spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong, JLong] {
|
||||
def zero: JLong = 0L
|
||||
def reduce(b: JLong, a: JLong): JLong = if (a == null) {
|
||||
b + 1
|
||||
} else {
|
||||
b
|
||||
}
|
||||
def merge(b1: JLong, b2: JLong): JLong = b1 + b2
|
||||
def finish(r: JLong): JLong = r
|
||||
def bufferEncoder: Encoder[JLong] = Encoders.LONG
|
||||
def outputEncoder: Encoder[JLong] = Encoders.LONG
|
||||
}))
|
||||
|
||||
val result = testData.selectExpr(
|
||||
"countNulls(key)",
|
||||
"countNulls(DISTINCT key)",
|
||||
"countNulls(key) FILTER (WHERE key > 50)",
|
||||
"countNulls(DISTINCT key) FILTER (WHERE key > 50)",
|
||||
"count(DISTINCT key)")
|
||||
|
||||
checkAnswer(result, Row(0, 0, 0, 0, 100))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
case class GroupByKey(a: Int, b: Int)
|
||||
|
|
Loading…
Reference in a new issue