[SPARK-25084][SQL] "distribute by" on multiple columns (wrapped in brackets) may lead to a codegen issue
## What changes were proposed in this pull request? "distribute by" on multiple columns (wrapped in brackets) may lead to a codegen issue. A simple way to reproduce it: ```scala val df = spark.range(1000) val columns = (0 until 400).map{ i => s"id as id$i" } val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",") df.selectExpr(columns : _*).createTempView("test") spark.sql(s"select * from test distribute by ($distributeExprs)").count() ``` ## How was this patch tested? Added a unit test. Closes #22066 from yucai/SPARK-25084. Authored-by: yucai <yyu1@ebay.com> Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
b73eb0efe8
commit
41a7de6002
|
@ -404,14 +404,15 @@ abstract class HashExpression[E] extends Expression {
|
||||||
input: String,
|
input: String,
|
||||||
result: String,
|
result: String,
|
||||||
fields: Array[StructField]): String = {
|
fields: Array[StructField]): String = {
|
||||||
|
val tmpInput = ctx.freshName("input")
|
||||||
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
||||||
nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
|
nullSafeElementHash(tmpInput, index.toString, field.nullable, field.dataType, result, ctx)
|
||||||
}
|
}
|
||||||
val hashResultType = CodeGenerator.javaType(dataType)
|
val hashResultType = CodeGenerator.javaType(dataType)
|
||||||
ctx.splitExpressions(
|
val code = ctx.splitExpressions(
|
||||||
expressions = fieldsHash,
|
expressions = fieldsHash,
|
||||||
funcName = "computeHashForStruct",
|
funcName = "computeHashForStruct",
|
||||||
arguments = Seq("InternalRow" -> input, hashResultType -> result),
|
arguments = Seq("InternalRow" -> tmpInput, hashResultType -> result),
|
||||||
returnType = hashResultType,
|
returnType = hashResultType,
|
||||||
makeSplitFunction = body =>
|
makeSplitFunction = body =>
|
||||||
s"""
|
s"""
|
||||||
|
@ -419,6 +420,10 @@ abstract class HashExpression[E] extends Expression {
|
||||||
|return $result;
|
|return $result;
|
||||||
""".stripMargin,
|
""".stripMargin,
|
||||||
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
||||||
|
s"""
|
||||||
|
|final InternalRow $tmpInput = $input;
|
||||||
|
|$code
|
||||||
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
@tailrec
|
@tailrec
|
||||||
|
@ -778,10 +783,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
||||||
input: String,
|
input: String,
|
||||||
result: String,
|
result: String,
|
||||||
fields: Array[StructField]): String = {
|
fields: Array[StructField]): String = {
|
||||||
|
val tmpInput = ctx.freshName("input")
|
||||||
val childResult = ctx.freshName("childResult")
|
val childResult = ctx.freshName("childResult")
|
||||||
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
val fieldsHash = fields.zipWithIndex.map { case (field, index) =>
|
||||||
val computeFieldHash = nullSafeElementHash(
|
val computeFieldHash = nullSafeElementHash(
|
||||||
input, index.toString, field.nullable, field.dataType, childResult, ctx)
|
tmpInput, index.toString, field.nullable, field.dataType, childResult, ctx)
|
||||||
s"""
|
s"""
|
||||||
|$childResult = 0;
|
|$childResult = 0;
|
||||||
|$computeFieldHash
|
|$computeFieldHash
|
||||||
|
@ -789,10 +795,10 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
s"${CodeGenerator.JAVA_INT} $childResult = 0;\n" + ctx.splitExpressions(
|
val code = ctx.splitExpressions(
|
||||||
expressions = fieldsHash,
|
expressions = fieldsHash,
|
||||||
funcName = "computeHashForStruct",
|
funcName = "computeHashForStruct",
|
||||||
arguments = Seq("InternalRow" -> input, CodeGenerator.JAVA_INT -> result),
|
arguments = Seq("InternalRow" -> tmpInput, CodeGenerator.JAVA_INT -> result),
|
||||||
returnType = CodeGenerator.JAVA_INT,
|
returnType = CodeGenerator.JAVA_INT,
|
||||||
makeSplitFunction = body =>
|
makeSplitFunction = body =>
|
||||||
s"""
|
s"""
|
||||||
|
@ -801,6 +807,11 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
|
||||||
|return $result;
|
|return $result;
|
||||||
""".stripMargin,
|
""".stripMargin,
|
||||||
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
foldFunctions = _.map(funcCall => s"$result = $funcCall;").mkString("\n"))
|
||||||
|
s"""
|
||||||
|
|final InternalRow $tmpInput = $input;
|
||||||
|
|${CodeGenerator.JAVA_INT} $childResult = 0;
|
||||||
|
|$code
|
||||||
|
""".stripMargin
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2840,4 +2840,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-25084: 'distribute by' on multiple columns may lead to codegen issue") {
|
||||||
|
withView("spark_25084") {
|
||||||
|
val count = 1000
|
||||||
|
val df = spark.range(count)
|
||||||
|
val columns = (0 until 400).map{ i => s"id as id$i" }
|
||||||
|
val distributeExprs = (0 until 100).map(c => s"id$c").mkString(",")
|
||||||
|
df.selectExpr(columns : _*).createTempView("spark_25084")
|
||||||
|
assert(
|
||||||
|
spark.sql(s"select * from spark_25084 distribute by ($distributeExprs)").count === count)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue