[SPARK-33822][SQL] Use the CastSupport.cast
method in HashJoin
### What changes were proposed in this pull request? This PR intends to fix the bug that throws an `UnsupportedOperationException` when running [the TPCDS q5](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q5.sql) with AQE enabled ([this option is enabled by default now via SPARK-33679](031c5ef280
)): ``` java.lang.UnsupportedOperationException: BroadcastExchange does not support the execute() code path. at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.doExecute(BroadcastExchangeExec.scala:189) at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180) at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176) at org.apache.spark.sql.execution.exchange.ReusedExchangeExec.doExecute(Exchange.scala:60) at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180) at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176) at org.apache.spark.sql.execution.adaptive.QueryStageExec.doExecute(QueryStageExec.scala:115) at org.apache.spark.sql.execution.SparkPlan.$anonfun$execute$1(SparkPlan.scala:180) at org.apache.spark.sql.execution.SparkPlan.$anonfun$executeQuery$1(SparkPlan.scala:218) at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151) at org.apache.spark.sql.execution.SparkPlan.executeQuery(SparkPlan.scala:215) at org.apache.spark.sql.execution.SparkPlan.execute(SparkPlan.scala:176) at org.apache.spark.sql.execution.SparkPlan.getByteArrayRdd(SparkPlan.scala:321) at org.apache.spark.sql.execution.SparkPlan.executeCollectIterator(SparkPlan.scala:397) at org.apache.spark.sql.execution.exchange.BroadcastExchangeExec.$anonfun$relationFuture$1(BroadcastExchangeExec.scala:118) at 
org.apache.spark.sql.execution.SQLExecution$.$anonfun$withThreadLocalCaptured$1(SQLExecution.scala:185) at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264) ... ``` I've checked the AQE code and I found `EnsureRequirements` wrongly puts `BroadcastExchange` on top of `BroadcastQueryStage` in the `reOptimize` phase as follows: ``` +- BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#2183] +- BroadcastQueryStage 2 +- ReusedExchange [d_date_sk#1086], BroadcastExchange HashedRelationBroadcastMode(List(cast(input[0, int, true] as bigint)),false), [id=#1963] ``` A root cause is that a `Cast` class in a required child's distribution does not have a `timeZoneId` field (`timeZoneId=None`), and a `Cast` class in `child.outputPartitioning` has it. So, this difference can make the distribution requirement check fail in `EnsureRequirements`:1e85707738/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala (L47-L50)
The `Cast` class that does not have a `timeZoneId` field is generated in the `HashJoin` object. To fix this issue, this PR proposes to use the `CastSupport.cast` method there. ### Why are the changes needed? Bugfix. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Manually checked that q5 passed. Closes #30818 from maropu/BugfixInAQE. Authored-by: Takeshi Yamamuro <yamamuro@apache.org> Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
This commit is contained in:
parent
15616f499a
commit
51ef4430dc
|
@ -17,7 +17,8 @@
|
||||||
|
|
||||||
package org.apache.spark.sql.execution.joins
|
package org.apache.spark.sql.execution.joins
|
||||||
|
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper}
|
||||||
|
import org.apache.spark.sql.catalyst.analysis.CastSupport
|
||||||
import org.apache.spark.sql.catalyst.expressions._
|
import org.apache.spark.sql.catalyst.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
|
import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
|
@ -756,7 +757,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
|
||||||
protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
|
protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
object HashJoin {
|
object HashJoin extends CastSupport with SQLConfHelper {
|
||||||
/**
|
/**
|
||||||
* Try to rewrite the key as LongType so we can use getLong(), if the key can fit in a long.
|
* Try to rewrite the key as LongType so we can use getLong(), if the key can fit in a long.
|
||||||
*
|
*
|
||||||
|
@ -771,14 +772,14 @@ object HashJoin {
|
||||||
}
|
}
|
||||||
|
|
||||||
var keyExpr: Expression = if (keys.head.dataType != LongType) {
|
var keyExpr: Expression = if (keys.head.dataType != LongType) {
|
||||||
Cast(keys.head, LongType)
|
cast(keys.head, LongType)
|
||||||
} else {
|
} else {
|
||||||
keys.head
|
keys.head
|
||||||
}
|
}
|
||||||
keys.tail.foreach { e =>
|
keys.tail.foreach { e =>
|
||||||
val bits = e.dataType.defaultSize * 8
|
val bits = e.dataType.defaultSize * 8
|
||||||
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
|
keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
|
||||||
BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
|
BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1)))
|
||||||
}
|
}
|
||||||
keyExpr :: Nil
|
keyExpr :: Nil
|
||||||
}
|
}
|
||||||
|
@ -791,13 +792,13 @@ object HashJoin {
|
||||||
// jump over keys that have a higher index value than the required key
|
// jump over keys that have a higher index value than the required key
|
||||||
if (keys.size == 1) {
|
if (keys.size == 1) {
|
||||||
assert(index == 0)
|
assert(index == 0)
|
||||||
Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
|
cast(BoundReference(0, LongType, nullable = false), keys(index).dataType)
|
||||||
} else {
|
} else {
|
||||||
val shiftedBits =
|
val shiftedBits =
|
||||||
keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
|
keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum
|
||||||
val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
|
val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1
|
||||||
// build the schema for unpacking the required key
|
// build the schema for unpacking the required key
|
||||||
Cast(BitwiseAnd(
|
cast(BitwiseAnd(
|
||||||
ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
|
ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)),
|
||||||
Literal(mask)), keys(index).dataType)
|
Literal(mask)), keys(index).dataType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -242,33 +242,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils
|
||||||
assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
|
assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
|
assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
|
||||||
|
|
||||||
assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
|
assert(HashJoin.rewriteKeyExpr(i :: Nil) ===
|
||||||
|
Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
|
assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
|
assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
|
||||||
BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
|
BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)),
|
||||||
BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
|
BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) ::
|
||||||
|
Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
|
assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
|
||||||
|
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
|
assert(HashJoin.rewriteKeyExpr(s :: Nil) ===
|
||||||
|
Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
|
assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
|
assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
|
||||||
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
|
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
|
||||||
|
Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
|
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
|
||||||
BitwiseOr(ShiftLeft(
|
BitwiseOr(ShiftLeft(
|
||||||
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
|
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
|
||||||
Literal(16)),
|
Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
|
||||||
|
Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
|
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
|
||||||
BitwiseOr(ShiftLeft(
|
BitwiseOr(ShiftLeft(
|
||||||
BitwiseOr(ShiftLeft(
|
BitwiseOr(ShiftLeft(
|
||||||
BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
|
BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)),
|
||||||
|
Literal((1L << 16) - 1))),
|
||||||
Literal(16)),
|
Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))),
|
||||||
Literal(16)),
|
Literal(16)),
|
||||||
BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
|
BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) ::
|
||||||
|
Nil)
|
||||||
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
|
assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
|
||||||
s :: s :: s :: s :: s :: Nil)
|
s :: s :: s :: s :: s :: Nil)
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue