[SPARK-23736][SQL] Extending the concat function to support array columns
## What changes were proposed in this pull request? The PR adds a logic for easy concatenation of multiple array columns and covers: - Concat expression has been extended to support array columns - A Python wrapper ## How was this patch tested? New tests added into: - CollectionExpressionsSuite - DataFrameFunctionsSuite - typeCoercion/native/concat.sql ## Codegen examples ### Primitive-type elements ``` val df = Seq( (Seq(1 ,2), Seq(3, 4)), (Seq(1, 2, 3), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? 
/* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray( /* 070 */ project_numElements, /* 071 */ 4); /* 072 */ if (project_size > 2147483632) { /* 073 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size + /* 074 */ " bytes of data due to exceeding the limit 2147483632 bytes" + /* 075 */ " for UnsafeArrayData."); /* 076 */ } /* 077 */ /* 078 */ byte[] project_array = new byte[(int)project_size]; /* 079 */ UnsafeArrayData project_arrayData = new UnsafeArrayData(); /* 080 */ Platform.putLong(project_array, 16, project_numElements); /* 081 */ project_arrayData.pointTo(project_array, 16, (int)project_size); /* 082 */ int project_counter = 0; /* 083 */ for (int y = 0; y < 2; y++) { /* 084 */ for (int z = 0; z < args[y].numElements(); z++) { /* 085 */ if (args[y].isNullAt(z)) { /* 086 */ project_arrayData.setNullAt(project_counter); /* 087 */ } else { /* 088 */ project_arrayData.setInt( /* 089 */ project_counter, /* 090 */ args[y].getInt(z) /* 091 */ ); /* 092 */ } /* 093 */ project_counter++; /* 094 */ } /* 095 */ } /* 096 */ return project_arrayData; /* 097 */ } /* 098 */ }.concat(project_args); /* 099 */ boolean project_isNull = project_value == null; ``` 
### Non-primitive-type elements ``` val df = Seq( (Seq("aa" ,"bb"), Seq("ccc", "ddd")), (Seq("x", "y"), null) ).toDF("a", "b") df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen() ``` Result: ``` /* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0); /* 034 */ ArrayData inputadapter_value = inputadapter_isNull ? /* 035 */ null : (inputadapter_row.getArray(0)); /* 036 */ /* 037 */ if (!(!inputadapter_isNull)) continue; /* 038 */ /* 039 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1); /* 040 */ /* 041 */ ArrayData[] project_args = new ArrayData[2]; /* 042 */ /* 043 */ if (!false) { /* 044 */ project_args[0] = inputadapter_value; /* 045 */ } /* 046 */ /* 047 */ boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1); /* 048 */ ArrayData inputadapter_value1 = inputadapter_isNull1 ? /* 049 */ null : (inputadapter_row.getArray(1)); /* 050 */ if (!inputadapter_isNull1) { /* 051 */ project_args[1] = inputadapter_value1; /* 052 */ } /* 053 */ /* 054 */ ArrayData project_value = new Object() { /* 055 */ public ArrayData concat(ArrayData[] args) { /* 056 */ for (int z = 0; z < 2; z++) { /* 057 */ if (args[z] == null) return null; /* 058 */ } /* 059 */ /* 060 */ long project_numElements = 0L; /* 061 */ for (int z = 0; z < 2; z++) { /* 062 */ project_numElements += args[z].numElements(); /* 063 */ } /* 064 */ if (project_numElements > 2147483632) { /* 065 */ throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements + /* 066 */ " elements due to exceeding the array size limit 2147483632."); /* 067 */ } /* 068 */ /* 069 */ Object[] project_arrayObjects = new Object[(int)project_numElements]; /* 070 */ int project_counter = 0; /* 071 */ for (int y = 0; y < 2; y++) { /* 072 */ for (int z = 0; z < args[y].numElements(); z++) { /* 073 */ project_arrayObjects[project_counter] = args[y].getUTF8String(z); /* 074 */ project_counter++; /* 075 */ } /* 076 */ } /* 077 */ return new 
org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects); /* 078 */ } /* 079 */ }.concat(project_args); /* 080 */ boolean project_isNull = project_value == null; ``` Author: mn-mikke <mrkAha12346github> Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master.
This commit is contained in:
parent
b3fde5a41e
commit
e6b466084c
|
@ -33,7 +33,11 @@ public class ByteArrayMethods {
|
|||
}
|
||||
|
||||
public static int roundNumberOfBytesToNearestWord(int numBytes) {
|
||||
int remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
|
||||
return (int)roundNumberOfBytesToNearestWord((long)numBytes);
|
||||
}
|
||||
|
||||
public static long roundNumberOfBytesToNearestWord(long numBytes) {
|
||||
long remainder = numBytes & 0x07; // This is equivalent to `numBytes % 8`
|
||||
if (remainder == 0) {
|
||||
return numBytes;
|
||||
} else {
|
||||
|
|
|
@ -1425,21 +1425,6 @@ for _name, _doc in _string_functions.items():
|
|||
del _name, _doc
|
||||
|
||||
|
||||
@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
    """
    Concatenates multiple input columns together into a single column.
    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.

    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
    >>> df.select(concat(df.s, df.d).alias('s')).collect()
    [Row(s=u'abcd123')]
    """
    # Delegate to the JVM-side functions.concat via Py4J; cols are converted
    # to a JVM Seq of Column expressions first.
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
|
||||
|
||||
|
||||
@since(1.5)
|
||||
@ignore_unicode_prefix
|
||||
def concat_ws(sep, *cols):
|
||||
|
@ -1845,6 +1830,25 @@ def array_contains(col, value):
|
|||
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
|
||||
|
||||
|
||||
@since(1.5)
@ignore_unicode_prefix
def concat(*cols):
    """
    Concatenates multiple input columns together into a single column.
    The function works with strings, binary and compatible array columns.

    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
    >>> df.select(concat(df.s, df.d).alias('s')).collect()
    [Row(s=u'abcd123')]

    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
    >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
    """
    # Convert the Python columns to a JVM Seq, then call the Scala-side
    # functions.concat implementation through the Py4J gateway.
    sc = SparkContext._active_spark_context
    jcols = _to_seq(sc, cols, _to_java_column)
    return Column(sc._jvm.functions.concat(jcols))
|
||||
|
||||
|
||||
@since(2.4)
|
||||
def array_position(col, value):
|
||||
"""
|
||||
|
|
|
@ -56,9 +56,19 @@ import org.apache.spark.unsafe.types.UTF8String;
|
|||
public final class UnsafeArrayData extends ArrayData {
|
||||
|
||||
/**
 * Int-overload convenience wrapper: computes the header size via the long
 * overload and narrows the result back to int for existing int-based callers.
 */
public static int calculateHeaderPortionInBytes(int numFields) {
  return (int)calculateHeaderPortionInBytes((long)numFields);
}
|
||||
|
||||
public static long calculateHeaderPortionInBytes(long numFields) {
|
||||
return 8 + ((numFields + 63)/ 64) * 8;
|
||||
}
|
||||
|
||||
/**
 * Total byte size of the array's backing byte[]: the header (element count +
 * null bitmap) plus the element region rounded up to a whole 8-byte word.
 */
public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
  long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
    ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
  return size;
}
|
||||
|
||||
private Object baseObject;
|
||||
private long baseOffset;
|
||||
|
||||
|
|
|
@ -308,7 +308,6 @@ object FunctionRegistry {
|
|||
expression[BitLength]("bit_length"),
|
||||
expression[Length]("char_length"),
|
||||
expression[Length]("character_length"),
|
||||
expression[Concat]("concat"),
|
||||
expression[ConcatWs]("concat_ws"),
|
||||
expression[Decode]("decode"),
|
||||
expression[Elt]("elt"),
|
||||
|
@ -413,6 +412,7 @@ object FunctionRegistry {
|
|||
expression[ArrayMin]("array_min"),
|
||||
expression[ArrayMax]("array_max"),
|
||||
expression[Reverse]("reverse"),
|
||||
expression[Concat]("concat"),
|
||||
CreateStruct.registryEntry,
|
||||
|
||||
// misc functions
|
||||
|
|
|
@ -520,6 +520,14 @@ object TypeCoercion {
|
|||
case None => a
|
||||
}
|
||||
|
||||
case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
|
||||
!haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findWiderCommonType(types) match {
|
||||
case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
|
||||
case None => c
|
||||
}
|
||||
|
||||
case m @ CreateMap(children) if m.keys.length == m.values.length &&
|
||||
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
|
||||
val newKeys = if (haveSameType(m.keys)) {
|
||||
|
|
|
@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
import org.apache.spark.unsafe.Platform
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
||||
/**
|
||||
* Given an array or map, returns its size. Returns -1 if null.
|
||||
|
@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
|
|||
|
||||
override def prettyName: String = "element_at"
|
||||
}
|
||||
|
||||
/**
 * Concatenates multiple input columns together into a single column.
 * The function works with strings, binary and compatible array columns.
 */
@ExpressionDescription(
  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
  examples = """
    Examples:
      > SELECT _FUNC_('Spark', 'SQL');
       SparkSQL
      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
       [1,2,3,4,5,6]
  """)
case class Concat(children: Seq[Expression]) extends Expression {

  // Hard cap on the concatenated element count and on the UnsafeArrayData byte
  // size; both checks below compare against this limit.
  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  // The accepted type families; checkInputDataTypes additionally enforces that
  // all children share the same concrete type.
  val allowedTypes = Seq(StringType, BinaryType, ArrayType)

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.isEmpty) {
      // Zero-argument concat is accepted; dataType then defaults to StringType.
      TypeCheckResult.TypeCheckSuccess
    } else {
      val childTypes = children.map(_.dataType)
      if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
        return TypeCheckResult.TypeCheckFailure(
          s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
            s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
      }
      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
    }
  }

  // Result type is the type of the (type-checked, homogeneous) inputs;
  // StringType when there are no inputs.
  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

  lazy val javaType: String = CodeGenerator.javaType(dataType)

  override def nullable: Boolean = children.exists(_.nullable)

  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = dataType match {
    case BinaryType =>
      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
      ByteArray.concat(inputs: _*)
    case StringType =>
      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
      UTF8String.concat(inputs : _*)
    case ArrayType(elementType, _) =>
      // toStream keeps child evaluation lazy so the null check can short-circuit
      // before all children are materialized.
      val inputs = children.toStream.map(_.eval(input))
      if (inputs.contains(null)) {
        // Any null input array makes the whole concatenation null.
        null
      } else {
        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
        if (numberOfElements > MAX_ARRAY_LENGTH) {
          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
        }
        // Copy every child array, in order, into one flat object array.
        val finalData = new Array[AnyRef](numberOfElements.toInt)
        var position = 0
        for(ad <- arrayData) {
          val arr = ad.toObjectArray(elementType)
          Array.copy(arr, 0, finalData, position, arr.length)
          position += arr.length
        }
        new GenericArrayData(finalData)
      }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val evals = children.map(_.genCode(ctx))
    val args = ctx.freshName("args")

    // Evaluate each child; only non-null results are stored into `args`,
    // leaving null slots for null children.
    val inputs = evals.zipWithIndex.map { case (eval, index) =>
      s"""
        ${eval.code}
        if (!${eval.isNull}) {
          $args[$index] = ${eval.value};
        }
      """
    }

    // Pick the concatenating helper and the declaration of `args` by result type.
    val (concatenator, initCode) = dataType match {
      case BinaryType =>
        (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
      case StringType =>
        ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
      case ArrayType(elementType, _) =>
        // Arrays use a generated anonymous class: primitive element types take
        // the UnsafeArrayData path, everything else goes through GenericArrayData.
        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
          genCodeForPrimitiveArrays(ctx, elementType)
        } else {
          genCodeForNonPrimitiveArrays(ctx, elementType)
        }
        (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
    }
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = inputs,
      funcName = "valueConcat",
      extraArguments = (s"$javaType[]", args) :: Nil)
    ev.copy(s"""
      $initCode
      $codes
      $javaType ${ev.value} = $concatenator.concat($args);
      boolean ${ev.isNull} = ${ev.value} == null;
    """)
  }

  // Emits code that sums the element counts of all input arrays into a fresh
  // long variable and fails fast when the total exceeds MAX_ARRAY_LENGTH.
  // Returns (code, variableName).
  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
    val numElements = ctx.freshName("numElements")
    val code = s"""
      |long $numElements = 0L;
      |for (int z = 0; z < ${children.length}; z++) {
      |  $numElements += args[z].numElements();
      |}
      |if ($numElements > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
      |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
      |}
    """.stripMargin

    (code, numElements)
  }

  // Emits an early-return-null guard over `args` when any child can be null;
  // empty string otherwise (guard is dead code for non-nullable inputs).
  private def nullArgumentProtection() : String = {
    if (nullable) {
      s"""
        |for (int z = 0; z < ${children.length}; z++) {
        |  if (args[z] == null) return null;
        |}
      """.stripMargin
    } else {
      ""
    }
  }

  // Generates an anonymous class that concatenates arrays of primitive elements
  // directly into an UnsafeArrayData byte buffer, preserving per-element nulls.
  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
    val arrayName = ctx.freshName("array")
    val arraySizeName = ctx.freshName("size")
    val counter = ctx.freshName("counter")
    val arrayData = ctx.freshName("arrayData")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

    // Byte-size check mirrors the element-count check: the flat UnsafeArrayData
    // buffer must also fit under MAX_ARRAY_LENGTH bytes.
    val unsafeArraySizeInBytes = s"""
      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
      |  $numElemName,
      |  ${elementType.defaultSize});
      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
      |    " for UnsafeArrayData.");
      |}
    """.stripMargin
    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

    s"""
      |new Object() {
      |  public ArrayData concat($javaType[] args) {
      |    ${nullArgumentProtection()}
      |    $numElemCode
      |    $unsafeArraySizeInBytes
      |    byte[] $arrayName = new byte[(int)$arraySizeName];
      |    UnsafeArrayData $arrayData = new UnsafeArrayData();
      |    Platform.putLong($arrayName, $baseOffset, $numElemName);
      |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
      |    int $counter = 0;
      |    for (int y = 0; y < ${children.length}; y++) {
      |      for (int z = 0; z < args[y].numElements(); z++) {
      |        if (args[y].isNullAt(z)) {
      |          $arrayData.setNullAt($counter);
      |        } else {
      |          $arrayData.set$primitiveValueTypeName(
      |            $counter,
      |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
      |          );
      |        }
      |        $counter++;
      |      }
      |    }
      |    return $arrayData;
      |  }
      |}""".stripMargin.stripPrefix("\n")
  }

  // Generates an anonymous class that concatenates arrays of non-primitive
  // elements by copying element references into a GenericArrayData.
  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val arrayData = ctx.freshName("arrayObjects")
    val counter = ctx.freshName("counter")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)

    s"""
      |new Object() {
      |  public ArrayData concat($javaType[] args) {
      |    ${nullArgumentProtection()}
      |    $numElemCode
      |    Object[] $arrayData = new Object[(int)$numElemName];
      |    int $counter = 0;
      |    for (int y = 0; y < ${children.length}; y++) {
      |      for (int z = 0; z < args[y].numElements(); z++) {
      |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
      |        $counter++;
      |      }
      |    }
      |    return new $genericArrayClass($arrayData);
      |  }
      |}""".stripMargin.stripPrefix("\n")
  }

  override def toString: String = s"concat(${children.mkString(", ")})"

  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
|
||||
|
|
|
@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
|||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
||||
/**
 * An expression that concatenates multiple inputs into a single output.
 * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 * If any input is null, concat returns null.
 */
@ExpressionDescription(
  usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.",
  examples = """
    Examples:
      > SELECT _FUNC_('Spark', 'SQL');
       SparkSQL
  """)
case class Concat(children: Seq[Expression]) extends Expression {

  // Binary mode iff the (homogeneous, type-checked) inputs are all BinaryType.
  private lazy val isBinaryMode: Boolean = dataType == BinaryType

  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.isEmpty) {
      // Zero-argument concat is accepted; dataType then defaults to StringType.
      TypeCheckResult.TypeCheckSuccess
    } else {
      val childTypes = children.map(_.dataType)
      if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
        return TypeCheckResult.TypeCheckFailure(
          s"input to function $prettyName should have StringType or BinaryType, but it's " +
            childTypes.map(_.simpleString).mkString("[", ", ", "]"))
      }
      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
    }
  }

  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)

  override def nullable: Boolean = children.exists(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = {
    if (isBinaryMode) {
      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
      ByteArray.concat(inputs: _*)
    } else {
      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
      UTF8String.concat(inputs : _*)
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val evals = children.map(_.genCode(ctx))
    val args = ctx.freshName("args")

    // Evaluate each child; only non-null results are stored into `args`,
    // leaving null slots for null children.
    val inputs = evals.zipWithIndex.map { case (eval, index) =>
      s"""
        ${eval.code}
        if (!${eval.isNull}) {
          $args[$index] = ${eval.value};
        }
      """
    }

    // Binary inputs are concatenated by ByteArray, strings by UTF8String.
    val (concatenator, initCode) = if (isBinaryMode) {
      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
    } else {
      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
    }
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = inputs,
      funcName = "valueConcat",
      extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
    ev.copy(s"""
      $initCode
      $codes
      ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
      boolean ${ev.isNull} = ${ev.value} == null;
    """)
  }

  override def toString: String = s"concat(${children.mkString(", ")})"

  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
|
||||
|
||||
|
||||
/**
|
||||
* An expression that concatenates multiple input strings or array of strings into a single string,
|
||||
* using a given separator (the first child).
|
||||
|
|
|
@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
|
||||
checkEvaluation(ElementAt(m2, Literal("a")), null)
|
||||
}
|
||||
|
||||
test("Concat") {
  // Primitive-type elements (exercises the UnsafeArrayData codegen path)
  val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
  val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
  val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
  val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
  val ai4 = Literal.create(null, ArrayType(IntegerType))  // null array, not array of nulls

  checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
  checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
  checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3))
  checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3))
  checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5))
  checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5))
  // A null array anywhere in the inputs makes the whole result null.
  checkEvaluation(Concat(Seq(ai4)), null)
  checkEvaluation(Concat(Seq(ai0, ai4)), null)
  checkEvaluation(Concat(Seq(ai4, ai0)), null)

  // Non-primitive-type elements (exercises the GenericArrayData codegen path)
  val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
  val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
  val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
  val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
  val as4 = Literal.create(null, ArrayType(StringType))

  // Nested arrays also go through the non-primitive path.
  val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType)))
  val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType)))

  checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
  checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
  checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c"))
  checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c"))
  checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e"))
  checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e"))
  checkEvaluation(Concat(Seq(as4)), null)
  checkEvaluation(Concat(Seq(as0, as4)), null)
  checkEvaluation(Concat(Seq(as4, as0)), null)

  checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
}
|
||||
}
|
||||
|
|
|
@ -2228,16 +2228,6 @@ object functions {
|
|||
*/
|
||||
def base64(e: Column): Column = withExpr { Base64(e.expr) }
|
||||
|
||||
/**
 * Concatenates multiple input columns together into a single column.
 * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
 *
 * Thin wrapper building a catalyst `Concat` expression over the inputs.
 *
 * @group string_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
|
||||
|
||||
/**
|
||||
* Concatenates multiple input string columns together into a single string column,
|
||||
* using the given separator.
|
||||
|
@ -3038,6 +3028,16 @@ object functions {
|
|||
ArrayContains(column.expr, Literal(value))
|
||||
}
|
||||
|
||||
/**
 * Concatenates multiple input columns together into a single column.
 * The function works with strings, binary and compatible array columns.
 *
 * Thin wrapper building a catalyst `Concat` expression over the inputs.
 *
 * @group collection_funcs
 * @since 1.5.0
 */
@scala.annotation.varargs
def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
|
||||
|
||||
/**
|
||||
* Locates the position of the first occurrence of the value in the given array as long.
|
||||
* Returns null if either of the arguments are null.
|
||||
|
|
|
@ -91,3 +91,65 @@ FROM (
|
|||
encode(string(id + 3), 'utf-8') col4
|
||||
FROM range(10)
|
||||
);
|
||||
|
||||
-- Fixture view: a single row holding two array columns per element type,
-- exercised below by the || (concat) operator.
CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
  array(true, false), array(true),
  array(2Y, 1Y), array(3Y, 4Y),
  array(2S, 1S), array(3S, 4S),
  array(2, 1), array(3, 4),
  array(2L, 1L), array(3L, 4L),
  array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
  array(2.0D, 1.0D), array(3.0D, 4.0D),
  array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
  array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
  array(timestamp '2016-11-11 20:54:00.000'),
  array('a', 'b'), array('c', 'd'),
  array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
  array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
  array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
) AS various_arrays(
  boolean_array1, boolean_array2,
  tinyint_array1, tinyint_array2,
  smallint_array1, smallint_array2,
  int_array1, int_array2,
  bigint_array1, bigint_array2,
  decimal_array1, decimal_array2,
  double_array1, double_array2,
  float_array1, float_array2,
  -- NOTE(review): the second column is spelled `data_array2` (not date_array2);
  -- kept as-is because the queries and golden output reference this spelling.
  date_array1, data_array2,
  timestamp_array1, timestamp_array2,
  string_array1, string_array2,
  array_array1, array_array2,
  struct_array1, struct_array2,
  map_array1, map_array2
);

-- Concatenate arrays of the same type
SELECT
    (boolean_array1 || boolean_array2) boolean_array,
    (tinyint_array1 || tinyint_array2) tinyint_array,
    (smallint_array1 || smallint_array2) smallint_array,
    (int_array1 || int_array2) int_array,
    (bigint_array1 || bigint_array2) bigint_array,
    (decimal_array1 || decimal_array2) decimal_array,
    (double_array1 || double_array2) double_array,
    (float_array1 || float_array2) float_array,
    (date_array1 || data_array2) data_array,
    (timestamp_array1 || timestamp_array2) timestamp_array,
    (string_array1 || string_array2) string_array,
    (array_array1 || array_array2) array_array,
    (struct_array1 || struct_array2) struct_array,
    (map_array1 || map_array2) map_array
FROM various_arrays;

-- Concatenate arrays of different types (exercises implicit coercion of the
-- operands to a wider common element type)
SELECT
    (tinyint_array1 || smallint_array2) ts_array,
    (smallint_array1 || int_array2) si_array,
    (int_array1 || bigint_array2) ib_array,
    (double_array1 || float_array2) df_array,
    (string_array1 || data_array2) std_array,
    (timestamp_array1 || string_array2) tst_array,
    (string_array1 || int_array2) sti_array
FROM various_arrays;
|
||||
|
|
|
@ -237,3 +237,81 @@ struct<col:binary>
|
|||
78910
|
||||
891011
|
||||
9101112
|
||||
|
||||
|
||||
-- !query 11
|
||||
CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
|
||||
array(true, false), array(true),
|
||||
array(2Y, 1Y), array(3Y, 4Y),
|
||||
array(2S, 1S), array(3S, 4S),
|
||||
array(2, 1), array(3, 4),
|
||||
array(2L, 1L), array(3L, 4L),
|
||||
array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
|
||||
array(2.0D, 1.0D), array(3.0D, 4.0D),
|
||||
array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
|
||||
array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
|
||||
array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
|
||||
array(timestamp '2016-11-11 20:54:00.000'),
|
||||
array('a', 'b'), array('c', 'd'),
|
||||
array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
|
||||
array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
|
||||
array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
|
||||
) AS various_arrays(
|
||||
boolean_array1, boolean_array2,
|
||||
tinyint_array1, tinyint_array2,
|
||||
smallint_array1, smallint_array2,
|
||||
int_array1, int_array2,
|
||||
bigint_array1, bigint_array2,
|
||||
decimal_array1, decimal_array2,
|
||||
double_array1, double_array2,
|
||||
float_array1, float_array2,
|
||||
date_array1, data_array2,
|
||||
timestamp_array1, timestamp_array2,
|
||||
string_array1, string_array2,
|
||||
array_array1, array_array2,
|
||||
struct_array1, struct_array2,
|
||||
map_array1, map_array2
|
||||
)
|
||||
-- !query 11 schema
|
||||
struct<>
|
||||
-- !query 11 output
|
||||
|
||||
|
||||
|
||||
-- !query 12
|
||||
SELECT
|
||||
(boolean_array1 || boolean_array2) boolean_array,
|
||||
(tinyint_array1 || tinyint_array2) tinyint_array,
|
||||
(smallint_array1 || smallint_array2) smallint_array,
|
||||
(int_array1 || int_array2) int_array,
|
||||
(bigint_array1 || bigint_array2) bigint_array,
|
||||
(decimal_array1 || decimal_array2) decimal_array,
|
||||
(double_array1 || double_array2) double_array,
|
||||
(float_array1 || float_array2) float_array,
|
||||
(date_array1 || data_array2) data_array,
|
||||
(timestamp_array1 || timestamp_array2) timestamp_array,
|
||||
(string_array1 || string_array2) string_array,
|
||||
(array_array1 || array_array2) array_array,
|
||||
(struct_array1 || struct_array2) struct_array,
|
||||
(map_array1 || map_array2) map_array
|
||||
FROM various_arrays
|
||||
-- !query 12 schema
|
||||
struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,data_array:array<date>,timestamp_array:array<timestamp>,string_array:array<string>,array_array:array<array<string>>,struct_array:array<struct<col1:string,col2:int>>,map_array:array<map<string,int>>>
|
||||
-- !query 12 output
|
||||
[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}]
|
||||
|
||||
|
||||
-- !query 13
|
||||
SELECT
|
||||
(tinyint_array1 || smallint_array2) ts_array,
|
||||
(smallint_array1 || int_array2) si_array,
|
||||
(int_array1 || bigint_array2) ib_array,
|
||||
(double_array1 || float_array2) df_array,
|
||||
(string_array1 || data_array2) std_array,
|
||||
(timestamp_array1 || string_array2) tst_array,
|
||||
(string_array1 || int_array2) sti_array
|
||||
FROM various_arrays
|
||||
-- !query 13 schema
|
||||
struct<ts_array:array<smallint>,si_array:array<int>,ib_array:array<bigint>,df_array:array<double>,std_array:array<string>,tst_array:array<string>,sti_array:array<string>>
|
||||
-- !query 13 output
|
||||
[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"]
|
||||
|
|
|
@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("concat function - arrays") {
  // Typed nulls: the Seq element types keep the inferred schema at
  // array<int> ("in") and array<string> ("sn") even though the values are null.
  val nseqi : Seq[Int] = null
  val nseqs : Seq[String] = null
  val df = Seq(
    (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs),
    (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs)
  ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn")

  // A predicate that is always true but defeats constant folding,
  // forcing the whole-stage-codegen path to be exercised.
  val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on

  // Simple test cases
  checkAnswer(
    df.selectExpr("array(1, 2, 3L)"),
    Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L)))
  )

  // Mixed array<int> / array<string> inputs coerce to array<string>.
  checkAnswer(
    df.select(concat($"i1", $"s1")),
    Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
  )
  checkAnswer(
    df.select(concat($"i1", $"i2", $"i3")),
    Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
  )
  checkAnswer(
    df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")),
    Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
  )
  // A null *element* inside an input array is preserved in the result.
  checkAnswer(
    df.selectExpr("concat(array(1, null), i2, i3)"),
    Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2)))
  )
  checkAnswer(
    df.select(concat($"s1", $"s2", $"s3")),
    Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
  )
  checkAnswer(
    df.selectExpr("concat(s1, s2, s3)"),
    Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
  )
  checkAnswer(
    df.filter(dummyFilter($"s1")).select(concat($"s1", $"s2", $"s3")),
    Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
  )

  // Null test cases: any null *argument* makes the whole result null,
  // regardless of argument position.
  checkAnswer(
    df.select(concat($"i1", $"in")),
    Seq(Row(null), Row(null))
  )
  checkAnswer(
    df.select(concat($"in", $"i1")),
    Seq(Row(null), Row(null))
  )
  checkAnswer(
    df.select(concat($"s1", $"sn")),
    Seq(Row(null), Row(null))
  )
  checkAnswer(
    df.select(concat($"sn", $"s1")),
    Seq(Row(null), Row(null))
  )

  // Type error test cases: analysis must reject arguments that cannot be
  // coerced to a common array type.
  intercept[AnalysisException] {
    df.selectExpr("concat(i1, i2, null)")
  }

  intercept[AnalysisException] {
    df.selectExpr("concat(i1, array(i1, i2))")
  }
}
|
||||
|
||||
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
|
||||
import DataFrameFunctionsSuite.CodegenFallbackExpr
|
||||
for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
|
||||
|
|
|
@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
|
|||
sql("DESCRIBE FUNCTION 'concat'"),
|
||||
Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") ::
|
||||
Row("Function: concat") ::
|
||||
Row("Usage: concat(str1, str2, ..., strN) - " +
|
||||
"Returns the concatenation of str1, str2, ..., strN.") :: Nil
|
||||
Row("Usage: concat(col1, col2, ..., colN) - " +
|
||||
"Returns the concatenation of col1, col2, ..., colN.") :: Nil
|
||||
)
|
||||
// extended mode
|
||||
checkAnswer(
|
||||
|
|
Loading…
Reference in a new issue