[SPARK-25084][SQL] "distribute by" on multiple columns (wrap in brackets) may lead to codegen issue

## What changes were proposed in this pull request?

"distribute by" on multiple columns (wrap in brackets) may lead to codegen issue.

Simple way to reproduce:
```scala
  val df = spark.range(1000)
  val columns = (0 until 400).map{ i => s"id as id$i" }
  val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
  df.selectExpr(columns : _*).createTempView("test")
  spark.sql(s"select * from test distribute by ($distributeExprs)").count()
```

## How was this patch tested?

Add UT.

Closes #22066 from yucai/SPARK-25084.

Authored-by: yucai <yyu1@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
yucai 2018-08-11 21:38:31 +08:00 committed by Wenchen Fan
parent b73eb0efe8
commit 41a7de6002
2 changed files with 29 additions and 6 deletions

View file

@@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
-      nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
+      nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
     }
     val hashResultType = CodeGenerator.javaType(dataType)
-    ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, hashResultType -> result),
+      arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
       returnType = hashResultType,
       makeSplitFunction = body =>
         s"""
@@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression {
          |return $result;
        """.stripMargin,
       foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |$code
+     """.stripMargin
   }
@tailrec @tailrec
@@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
       input: String,
       result: String,
       fields: Array[StructField]): String = {
+    val tmpInput = ctx.freshName("input")
     val childResult = ctx.freshName("childResult")
     val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
       val computeFieldHash = nullSafeElementHash(
-        input, index.toString, field.nullable, field.dataType, childResult, ctx)
+        tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
       s"""
          |$childResult = 0;
          |$computeFieldHash
@@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
        """.stripMargin
     }
-    s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
+    val code = ctx.splitExpressions(
       expressions = fieldsHash,
       funcName = "computeHashForStruct",
-      arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
+      arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
       returnType = CodeGenerator.JAVA_INT,
       makeSplitFunction = body =>
         s"""
@@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
          |return $result;
        """.stripMargin,
       foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
+    s"""
+       |final InternalRow $tmpInput = $input;
+       |${CodeGenerator.JAVA_INT} $childResult = 0;
+       |$code
+     """.stripMargin
   }
 }

View file

@@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
       }
     }
   }
+
+  test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") {
+    withView("spark_25084") {
+      val count = 1000
+      val df = spark.range(count)
+      val columns = (0 until 400).map{ i => s"id as id$i" }
+      val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
+      df.selectExpr(columns : _*).createTempView("spark_25084")
+      assert(
+        spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count)
+    }
+  }
 }