[SPARK-33822][SQL] Use the CastSupport.cast method in HashJoin

### What changes were proposed in this pull request?

This PR intends to fix a bug that throws an `UnsupportedOperationException` when running [the TPCDS q5](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q5.sql) with AQE enabled ([this option is enabled by default now via SPARK-33679](031c5ef280)):
```
java.lang.UnsupportedOperationException: BroadcastExchange does not support the execute() code path.
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecute(BroadcastExchangeExec.scala:189)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.exchange.ReusedExchangeExec.doExecute(Exchange.scala:60)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.adaptive.QueryStageExec.doExecute(QueryStageExec.scala:115)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180)
  at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218)
  at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
  at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215)
  at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176)
  at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:321)
  at org.apache.spark.sql.execution.SparkPlan.executeCollectIterator(SparkPlan.scala:397)
  at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.$anonfun$relationFuture$1(BroadcastExchangeExec.scala:118)
  at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:185)
  at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
  ...
```

I checked the AQE code and found that `EnsureRequirements` wrongly puts a `BroadcastExchange` on top of a `BroadcastQueryStage` in the `reOptimize` phase, as follows:
```
+- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#2183]
  +- BroadcastQueryStage 2
    +- ReusedExchange [d_date_sk#1086], BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#1963]
```
The root cause is that the `Cast` in a required child distribution does not have a `timeZoneId` (`timeZoneId=None`), while the `Cast` in `child.outputPartitioning` does. This mismatch makes the distribution requirement check fail in `EnsureRequirements`:
1e85707738/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala (L47-L50)

The `Cast` without a `timeZoneId` is generated in the `HashJoin` object. To fix this issue, this PR proposes to use the `CastSupport.cast` method there, which builds the `Cast` with the session-local time zone from the SQL conf.
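To make the mismatch concrete, here is a minimal sketch (illustration only, not code from this patch) showing that a `Cast` without a `timeZoneId` does not equal the same `Cast` carrying the session-local time zone, which is exactly the comparison that trips up `EnsureRequirements`:
```scala
import org.apache.spark.sql.catalyst.expressions.{Cast, Literal}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.LongType

// The key expression HashJoin used to build: no time zone attached.
val withoutTz = Cast(Literal(1), LongType)
// The equivalent expression as it appears in child.outputPartitioning: time zone attached.
val withTz = Cast(Literal(1), LongType, Some(SQLConf.get.sessionLocalTimeZone))

// Case-class equality includes timeZoneId, so the two expressions differ, and the
// required distribution no longer matches the child's output partitioning.
assert(withoutTz != withTz)

// CastSupport.cast (which object HashJoin now mixes in) attaches the zone, i.e.
//   cast(keys.head, LongType) == Cast(keys.head, LongType, Some(conf.sessionLocalTimeZone))
// so both sides of the comparison agree.
```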

### Why are the changes needed?

Bugfix.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Manually checked that q5 passed.

Closes #30818 from maropu/BugfixInAQE.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
Authored by Takeshi Yamamuro on 2020-12-17 16:16:05 -08:00; committed by Dongjoon Hyun
parent 15616f499a
commit 51ef4430dc
2 changed files with 27 additions and 19 deletions


@@ -17,7 +17,8 @@
 package org.apache.spark.sql.execution.joins
 
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
+import org.apache.spark.sql.catalyst.analysis.CastSupport
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
 import org.apache.spark.sql.catalyst.expressions.codegen._
@@ -756,7 +757,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
   protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
 }
 
-object HashJoin {
+object HashJoin extends CastSupport with SQLConfHelper {
   /**
    * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
    *
@@ -771,14 +772,14 @@
     }
 
     var keyExpr: Expression = if (keys.head.dataType != LongType) {
-      Cast(keys.head, LongType)
+      cast(keys.head, LongType)
     } else {
       keys.head
     }
     keys.tail.foreach { e =>
       val bits = e.dataType.defaultSize * 8
       keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
-        BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+        BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1)))
     }
     keyExpr :: Nil
   }
@@ -791,13 +792,13 @@
     // jump over keys that have a higher index value than the required key
     if (keys.size == 1) {
       assert(index == 0)
-      Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
+      cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
     } else {
       val shiftedBits =
         keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
       val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
       // build the schema for unpacking the required key
-      Cast(BitwiseAnd(
+      cast(BitwiseAnd(
         ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
         Literal(mask)), keys(index).dataType)
     }
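As context for the hunks above (not part of the change itself), a plain-Scala sketch of the bit packing that `rewriteKeyExpr` and `extractKeyExprAt` express with catalyst expressions; the helper names here are made up for illustration:
```scala
// Pack several integral keys (value, size-in-bytes) into one Long, mirroring the
// ShiftLeft/BitwiseOr/BitwiseAnd chain in rewriteKeyExpr: the head key is used
// as-is, each following key is masked to its width and appended on the right.
def pack(keys: Seq[(Long, Int)]): Long =
  keys.tail.foldLeft(keys.head._1) { case (acc, (value, sizeBytes)) =>
    val bits = sizeBytes * 8
    (acc << bits) | (value & ((1L << bits) - 1))
  }

// Recover the key at `index`, mirroring extractKeyExprAt: shift past the keys
// packed to the right of it, then mask down to its own width.
def extractAt(packed: Long, sizesInBytes: Seq[Int], index: Int): Long = {
  val shiftedBits = sizesInBytes.slice(index + 1, sizesInBytes.size).map(_ * 8).sum
  val mask = (1L << (sizesInBytes(index) * 8)) - 1
  (packed >>> shiftedBits) & mask
}

// Two 4-byte int keys packed into a single Long key and unpacked again.
val packed = pack(Seq((7L, 4), (42L, 4)))
assert(extractAt(packed, Seq(4, 4), 0) == 7L)
assert(extractAt(packed, Seq(4, 4), 1) == 42L)
```
The second changed file, shown below, updates the test expectations in `BroadcastJoinSuiteBase` so that the expected `Cast`s carry the session-local time zone.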


@@ -242,33 +242,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
     assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
 
-    assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(i :: Nil) ===
+      Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
-      BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
-        BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+      BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)),
+        BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
 
-    assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+    assert(HashJoin.rewriteKeyExpr(s :: Nil) ===
+      Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
     assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
-      BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+      BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
       BitwiseOr(ShiftLeft(
-        BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+        BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+          BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
         Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
       BitwiseOr(ShiftLeft(
         BitwiseOr(ShiftLeft(
-          BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
-            BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
+            BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)),
+              Literal((1L << 16) - 1))),
           Literal(16)),
-          BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+          BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
         Literal(16)),
-        BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+        BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
+        Nil)
     assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
       s :: s :: s :: s :: s :: Nil)