[SPARK-24871][SQL] Refactor Concat and MapConcat to avoid creating a concatenator object for each row.

## What changes were proposed in this pull request?

Refactor `Concat` and `MapConcat` to:

- avoid creating a concatenator object for each row.
- make `Concat` handle `containsNull` properly.
- make `Concat` short-circuit as soon as a `null` child is found.

## How was this patch tested?

Added some new tests; also covered by existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #21824 from ueshin/issues/SPARK-24871/refactor_concat_mapconcat.
This commit is contained in:
Takuya UESHIN 2018-07-20 20:08:42 +08:00 committed by Wenchen Fan
parent 0ab07b357b
commit 7b6d36bc9e
2 changed files with 195 additions and 125 deletions

View file

@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|$mapDataClass ${ev.value} = null;
""".stripMargin
val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
| if (${m.isNull}) {
| $hasNullName = true;
| }
|}
""".stripMargin
val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
case ((m, true), i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| if (!${m.isNull}) {
| $argsName[$i] = ${m.value};
| } else {
| $hasNullName = true;
| }
|}
""".stripMargin
case ((m, false), i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
|}
""".stripMargin
}
val codes = ctx.splitExpressionsWithCurrentInputs(
@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues")
val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
genCodeForPrimitiveArrays(ctx, keyType, false)
} else {
genCodeForNonPrimitiveArrays(ctx, keyType)
}
val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}
val valueConcat =
if (valueType.sameType(keyType) &&
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
keyConcat
} else if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}
val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs")
@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
| $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
| }
| $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
| $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
| (int) $numElementsName);
| $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
| $arrayDataClass $finValsName = $valueConcat($valArgsName,
| (int) $numElementsName);
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|}
@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
setterCode1
}
s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $setterCode
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $setterCode
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
}
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements")
s"""
|new Object() {
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {;
| Object[] $arrayData = new Object[$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| Object[] $arrayData = new Object[$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < $argsName[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
}
override def prettyName: String = "map_concat"
@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args")
val hasNull = ctx.freshName("hasNull")
val inputs = evals.zipWithIndex.map { case (eval, index) =>
s"""
${eval.code}
if (!${eval.isNull}) {
$args[$index] = ${eval.value};
}
"""
val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
case ((eval, true), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| if (!${eval.isNull}) {
| $args[$index] = ${eval.value};
| } else {
| $hasNull = true;
| }
|}
""".stripMargin
case ((eval, false), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| $args[$index] = ${eval.value};
|}
""".stripMargin
}
val (concatenator, initCode) = dataType match {
case BinaryType =>
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, _) =>
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil)
ev.copy(code"""
$initCode
$codes
$javaType ${ev.value} = $concatenator.concat($args);
boolean ${ev.isNull} = ${ev.value} == null;
""")
extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
returnType = "boolean",
makeSplitFunction = body =>
s"""
|$body
|return $hasNull;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n")
)
val (concat, initCode) = dataType match {
case BinaryType =>
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, containsNull) =>
val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType, containsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
ev.copy(code =
code"""
|boolean $hasNull = false;
|$initCode
|$codes
|$javaType ${ev.value} = null;
|if (!$hasNull) {
| ${ev.value} = $concat($args);
|}
|boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
(code, numElements)
}
private def nullArgumentProtection() : String = {
if (nullable) {
s"""
|for (int z = 0; z < ${children.length}; z++) {
| if (args[z] == null) return null;
|}
""".stripMargin
} else {
""
}
}
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
private def genCodeForPrimitiveArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s"""
|new Object() {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| if (args[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
| } else {
| $arrayData.set$primitiveValueTypeName(
| $counter,
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
| );
| }
| $counter++;
| }
| }
| return $arrayData;
| }
|}""".stripMargin.stripPrefix("\n")
val setterCode =
s"""
|$arrayData.set$primitiveValueTypeName(
| $counter,
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
|);
""".stripMargin
val nullSafeSetterCode = if (checkForNull) {
s"""
|if (args[y].isNullAt(z)) {
| $arrayData.setNullAt($counter);
|} else {
| $setterCode
|}
""".stripMargin
} else {
setterCode
}
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $nullSafeSetterCode
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
}
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
s"""
|new Object() {
| public ArrayData concat($javaType[] args) {
| ${nullArgumentProtection()}
| $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
| }
|}""".stripMargin.stripPrefix("\n")
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
| $counter++;
| }
| }
| return new $genericArrayClass($arrayData);
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
}
override def toString: String = s"concat(${children.mkString(", ")})"

View file

@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
valueContainsNull = false))
val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
valueContainsNull = false))
val m13 = Literal.create(Map(1 -> 2, 3 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val m14 = Literal.create(Map(5 -> 6),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val m15 = Literal.create(Map(7 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mNull = Literal.create(null, MapType(StringType, StringType))
// overlapping maps
@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
)
)
// both keys and values are primitive and valueContainsNull = false
checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6))
// both keys and values are primitive and valueContainsNull = true
checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null))
// null map
checkEvaluation(MapConcat(Seq(m0, mNull)), null)
checkEvaluation(MapConcat(Seq(mNull, m0)), null)
@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
assert(Concat(Seq(aa0, aa2)).dataType ===
ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
// force split expressions for input in generated code
checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten)
}
test("Flatten") {