diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 67298678e8..dd7193b256 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -193,27 +193,6 @@ class EquivalentExpressions { .sortBy(_.head)(new ExpressionContainmentOrdering) } - /** - * Orders `Expression` by parent/child relations. The child expression is smaller - * than parent expression. If there is child-parent relationships among the subexpressions, - * we want the child expressions come first than parent expressions, so we can replace - * child expressions in parent expressions with subexpression evaluation. Note that - * this is not for general expression ordering. For example, two irrelevant expressions - * will be considered as e1 < e2 and e2 < e1 by this ordering. But for the usage here, - * the order of irrelevant expressions does not matter. - */ - class ExpressionContainmentOrdering extends Ordering[Expression] { - override def compare(x: Expression, y: Expression): Int = { - if (x.semanticEquals(y)) { - 0 - } else if (x.find(_.semanticEquals(y)).isDefined) { - 1 - } else { - -1 - } - } - } - /** * Returns the state of the data structure as a string. If `all` is false, skips sets of * equivalent expressions with cardinality 1. @@ -229,3 +208,27 @@ class EquivalentExpressions { sb.toString() } } + +/** + * Orders `Expression` by parent/child relations. The child expression is smaller + * than parent expression. If there is child-parent relationships among the subexpressions, + * we want the child expressions come first than parent expressions, so we can replace + * child expressions in parent expressions with subexpression evaluation. Note that + * this is not for general expression ordering. For example, two irrelevant or semantically-equal + * expressions will be considered as equal by this ordering. But for the usage here, the order of + * irrelevant expressions does not matter. + */ +class ExpressionContainmentOrdering extends Ordering[Expression] { + override def compare(x: Expression, y: Expression): Int = { + if (x.find(_.semanticEquals(y)).isDefined) { + // `y` is child expression of `x`. + 1 + } else if (y.find(_.semanticEquals(x)).isDefined) { + // `x` is child expression of `y`. + -1 + } else { + // Irrelevant or semantically-equal expressions + 0 + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index 7a17a05439..11f987c960 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -370,6 +370,27 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel // `add1` is not in the elseValue, so we can't extract it from the branches assert(equivalence.getAllEquivalentExprs().count(_.size == 2) == 0) } + + test("SPARK-35439: sort exprs with ExpressionContainmentOrdering") { + val exprOrdering = new ExpressionContainmentOrdering + + val add1 = Add(Literal(1), Literal(2)) + val add2 = Add(Literal(2), Literal(3)) + + // Non parent-child expressions. Don't sort on them. + val exprs = Seq(add2, add1, add2, add1, add2, add1) + assert(exprs.sorted(exprOrdering) === exprs) + + val conditions = (GreaterThan(add1, Literal(3)), add1) :: + (GreaterThan(add2, Literal(4)), add1) :: + (GreaterThan(add2, Literal(5)), add1) :: Nil + + // `caseWhenExpr` contains add1, add2. + val caseWhenExpr = CaseWhen(conditions, None) + val exprs2 = Seq(caseWhenExpr, add2, add1, add2, add1, add2, add1) + assert(exprs2.sorted(exprOrdering) === + Seq(add2, add1, add2, add1, add2, add1, caseWhenExpr)) + } } case class CodegenFallbackExpression(child: Expression)