diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8cf4073826..574f91b099 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -966,9 +966,9 @@ class Analyzer( case s @ Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { - case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering) => + case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => if (index > 0 && index <= child.output.size) { - SortOrder(child.output(index - 1), direction, nullOrdering) + SortOrder(child.output(index - 1), direction, nullOrdering, Set.empty) } else { s.failAnalysis( s"ORDER BY position $index is not in select list " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala index af0a565f73..38a3d3de12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/SubstituteUnresolvedOrdinals.scala @@ -36,7 +36,7 @@ class SubstituteUnresolvedOrdinals(conf: CatalystConf) extends Rule[LogicalPlan] def apply(plan: LogicalPlan): LogicalPlan = plan transform { case s: Sort if conf.orderByOrdinal && s.order.exists(o => isIntLiteral(o.child)) => val newOrders = s.order.map { - case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _) => + case order @ SortOrder(ordinal @ Literal(index: Int, IntegerType), _, _, _) => val newOrdinal = withOrigin(ordinal.origin)(UnresolvedOrdinal(index)) withOrigin(order.origin)(order.copy(child = newOrdinal)) case other => other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 35ca2a0aa5..75bf780d41 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -109,9 +109,9 @@ package object dsl { def cast(to: DataType): Expression = Cast(expr, to) def asc: SortOrder = SortOrder(expr, Ascending) - def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast) + def asc_nullsLast: SortOrder = SortOrder(expr, Ascending, NullsLast, Set.empty) def desc: SortOrder = SortOrder(expr, Descending) - def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst) + def desc_nullsFirst: SortOrder = SortOrder(expr, Descending, NullsFirst, Set.empty) def as(alias: String): NamedExpression = Alias(expr, alias)() def as(alias: Symbol): NamedExpression = Alias(expr, alias.name)() } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 3bebd552ef..abcb9a2b93 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -53,8 +53,15 @@ case object NullsLast extends NullOrdering{ /** * An expression that can be used to sort a tuple. This class extends expression primarily so that * transformations over expression will descend into its child. + * `sameOrderExpressions` is a set of expressions with the same sort order as the child. It is + * derived from equivalence relation in an operator, e.g. left/right keys of an inner sort merge + * join. */ -case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) +case class SortOrder( + child: Expression, + direction: SortDirection, + nullOrdering: NullOrdering, + sameOrderExpressions: Set[Expression]) extends UnaryExpression with Unevaluable { /** Sort order is not foldable because we don't have an eval for it. */ @@ -75,11 +82,19 @@ case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: override def sql: String = child.sql + " " + direction.sql + " " + nullOrdering.sql def isAscending: Boolean = direction == Ascending + + def satisfies(required: SortOrder): Boolean = { + (sameOrderExpressions + child).exists(required.child.semanticEquals) && + direction == required.direction && nullOrdering == required.nullOrdering + } } object SortOrder { - def apply(child: Expression, direction: SortDirection): SortOrder = { - new SortOrder(child, direction, direction.defaultNullOrdering) + def apply( + child: Expression, + direction: SortDirection, + sameOrderExpressions: Set[Expression] = Set.empty): SortOrder = { + new SortOrder(child, direction, direction.defaultNullOrdering, sameOrderExpressions) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4c9fb2ec27..cd238e05d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1229,7 +1229,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } else { direction.defaultNullOrdering } - SortOrder(expression(ctx.expression), direction, nullOrdering) + SortOrder(expression(ctx.expression), direction, nullOrdering, Set.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 38029552d1..ae0703513c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -1037,7 +1037,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst) } + def desc_nulls_first: Column = withExpr { SortOrder(expr, Descending, NullsFirst, Set.empty) } /** * Returns a descending ordering used in sorting, where null values appear after non-null values. @@ -1052,7 +1052,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast) } + def desc_nulls_last: Column = withExpr { SortOrder(expr, Descending, NullsLast, Set.empty) } /** * Returns an ascending ordering used in sorting. @@ -1082,7 +1082,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst) } + def asc_nulls_first: Column = withExpr { SortOrder(expr, Ascending, NullsFirst, Set.empty) } /** * Returns an ordering used in sorting, where null values appear after non-null values. @@ -1097,7 +1097,7 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 2.1.0 */ - def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast) } + def asc_nulls_last: Column = withExpr { SortOrder(expr, Ascending, NullsLast, Set.empty) } /** * Prints the expression to the console for debugging purpose. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index f17049949a..b91d077442 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -241,7 +241,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } else { requiredOrdering.zip(child.outputOrdering).forall { case (requiredOrder, childOutputOrder) => - requiredOrder.semanticEquals(childOutputOrder) + childOutputOrder.satisfies(requiredOrder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 02f4f55c79..c6aae1a4db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -81,17 +81,37 @@ case class SortMergeJoinExec( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def outputOrdering: Seq[SortOrder] = joinType match { + // For inner join, orders of both sides keys should be kept. + case Inner => + val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => + // Also add the right key and its `sameOrderExpressions` + SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey + .sameOrderExpressions) + } // For left and right outer joins, the output is ordered by the streamed input's join keys. - case LeftOuter => requiredOrders(leftKeys) - case RightOuter => requiredOrders(rightKeys) + case LeftOuter => getKeyOrdering(leftKeys, left.outputOrdering) + case RightOuter => getKeyOrdering(rightKeys, right.outputOrdering) // There are null rows in both streams, so there is no order. case FullOuter => Nil - case _: InnerLike | LeftExistence(_) => requiredOrders(leftKeys) + case LeftExistence(_) => getKeyOrdering(leftKeys, left.outputOrdering) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") } + /** + * For SMJ, child's output must have been sorted on key or expressions with the same order as + * key, so we can get ordering for key from child's output ordering. + */ + private def getKeyOrdering(keys: Seq[Expression], childOutputOrdering: Seq[SortOrder]) + : Seq[SortOrder] = { + keys.zip(childOutputOrdering).map { case (key, childOrder) => + SortOrder(key, Ascending, childOrder.sameOrderExpressions + childOrder.child - key) + } + } + override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index f2232fc489..4d155d538d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -477,14 +477,18 @@ class PlannerSuite extends SharedSQLContext { private val exprA = Literal(1) private val exprB = Literal(2) + private val exprC = Literal(3) private val orderingA = SortOrder(exprA, Ascending) private val orderingB = SortOrder(exprB, Ascending) + private val orderingC = SortOrder(exprC, Ascending) private val planA = DummySparkPlan(outputOrdering = Seq(orderingA), outputPartitioning = HashPartitioning(exprA :: Nil, 5)) private val planB = DummySparkPlan(outputOrdering = Seq(orderingB), outputPartitioning = HashPartitioning(exprB :: Nil, 5)) + private val planC = DummySparkPlan(outputOrdering = Seq(orderingC), + outputPartitioning = HashPartitioning(exprC :: Nil, 5)) - assert(orderingA != orderingB) + assert(orderingA != orderingB && orderingA != orderingC && orderingB != orderingC) private def assertSortRequirementsAreSatisfied( childPlan: SparkPlan, @@ -508,6 +512,30 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + // Both left and right keys should be sorted after the SMJ. + Seq(orderingA, orderingB).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = innerSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + + test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + + "child SMJ") { + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, Inner, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, Inner, None, childSmj, planC) + // After the second SMJ, exprA, exprB and exprC should all be sorted. + Seq(orderingA, orderingB, orderingC).foreach { ordering => + assertSortRequirementsAreSatisfied( + childPlan = parentSmj, + requiredOrdering = Seq(ordering), + shouldHaveSort = false) + } + } + test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB)