[SPARK-34882][SQL] Replace if with filter clause in RewriteDistinctAggregates

### What changes were proposed in this pull request?

Replaced the `agg(if (('gid = 1)) 'cat1 else null)` pattern in `RewriteDistinctAggregates` with `agg('cat1) FILTER (WHERE 'gid = 1)`.
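
At the SQL level the new form is the standard aggregate `FILTER` clause. A minimal sketch of the shape of query this rewrite targets, modelled on the rule's scaladoc example in the diff below (the `spark` session and the `data` table with columns `key`, `cat1`, `cat2`, `value` are assumptions for illustration, not part of this patch):

```scala
// A GROUP BY with two DISTINCT aggregates plus a regular aggregate; this is
// the query shape that RewriteDistinctAggregates expands over a synthetic 'gid.
val df = spark.sql("""
  SELECT
    key,
    COUNT(DISTINCT cat1) AS cat1_cnt,
    COUNT(DISTINCT cat2) AS cat2_cnt,
    SUM(value) AS total
  FROM data
  GROUP BY key
""")
// Before this PR the expanded aggregates looked like
//   count(if (('gid = 1)) 'cat1 else null)
// and after it like
//   count('cat1) FILTER (WHERE 'gid = 1)
```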

### Why are the changes needed?

For aggregate functions that do not ignore NULL values (`First`, `Last`, or UDAFs), the current approach can return wrong results: wrapping each child in `if ('gid = ...) ... else null` injects NULLs for the rows that belong to other aggregate groups, and a null-sensitive function then aggregates those injected NULLs.

In the added UT there are no nulls in the input `testData`, so every `countNulls` aggregate should return 0. Before this PR the query returned `Row(0, 1, 0, 51, 100)` instead of the expected `Row(0, 0, 0, 0, 100)`.
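
The mechanism is visible without Spark at all; a plain-Scala sketch (all rows hypothetical) of how the old `if`-based expansion feeds injected nulls to a null-counting aggregator while the filter-based one does not:

```scala
// Plain-Scala sketch of the bug, no Spark required. After the Expand step
// each input row is duplicated once per aggregate group and tagged with a
// group id; None stands in for the NULL the old rewrite writes into rows
// that belong to other groups.
object NullInjectionSketch extends App {
  // Hypothetical expanded rows (gid, key): gid = 1 feeds the distinct
  // aggregate, gid = 0 rows carry null in this column.
  val expanded = Seq((1, Some(10L)), (0, None), (1, Some(20L)), (0, None))

  // Old rewrite: countNulls(if ('gid = 1) key else null) still evaluates the
  // gid = 0 rows, whose value was nulled out, and wrongly counts them.
  val oldCount = expanded
    .map { case (gid, key) => if (gid == 1) key else None }
    .count(_.isEmpty)

  // New rewrite: countNulls(key) FILTER (WHERE 'gid = 1) drops the other
  // groups' rows before the aggregate ever sees them.
  val newCount = expanded
    .filter { case (gid, _) => gid == 1 }
    .count { case (_, key) => key.isEmpty }

  println(s"if-based rewrite counts $oldCount nulls, filter-based counts $newCount")
  // if-based rewrite counts 2 nulls, filter-based counts 0
}
```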

### Does this PR introduce _any_ user-facing change?

Bugfix

### How was this patch tested?

UT

Closes #31983 from tanelk/SPARK-34882_distinct_agg_filter.

Lead-authored-by: Tanel Kiis <tanel.kiis@gmail.com>
Co-authored-by: tanel.kiis@gmail.com <tanel.kiis@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
Tanel Kiis authored 2021-04-01 07:42:53 +09:00, committed by Takeshi Yamamuro
parent 3951e3371a
commit 90f2d4d9cf
2 changed files with 49 additions and 27 deletions

RewriteDistinctAggregates.scala

@@ -59,9 +59,9 @@ import org.apache.spark.sql.types.IntegerType
  * {{{
  *   Aggregate(
  *      key = ['key]
- *      functions = [count(if (('gid = 1)) 'cat1 else null),
- *                   count(if (('gid = 2)) 'cat2 else null),
- *                   first(if (('gid = 0)) 'total else null) ignore nulls]
+ *      functions = [count('cat1) FILTER (WHERE 'gid = 1),
+ *                   count('cat2) FILTER (WHERE 'gid = 2),
+ *                   first('total) ignore nulls FILTER (WHERE 'gid = 0)]
  *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   Aggregate(
  *      key = ['key, 'cat1, 'cat2, 'gid]
@@ -102,9 +102,9 @@ import org.apache.spark.sql.types.IntegerType
  * {{{
  *   Aggregate(
  *      key = ['key]
- *      functions = [count(if (('gid = 1)) 'cat1 else null),
- *                   count(if (('gid = 2)) 'cat2 else null),
- *                   first(if (('gid = 0)) 'total else null) ignore nulls]
+ *      functions = [count('cat1) FILTER (WHERE 'gid = 1),
+ *                   count('cat2) FILTER (WHERE 'gid = 2),
+ *                   first('total) ignore nulls FILTER (WHERE 'gid = 0)]
  *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   Aggregate(
  *      key = ['key, 'cat1, 'cat2, 'gid]
@@ -145,9 +145,9 @@ import org.apache.spark.sql.types.IntegerType
  * {{{
  *   Aggregate(
  *      key = ['key]
- *      functions = [count(if (('gid = 1) and 'max_cond1) 'cat1 else null),
- *                   count(if (('gid = 2) and 'max_cond2) 'cat2 else null),
- *                   first(if (('gid = 0)) 'total else null) ignore nulls]
+ *      functions = [count('cat1) FILTER (WHERE 'gid = 1 and 'max_cond1),
+ *                   count('cat2) FILTER (WHERE 'gid = 2 and 'max_cond2),
+ *                   first('total) ignore nulls FILTER (WHERE 'gid = 0)]
  *      output = ['key, 'cat1_cnt, 'cat2_cnt, 'total])
  *   Aggregate(
  *      key = ['key, 'cat1, 'cat2, 'gid]
@@ -242,14 +242,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
     }
     val groupByAttrs = groupByMap.map(_._2)
     // Functions used to modify aggregate functions and their inputs.
-    def evalWithinGroup(id: Literal, e: Expression, condition: Option[Expression]) =
-      if (condition.isDefined) {
-        If(And(EqualTo(gid, id), condition.get), e, nullify(e))
-      } else {
-        If(EqualTo(gid, id), e, nullify(e))
-      }
     def patchAggregateFunctionChildren(
         af: AggregateFunction)(
         attrs: Expression => Option[Expression]): AggregateFunction = {
@@ -294,17 +286,19 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
         val af = e.aggregateFunction
         val condition = e.filter.flatMap(distinctAggFilterAttrLookup.get)
         val naf = if (af.children.forall(_.foldable)) {
-          // If aggregateFunction's children are all foldable, we only put the first child in
-          // distinctAggGroups. So here we only need to rewrite the first child to
-          // `if (gid = ...) ...` or `if (gid = ... and condition) ...`.
-          val firstChild = evalWithinGroup(id, af.children.head, condition)
-          af.withNewChildren(firstChild +: af.children.drop(1)).asInstanceOf[AggregateFunction]
+          af
         } else {
           patchAggregateFunctionChildren(af) { x =>
-            distinctAggChildAttrLookup.get(x).map(evalWithinGroup(id, _, condition))
+            distinctAggChildAttrLookup.get(x)
           }
         }
-        (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = None))
+        val newCondition = if (condition.isDefined) {
+          And(EqualTo(gid, id), condition.get)
+        } else {
+          EqualTo(gid, id)
+        }
+        (e, e.copy(aggregateFunction = naf, isDistinct = false, filter = Some(newCondition)))
       }
 
       (projection ++ filterProjection, operators)
@@ -335,9 +329,10 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] {
       // Select the result of the first aggregate in the last aggregate.
       val result = AggregateExpression(
-        aggregate.First(evalWithinGroup(regularGroupId, operator.toAttribute, None), true),
+        aggregate.First(operator.toAttribute, ignoreNulls = true),
         mode = Complete,
-        isDistinct = false)
+        isDistinct = false,
+        filter = Some(EqualTo(gid, regularGroupId)))
       // Some aggregate functions (COUNT) have the special property that they can return a
       // non-null result without any input. We need to make sure we return a result in this case.

DataFrameSuite.scala

@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, File}
+import java.lang.{Long => JLong}
 import java.nio.charset.StandardCharsets
 import java.sql.{Date, Timestamp}
 import java.util.{Locale, UUID}
@@ -41,7 +42,7 @@ import org.apache.spark.sql.execution.{FilterExec, QueryExecution, WholeStageCod
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.expressions.Window
+import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
@@ -2834,6 +2835,32 @@ class DataFrameSuite extends QueryTest
     df10.select(zip_with(col("array1"), col("array2"), (b1, b2) => reverseThenConcat2(b1, b2)))
     checkAnswer(test10, Row(Array(Row("cbaihg"), Row("fedlkj"))) :: Nil)
   }
+
+  test("SPARK-34882: Aggregate with multiple distinct null sensitive aggregators") {
+    withUserDefinedFunction(("countNulls", true)) {
+      spark.udf.register("countNulls", udaf(new Aggregator[JLong, JLong, JLong] {
+        def zero: JLong = 0L
+        def reduce(b: JLong, a: JLong): JLong = if (a == null) {
+          b + 1
+        } else {
+          b
+        }
+        def merge(b1: JLong, b2: JLong): JLong = b1 + b2
+        def finish(r: JLong): JLong = r
+        def bufferEncoder: Encoder[JLong] = Encoders.LONG
+        def outputEncoder: Encoder[JLong] = Encoders.LONG
+      }))
+
+      val result = testData.selectExpr(
+        "countNulls(key)",
+        "countNulls(DISTINCT key)",
+        "countNulls(key) FILTER (WHERE key > 50)",
+        "countNulls(DISTINCT key) FILTER (WHERE key > 50)",
+        "count(DISTINCT key)")
+      checkAnswer(result, Row(0, 0, 0, 0, 100))
+    }
+  }
+
 }
 
 case class GroupByKey(a: Int, b: Int)