[SPARK-24871][SQL] Refactor Concat and MapConcat to avoid creating concatenator object for each row.

## What changes were proposed in this pull request?

Refactor `Concat` and `MapConcat` to:

- avoid creating a concatenator object for each row.
- make `Concat` handle `containsNull` properly.
- make `Concat` short-circuit if a `null` child is found.

## How was this patch tested?

Added some new tests and ran the existing tests.

Author: Takuya UESHIN <ueshin@databricks.com>

Closes #21824 from ueshin/issues/SPARK-24871/refactor_concat_mapconcat.
This commit is contained in:
Takuya UESHIN 2018-07-20 20:08:42 +08:00 committed by Wenchen Fan
parent 0ab07b357b
commit 7b6d36bc9e
2 changed files with 195 additions and 125 deletions

View file

@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|$mapDataClass ${ev.value} = null; |$mapDataClass ${ev.value} = null;
""".stripMargin """.stripMargin
val assignments = mapCodes.zipWithIndex.map { case (m, i) => val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
s""" case ((m, true), i) =>
|if (!$hasNullName) { s"""
| ${m.code} |if (!$hasNullName) {
| $argsName[$i] = ${m.value}; | ${m.code}
| if (${m.isNull}) { | if (!${m.isNull}) {
| $hasNullName = true; | $argsName[$i] = ${m.value};
| } | } else {
|} | $hasNullName = true;
""".stripMargin | }
|}
""".stripMargin
case ((m, false), i) =>
s"""
|if (!$hasNullName) {
| ${m.code}
| $argsName[$i] = ${m.value};
|}
""".stripMargin
} }
val codes = ctx.splitExpressionsWithCurrentInputs( val codes = ctx.splitExpressionsWithCurrentInputs(
@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val finKeysName = ctx.freshName("finalKeys") val finKeysName = ctx.freshName("finalKeys")
val finValsName = ctx.freshName("finalValues") val finValsName = ctx.freshName("finalValues")
val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
genCodeForPrimitiveArrays(ctx, keyType, false) genCodeForPrimitiveArrays(ctx, keyType, false)
} else { } else {
genCodeForNonPrimitiveArrays(ctx, keyType) genCodeForNonPrimitiveArrays(ctx, keyType)
} }
val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { val valueConcat =
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) if (valueType.sameType(keyType) &&
} else { !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
genCodeForNonPrimitiveArrays(ctx, valueType) keyConcat
} } else if (CodeGenerator.isPrimitiveType(valueType)) {
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, valueType)
}
val keyArgsName = ctx.freshName("keyArgs") val keyArgsName = ctx.freshName("keyArgs")
val valArgsName = ctx.freshName("valArgs") val valArgsName = ctx.freshName("valArgs")
@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
| $numElementsName + " elements due to exceeding the map size limit " + | $numElementsName + " elements due to exceeding the map size limit " +
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
| } | }
| $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, | $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
| (int) $numElementsName); | (int) $numElementsName);
| $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, | $arrayDataClass $finValsName = $valueConcat($valArgsName,
| (int) $numElementsName); | (int) $numElementsName);
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|} |}
@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
setterCode1 setterCode1
} }
s""" val concat = ctx.freshName("concat")
|new Object() { val concatDef =
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { s"""
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| int $counter = 0; | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| for (int y = 0; y < ${children.length}; y++) { | int $counter = 0;
| for (int z = 0; z < $argsName[y].numElements(); z++) { | for (int y = 0; y < ${children.length}; y++) {
| $setterCode | for (int z = 0; z < $argsName[y].numElements(); z++) {
| $counter++; | $setterCode
| } | $counter++;
| } | }
| return $arrayData; | }
| } | return $arrayData;
|}""".stripMargin.stripPrefix("\n") |}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
} }
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
val argsName = ctx.freshName("args") val argsName = ctx.freshName("args")
val numElemName = ctx.freshName("numElements") val numElemName = ctx.freshName("numElements")
s""" val concat = ctx.freshName("concat")
|new Object() { val concatDef =
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; s"""
| Object[] $arrayData = new Object[$numElemName]; |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
| int $counter = 0; | Object[] $arrayData = new Object[$numElemName];
| for (int y = 0; y < ${children.length}; y++) { | int $counter = 0;
| for (int z = 0; z < $argsName[y].numElements(); z++) { | for (int y = 0; y < ${children.length}; y++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; | for (int z = 0; z < $argsName[y].numElements(); z++) {
| $counter++; | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
| } | $counter++;
| } | }
| return new $genericArrayClass($arrayData); | }
| } | return new $genericArrayClass($arrayData);
|}""".stripMargin.stripPrefix("\n") |}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
} }
override def prettyName: String = "map_concat" override def prettyName: String = "map_concat"
@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx)) val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args") val args = ctx.freshName("args")
val hasNull = ctx.freshName("hasNull")
val inputs = evals.zipWithIndex.map { case (eval, index) => val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
s""" case ((eval, true), index) =>
${eval.code} s"""
if (!${eval.isNull}) { |if (!$hasNull) {
$args[$index] = ${eval.value}; | ${eval.code}
} | if (!${eval.isNull}) {
""" | $args[$index] = ${eval.value};
| } else {
| $hasNull = true;
| }
|}
""".stripMargin
case ((eval, false), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| $args[$index] = ${eval.value};
|}
""".stripMargin
} }
val (concatenator, initCode) = dataType match {
case BinaryType =>
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, _) =>
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
val codes = ctx.splitExpressionsWithCurrentInputs( val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs, expressions = inputs,
funcName = "valueConcat", funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: Nil) extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
ev.copy(code""" returnType = "boolean",
$initCode makeSplitFunction = body =>
$codes s"""
$javaType ${ev.value} = $concatenator.concat($args); |$body
boolean ${ev.isNull} = ${ev.value} == null; |return $hasNull;
""") """.stripMargin,
foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n")
)
val (concat, initCode) = dataType match {
case BinaryType =>
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, containsNull) =>
val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
genCodeForPrimitiveArrays(ctx, elementType, containsNull)
} else {
genCodeForNonPrimitiveArrays(ctx, elementType)
}
(concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
ev.copy(code =
code"""
|boolean $hasNull = false;
|$initCode
|$codes
|$javaType ${ev.value} = null;
|if (!$hasNull) {
| ${ev.value} = $concat($args);
|}
|boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
} }
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
(code, numElements) (code, numElements)
} }
private def nullArgumentProtection() : String = { private def genCodeForPrimitiveArrays(
if (nullable) { ctx: CodegenContext,
s""" elementType: DataType,
|for (int z = 0; z < ${children.length}; z++) { checkForNull: Boolean): String = {
| if (args[z] == null) return null;
|}
""".stripMargin
} else {
""
}
}
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
val counter = ctx.freshName("counter") val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData") val arrayData = ctx.freshName("arrayData")
@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
s""" val setterCode =
|new Object() { s"""
| public ArrayData concat($javaType[] args) { |$arrayData.set$primitiveValueTypeName(
| ${nullArgumentProtection()} | $counter,
| $numElemCode | ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} |);
| int $counter = 0; """.stripMargin
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) { val nullSafeSetterCode = if (checkForNull) {
| if (args[y].isNullAt(z)) { s"""
| $arrayData.setNullAt($counter); |if (args[y].isNullAt(z)) {
| } else { | $arrayData.setNullAt($counter);
| $arrayData.set$primitiveValueTypeName( |} else {
| $counter, | $setterCode
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")} |}
| ); """.stripMargin
| } } else {
| $counter++; setterCode
| } }
| }
| return $arrayData; val concat = ctx.freshName("concat")
| } val concatDef =
|}""".stripMargin.stripPrefix("\n") s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
| int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) {
| $nullSafeSetterCode
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
} }
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
s""" val concat = ctx.freshName("concat")
|new Object() { val concatDef =
| public ArrayData concat($javaType[] args) { s"""
| ${nullArgumentProtection()} |private ArrayData $concat(ArrayData[] args) {
| $numElemCode | $numElemCode
| Object[] $arrayData = new Object[(int)$numElemName]; | Object[] $arrayData = new Object[(int)$numElemName];
| int $counter = 0; | int $counter = 0;
| for (int y = 0; y < ${children.length}; y++) { | for (int y = 0; y < ${children.length}; y++) {
| for (int z = 0; z < args[y].numElements(); z++) { | for (int z = 0; z < args[y].numElements(); z++) {
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
| $counter++; | $counter++;
| } | }
| } | }
| return new $genericArrayClass($arrayData); | return new $genericArrayClass($arrayData);
| } |}
|}""".stripMargin.stripPrefix("\n") """.stripMargin
ctx.addNewFunction(concat, concatDef)
} }
override def toString: String = s"concat(${children.mkString(", ")})" override def toString: String = s"concat(${children.mkString(", ")})"

View file

@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
valueContainsNull = false)) valueContainsNull = false))
val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
valueContainsNull = false)) valueContainsNull = false))
val m13 = Literal.create(Map(1 -> 2, 3 -> 4),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val m14 = Literal.create(Map(5 -> 6),
MapType(IntegerType, IntegerType, valueContainsNull = false))
val m15 = Literal.create(Map(7 -> null),
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mNull = Literal.create(null, MapType(StringType, StringType)) val mNull = Literal.create(null, MapType(StringType, StringType))
// overlapping maps // overlapping maps
@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
) )
) )
// both keys and value are primitive and valueContainsNull = false
checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6))
// both keys and value are primitive and valueContainsNull = true
checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null))
// null map // null map
checkEvaluation(MapConcat(Seq(m0, mNull)), null) checkEvaluation(MapConcat(Seq(m0, mNull)), null)
checkEvaluation(MapConcat(Seq(mNull, m0)), null) checkEvaluation(MapConcat(Seq(mNull, m0)), null)
@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
assert(Concat(Seq(aa0, aa2)).dataType === assert(Concat(Seq(aa0, aa2)).dataType ===
ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
// force split expressions for input in generated code
checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten)
} }
test("Flatten") { test("Flatten") {