[SPARK-35439][SQL][FOLLOWUP] ExpressionContainmentOrdering should not sort unrelated expressions

### What changes were proposed in this pull request?

This is a followup of #32586. We introduced `ExpressionContainmentOrdering` to sort common expressions according to their parent-child relations. For unrelated expressions, the ordering previously returned -1 regardless of argument order, which is incorrect and can lead to transitivity issues.

### Why are the changes needed?

To fix the possible transitivity issue of `ExpressionContainmentOrdering`.
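
A minimal sketch of the problem, using a toy expression tree rather than Catalyst's `Expression`. The names `Expr`, `Lit`, `Plus`, `properlyContains`, `OldOrdering` and `NewOrdering` are hypothetical and exist only for illustration. `OldOrdering` mimics the previous "-1 for anything that is not a parent" behaviour, and `NewOrdering` mimics the intent described above, where unrelated or equal expressions compare as 0.

```scala
object ContainmentOrderingSketch {
  // Toy stand-in for Catalyst's `Expression`; these types are hypothetical.
  sealed trait Expr {
    def children: Seq[Expr]
    // True if `target` occurs strictly below this node.
    def properlyContains(target: Expr): Boolean =
      children.exists(c => c == target || c.properlyContains(target))
  }
  case class Lit(value: Int) extends Expr {
    def children: Seq[Expr] = Nil
  }
  case class Plus(left: Expr, right: Expr) extends Expr {
    def children: Seq[Expr] = Seq(left, right)
  }

  // Previous behaviour: whatever is not equal to or a parent of `y` is "smaller".
  object OldOrdering extends Ordering[Expr] {
    override def compare(x: Expr, y: Expr): Int =
      if (x == y) 0
      else if (x.properlyContains(y)) 1
      else -1
  }

  // Intended behaviour: only a parent/child relation produces a non-zero result.
  object NewOrdering extends Ordering[Expr] {
    override def compare(x: Expr, y: Expr): Int =
      if (x.properlyContains(y)) 1
      else if (y.properlyContains(x)) -1
      else 0
  }

  val a: Expr = Plus(Lit(1), Lit(2))
  val b: Expr = Plus(Lit(2), Lit(3))
  val parent: Expr = Plus(a, b)

  def main(args: Array[String]): Unit = {
    // Old: a < b and b < a hold at once. Transitivity would then give a < a,
    // but compare(a, a) is 0, so the Ordering contract is broken.
    assert(OldOrdering.compare(a, b) < 0 && OldOrdering.compare(b, a) < 0)
    assert(OldOrdering.compare(a, a) == 0)

    // New: unrelated expressions compare as equal in both directions.
    assert(NewOrdering.compare(a, b) == 0 && NewOrdering.compare(b, a) == 0)

    // Children still sort before the expression that contains them.
    assert(NewOrdering.compare(a, parent) < 0 && NewOrdering.compare(parent, a) > 0)
    assert(Seq(parent, b, a).sorted(NewOrdering).last == parent)
  }
}
```

The sketch checks containment strictly below the node, so equal expressions always fall through to 0. That is a simplification chosen for the toy model, not a statement about how `find` and `semanticEquals` behave in Catalyst.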

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test.

Closes #32870 from viirya/SPARK-35439-followup.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
Liang-Chi Hsieh 2021-06-11 16:13:46 +09:00 committed by Takeshi Yamamuro
parent e9af4576d5
commit c463472e85
2 changed files with 45 additions and 21 deletions

@@ -193,27 +193,6 @@ class EquivalentExpressions {
       .sortBy(_.head)(new ExpressionContainmentOrdering)
   }
-  /**
-   * Orders `Expression` by parent/child relations. The child expression is smaller
-   * than parent expression. If there is child-parent relationships among the subexpressions,
-   * we want the child expressions come first than parent expressions, so we can replace
-   * child expressions in parent expressions with subexpression evaluation. Note that
-   * this is not for general expression ordering. For example, two irrelevant expressions
-   * will be considered as e1 < e2 and e2 < e1 by this ordering. But for the usage here,
-   * the order of irrelevant expressions does not matter.
-   */
-  class ExpressionContainmentOrdering extends Ordering[Expression] {
-    override def compare(x: Expression, y: Expression): Int = {
-      if (x.semanticEquals(y)) {
-        0
-      } else if (x.find(_.semanticEquals(y)).isDefined) {
-        1
-      } else {
-        -1
-      }
-    }
-  }
   /**
    * Returns the state of the data structure as a string. If `all` is false, skips sets of
    * equivalent expressions with cardinality 1.
@@ -229,3 +208,27 @@ class EquivalentExpressions {
     sb.toString()
   }
 }
+/**
+ * Orders `Expression` by parent/child relations. The child expression is smaller
+ * than parent expression. If there is child-parent relationships among the subexpressions,
+ * we want the child expressions come first than parent expressions, so we can replace
+ * child expressions in parent expressions with subexpression evaluation. Note that
+ * this is not for general expression ordering. For example, two irrelevant or semantically-equal
+ * expressions will be considered as equal by this ordering. But for the usage here, the order of
+ * irrelevant expressions does not matter.
+ */
+class ExpressionContainmentOrdering extends Ordering[Expression] {
+  override def compare(x: Expression, y: Expression): Int = {
+    if (x.find(_.semanticEquals(y)).isDefined) {
+      // `y` is child expression of `x`.
+      1
+    } else if (y.find(_.semanticEquals(x)).isDefined) {
+      // `x` is child expression of `y`.
+      -1
+    } else {
+      // Irrelevant or semantically-equal expressions
+      0
+    }
+  }
+}

@@ -370,6 +370,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
     // `add1` is not in the elseValue, so we can't extract it from the branches
     assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0)
   }
+  test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") {
+    val exprOrdering = new ExpressionContainmentOrdering
+    val add1 = Add(Literal(1), Literal(2))
+    val add2 = Add(Literal(2), Literal(3))
+    // Non parent-child expressions. Don't sort on them.
+    val exprs = Seq(add2, add1, add2, add1, add2, add1)
+    assert(exprs.sorted(exprOrdering) === exprs)
+    val conditions = (GreaterThan(add1, Literal(3)), add1) ::
+      (GreaterThan(add2, Literal(4)), add1) ::
+      (GreaterThan(add2, Literal(5)), add1) :: Nil
+    // `caseWhenExpr` contains add1, add2.
+    val caseWhenExpr = CaseWhen(conditions, None)
+    val exprs2 = Seq(caseWhenExpr, add2, add1, add2, add1, add2, add1)
+    assert(exprs2.sorted(exprOrdering) ===
+      Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr))
+  }
 }
 case class CodegenFallbackExpression(child: Expression)