diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index cec00b66f8..a754e87a17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val fieldsHash = fields.zipWithIndex.map { case (field, index) => - nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) + nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx) } val hashResultType = CodeGenerator.javaType(dataType) - ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, hashResultType -> result), + arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result), returnType = hashResultType, makeSplitFunction = body => s""" @@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |$code + """.stripMargin } @tailrec @@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { input: String, result: String, fields: Array[StructField]): String = { + val tmpInput = ctx.freshName("input") val childResult = ctx.freshName("childResult") val fieldsHash = fields.zipWithIndex.map { case (field, index) => val computeFieldHash = nullSafeElementHash( - input, index.toString, field.nullable, field.dataType, childResult, ctx) + tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx) s""" |$childResult = 0; |$computeFieldHash @@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { """.stripMargin } - s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions( + val code = ctx.splitExpressions( expressions = fieldsHash, funcName = "computeHashForStruct", - arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result), + arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result), returnType = CodeGenerator.JAVA_INT, makeSplitFunction = body => s""" @@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { |return $result; """.stripMargin, foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n")) + s""" + |final InternalRow $tmpInput = $input; + |${CodeGenerator.JAVA_INT} $childResult = 0; + |$code + """.stripMargin } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c1a5f50fd8..84efd2b7a1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } } + + test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") { + withView("spark_25084") { + val count = 1000 + val df = spark.range(count) + val columns = (0 until 400).map{ i => s"id as id$i" } + val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") + df.selectExpr(columns : _*).createTempView("spark_25084") + assert( + spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count) + } + } }