diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 0c75eda7a4..53bd591d98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} +import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -756,7 +757,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport { protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo } -object HashJoin { +object HashJoin extends CastSupport with SQLConfHelper { /** * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. * @@ -771,14 +772,14 @@ object HashJoin { } var keyExpr: Expression = if (keys.head.dataType != LongType) { - Cast(keys.head, LongType) + cast(keys.head, LongType) } else { keys.head } keys.tail.foreach { e => val bits = e.dataType.defaultSize * 8 keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)), - BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1))) + BitwiseAnd(cast(e, LongType), Literal((1L << bits) - 1))) } keyExpr :: Nil } @@ -791,13 +792,13 @@ object HashJoin { // jump over keys that have a higher index value than the required key if (keys.size == 1) { assert(index == 0) - Cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) + cast(BoundReference(0, LongType, nullable = false), keys(index).dataType) } else { val shiftedBits = keys.slice(index + 1, keys.size).map(_.dataType.defaultSize * 8).sum val mask = (1L << (keys(index).dataType.defaultSize * 8)) - 1 // build the schema for unpacking the required key - Cast(BitwiseAnd( + cast(BitwiseAnd( ShiftRightUnsigned(BoundReference(0, LongType, nullable = false), Literal(shiftedBits)), Literal(mask)), keys(index).dataType) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 044e9ace62..98a1089709 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -242,33 +242,40 @@ abstract class BroadcastJoinSuiteBase extends QueryTest with SQLTestUtils assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil) assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil) - assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(i :: Nil) === + Cast(i, LongType, Some(conf.sessionLocalTimeZone)) :: Nil) assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil) assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) === - BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)), - BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil) + BitwiseOr(ShiftLeft(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal(32)), + BitwiseAnd(Cast(i, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 32) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil) - assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil) + assert(HashJoin.rewriteKeyExpr(s :: Nil) === + Cast(s, LongType, Some(conf.sessionLocalTimeZone)) :: Nil) assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) === - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) === BitwiseOr(ShiftLeft( - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) === BitwiseOr(ShiftLeft( BitwiseOr(ShiftLeft( - BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseOr(ShiftLeft(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal(16)), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), + Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))), + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))), Literal(16)), - BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil) + BitwiseAnd(Cast(s, LongType, Some(conf.sessionLocalTimeZone)), Literal((1L << 16) - 1))) :: + Nil) assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) === s :: s :: s :: s :: s :: Nil)