[SPARK-18300][SQL] Do not apply foldable propagation with expand as a child.
## What changes were proposed in this pull request? The `FoldablePropagation` optimizer rule, pulls foldable values out from under an `Expand`. This breaks the `Expand` in two ways: - It rewrites the output attributes of the `Expand`. We explicitly define output attributes for `Expand`, these are (unfortunately) considered as part of the expressions of the `Expand` and can be rewritten. - Expand can actually change the column (it will typically re-use the attributes or the underlying plan). This means that we cannot safely propagate the expressions from under an `Expand`. This PR fixes this and (hopefully) other issues by explicitly whitelisting allowed operators. ## How was this patch tested? Added tests to `FoldablePropagationSuite` and to `SQLQueryTestSuite`. Author: Herman van Hovell <hvanhovell@databricks.com> Closes #15857 from hvanhovell/SPARK-18300.
This commit is contained in:
parent
33be4da539
commit
f14ae4900a
|
@ -428,43 +428,49 @@ object FoldablePropagation extends Rule[LogicalPlan] {
|
|||
}
|
||||
case _ => Nil
|
||||
})
|
||||
val replaceFoldable: PartialFunction[Expression, Expression] = {
|
||||
case a: AttributeReference if foldableMap.contains(a) => foldableMap(a)
|
||||
}
|
||||
|
||||
if (foldableMap.isEmpty) {
|
||||
plan
|
||||
} else {
|
||||
var stop = false
|
||||
CleanupAliases(plan.transformUp {
|
||||
case u: Union =>
|
||||
stop = true
|
||||
u
|
||||
case c: Command =>
|
||||
stop = true
|
||||
c
|
||||
// For outer join, although its output attributes are derived from its children, they are
|
||||
// actually different attributes: the output of outer join is not always picked from its
|
||||
// children, but can also be null.
|
||||
// A leaf node should not stop the folding process (note that we are traversing up the
|
||||
// tree, starting at the leaf nodes); so we are allowing it.
|
||||
case l: LeafNode =>
|
||||
l
|
||||
|
||||
// Whitelist of all nodes we are allowed to apply this rule to.
|
||||
case p @ (_: Project | _: Filter | _: SubqueryAlias | _: Aggregate | _: Window |
|
||||
_: Sample | _: GlobalLimit | _: LocalLimit | _: Generate | _: Distinct |
|
||||
_: AppendColumns | _: AppendColumnsWithObject | _: BroadcastHint |
|
||||
_: RedistributeData | _: Repartition | _: Sort | _: TypedFilter) if !stop =>
|
||||
p.transformExpressions(replaceFoldable)
|
||||
|
||||
// Allow inner joins. We do not allow outer join, although its output attributes are
|
||||
// derived from its children, they are actually different attributes: the output of outer
|
||||
// join is not always picked from its children, but can also be null.
|
||||
// TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes
|
||||
// of outer join.
|
||||
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
|
||||
stop = true
|
||||
j
|
||||
case j @ Join(_, _, Inner, _) =>
|
||||
j.transformExpressions(replaceFoldable)
|
||||
|
||||
// These 3 operators take attributes as constructor parameters, and these attributes
|
||||
// can't be replaced by alias.
|
||||
case m: MapGroups =>
|
||||
// We can fold the projections an expand holds. However expand changes the output columns
|
||||
// and often reuses the underlying attributes; so we cannot assume that a column is still
|
||||
// foldable after the expand has been applied.
|
||||
// TODO(hvanhovell): Expand should use new attributes as the output attributes.
|
||||
case expand: Expand if !stop =>
|
||||
val newExpand = expand.copy(projections = expand.projections.map { projection =>
|
||||
projection.map(_.transform(replaceFoldable))
|
||||
})
|
||||
stop = true
|
||||
m
|
||||
case f: FlatMapGroupsInR =>
|
||||
stop = true
|
||||
f
|
||||
case c: CoGroup =>
|
||||
stop = true
|
||||
c
|
||||
newExpand
|
||||
|
||||
case p: LogicalPlan if !stop => p.transformExpressions {
|
||||
case a: AttributeReference if foldableMap.contains(a) =>
|
||||
foldableMap(a)
|
||||
}
|
||||
case other =>
|
||||
stop = true
|
||||
other
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -116,16 +116,35 @@ class FoldablePropagationSuite extends PlanTest {
|
|||
test("Propagate in subqueries of Union queries") {
|
||||
val query = Union(
|
||||
Seq(
|
||||
testRelation.select(Literal(1).as('x), 'a).select('x + 'a),
|
||||
testRelation.select(Literal(2).as('x), 'a).select('x + 'a)))
|
||||
testRelation.select(Literal(1).as('x), 'a).select('x, 'x + 'a),
|
||||
testRelation.select(Literal(2).as('x), 'a).select('x, 'x + 'a)))
|
||||
.select('x)
|
||||
val optimized = Optimize.execute(query.analyze)
|
||||
val correctAnswer = Union(
|
||||
Seq(
|
||||
testRelation.select(Literal(1).as('x), 'a).select((Literal(1).as('x) + 'a).as("(x + a)")),
|
||||
testRelation.select(Literal(2).as('x), 'a).select((Literal(2).as('x) + 'a).as("(x + a)"))))
|
||||
testRelation.select(Literal(1).as('x), 'a)
|
||||
.select(Literal(1).as('x), (Literal(1).as('x) + 'a).as("(x + a)")),
|
||||
testRelation.select(Literal(2).as('x), 'a)
|
||||
.select(Literal(2).as('x), (Literal(2).as('x) + 'a).as("(x + a)"))))
|
||||
.select('x).analyze
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("Propagate in expand") {
|
||||
val c1 = Literal(1).as('a)
|
||||
val c2 = Literal(2).as('b)
|
||||
val a1 = c1.toAttribute.withNullability(true)
|
||||
val a2 = c2.toAttribute.withNullability(true)
|
||||
val expand = Expand(
|
||||
Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))),
|
||||
Seq(a1, a2),
|
||||
OneRowRelation.select(c1, c2))
|
||||
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
|
||||
val optimized = Optimize.execute(query)
|
||||
val correctExpand = expand.copy(projections = Seq(
|
||||
Seq(Literal(null), c2),
|
||||
Seq(c1, Literal(null))))
|
||||
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -32,3 +32,6 @@ SELECT a + 1 + 1, COUNT(b) FROM testData GROUP BY a + 1;
|
|||
-- Aggregate with nulls.
|
||||
SELECT SKEWNESS(a), KURTOSIS(a), MIN(a), MAX(a), AVG(a), VARIANCE(a), STDDEV(a), SUM(a), COUNT(a)
|
||||
FROM testData;
|
||||
|
||||
-- Aggregate with foldable input and multiple distinct groups.
|
||||
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 14
|
||||
-- Number of queries: 15
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -131,3 +131,11 @@ FROM testData
|
|||
struct<skewness(CAST(a AS DOUBLE)):double,kurtosis(CAST(a AS DOUBLE)):double,min(a):int,max(a):int,avg(a):double,var_samp(CAST(a AS DOUBLE)):double,stddev_samp(CAST(a AS DOUBLE)):double,sum(a):bigint,count(a):bigint>
|
||||
-- !query 13 output
|
||||
-0.2723801058145729 -1.5069204152249134 1 3 2.142857142857143 0.8095238095238094 0.8997354108424372 15 7
|
||||
|
||||
|
||||
-- !query 14
|
||||
SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a
|
||||
-- !query 14 schema
|
||||
struct<count(DISTINCT b):bigint,count(DISTINCT b, c):bigint>
|
||||
-- !query 14 output
|
||||
1 1
|
||||
|
|
Loading…
Reference in a new issue