[SPARK-24871][SQL] Refactor Concat and MapConcat to avoid creating concatenator object for each row.
## What changes were proposed in this pull request? Refactor `Concat` and `MapConcat` to: - avoid creating concatenator object for each row. - make `Concat` handle `containsNull` properly. - make `Concat` shortcut if `null` child is found. ## How was this patch tested? Added some tests and existing tests. Author: Takuya UESHIN <ueshin@databricks.com> Closes #21824 from ueshin/issues/SPARK-24871/refactor_concat_mapconcat.
This commit is contained in:
parent
0ab07b357b
commit
7b6d36bc9e
|
@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
||||||
|$mapDataClass ${ev.value} = null;
|
|$mapDataClass ${ev.value} = null;
|
||||||
""".stripMargin
|
""".stripMargin
|
||||||
|
|
||||||
val assignments = mapCodes.zipWithIndex.map { case (m, i) =>
|
val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
|
||||||
s"""
|
case ((m, true), i) =>
|
||||||
|if (!$hasNullName) {
|
s"""
|
||||||
| ${m.code}
|
|if (!$hasNullName) {
|
||||||
| $argsName[$i] = ${m.value};
|
| ${m.code}
|
||||||
| if (${m.isNull}) {
|
| if (!${m.isNull}) {
|
||||||
| $hasNullName = true;
|
| $argsName[$i] = ${m.value};
|
||||||
| }
|
| } else {
|
||||||
|}
|
| $hasNullName = true;
|
||||||
""".stripMargin
|
| }
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
case ((m, false), i) =>
|
||||||
|
s"""
|
||||||
|
|if (!$hasNullName) {
|
||||||
|
| ${m.code}
|
||||||
|
| $argsName[$i] = ${m.value};
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||||
|
@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
||||||
val finKeysName = ctx.freshName("finalKeys")
|
val finKeysName = ctx.freshName("finalKeys")
|
||||||
val finValsName = ctx.freshName("finalValues")
|
val finValsName = ctx.freshName("finalValues")
|
||||||
|
|
||||||
val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) {
|
val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) {
|
||||||
genCodeForPrimitiveArrays(ctx, keyType, false)
|
genCodeForPrimitiveArrays(ctx, keyType, false)
|
||||||
} else {
|
} else {
|
||||||
genCodeForNonPrimitiveArrays(ctx, keyType)
|
genCodeForNonPrimitiveArrays(ctx, keyType)
|
||||||
}
|
}
|
||||||
|
|
||||||
val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) {
|
val valueConcat =
|
||||||
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
|
if (valueType.sameType(keyType) &&
|
||||||
} else {
|
!(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
|
||||||
genCodeForNonPrimitiveArrays(ctx, valueType)
|
keyConcat
|
||||||
}
|
} else if (CodeGenerator.isPrimitiveType(valueType)) {
|
||||||
|
genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull)
|
||||||
|
} else {
|
||||||
|
genCodeForNonPrimitiveArrays(ctx, valueType)
|
||||||
|
}
|
||||||
|
|
||||||
val keyArgsName = ctx.freshName("keyArgs")
|
val keyArgsName = ctx.freshName("keyArgs")
|
||||||
val valArgsName = ctx.freshName("valArgs")
|
val valArgsName = ctx.freshName("valArgs")
|
||||||
|
@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
||||||
| $numElementsName + " elements due to exceeding the map size limit " +
|
| $numElementsName + " elements due to exceeding the map size limit " +
|
||||||
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|
| "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
|
||||||
| }
|
| }
|
||||||
| $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName,
|
| $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
|
||||||
| (int) $numElementsName);
|
| (int) $numElementsName);
|
||||||
| $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName,
|
| $arrayDataClass $finValsName = $valueConcat($valArgsName,
|
||||||
| (int) $numElementsName);
|
| (int) $numElementsName);
|
||||||
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|
| ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
|
||||||
|}
|
|}
|
||||||
|
@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
||||||
setterCode1
|
setterCode1
|
||||||
}
|
}
|
||||||
|
|
||||||
s"""
|
val concat = ctx.freshName("concat")
|
||||||
|new Object() {
|
val concatDef =
|
||||||
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {
|
s"""
|
||||||
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
|
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
|
||||||
| int $counter = 0;
|
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
|
||||||
| for (int y = 0; y < ${children.length}; y++) {
|
| int $counter = 0;
|
||||||
| for (int z = 0; z < $argsName[y].numElements(); z++) {
|
| for (int y = 0; y < ${children.length}; y++) {
|
||||||
| $setterCode
|
| for (int z = 0; z < $argsName[y].numElements(); z++) {
|
||||||
| $counter++;
|
| $setterCode
|
||||||
| }
|
| $counter++;
|
||||||
| }
|
| }
|
||||||
| return $arrayData;
|
| }
|
||||||
| }
|
| return $arrayData;
|
||||||
|}""".stripMargin.stripPrefix("\n")
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
|
||||||
|
ctx.addNewFunction(concat, concatDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
||||||
|
@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres
|
||||||
val argsName = ctx.freshName("args")
|
val argsName = ctx.freshName("args")
|
||||||
val numElemName = ctx.freshName("numElements")
|
val numElemName = ctx.freshName("numElements")
|
||||||
|
|
||||||
s"""
|
val concat = ctx.freshName("concat")
|
||||||
|new Object() {
|
val concatDef =
|
||||||
| public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {;
|
s"""
|
||||||
| Object[] $arrayData = new Object[$numElemName];
|
|private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
|
||||||
| int $counter = 0;
|
| Object[] $arrayData = new Object[$numElemName];
|
||||||
| for (int y = 0; y < ${children.length}; y++) {
|
| int $counter = 0;
|
||||||
| for (int z = 0; z < $argsName[y].numElements(); z++) {
|
| for (int y = 0; y < ${children.length}; y++) {
|
||||||
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
|
| for (int z = 0; z < $argsName[y].numElements(); z++) {
|
||||||
| $counter++;
|
| $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")};
|
||||||
| }
|
| $counter++;
|
||||||
| }
|
| }
|
||||||
| return new $genericArrayClass($arrayData);
|
| }
|
||||||
| }
|
| return new $genericArrayClass($arrayData);
|
||||||
|}""".stripMargin.stripPrefix("\n")
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
|
||||||
|
ctx.addNewFunction(concat, concatDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def prettyName: String = "map_concat"
|
override def prettyName: String = "map_concat"
|
||||||
|
@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|
||||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val evals = children.map(_.genCode(ctx))
|
val evals = children.map(_.genCode(ctx))
|
||||||
val args = ctx.freshName("args")
|
val args = ctx.freshName("args")
|
||||||
|
val hasNull = ctx.freshName("hasNull")
|
||||||
|
|
||||||
val inputs = evals.zipWithIndex.map { case (eval, index) =>
|
val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
|
||||||
s"""
|
case ((eval, true), index) =>
|
||||||
${eval.code}
|
s"""
|
||||||
if (!${eval.isNull}) {
|
|if (!$hasNull) {
|
||||||
$args[$index] = ${eval.value};
|
| ${eval.code}
|
||||||
}
|
| if (!${eval.isNull}) {
|
||||||
"""
|
| $args[$index] = ${eval.value};
|
||||||
|
| } else {
|
||||||
|
| $hasNull = true;
|
||||||
|
| }
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
case ((eval, false), index) =>
|
||||||
|
s"""
|
||||||
|
|if (!$hasNull) {
|
||||||
|
| ${eval.code}
|
||||||
|
| $args[$index] = ${eval.value};
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
}
|
}
|
||||||
|
|
||||||
val (concatenator, initCode) = dataType match {
|
|
||||||
case BinaryType =>
|
|
||||||
(classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
|
|
||||||
case StringType =>
|
|
||||||
("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
|
|
||||||
case ArrayType(elementType, _) =>
|
|
||||||
val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
|
|
||||||
genCodeForPrimitiveArrays(ctx, elementType)
|
|
||||||
} else {
|
|
||||||
genCodeForNonPrimitiveArrays(ctx, elementType)
|
|
||||||
}
|
|
||||||
(arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
|
|
||||||
}
|
|
||||||
val codes = ctx.splitExpressionsWithCurrentInputs(
|
val codes = ctx.splitExpressionsWithCurrentInputs(
|
||||||
expressions = inputs,
|
expressions = inputs,
|
||||||
funcName = "valueConcat",
|
funcName = "valueConcat",
|
||||||
extraArguments = (s"$javaType[]", args) :: Nil)
|
extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
|
||||||
ev.copy(code"""
|
returnType = "boolean",
|
||||||
$initCode
|
makeSplitFunction = body =>
|
||||||
$codes
|
s"""
|
||||||
$javaType ${ev.value} = $concatenator.concat($args);
|
|$body
|
||||||
boolean ${ev.isNull} = ${ev.value} == null;
|
|return $hasNull;
|
||||||
""")
|
""".stripMargin,
|
||||||
|
foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n")
|
||||||
|
)
|
||||||
|
|
||||||
|
val (concat, initCode) = dataType match {
|
||||||
|
case BinaryType =>
|
||||||
|
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
|
||||||
|
case StringType =>
|
||||||
|
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
|
||||||
|
case ArrayType(elementType, containsNull) =>
|
||||||
|
val concat = if (CodeGenerator.isPrimitiveType(elementType)) {
|
||||||
|
genCodeForPrimitiveArrays(ctx, elementType, containsNull)
|
||||||
|
} else {
|
||||||
|
genCodeForNonPrimitiveArrays(ctx, elementType)
|
||||||
|
}
|
||||||
|
(concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
|
||||||
|
}
|
||||||
|
|
||||||
|
ev.copy(code =
|
||||||
|
code"""
|
||||||
|
|boolean $hasNull = false;
|
||||||
|
|$initCode
|
||||||
|
|$codes
|
||||||
|
|$javaType ${ev.value} = null;
|
||||||
|
|if (!$hasNull) {
|
||||||
|
| ${ev.value} = $concat($args);
|
||||||
|
|}
|
||||||
|
|boolean ${ev.isNull} = ${ev.value} == null;
|
||||||
|
""".stripMargin)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
|
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
|
||||||
|
@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|
||||||
(code, numElements)
|
(code, numElements)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def nullArgumentProtection() : String = {
|
private def genCodeForPrimitiveArrays(
|
||||||
if (nullable) {
|
ctx: CodegenContext,
|
||||||
s"""
|
elementType: DataType,
|
||||||
|for (int z = 0; z < ${children.length}; z++) {
|
checkForNull: Boolean): String = {
|
||||||
| if (args[z] == null) return null;
|
|
||||||
|}
|
|
||||||
""".stripMargin
|
|
||||||
} else {
|
|
||||||
""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
|
||||||
val counter = ctx.freshName("counter")
|
val counter = ctx.freshName("counter")
|
||||||
val arrayData = ctx.freshName("arrayData")
|
val arrayData = ctx.freshName("arrayData")
|
||||||
|
|
||||||
|
@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|
||||||
|
|
||||||
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
|
||||||
|
|
||||||
s"""
|
val setterCode =
|
||||||
|new Object() {
|
s"""
|
||||||
| public ArrayData concat($javaType[] args) {
|
|$arrayData.set$primitiveValueTypeName(
|
||||||
| ${nullArgumentProtection()}
|
| $counter,
|
||||||
| $numElemCode
|
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
|
||||||
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
|
|);
|
||||||
| int $counter = 0;
|
""".stripMargin
|
||||||
| for (int y = 0; y < ${children.length}; y++) {
|
|
||||||
| for (int z = 0; z < args[y].numElements(); z++) {
|
val nullSafeSetterCode = if (checkForNull) {
|
||||||
| if (args[y].isNullAt(z)) {
|
s"""
|
||||||
| $arrayData.setNullAt($counter);
|
|if (args[y].isNullAt(z)) {
|
||||||
| } else {
|
| $arrayData.setNullAt($counter);
|
||||||
| $arrayData.set$primitiveValueTypeName(
|
|} else {
|
||||||
| $counter,
|
| $setterCode
|
||||||
| ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
|
|}
|
||||||
| );
|
""".stripMargin
|
||||||
| }
|
} else {
|
||||||
| $counter++;
|
setterCode
|
||||||
| }
|
}
|
||||||
| }
|
|
||||||
| return $arrayData;
|
val concat = ctx.freshName("concat")
|
||||||
| }
|
val concatDef =
|
||||||
|}""".stripMargin.stripPrefix("\n")
|
s"""
|
||||||
|
|private ArrayData $concat(ArrayData[] args) {
|
||||||
|
| $numElemCode
|
||||||
|
| ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")}
|
||||||
|
| int $counter = 0;
|
||||||
|
| for (int y = 0; y < ${children.length}; y++) {
|
||||||
|
| for (int z = 0; z < args[y].numElements(); z++) {
|
||||||
|
| $nullSafeSetterCode
|
||||||
|
| $counter++;
|
||||||
|
| }
|
||||||
|
| }
|
||||||
|
| return $arrayData;
|
||||||
|
|}
|
||||||
|
""".stripMargin
|
||||||
|
|
||||||
|
ctx.addNewFunction(concat, concatDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
|
||||||
|
@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio
|
||||||
|
|
||||||
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
|
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
|
||||||
|
|
||||||
s"""
|
val concat = ctx.freshName("concat")
|
||||||
|new Object() {
|
val concatDef =
|
||||||
| public ArrayData concat($javaType[] args) {
|
s"""
|
||||||
| ${nullArgumentProtection()}
|
|private ArrayData $concat(ArrayData[] args) {
|
||||||
| $numElemCode
|
| $numElemCode
|
||||||
| Object[] $arrayData = new Object[(int)$numElemName];
|
| Object[] $arrayData = new Object[(int)$numElemName];
|
||||||
| int $counter = 0;
|
| int $counter = 0;
|
||||||
| for (int y = 0; y < ${children.length}; y++) {
|
| for (int y = 0; y < ${children.length}; y++) {
|
||||||
| for (int z = 0; z < args[y].numElements(); z++) {
|
| for (int z = 0; z < args[y].numElements(); z++) {
|
||||||
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
|
| $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
|
||||||
| $counter++;
|
| $counter++;
|
||||||
| }
|
| }
|
||||||
| }
|
| }
|
||||||
| return new $genericArrayClass($arrayData);
|
| return new $genericArrayClass($arrayData);
|
||||||
| }
|
|}
|
||||||
|}""".stripMargin.stripPrefix("\n")
|
""".stripMargin
|
||||||
|
|
||||||
|
ctx.addNewFunction(concat, concatDef)
|
||||||
}
|
}
|
||||||
|
|
||||||
override def toString: String = s"concat(${children.mkString(", ")})"
|
override def toString: String = s"concat(${children.mkString(", ")})"
|
||||||
|
|
|
@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
||||||
valueContainsNull = false))
|
valueContainsNull = false))
|
||||||
val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
|
val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType,
|
||||||
valueContainsNull = false))
|
valueContainsNull = false))
|
||||||
|
val m13 = Literal.create(Map(1 -> 2, 3 -> 4),
|
||||||
|
MapType(IntegerType, IntegerType, valueContainsNull = false))
|
||||||
|
val m14 = Literal.create(Map(5 -> 6),
|
||||||
|
MapType(IntegerType, IntegerType, valueContainsNull = false))
|
||||||
|
val m15 = Literal.create(Map(7 -> null),
|
||||||
|
MapType(IntegerType, IntegerType, valueContainsNull = true))
|
||||||
val mNull = Literal.create(null, MapType(StringType, StringType))
|
val mNull = Literal.create(null, MapType(StringType, StringType))
|
||||||
|
|
||||||
// overlapping maps
|
// overlapping maps
|
||||||
|
@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// both keys and value are primitive and valueContainsNull = false
|
||||||
|
checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6))
|
||||||
|
|
||||||
|
// both keys and value are primitive and valueContainsNull = true
|
||||||
|
checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null))
|
||||||
|
|
||||||
// null map
|
// null map
|
||||||
checkEvaluation(MapConcat(Seq(m0, mNull)), null)
|
checkEvaluation(MapConcat(Seq(m0, mNull)), null)
|
||||||
checkEvaluation(MapConcat(Seq(mNull, m0)), null)
|
checkEvaluation(MapConcat(Seq(mNull, m0)), null)
|
||||||
|
@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
||||||
ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
|
ArrayType(ArrayType(StringType, containsNull = false), containsNull = false))
|
||||||
assert(Concat(Seq(aa0, aa2)).dataType ===
|
assert(Concat(Seq(aa0, aa2)).dataType ===
|
||||||
ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
|
ArrayType(ArrayType(StringType, containsNull = true), containsNull = true))
|
||||||
|
|
||||||
|
// force split expressions for input in generated code
|
||||||
|
checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten)
|
||||||
}
|
}
|
||||||
|
|
||||||
test("Flatten") {
|
test("Flatten") {
|
||||||
|
|
Loading…
Reference in a new issue