[SPARK-23079][SQL] Fix query constraints propagation with aliases
## What changes were proposed in this pull request? Previously, PR #19201 fixed the problem of non-converging constraints. After that, PR #19149 improved the loop so that constraints are inferred only once. So the problem of non-converging constraints is gone. However, the case below will fail. ``` spark.range(5).write.saveAsTable("t") val t = spark.read.table("t") val left = t.withColumn("xid", $"id" + lit(1)).as("x") val right = t.withColumnRenamed("id", "xid").as("y") val df = left.join(right, "xid").filter("id = 3").toDF() checkAnswer(df, Row(4, 3)) ``` This is because `aliasMap` replaces all the aliased children. See the test case in the PR for details. This PR fixes the bug by removing now-useless code that was originally added to prevent non-converging constraints. It could also be fixed with #20270, but this approach is much simpler and cleans up the code. ## How was this patch tested? Unit test Author: Wang Gengliang <ltnwgl@gmail.com> Closes #20278 from gengliangwang/FixConstraintSimple.
This commit is contained in:
parent
0f8a28617a
commit
8598a982b4
|
@ -255,6 +255,7 @@ abstract class UnaryNode extends LogicalPlan {
|
|||
case expr: Expression if expr.semanticEquals(e) =>
|
||||
a.toAttribute
|
||||
})
|
||||
allConstraints += EqualNullSafe(e, a.toAttribute)
|
||||
case _ => // Don't change.
|
||||
}
|
||||
|
||||
|
|
|
@ -94,25 +94,16 @@ trait QueryPlanConstraints { self: LogicalPlan =>
|
|||
case _ => Seq.empty[Attribute]
|
||||
}
|
||||
|
||||
// Collect aliases from expressions of the whole tree rooted by the current QueryPlan node, so
|
||||
// we may avoid producing recursive constraints.
|
||||
private lazy val aliasMap: AttributeMap[Expression] = AttributeMap(
|
||||
expressions.collect {
|
||||
case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child)
|
||||
} ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap))
|
||||
// Note: the explicit cast is necessary, since Scala compiler fails to infer the type.
|
||||
|
||||
/**
|
||||
* Infers an additional set of constraints from a given set of equality constraints.
|
||||
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
|
||||
* additional constraint of the form `b = 5`.
|
||||
*/
|
||||
private def inferAdditionalConstraints(constraints: Set[Expression]): Set[Expression] = {
|
||||
val aliasedConstraints = eliminateAliasedExpressionInConstraints(constraints)
|
||||
var inferredConstraints = Set.empty[Expression]
|
||||
aliasedConstraints.foreach {
|
||||
constraints.foreach {
|
||||
case eq @ EqualTo(l: Attribute, r: Attribute) =>
|
||||
val candidateConstraints = aliasedConstraints - eq
|
||||
val candidateConstraints = constraints - eq
|
||||
inferredConstraints ++= replaceConstraints(candidateConstraints, l, r)
|
||||
inferredConstraints ++= replaceConstraints(candidateConstraints, r, l)
|
||||
case _ => // No inference
|
||||
|
@ -120,30 +111,6 @@ trait QueryPlanConstraints { self: LogicalPlan =>
|
|||
inferredConstraints -- constraints
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace the aliased expression in [[Alias]] with the alias name if both exist in constraints.
|
||||
* Thus non-converging inference can be prevented.
|
||||
* E.g. `Alias(b, f(a)), a = b` infers `f(a) = f(f(a))` without eliminating aliased expressions.
|
||||
* Also, the size of constraints is reduced without losing any information.
|
||||
* When the inferred filters are pushed down the operators that generate the alias,
|
||||
* the alias names used in filters are replaced by the aliased expressions.
|
||||
*/
|
||||
private def eliminateAliasedExpressionInConstraints(constraints: Set[Expression])
|
||||
: Set[Expression] = {
|
||||
val attributesInEqualTo = constraints.flatMap {
|
||||
case EqualTo(l: Attribute, r: Attribute) => l :: r :: Nil
|
||||
case _ => Nil
|
||||
}
|
||||
var aliasedConstraints = constraints
|
||||
attributesInEqualTo.foreach { a =>
|
||||
if (aliasMap.contains(a)) {
|
||||
val child = aliasMap.get(a).get
|
||||
aliasedConstraints = replaceConstraints(aliasedConstraints, child, a)
|
||||
}
|
||||
}
|
||||
aliasedConstraints
|
||||
}
|
||||
|
||||
private def replaceConstraints(
|
||||
constraints: Set[Expression],
|
||||
source: Expression,
|
||||
|
|
|
@ -34,6 +34,7 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
|
|||
PushDownPredicate,
|
||||
InferFiltersFromConstraints,
|
||||
CombineFilters,
|
||||
SimplifyBinaryComparison,
|
||||
BooleanSimplification) :: Nil
|
||||
}
|
||||
|
||||
|
@ -160,64 +161,6 @@ class InferFiltersFromConstraintsSuite extends PlanTest {
|
|||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("inner join with alias: don't generate constraints for recursive functions") {
|
||||
val t1 = testRelation.subquery('t1)
|
||||
val t2 = testRelation.subquery('t2)
|
||||
|
||||
// We should prevent `Coalesce(a, b)` from recursively creating complicated constraints through
|
||||
// the constraint inference procedure.
|
||||
val originalQuery = t1.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
|
||||
// We hide an `Alias` inside the child's child's expressions, to cover the situation reported
|
||||
// in [SPARK-20700].
|
||||
.select('int_col, 'd, 'a).as("t")
|
||||
.join(t2, Inner,
|
||||
Some("t.a".attr === "t2.a".attr
|
||||
&& "t.d".attr === "t2.a".attr
|
||||
&& "t.int_col".attr === "t2.a".attr))
|
||||
.analyze
|
||||
val correctAnswer = t1
|
||||
.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) && IsNotNull(Coalesce(Seq('b, 'a)))
|
||||
&& IsNotNull('b) && IsNotNull(Coalesce(Seq('b, 'b))) && IsNotNull(Coalesce(Seq('a, 'b)))
|
||||
&& 'a === 'b && 'a === Coalesce(Seq('a, 'a)) && 'a === Coalesce(Seq('a, 'b))
|
||||
&& 'a === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('a, 'b))
|
||||
&& 'b === Coalesce(Seq('b, 'a)) && 'b === Coalesce(Seq('b, 'b)))
|
||||
.select('a, 'b.as('d), Coalesce(Seq('a, 'b)).as('int_col))
|
||||
.select('int_col, 'd, 'a).as("t")
|
||||
.join(
|
||||
t2.where(IsNotNull('a) && IsNotNull(Coalesce(Seq('a, 'a))) &&
|
||||
'a === Coalesce(Seq('a, 'a))),
|
||||
Inner,
|
||||
Some("t.a".attr === "t2.a".attr && "t.d".attr === "t2.a".attr
|
||||
&& "t.int_col".attr === "t2.a".attr))
|
||||
.analyze
|
||||
val optimized = Optimize.execute(originalQuery)
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("inner join with EqualTo expressions containing part of each other: don't generate " +
|
||||
"constraints for recursive functions") {
|
||||
val t1 = testRelation.subquery('t1)
|
||||
val t2 = testRelation.subquery('t2)
|
||||
|
||||
// We should prevent `c = Coalesce(a, b)` and `a = Coalesce(b, c)` from recursively creating
|
||||
// complicated constraints through the constraint inference procedure.
|
||||
val originalQuery = t1
|
||||
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
|
||||
.where('a === 'd && 'c === 'e)
|
||||
.join(t2, Inner, Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
|
||||
.analyze
|
||||
val correctAnswer = t1
|
||||
.where(IsNotNull('a) && IsNotNull('c) && 'a === Coalesce(Seq('b, 'c)) &&
|
||||
'c === Coalesce(Seq('a, 'b)))
|
||||
.select('a, 'b, 'c, Coalesce(Seq('b, 'c)).as('d), Coalesce(Seq('a, 'b)).as('e))
|
||||
.join(t2.where(IsNotNull('a) && IsNotNull('c)),
|
||||
Inner,
|
||||
Some("t1.a".attr === "t2.a".attr && "t1.c".attr === "t2.c".attr))
|
||||
.analyze
|
||||
val optimized = Optimize.execute(originalQuery)
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("generate correct filters for alias that don't produce recursive constraints") {
|
||||
val t1 = testRelation.subquery('t1)
|
||||
|
||||
|
|
|
@ -134,6 +134,8 @@ class ConstraintPropagationSuite extends SparkFunSuite with PlanTest {
|
|||
verifyConstraints(aliasedRelation.analyze.constraints,
|
||||
ExpressionSet(Seq(resolveColumn(aliasedRelation.analyze, "x") > 10,
|
||||
IsNotNull(resolveColumn(aliasedRelation.analyze, "x")),
|
||||
resolveColumn(aliasedRelation.analyze, "b") <=> resolveColumn(aliasedRelation.analyze, "y"),
|
||||
resolveColumn(aliasedRelation.analyze, "z") <=> resolveColumn(aliasedRelation.analyze, "x"),
|
||||
resolveColumn(aliasedRelation.analyze, "z") > 10,
|
||||
IsNotNull(resolveColumn(aliasedRelation.analyze, "z")))))
|
||||
|
||||
|
|
|
@ -2717,6 +2717,17 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
}
|
||||
|
||||
test("SPARK-23079: constraints should be inferred correctly with aliases") {
|
||||
withTable("t") {
|
||||
spark.range(5).write.saveAsTable("t")
|
||||
val t = spark.read.table("t")
|
||||
val left = t.withColumn("xid", $"id" + lit(1)).as("x")
|
||||
val right = t.withColumnRenamed("id", "xid").as("y")
|
||||
val df = left.join(right, "xid").filter("id = 3").toDF()
|
||||
checkAnswer(df, Row(4, 3))
|
||||
}
|
||||
}
|
||||
|
||||
test("SRARK-22266: the same aggregate function was calculated multiple times") {
|
||||
val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a"
|
||||
val df = sql(query)
|
||||
|
|
Loading…
Reference in a new issue