diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 92635417e9..f438748d9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -571,16 +571,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres |$mapDataClass ${ev.value} = null; """.stripMargin - val assignments = mapCodes.zipWithIndex.map { case (m, i) => - s""" - |if (!$hasNullName) { - | ${m.code} - | $argsName[$i] = ${m.value}; - | if (${m.isNull}) { - | $hasNullName = true; - | } - |} - """.stripMargin + val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { + case ((m, true), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | if (!${m.isNull}) { + | $argsName[$i] = ${m.value}; + | } else { + | $hasNullName = true; + | } + |} + """.stripMargin + case ((m, false), i) => + s""" + |if (!$hasNullName) { + | ${m.code} + | $argsName[$i] = ${m.value}; + |} + """.stripMargin } val codes = ctx.splitExpressionsWithCurrentInputs( @@ -601,17 +610,21 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val finKeysName = ctx.freshName("finalKeys") val finValsName = ctx.freshName("finalValues") - val keyConcatenator = if (CodeGenerator.isPrimitiveType(keyType)) { + val keyConcat = if (CodeGenerator.isPrimitiveType(keyType)) { genCodeForPrimitiveArrays(ctx, keyType, false) } else { genCodeForNonPrimitiveArrays(ctx, keyType) } - val valueConcatenator = if (CodeGenerator.isPrimitiveType(valueType)) { - genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) - } else { - genCodeForNonPrimitiveArrays(ctx, valueType) - } + val valueConcat = + if (valueType.sameType(keyType) && + !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) { + keyConcat + } else if (CodeGenerator.isPrimitiveType(valueType)) { + genCodeForPrimitiveArrays(ctx, valueType, dataType.valueContainsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, valueType) + } val keyArgsName = ctx.freshName("keyArgs") val valArgsName = ctx.freshName("valArgs") @@ -633,9 +646,9 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres | $numElementsName + " elements due to exceeding the map size limit " + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); | } - | $arrayDataClass $finKeysName = $keyConcatenator.concat($keyArgsName, + | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, | (int) $numElementsName); - | $arrayDataClass $finValsName = $valueConcatenator.concat($valArgsName, + | $arrayDataClass $finValsName = $valueConcat($valArgsName, | (int) $numElementsName); | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); |} @@ -677,20 +690,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres setterCode1 } - s""" - |new Object() { - | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) { - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $setterCode - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $setterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { @@ -700,20 +716,23 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val argsName = ctx.freshName("args") val numElemName = ctx.freshName("numElements") - s""" - |new Object() { - | public ArrayData concat(${classOf[ArrayData].getName}[] $argsName, int $numElemName) {; - | Object[] $arrayData = new Object[$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < $argsName[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) { + | Object[] $arrayData = new Object[$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < $argsName[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"$argsName[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } override def prettyName: String = "map_concat" @@ -2270,39 +2289,67 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) val args = ctx.freshName("args") + val hasNull = ctx.freshName("hasNull") - val inputs = evals.zipWithIndex.map { case (eval, index) => - s""" - ${eval.code} - if (!${eval.isNull}) { - $args[$index] = ${eval.value}; - } - """ + val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map { + case ((eval, true), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | if (!${eval.isNull}) { + | $args[$index] = ${eval.value}; + | } else { + | $hasNull = true; + | } + |} + """.stripMargin + case ((eval, false), index) => + s""" + |if (!$hasNull) { + | ${eval.code} + | $args[$index] = ${eval.value}; + |} + """.stripMargin } - val (concatenator, initCode) = dataType match { - case BinaryType => - (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];") - case StringType => - ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];") - case ArrayType(elementType, _) => - val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) { - genCodeForPrimitiveArrays(ctx, elementType) - } else { - genCodeForNonPrimitiveArrays(ctx, elementType) - } - (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];") - } val codes = ctx.splitExpressionsWithCurrentInputs( expressions = inputs, funcName = "valueConcat", - extraArguments = (s"$javaType[]", args) :: Nil) - ev.copy(code""" - $initCode - $codes - $javaType ${ev.value} = $concatenator.concat($args); - boolean ${ev.isNull} = ${ev.value} == null; - """) + extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil, + returnType = "boolean", + makeSplitFunction = body => + s""" + |$body + |return $hasNull; + """.stripMargin, + foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n") + ) + + val (concat, initCode) = dataType match { + case BinaryType => + (s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];") + case StringType => + ("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];") + case ArrayType(elementType, containsNull) => + val concat = if (CodeGenerator.isPrimitiveType(elementType)) { + genCodeForPrimitiveArrays(ctx, elementType, containsNull) + } else { + genCodeForNonPrimitiveArrays(ctx, elementType) + } + (concat, s"ArrayData[] $args = new ArrayData[${evals.length}];") + } + + ev.copy(code = + code""" + |boolean $hasNull = false; + |$initCode + |$codes + |$javaType ${ev.value} = null; + |if (!$hasNull) { + | ${ev.value} = $concat($args); + |} + |boolean ${ev.isNull} = ${ev.value} == null; + """.stripMargin) } private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = { @@ -2322,19 +2369,10 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio (code, numElements) } - private def nullArgumentProtection() : String = { - if (nullable) { - s""" - |for (int z = 0; z < ${children.length}; z++) { - | if (args[z] == null) return null; - |} - """.stripMargin - } else { - "" - } - } - - private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { + private def genCodeForPrimitiveArrays( + ctx: CodegenContext, + elementType: DataType, + checkForNull: Boolean): String = { val counter = ctx.freshName("counter") val arrayData = ctx.freshName("arrayData") @@ -2342,29 +2380,44 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | if (args[y].isNullAt(z)) { - | $arrayData.setNullAt($counter); - | } else { - | $arrayData.set$primitiveValueTypeName( - | $counter, - | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} - | ); - | } - | $counter++; - | } - | } - | return $arrayData; - | } - |}""".stripMargin.stripPrefix("\n") + val setterCode = + s""" + |$arrayData.set$primitiveValueTypeName( + | $counter, + | ${CodeGenerator.getValue(s"args[y]", elementType, "z")} + |); + """.stripMargin + + val nullSafeSetterCode = if (checkForNull) { + s""" + |if (args[y].isNullAt(z)) { + | $arrayData.setNullAt($counter); + |} else { + | $setterCode + |} + """.stripMargin + } else { + setterCode + } + + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | ${ctx.createUnsafeArray(arrayData, numElemName, elementType, s" $prettyName failed.")} + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $nullSafeSetterCode + | $counter++; + | } + | } + | return $arrayData; + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = { @@ -2374,22 +2427,24 @@ case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpressio val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx) - s""" - |new Object() { - | public ArrayData concat($javaType[] args) { - | ${nullArgumentProtection()} - | $numElemCode - | Object[] $arrayData = new Object[(int)$numElemName]; - | int $counter = 0; - | for (int y = 0; y < ${children.length}; y++) { - | for (int z = 0; z < args[y].numElements(); z++) { - | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; - | $counter++; - | } - | } - | return new $genericArrayClass($arrayData); - | } - |}""".stripMargin.stripPrefix("\n") + val concat = ctx.freshName("concat") + val concatDef = + s""" + |private ArrayData $concat(ArrayData[] args) { + | $numElemCode + | Object[] $arrayData = new Object[(int)$numElemName]; + | int $counter = 0; + | for (int y = 0; y < ${children.length}; y++) { + | for (int z = 0; z < args[y].numElements(); z++) { + | $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")}; + | $counter++; + | } + | } + | return new $genericArrayClass($arrayData); + |} + """.stripMargin + + ctx.addNewFunction(concat, concatDef) } override def toString: String = s"concat(${children.mkString(", ")})" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index f1e3bd0915..c7f0da71e1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -125,6 +125,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper valueContainsNull = false)) val m12 = Literal.create(Map(3 -> "3", 4 -> "4"), MapType(IntegerType, StringType, valueContainsNull = false)) + val m13 = Literal.create(Map(1 -> 2, 3 -> 4), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m14 = Literal.create(Map(5 -> 6), + MapType(IntegerType, IntegerType, valueContainsNull = false)) + val m15 = Literal.create(Map(7 -> null), + MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) // overlapping maps @@ -188,6 +194,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ) ) + // both keys and value are primitive and valueContainsNull = false + checkEvaluation(MapConcat(Seq(m13, m14)), Map(1 -> 2, 3 -> 4, 5 -> 6)) + + // both keys and value are primitive and valueContainsNull = true + checkEvaluation(MapConcat(Seq(m13, m15)), Map(1 -> 2, 3 -> 4, 7 -> null)) + // null map checkEvaluation(MapConcat(Seq(m0, mNull)), null) checkEvaluation(MapConcat(Seq(mNull, m0)), null) @@ -1121,6 +1133,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)) assert(Concat(Seq(aa0, aa2)).dataType === ArrayType(ArrayType(StringType, containsNull = true), containsNull = true)) + + // force split expressions for input in generated code + checkEvaluation(Concat(Seq.fill(100)(ai0)), Seq.fill(100)(Seq(1, 2, 3)).flatten) } test("Flatten") {