[SPARK-25084][SQL] "distribute by" on multiple columns (wrap in brackets) may lead to codegen issue
## What changes were proposed in this pull request?

"distribute by" on multiple columns (wrapped in brackets) may lead to a codegen issue. Simple way to reproduce:

```scala
val df = spark.range(1000)
val columns = (0 until 400).map{ i => s"id as id$i" }
val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
df.selectExpr(columns : _*).createTempView("test")
spark.sql(s"select * from test distribute by ($distributeExprs)").count()
```

## How was this patch tested?

Add UT.

Closes #22066 from yucai/SPARK-25084.

Authored-by: yucai <yyu1@ebay.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b73eb0efe8
commit
41a7de6002
|
@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression {
|
|||
input: String,
|
||||
result: String,
|
||||
fields: Array[StructField]): String = {
|
||||
val tmpInput = ctx.freshName("input")
|
||||
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
||||
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
|
||||
nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
|
||||
}
|
||||
val hashResultType = CodeGenerator.javaType(dataType)
|
||||
ctx.splitExpressions(
|
||||
val code = ctx.splitExpressions(
|
||||
expressions = fieldsHash,
|
||||
funcName = "computeHashForStruct",
|
||||
arguments = Seq("InternalRow" -> input, hashResultType -> result),
|
||||
arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
|
||||
returnType = hashResultType,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|
@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression {
|
|||
|return $result;
|
||||
""".stripMargin,
|
||||
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
||||
s"""
|
||||
|final InternalRow $tmpInput = $input;
|
||||
|$code
|
||||
""".stripMargin
|
||||
}
|
||||
|
||||
@tailrec
|
||||
|
@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
input: String,
|
||||
result: String,
|
||||
fields: Array[StructField]): String = {
|
||||
val tmpInput = ctx.freshName("input")
|
||||
val childResult = ctx.freshName("childResult")
|
||||
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
||||
val computeFieldHash = nullSafeElementHash(
|
||||
input, index.toString, field.nullable, field.dataType, childResult, ctx)
|
||||
tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
|
||||
s"""
|
||||
|$childResult = 0;
|
||||
|$computeFieldHash
|
||||
|
@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
""".stripMargin
|
||||
}
|
||||
|
||||
s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
|
||||
val code = ctx.splitExpressions(
|
||||
expressions = fieldsHash,
|
||||
funcName = "computeHashForStruct",
|
||||
arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
|
||||
arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
|
||||
returnType = CodeGenerator.JAVA_INT,
|
||||
makeSplitFunction = body =>
|
||||
s"""
|
||||
|
@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
|||
|return $result;
|
||||
""".stripMargin,
|
||||
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
||||
s"""
|
||||
|final InternalRow $tmpInput = $input;
|
||||
|${CodeGenerator.JAVA_INT} $childResult = 0;
|
||||
|$code
|
||||
""".stripMargin
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") {
|
||||
withView("spark_25084") {
|
||||
val count = 1000
|
||||
val df = spark.range(count)
|
||||
val columns = (0 until 400).map{ i => s"id as id$i" }
|
||||
val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
|
||||
df.selectExpr(columns : _*).createTempView("spark_25084")
|
||||
assert(
|
||||
spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue