[SPARK-36745][SQL] ExtractEquiJoinKeys should return the original predicates on join keys

### What changes were proposed in this pull request?

This PR updates `ExtractEquiJoinKeys` to return an extra field for the join condition with join keys.

### Why are the changes needed?

Sometimes we need to restore the original join condition. Before this PR, we need to build `EqualTo` expressions with the join keys, which is not always the original join condition. E.g. `EqualNullSafe(a, b)` will become `EqualTo(Coalesce(a, lit), Coalesce(b, lit))`. After this PR, we can simply use the new returned field.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests.

Closes #33985 from YannisSismanis/SPARK-36475-fix.

Authored-by: Yannis Sismanis <yannis.sismanis@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Yannis Sismanis 2021-09-16 13:16:16 +08:00 committed by Wenchen Fan
parent bbb33af2e4
commit afd406e4d0
11 changed files with 40 additions and 31 deletions

View file

@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging {
*/
def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = {
plan match {
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) =>
case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _, _) =>
(leftKeys ++ rightKeys).exists {
case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey)
case _ => false

View file

@ -67,7 +67,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
// Only hash join and sort merge join need the normalization. Here we catch all Joins with
// join keys, assuming Joins with join keys are always planned as hash join or sort merge
// join. It's very unlikely that we will break this assumption in the near future.
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _, _)
// The analyzer guarantees left and right joins keys are of the same data type. Here we
// only need to check join keys of one side.
if leftKeys.exists(k => needNormalize(k)) =>

View file

@ -176,10 +176,16 @@ object ScanOperation extends OperationHelper with PredicateHelper {
* value).
*/
object ExtractEquiJoinKeys extends Logging with PredicateHelper {
/** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild, joinHint) */
/** (joinType, leftKeys, rightKeys, otherCondition, conditionOnJoinKeys, leftChild,
* rightChild, joinHint).
*/
// Note that `otherCondition` is NOT the original Join condition and it contains only
// the subset that is not handled by the 'leftKeys' to 'rightKeys' equijoin.
// 'conditionOnJoinKeys' is the subset of the original Join condition that corresponds to the
// 'leftKeys' to 'rightKeys' equijoin.
type ReturnType =
(JoinType, Seq[Expression], Seq[Expression],
Option[Expression], LogicalPlan, LogicalPlan, JoinHint)
Option[Expression], Option[Expression], LogicalPlan, LogicalPlan, JoinHint)
def unapply(join: Join): Option[ReturnType] = join match {
case Join(left, right, joinType, condition, hint) =>
@ -197,15 +203,15 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
Seq((Coalesce(Seq(l, Literal.default(l.dataType))),
Coalesce(Seq(r, Literal.default(r.dataType)))),
(IsNull(l), IsNull(r))
)
) // (coalesce(l, default) = coalesce(r, default)) and (isnull(l) = isnull(r))
case EqualNullSafe(l, r) if canEvaluate(l, right) && canEvaluate(r, left) =>
Seq((Coalesce(Seq(r, Literal.default(r.dataType))),
Coalesce(Seq(l, Literal.default(l.dataType)))),
(IsNull(r), IsNull(l))
)
case other => None
) // Same as above with left/right reversed.
case _ => None
}
val otherPredicates = predicates.filterNot {
val (predicatesOfJoinKeys, otherPredicates) = predicates.partition {
case EqualTo(l, r) if l.references.isEmpty || r.references.isEmpty => false
case Equality(l, r) =>
canEvaluate(l, left) && canEvaluate(r, right) ||
@ -216,7 +222,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
if (joinKeys.nonEmpty) {
val (leftKeys, rightKeys) = joinKeys.unzip
logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys")
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right, hint))
Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And),
predicatesOfJoinKeys.reduceOption(And), left, right, hint))
} else {
None
}

View file

@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging {
case _ if !rowCountsExist(join.left, join.right) =>
None
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) =>
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _, _) =>
// 1. Compute join selectivity
val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)

View file

@ -198,7 +198,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// 4. Pick cartesian product if join type is inner like.
// 5. Pick broadcast nested loop join as the final solution. It may OOM but we don't have
// other choice.
case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond, left, right, hint) =>
case j @ ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, nonEquiCond,
_, left, right, hint) =>
def createBroadcastHashJoin(onlyLookingAtHint: Boolean) = {
val buildSide = getBroadcastBuildSide(
left, right, joinType, hint, onlyLookingAtHint, conf)
@ -461,11 +462,10 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object StreamingJoinStrategy extends Strategy {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = {
plan match {
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, _)
if left.isStreaming && right.isStreaming =>
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, _,
left, right, _) if left.isStreaming && right.isStreaming =>
val stateVersion = conf.getConf(SQLConf.STREAMING_JOIN_STATE_FORMAT_VERSION)
new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, condition,
new StreamingSymmetricHashJoinExec(leftKeys, rightKeys, joinType, otherCondition,
stateVersion, planLater(left), planLater(right)) :: Nil
case Join(left, right, _, _, _) if left.isStreaming && right.isStreaming =>

View file

@ -69,7 +69,7 @@ object DynamicJoinSelection extends Rule[LogicalPlan] {
}
def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
case j @ ExtractEquiJoinKeys(_, _, _, _, left, right, hint) =>
case j @ ExtractEquiJoinKeys(_, _, _, _, _, left, right, hint) =>
var newHint = hint
if (!hint.leftHint.exists(_.strategy.isDefined)) {
selectJoinStrategy(left).foreach { strategy =>

View file

@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BroadcastNes
* 2. Transforms [[Join]] which has one child relation already planned and executed as a
* [[BroadcastQueryStageExec]]. This is to prevent reversing a broadcast stage into a shuffle
* stage in case of the larger join child relation finishes before the smaller relation. Note
* that this rule needs to applied before regular join strategies.
* that this rule needs to be applied before regular join strategies.
*/
object LogicalQueryStageStrategy extends Strategy with PredicateHelper {
@ -43,11 +43,13 @@ object LogicalQueryStageStrategy extends Strategy with PredicateHelper {
}
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right, hint)
case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, otherCondition, _,
left, right, hint)
if isBroadcastStage(left) || isBroadcastStage(right) =>
val buildSide = if (isBroadcastStage(left)) BuildLeft else BuildRight
Seq(BroadcastHashJoinExec(
leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right)))
leftKeys, rightKeys, joinType, buildSide, otherCondition, planLater(left),
planLater(right)))
case j @ ExtractSingleColumnNullAwareAntiJoin(leftKeys, rightKeys)
if isBroadcastStage(j.right) =>

View file

@ -256,7 +256,7 @@ object PartitionPruning extends Rule[LogicalPlan] with PredicateHelper with Join
// extract the left and right keys of the join condition
val (leftKeys, rightKeys) = j match {
case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, _, _) => (lkeys, rkeys)
case ExtractEquiJoinKeys(_, lkeys, rkeys, _, _, _, _, _) => (lkeys, rkeys)
case _ => (Nil, Nil)
}

View file

@ -106,7 +106,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements.apply(
@ -125,7 +125,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements.apply(
@ -144,7 +144,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements.apply(

View file

@ -132,7 +132,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
@ -144,7 +144,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeBroadcastHashJoin(
@ -156,7 +156,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin (build=left)") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
@ -168,7 +168,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin (build=right)") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeShuffledHashJoin(
@ -180,7 +180,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),

View file

@ -107,7 +107,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using ShuffledHashJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
@ -127,7 +127,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
case RightOuter => BuildLeft
case _ => fail(s"Unsupported join type $joinType")
}
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashJoinExec(
@ -140,7 +140,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSparkSession {
}
testWithWholeStageCodegenOnAndOff(s"$testName using SortMergeJoin") { _ =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _) =>
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements.apply(