[SPARK-33822][SQL] Use the CastSupport.cast method in HashJoin

### What changes were proposed in this pull request?

This PR fixes a bug that throws an `UnsupportedOperationException` when running [the TPCDS q5](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q5.sql) with AQE enabled ([this option is enabled by default now via SPARK-33679](031c5ef280)):
```
java.lang.UnsupportedOperationException: BroadcastExchange does not support the execute() code path.
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecute(BroadcastExchangeExec.scala:189)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.exchange.ReusedExchangeExec.doExecute(Exchange.scala:60)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.adaptive.QueryStageExec.doExecute(QueryStageExec.scala:115)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:321)
  at org.apache.spark.sql.execution.SparkPlan.executeCollectIterator(SparkPlan.scala:397)
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.$anonfun$relationFuture$1(BroadcastExchangeExec.scala:118)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:185)
  at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
  ...
```

I checked the AQE code and found that `EnsureRequirements` wrongly puts a `BroadcastExchange` on top of a `BroadcastQueryStage` in the `reOptimize` phase, as follows:
```
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#2183]
  +- BroadcastQueryStage 2
    +- ReusedExchange [d_date_sk#1086], BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#1963]
```
The root cause is that the `Cast` expression in the required child distribution has no `timeZoneId` set (`timeZoneId=None`), whereas the `Cast` expression in `child.outputPartitioning` does have one. This difference makes the distribution requirement check fail in `EnsureRequirements`:
1e85707738/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala (L47-L50)

The `Cast` expression without a `timeZoneId` is created in the `HashJoin` object. To fix this issue, this PR proposes to use the `CastSupport.cast` method there, which creates the `Cast` with the session-local time zone set.

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually checked that q5 passed.

Closes #30818 from maropu/BugfixInAQE.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
Takeshi Yamamuro 2020-12-17 16:16:05 -08:00 committed by Dongjoon Hyun
parent 15616f499a
commit 51ef4430dc
No known key found for this signature in database
GPG key ID: EDA00CE834F0FC5C
2 changed files with 27 additions and 19 deletions

View file

@ -17,7 +17,8 @@
package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
import org.apache.spark.sql.catalyst.analysis.CastSupport
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
import org.apache.spark.sql.catalyst.expressions.codegen._
@ -756,7 +757,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
}
object HashJoin {
object HashJoin extends CastSupport with SQLConfHelper {
/**
* Try to rewrite the key as LongType so we can use getLong(), if the key can fit in a long.
*
@ -771,14 +772,14 @@ object HashJoin {
}
var keyExpr: Expression = if (keys.head.dataType != LongType) {
Cast(keys.head, LongType)
cast(keys.head, LongType)
} else {
keys.head
}
keys.tail.foreach { e =>
val bits = e.dataType.defaultSize * 8
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1)))
}
keyExpr :: Nil
}
@ -791,13 +792,13 @@ object HashJoin {
// jump over keys that have a higher index value than the required key
if (keys.size == 1) {
assert(index == 0)
Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
} else {
val shiftedBits =
keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
// build the schema for unpacking the required key
Cast(BitwiseAnd(
cast(BitwiseAnd(
ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
Literal(mask)), keys(index).dataType)
}

View file

@ -242,33 +242,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
assert(HashJoin.rewriteKeyExpr(i :: Nil) ===
Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)),
BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) ::
Nil)
assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
assert(HashJoin.rewriteKeyExpr(s :: Nil) ===
Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
Nil)
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
BitwiseOr(ShiftLeft(
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
Nil)
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
BitwiseOr(ShiftLeft(
BitwiseOr(ShiftLeft(
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)),
Literal((1L << 16) - 1))),
Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
Literal(16)),
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
Nil)
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
s :: s :: s :: s :: s :: Nil)