[SPARK-14793] [SQL] Code generation for large complex type exceeds JVM size limit.
## What changes were proposed in this pull request? Code generation for complex type, `CreateArray`, `CreateMap`, `CreateStruct`, `CreateNamedStruct`, exceeds JVM size limit for large elements. We should split generated code into multiple `apply` functions if the complex types have large elements, like `UnsafeProjection` or others for large expressions. ## How was this patch tested? I added some tests to check if the generated codes for the expressions exceed or not. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #12559 from ueshin/issues/SPARK-14793.
This commit is contained in:
parent
df1953f0df
commit
f1fdb23821
|
@ -68,6 +68,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
|
||||||
this.$values = new Object[${schema.length}];
|
this.$values = new Object[${schema.length}];
|
||||||
$allFields
|
$allFields
|
||||||
final InternalRow $output = new $rowClass($values);
|
final InternalRow $output = new $rowClass($values);
|
||||||
|
this.$values = null;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
ExprCode(code, "false", output)
|
ExprCode(code, "false", output)
|
||||||
|
|
|
@ -51,9 +51,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
|
||||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val arrayClass = classOf[GenericArrayData].getName
|
val arrayClass = classOf[GenericArrayData].getName
|
||||||
val values = ctx.freshName("values")
|
val values = ctx.freshName("values")
|
||||||
|
ctx.addMutableState("Object[]", values, s"this.$values = null;")
|
||||||
|
|
||||||
ev.copy(code = s"""
|
ev.copy(code = s"""
|
||||||
final boolean ${ev.isNull} = false;
|
final boolean ${ev.isNull} = false;
|
||||||
final Object[] $values = new Object[${children.size}];""" +
|
this.$values = new Object[${children.size}];""" +
|
||||||
|
ctx.splitExpressions(
|
||||||
|
ctx.INPUT_ROW,
|
||||||
children.zipWithIndex.map { case (e, i) =>
|
children.zipWithIndex.map { case (e, i) =>
|
||||||
val eval = e.genCode(ctx)
|
val eval = e.genCode(ctx)
|
||||||
eval.code + s"""
|
eval.code + s"""
|
||||||
|
@ -63,8 +67,11 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
|
||||||
$values[$i] = ${eval.value};
|
$values[$i] = ${eval.value};
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}.mkString("\n") +
|
}) +
|
||||||
s"final ArrayData ${ev.value} = new $arrayClass($values);")
|
s"""
|
||||||
|
final ArrayData ${ev.value} = new $arrayClass($values);
|
||||||
|
this.$values = null;
|
||||||
|
""")
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "array"
|
override def prettyName: String = "array"
|
||||||
|
@ -119,12 +126,17 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
||||||
val mapClass = classOf[ArrayBasedMapData].getName
|
val mapClass = classOf[ArrayBasedMapData].getName
|
||||||
val keyArray = ctx.freshName("keyArray")
|
val keyArray = ctx.freshName("keyArray")
|
||||||
val valueArray = ctx.freshName("valueArray")
|
val valueArray = ctx.freshName("valueArray")
|
||||||
|
ctx.addMutableState("Object[]", keyArray, s"this.$keyArray = null;")
|
||||||
|
ctx.addMutableState("Object[]", valueArray, s"this.$valueArray = null;")
|
||||||
|
|
||||||
val keyData = s"new $arrayClass($keyArray)"
|
val keyData = s"new $arrayClass($keyArray)"
|
||||||
val valueData = s"new $arrayClass($valueArray)"
|
val valueData = s"new $arrayClass($valueArray)"
|
||||||
ev.copy(code = s"""
|
ev.copy(code = s"""
|
||||||
final boolean ${ev.isNull} = false;
|
final boolean ${ev.isNull} = false;
|
||||||
final Object[] $keyArray = new Object[${keys.size}];
|
$keyArray = new Object[${keys.size}];
|
||||||
final Object[] $valueArray = new Object[${values.size}];""" +
|
$valueArray = new Object[${values.size}];""" +
|
||||||
|
ctx.splitExpressions(
|
||||||
|
ctx.INPUT_ROW,
|
||||||
keys.zipWithIndex.map { case (key, i) =>
|
keys.zipWithIndex.map { case (key, i) =>
|
||||||
val eval = key.genCode(ctx)
|
val eval = key.genCode(ctx)
|
||||||
s"""
|
s"""
|
||||||
|
@ -135,8 +147,10 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
||||||
$keyArray[$i] = ${eval.value};
|
$keyArray[$i] = ${eval.value};
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}.mkString("\n") + values.zipWithIndex.map {
|
}) +
|
||||||
case (value, i) =>
|
ctx.splitExpressions(
|
||||||
|
ctx.INPUT_ROW,
|
||||||
|
values.zipWithIndex.map { case (value, i) =>
|
||||||
val eval = value.genCode(ctx)
|
val eval = value.genCode(ctx)
|
||||||
s"""
|
s"""
|
||||||
${eval.code}
|
${eval.code}
|
||||||
|
@ -146,7 +160,12 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
||||||
$valueArray[$i] = ${eval.value};
|
$valueArray[$i] = ${eval.value};
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
}.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);")
|
}) +
|
||||||
|
s"""
|
||||||
|
final MapData ${ev.value} = new $mapClass($keyData, $valueData);
|
||||||
|
this.$keyArray = null;
|
||||||
|
this.$valueArray = null;
|
||||||
|
""")
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "map"
|
override def prettyName: String = "map"
|
||||||
|
@ -182,9 +201,13 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
|
||||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val rowClass = classOf[GenericInternalRow].getName
|
val rowClass = classOf[GenericInternalRow].getName
|
||||||
val values = ctx.freshName("values")
|
val values = ctx.freshName("values")
|
||||||
|
ctx.addMutableState("Object[]", values, s"this.$values = null;")
|
||||||
|
|
||||||
ev.copy(code = s"""
|
ev.copy(code = s"""
|
||||||
boolean ${ev.isNull} = false;
|
boolean ${ev.isNull} = false;
|
||||||
final Object[] $values = new Object[${children.size}];""" +
|
this.$values = new Object[${children.size}];""" +
|
||||||
|
ctx.splitExpressions(
|
||||||
|
ctx.INPUT_ROW,
|
||||||
children.zipWithIndex.map { case (e, i) =>
|
children.zipWithIndex.map { case (e, i) =>
|
||||||
val eval = e.genCode(ctx)
|
val eval = e.genCode(ctx)
|
||||||
eval.code + s"""
|
eval.code + s"""
|
||||||
|
@ -193,8 +216,11 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
|
||||||
} else {
|
} else {
|
||||||
$values[$i] = ${eval.value};
|
$values[$i] = ${eval.value};
|
||||||
}"""
|
}"""
|
||||||
}.mkString("\n") +
|
}) +
|
||||||
s"final InternalRow ${ev.value} = new $rowClass($values);")
|
s"""
|
||||||
|
final InternalRow ${ev.value} = new $rowClass($values);
|
||||||
|
this.$values = null;
|
||||||
|
""")
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "struct"
|
override def prettyName: String = "struct"
|
||||||
|
@ -261,9 +287,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
|
||||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val rowClass = classOf[GenericInternalRow].getName
|
val rowClass = classOf[GenericInternalRow].getName
|
||||||
val values = ctx.freshName("values")
|
val values = ctx.freshName("values")
|
||||||
|
ctx.addMutableState("Object[]", values, s"this.$values = null;")
|
||||||
|
|
||||||
ev.copy(code = s"""
|
ev.copy(code = s"""
|
||||||
boolean ${ev.isNull} = false;
|
boolean ${ev.isNull} = false;
|
||||||
final Object[] $values = new Object[${valExprs.size}];""" +
|
$values = new Object[${valExprs.size}];""" +
|
||||||
|
ctx.splitExpressions(
|
||||||
|
ctx.INPUT_ROW,
|
||||||
valExprs.zipWithIndex.map { case (e, i) =>
|
valExprs.zipWithIndex.map { case (e, i) =>
|
||||||
val eval = e.genCode(ctx)
|
val eval = e.genCode(ctx)
|
||||||
eval.code + s"""
|
eval.code + s"""
|
||||||
|
@ -272,8 +302,11 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
|
||||||
} else {
|
} else {
|
||||||
$values[$i] = ${eval.value};
|
$values[$i] = ${eval.value};
|
||||||
}"""
|
}"""
|
||||||
}.mkString("\n") +
|
}) +
|
||||||
s"final InternalRow ${ev.value} = new $rowClass($values);")
|
s"""
|
||||||
|
final InternalRow ${ev.value} = new $rowClass($values);
|
||||||
|
this.$values = null;
|
||||||
|
""")
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "named_struct"
|
override def prettyName: String = "named_struct"
|
||||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.Row
|
||||||
import org.apache.spark.sql.catalyst.InternalRow
|
import org.apache.spark.sql.catalyst.InternalRow
|
||||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
|
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
import org.apache.spark.unsafe.types.UTF8String
|
||||||
import org.apache.spark.util.ThreadUtils
|
import org.apache.spark.util.ThreadUtils
|
||||||
|
@ -80,6 +81,62 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||||
assert(actual(0) == cases)
|
assert(actual(0) == cases)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") {
|
||||||
|
val length = 5000
|
||||||
|
val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
|
||||||
|
val plan = GenerateMutableProjection.generate(expressions)
|
||||||
|
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
|
||||||
|
val expected = Seq(new GenericArrayData(Seq.fill(length)(true)))
|
||||||
|
|
||||||
|
if (!checkResult(actual, expected)) {
|
||||||
|
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-14793: split wide map creation into blocks due to JVM code size limit") {
|
||||||
|
val length = 5000
|
||||||
|
val expressions = Seq(CreateMap(
|
||||||
|
List.fill(length)(EqualTo(Literal(1), Literal(1))).zipWithIndex.flatMap {
|
||||||
|
case (expr, i) => Seq(Literal(i), expr)
|
||||||
|
}))
|
||||||
|
val plan = GenerateMutableProjection.generate(expressions)
|
||||||
|
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
|
||||||
|
val expected = Seq(new ArrayBasedMapData(
|
||||||
|
new GenericArrayData(0 until length),
|
||||||
|
new GenericArrayData(Seq.fill(length)(true))))
|
||||||
|
|
||||||
|
if (!checkResult(actual, expected)) {
|
||||||
|
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-14793: split wide struct creation into blocks due to JVM code size limit") {
|
||||||
|
val length = 5000
|
||||||
|
val expressions = Seq(CreateStruct(List.fill(length)(EqualTo(Literal(1), Literal(1)))))
|
||||||
|
val plan = GenerateMutableProjection.generate(expressions)
|
||||||
|
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
|
||||||
|
val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
|
||||||
|
|
||||||
|
if (!checkResult(actual, expected)) {
|
||||||
|
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
test("SPARK-14793: split wide named struct creation into blocks due to JVM code size limit") {
|
||||||
|
val length = 5000
|
||||||
|
val expressions = Seq(CreateNamedStruct(
|
||||||
|
List.fill(length)(EqualTo(Literal(1), Literal(1))).flatMap {
|
||||||
|
expr => Seq(Literal(expr.toString), expr)
|
||||||
|
}))
|
||||||
|
val plan = GenerateMutableProjection.generate(expressions)
|
||||||
|
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
|
||||||
|
val expected = Seq(InternalRow(Seq.fill(length)(true): _*))
|
||||||
|
|
||||||
|
if (!checkResult(actual, expected)) {
|
||||||
|
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test("test generated safe and unsafe projection") {
|
test("test generated safe and unsafe projection") {
|
||||||
val schema = new StructType(Array(
|
val schema = new StructType(Array(
|
||||||
StructField("a", StringType, true),
|
StructField("a", StringType, true),
|
||||||
|
|
Loading…
Reference in a new issue