[SPARK-23736][SQL] Extending the concat function to support array columns

## What changes were proposed in this pull request?
The PR adds logic for easy concatenation of multiple array columns and covers:
- The `Concat` expression has been extended to support array columns
- A Python wrapper (a usage sketch follows)
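
For illustration, a minimal usage sketch of the extended function (Scala; a `SparkSession` named `spark` and its implicits import are assumed — illustrative only, not part of the patch):
```scala
import org.apache.spark.sql.functions.concat
import spark.implicits._

val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(5, 6), Seq(7))
).toDF("a", "b")

// concat now accepts array columns and appends their elements in argument order.
df.select(concat($"a", $"b").as("ab")).show()
// Expected rows: [1, 2, 3, 4] and [5, 6, 7]
```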

## How was this patch tested?
New tests were added to:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite
- typeCoercion/native/concat.sql

## Codegen examples
### Primitive-type elements
```
val df = Seq(
  (Seq(1, 2), Seq(3, 4)),
  (Seq(1, 2, 3), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */         ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */         if (!false) {
/* 044 */           project_args[0] = inputadapter_value;
/* 045 */         }
/* 046 */
/* 047 */         boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */         ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */         null : (inputadapter_row.getArray(1));
/* 050 */         if (!inputadapter_isNull1) {
/* 051 */           project_args[1] = inputadapter_value1;
/* 052 */         }
/* 053 */
/* 054 */         ArrayData project_value = new Object() {
/* 055 */           public ArrayData concat(ArrayData[] args) {
/* 056 */             for (int z = 0; z < 2; z++) {
/* 057 */               if (args[z] == null) return null;
/* 058 */             }
/* 059 */
/* 060 */             long project_numElements = 0L;
/* 061 */             for (int z = 0; z < 2; z++) {
/* 062 */               project_numElements += args[z].numElements();
/* 063 */             }
/* 064 */             if (project_numElements > 2147483632) {
/* 065 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */                 " elements due to exceeding the array size limit 2147483632.");
/* 067 */             }
/* 068 */
/* 069 */             long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 070 */               project_numElements,
/* 071 */               4);
/* 072 */             if (project_size > 2147483632) {
/* 073 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_size +
/* 074 */                 " bytes of data due to exceeding the limit 2147483632 bytes" +
/* 075 */                 " for UnsafeArrayData.");
/* 076 */             }
/* 077 */
/* 078 */             byte[] project_array = new byte[(int)project_size];
/* 079 */             UnsafeArrayData project_arrayData = new UnsafeArrayData();
/* 080 */             Platform.putLong(project_array, 16, project_numElements);
/* 081 */             project_arrayData.pointTo(project_array, 16, (int)project_size);
/* 082 */             int project_counter = 0;
/* 083 */             for (int y = 0; y < 2; y++) {
/* 084 */               for (int z = 0; z < args[y].numElements(); z++) {
/* 085 */                 if (args[y].isNullAt(z)) {
/* 086 */                   project_arrayData.setNullAt(project_counter);
/* 087 */                 } else {
/* 088 */                   project_arrayData.setInt(
/* 089 */                     project_counter,
/* 090 */                     args[y].getInt(z)
/* 091 */                   );
/* 092 */                 }
/* 093 */                 project_counter++;
/* 094 */               }
/* 095 */             }
/* 096 */             return project_arrayData;
/* 097 */           }
/* 098 */         }.concat(project_args);
/* 099 */         boolean project_isNull = project_value == null;
```

### Non-primitive-type elements
```
val df = Seq(
  (Seq("aa" ,"bb"), Seq("ccc", "ddd")),
  (Seq("x", "y"), null)
).toDF("a", "b")
df.filter('a.isNotNull).select(concat('a, 'b)).debugCodegen()
```
Result:
```
/* 033 */         boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */         ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */         null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */         if (!(!inputadapter_isNull)) continue;
/* 038 */
/* 039 */         ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 040 */
/* 041 */         ArrayData[] project_args = new ArrayData[2];
/* 042 */
/* 043 */         if (!false) {
/* 044 */           project_args[0] = inputadapter_value;
/* 045 */         }
/* 046 */
/* 047 */         boolean inputadapter_isNull1 = inputadapter_row.isNullAt(1);
/* 048 */         ArrayData inputadapter_value1 = inputadapter_isNull1 ?
/* 049 */         null : (inputadapter_row.getArray(1));
/* 050 */         if (!inputadapter_isNull1) {
/* 051 */           project_args[1] = inputadapter_value1;
/* 052 */         }
/* 053 */
/* 054 */         ArrayData project_value = new Object() {
/* 055 */           public ArrayData concat(ArrayData[] args) {
/* 056 */             for (int z = 0; z < 2; z++) {
/* 057 */               if (args[z] == null) return null;
/* 058 */             }
/* 059 */
/* 060 */             long project_numElements = 0L;
/* 061 */             for (int z = 0; z < 2; z++) {
/* 062 */               project_numElements += args[z].numElements();
/* 063 */             }
/* 064 */             if (project_numElements > 2147483632) {
/* 065 */               throw new RuntimeException("Unsuccessful try to concat arrays with " + project_numElements +
/* 066 */                 " elements due to exceeding the array size limit 2147483632.");
/* 067 */             }
/* 068 */
/* 069 */             Object[] project_arrayObjects = new Object[(int)project_numElements];
/* 070 */             int project_counter = 0;
/* 071 */             for (int y = 0; y < 2; y++) {
/* 072 */               for (int z = 0; z < args[y].numElements(); z++) {
/* 073 */                 project_arrayObjects[project_counter] = args[y].getUTF8String(z);
/* 074 */                 project_counter++;
/* 075 */               }
/* 076 */             }
/* 077 */             return new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObjects);
/* 078 */           }
/* 079 */         }.concat(project_args);
/* 080 */         boolean project_isNull = project_value == null;
```
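
Note the difference between the two code paths: for primitive element types the generated code writes the result directly into an `UnsafeArrayData` buffer whose size is computed up front via `calculateSizeOfUnderlyingByteArray`, whereas for non-primitive element types it collects object references into a `GenericArrayData`. The `2147483632` literal appearing in both dumps is `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`, i.e. `Integer.MAX_VALUE - 15`.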

Author: mn-mikke <mrkAha12346github>

Closes #20858 from mn-mikke/feature/array-api-concat_arrays-to-master.
Authored by mn-mikke on 2018-04-20 14:58:11 +09:00; committed by Takuya UESHIN.
Commit e6b466084c (parent b3fde5a41e). 13 changed files with 529 additions and 111 deletions.


```
@@ -33,7 +33,11 @@ public class ByteArrayMethods {
   }

   public static int roundNumberOfBytesToNearestWord(int numBytes) {
-    int remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
+    return (int)roundNumberOfBytesToNearestWord((long)numBytes);
+  }
+
+  public static long roundNumberOfBytesToNearestWord(long numBytes) {
+    long remainder = numBytes & 0x07;  // This is equivalent to `numBytes % 8`
     if (remainder == 0) {
       return numBytes;
     } else {
```
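
For reference, the word-alignment rounding both overloads perform can be sketched in Scala (an illustrative mirror of the Java logic above, not the actual source):
```scala
// Rounds a byte count up to the nearest multiple of 8 (one 64-bit word).
def roundNumberOfBytesToNearestWord(numBytes: Long): Long = {
  val remainder = numBytes & 0x07L // equivalent to numBytes % 8
  if (remainder == 0) numBytes else numBytes + (8 - remainder)
}

assert(roundNumberOfBytesToNearestWord(13L) == 16L)
assert(roundNumberOfBytesToNearestWord(16L) == 16L)
```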


```
@@ -1425,21 +1425,6 @@ for _name, _doc in _string_functions.items():
 del _name, _doc


-@since(1.5)
-@ignore_unicode_prefix
-def concat(*cols):
-    """
-    Concatenates multiple input columns together into a single column.
-    If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-
-    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
-    >>> df.select(concat(df.s, df.d).alias('s')).collect()
-    [Row(s=u'abcd123')]
-    """
-    sc = SparkContext._active_spark_context
-    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
-
-
 @since(1.5)
 @ignore_unicode_prefix
 def concat_ws(sep, *cols):
@@ -1845,6 +1830,25 @@ def array_contains(col, value):
     return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))


+@since(1.5)
+@ignore_unicode_prefix
+def concat(*cols):
+    """
+    Concatenates multiple input columns together into a single column.
+    The function works with strings, binary and compatible array columns.
+
+    >>> df = spark.createDataFrame([('abcd','123')], ['s', 'd'])
+    >>> df.select(concat(df.s, df.d).alias('s')).collect()
+    [Row(s=u'abcd123')]
+
+    >>> df = spark.createDataFrame([([1, 2], [3, 4], [5]), ([1, 2], None, [3])], ['a', 'b', 'c'])
+    >>> df.select(concat(df.a, df.b, df.c).alias("arr")).collect()
+    [Row(arr=[1, 2, 3, 4, 5]), Row(arr=None)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.concat(_to_seq(sc, cols, _to_java_column)))
+
+
 @since(2.4)
 def array_position(col, value):
     """
```


```
@@ -56,9 +56,19 @@ import org.apache.spark.unsafe.types.UTF8String;
 public final class UnsafeArrayData extends ArrayData {

   public static int calculateHeaderPortionInBytes(int numFields) {
+    return (int)calculateHeaderPortionInBytes((long)numFields);
+  }
+
+  public static long calculateHeaderPortionInBytes(long numFields) {
     return 8 + ((numFields + 63)/ 64) * 8;
   }

+  public static long calculateSizeOfUnderlyingByteArray(long numFields, int elementSize) {
+    long size = UnsafeArrayData.calculateHeaderPortionInBytes(numFields) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord(numFields * elementSize);
+    return size;
+  }
+
   private Object baseObject;
   private long baseOffset;
```
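
To make the size formula concrete, here is an illustrative Scala re-derivation (a sketch mirroring the Java methods above): an `UnsafeArrayData` stores an 8-byte element count, one 8-byte null-tracking word per 64 elements, and the element data rounded up to a word boundary.
```scala
def calculateHeaderPortionInBytes(numFields: Long): Long =
  8 + ((numFields + 63) / 64) * 8

def roundNumberOfBytesToNearestWord(numBytes: Long): Long = {
  val remainder = numBytes & 0x07L
  if (remainder == 0) numBytes else numBytes + (8 - remainder)
}

def calculateSizeOfUnderlyingByteArray(numFields: Long, elementSize: Int): Long =
  calculateHeaderPortionInBytes(numFields) +
    roundNumberOfBytesToNearestWord(numFields * elementSize)

// 5 ints: 16-byte header + 20 data bytes rounded up to 24 = 40 bytes total.
assert(calculateSizeOfUnderlyingByteArray(5L, 4) == 40L)
```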


```
@@ -308,7 +308,6 @@ object FunctionRegistry {
     expression[BitLength]("bit_length"),
     expression[Length]("char_length"),
     expression[Length]("character_length"),
-    expression[Concat]("concat"),
     expression[ConcatWs]("concat_ws"),
     expression[Decode]("decode"),
     expression[Elt]("elt"),
@@ -413,6 +412,7 @@ object FunctionRegistry {
     expression[ArrayMin]("array_min"),
     expression[ArrayMax]("array_max"),
     expression[Reverse]("reverse"),
+    expression[Concat]("concat"),
     CreateStruct.registryEntry,

     // misc functions
```


```
@@ -520,6 +520,14 @@ object TypeCoercion {
         case None => a
       }

+    case c @ Concat(children) if children.forall(c => ArrayType.acceptsType(c.dataType)) &&
+        !haveSameType(children) =>
+      val types = children.map(_.dataType)
+      findWiderCommonType(types) match {
+        case Some(finalDataType) => Concat(children.map(Cast(_, finalDataType)))
+        case None => c
+      }
+
     case m @ CreateMap(children) if m.keys.length == m.values.length &&
         (!haveSameType(m.keys) || !haveSameType(m.values)) =>
       val newKeys = if (haveSameType(m.keys)) {
```
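
This rule is what allows arrays with different but compatible element types to be concatenated: the children are cast to the widest common type first. A hedged sketch of the resulting behavior (assuming a `SparkSession` named `spark`):
```scala
// The array<int> child is implicitly cast to array<bigint> before Concat runs,
// so the result column has type array<bigint>.
val widened = spark.sql("SELECT concat(array(1, 2), array(3L)) AS arr")
// widened.schema("arr").dataType is an ArrayType(LongType, ...)
```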


```
@@ -23,7 +23,9 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
 import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.array.ByteArrayMethods
+import org.apache.spark.unsafe.types.{ByteArray, UTF8String}

 /**
  * Given an array or map, returns its size. Returns -1 if null.
@@ -665,3 +667,219 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
   override def prettyName: String = "element_at"
 }
+
+/**
+ * Concatenates multiple input columns together into a single column.
+ * The function works with strings, binary and compatible array columns.
+ */
+@ExpressionDescription(
+  usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
+  examples = """
+    Examples:
+      > SELECT _FUNC_('Spark', 'SQL');
+       SparkSQL
+      > SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
+       [1,2,3,4,5,6]
+  """)
+case class Concat(children: Seq[Expression]) extends Expression {
+
+  private val MAX_ARRAY_LENGTH: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
+
+  val allowedTypes = Seq(StringType, BinaryType, ArrayType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      val childTypes = children.map(_.dataType)
+      if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
+        return TypeCheckResult.TypeCheckFailure(
+          s"input to function $prettyName should have been StringType, BinaryType or ArrayType," +
+            s" but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]"))
+      }
+      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
+    }
+  }
+
+  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
+
+  lazy val javaType: String = CodeGenerator.javaType(dataType)
+
+  override def nullable: Boolean = children.exists(_.nullable)
+
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  override def eval(input: InternalRow): Any = dataType match {
+    case BinaryType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
+      ByteArray.concat(inputs: _*)
+    case StringType =>
+      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
+      UTF8String.concat(inputs : _*)
+    case ArrayType(elementType, _) =>
+      val inputs = children.toStream.map(_.eval(input))
+      if (inputs.contains(null)) {
+        null
+      } else {
+        val arrayData = inputs.map(_.asInstanceOf[ArrayData])
+        val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
+        if (numberOfElements > MAX_ARRAY_LENGTH) {
+          throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
+            s" elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
+        }
+        val finalData = new Array[AnyRef](numberOfElements.toInt)
+        var position = 0
+        for(ad <- arrayData) {
+          val arr = ad.toObjectArray(elementType)
+          Array.copy(arr, 0, finalData, position, arr.length)
+          position += arr.length
+        }
+        new GenericArrayData(finalData)
+      }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val evals = children.map(_.genCode(ctx))
+    val args = ctx.freshName("args")
+
+    val inputs = evals.zipWithIndex.map { case (eval, index) =>
+      s"""
+        ${eval.code}
+        if (!${eval.isNull}) {
+          $args[$index] = ${eval.value};
+        }
+      """
+    }
+
+    val (concatenator, initCode) = dataType match {
+      case BinaryType =>
+        (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
+      case StringType =>
+        ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
+      case ArrayType(elementType, _) =>
+        val arrayConcatClass = if (CodeGenerator.isPrimitiveType(elementType)) {
+          genCodeForPrimitiveArrays(ctx, elementType)
+        } else {
+          genCodeForNonPrimitiveArrays(ctx, elementType)
+        }
+        (arrayConcatClass, s"ArrayData[] $args = new ArrayData[${evals.length}];")
+    }
+    val codes = ctx.splitExpressionsWithCurrentInputs(
+      expressions = inputs,
+      funcName = "valueConcat",
+      extraArguments = (s"$javaType[]", args) :: Nil)
+    ev.copy(s"""
+      $initCode
+      $codes
+      $javaType ${ev.value} = $concatenator.concat($args);
+      boolean ${ev.isNull} = ${ev.value} == null;
+    """)
+  }
+
+  private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
+    val numElements = ctx.freshName("numElements")
+    val code = s"""
+      |long $numElements = 0L;
+      |for (int z = 0; z < ${children.length}; z++) {
+      |  $numElements += args[z].numElements();
+      |}
+      |if ($numElements > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $numElements +
+      |    " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
+      |}
+    """.stripMargin
+
+    (code, numElements)
+  }
+
+  private def nullArgumentProtection() : String = {
+    if (nullable) {
+      s"""
+        |for (int z = 0; z < ${children.length}; z++) {
+        |  if (args[z] == null) return null;
+        |}
+      """.stripMargin
+    } else {
+      ""
+    }
+  }
+
+  private def genCodeForPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val arrayName = ctx.freshName("array")
+    val arraySizeName = ctx.freshName("size")
+    val counter = ctx.freshName("counter")
+    val arrayData = ctx.freshName("arrayData")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    val unsafeArraySizeInBytes = s"""
+      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
+      |  $numElemName,
+      |  ${elementType.defaultSize});
+      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
+      |  throw new RuntimeException("Unsuccessful try to concat arrays with " + $arraySizeName +
+      |    " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH bytes" +
+      |    " for UnsafeArrayData.");
+      |}
+      """.stripMargin
+    val baseOffset = Platform.BYTE_ARRAY_OFFSET
+
+    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat($javaType[] args) {
+      |    ${nullArgumentProtection()}
+      |    $numElemCode
+      |    $unsafeArraySizeInBytes
+      |    byte[] $arrayName = new byte[(int)$arraySizeName];
+      |    UnsafeArrayData $arrayData = new UnsafeArrayData();
+      |    Platform.putLong($arrayName, $baseOffset, $numElemName);
+      |    $arrayData.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < args[y].numElements(); z++) {
+      |        if (args[y].isNullAt(z)) {
+      |          $arrayData.setNullAt($counter);
+      |        } else {
+      |          $arrayData.set$primitiveValueTypeName(
+      |            $counter,
+      |            ${CodeGenerator.getValue(s"args[y]", elementType, "z")}
+      |          );
+      |        }
+      |        $counter++;
+      |      }
+      |    }
+      |    return $arrayData;
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  private def genCodeForNonPrimitiveArrays(ctx: CodegenContext, elementType: DataType): String = {
+    val genericArrayClass = classOf[GenericArrayData].getName
+    val arrayData = ctx.freshName("arrayObjects")
+    val counter = ctx.freshName("counter")
+
+    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
+
+    s"""
+      |new Object() {
+      |  public ArrayData concat($javaType[] args) {
+      |    ${nullArgumentProtection()}
+      |    $numElemCode
+      |    Object[] $arrayData = new Object[(int)$numElemName];
+      |    int $counter = 0;
+      |    for (int y = 0; y < ${children.length}; y++) {
+      |      for (int z = 0; z < args[y].numElements(); z++) {
+      |        $arrayData[$counter] = ${CodeGenerator.getValue(s"args[y]", elementType, "z")};
+      |        $counter++;
+      |      }
+      |    }
+      |    return new $genericArrayClass($arrayData);
+      |  }
+      |}""".stripMargin.stripPrefix("\n")
+  }
+
+  override def toString: String = s"concat(${children.mkString(", ")})"
+
+  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
+}
```
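
As the `eval` path above makes explicit, a null array input short-circuits the whole result to null, while null elements inside the input arrays are carried through. A minimal behavioral sketch (assuming the same `spark.implicits._` setup as earlier):
```scala
import org.apache.spark.sql.functions.concat

val df = Seq(
  (Seq(1, 2), Option(Seq(3, 4))),
  (Seq(5, 6), None) // the second array column is null in this row
).toDF("a", "b")

df.select(concat($"a", $"b")).collect()
// First row: [1, 2, 3, 4]; second row: null, because input `b` is null.
```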


```
@@ -36,87 +36,6 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
 ////////////////////////////////////////////////////////////////////////////////////////////////////

-/**
- * An expression that concatenates multiple inputs into a single output.
- * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
- * If any input is null, concat returns null.
- */
-@ExpressionDescription(
-  usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN.",
-  examples = """
-    Examples:
-      > SELECT _FUNC_('Spark', 'SQL');
-       SparkSQL
-  """)
-case class Concat(children: Seq[Expression]) extends Expression {
-
-  private lazy val isBinaryMode: Boolean = dataType == BinaryType
-
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (children.isEmpty) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      val childTypes = children.map(_.dataType)
-      if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) {
-        return TypeCheckResult.TypeCheckFailure(
-          s"input to function $prettyName should have StringType or BinaryType, but it's " +
-            childTypes.map(_.simpleString).mkString("[", ", ", "]"))
-      }
-      TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
-    }
-  }
-
-  override def dataType: DataType = children.map(_.dataType).headOption.getOrElse(StringType)
-
-  override def nullable: Boolean = children.exists(_.nullable)
-
-  override def foldable: Boolean = children.forall(_.foldable)
-
-  override def eval(input: InternalRow): Any = {
-    if (isBinaryMode) {
-      val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
-      ByteArray.concat(inputs: _*)
-    } else {
-      val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
-      UTF8String.concat(inputs : _*)
-    }
-  }
-
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
-    val evals = children.map(_.genCode(ctx))
-    val args = ctx.freshName("args")
-
-    val inputs = evals.zipWithIndex.map { case (eval, index) =>
-      s"""
-        ${eval.code}
-        if (!${eval.isNull}) {
-          $args[$index] = ${eval.value};
-        }
-      """
-    }
-
-    val (concatenator, initCode) = if (isBinaryMode) {
-      (classOf[ByteArray].getName, s"byte[][] $args = new byte[${evals.length}][];")
-    } else {
-      ("UTF8String", s"UTF8String[] $args = new UTF8String[${evals.length}];")
-    }
-    val codes = ctx.splitExpressionsWithCurrentInputs(
-      expressions = inputs,
-      funcName = "valueConcat",
-      extraArguments = (s"${CodeGenerator.javaType(dataType)}[]", args) :: Nil)
-    ev.copy(s"""
-      $initCode
-      $codes
-      ${CodeGenerator.javaType(dataType)} ${ev.value} = $concatenator.concat($args);
-      boolean ${ev.isNull} = ${ev.value} == null;
-    """)
-  }
-
-  override def toString: String = s"concat(${children.mkString(", ")})"
-
-  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
-}
-
 /**
  * An expression that concatenates multiple input strings or array of strings into a single string,
  * using a given separator (the first child).
```


```
@@ -239,4 +239,45 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(ElementAt(m2, Literal("a")), null)
   }
+
+  test("Concat") {
+    // Primitive-type elements
+    val ai0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val ai1 = Literal.create(Seq.empty[Integer], ArrayType(IntegerType))
+    val ai2 = Literal.create(Seq(4, null, 5), ArrayType(IntegerType))
+    val ai3 = Literal.create(Seq(null, null), ArrayType(IntegerType))
+    val ai4 = Literal.create(null, ArrayType(IntegerType))
+
+    checkEvaluation(Concat(Seq(ai0)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai1)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai1, ai0)), Seq(1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai0)), Seq(1, 2, 3, 1, 2, 3))
+    checkEvaluation(Concat(Seq(ai0, ai2)), Seq(1, 2, 3, 4, null, 5))
+    checkEvaluation(Concat(Seq(ai0, ai3, ai2)), Seq(1, 2, 3, null, null, 4, null, 5))
+    checkEvaluation(Concat(Seq(ai4)), null)
+    checkEvaluation(Concat(Seq(ai0, ai4)), null)
+    checkEvaluation(Concat(Seq(ai4, ai0)), null)
+
+    // Non-primitive-type elements
+    val as0 = Literal.create(Seq("a", "b", "c"), ArrayType(StringType))
+    val as1 = Literal.create(Seq.empty[String], ArrayType(StringType))
+    val as2 = Literal.create(Seq("d", null, "e"), ArrayType(StringType))
+    val as3 = Literal.create(Seq(null, null), ArrayType(StringType))
+    val as4 = Literal.create(null, ArrayType(StringType))
+
+    val aa0 = Literal.create(Seq(Seq("a", "b"), Seq("c")), ArrayType(ArrayType(StringType)))
+    val aa1 = Literal.create(Seq(Seq("d"), Seq("e", "f")), ArrayType(ArrayType(StringType)))
+
+    checkEvaluation(Concat(Seq(as0)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as1)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as1, as0)), Seq("a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as0)), Seq("a", "b", "c", "a", "b", "c"))
+    checkEvaluation(Concat(Seq(as0, as2)), Seq("a", "b", "c", "d", null, "e"))
+    checkEvaluation(Concat(Seq(as0, as3, as2)), Seq("a", "b", "c", null, null, "d", null, "e"))
+    checkEvaluation(Concat(Seq(as4)), null)
+    checkEvaluation(Concat(Seq(as0, as4)), null)
+    checkEvaluation(Concat(Seq(as4, as0)), null)
+
+    checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
+  }
 }
```


```
@@ -2228,16 +2228,6 @@ object functions {
    */
   def base64(e: Column): Column = withExpr { Base64(e.expr) }

-  /**
-   * Concatenates multiple input columns together into a single column.
-   * If all inputs are binary, concat returns an output as binary. Otherwise, it returns as string.
-   *
-   * @group string_funcs
-   * @since 1.5.0
-   */
-  @scala.annotation.varargs
-  def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
-
   /**
    * Concatenates multiple input string columns together into a single string column,
    * using the given separator.
@@ -3038,6 +3028,16 @@ object functions {
     ArrayContains(column.expr, Literal(value))
   }

+  /**
+   * Concatenates multiple input columns together into a single column.
+   * The function works with strings, binary and compatible array columns.
+   *
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def concat(exprs: Column*): Column = withExpr { Concat(exprs.map(_.expr)) }
+
   /**
    * Locates the position of the first occurrence of the value in the given array as long.
    * Returns null if either of the arguments are null.
```


```
@@ -91,3 +91,65 @@ FROM (
     encode(string(id + 3), 'utf-8') col4
   FROM range(10)
 );
+
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+  array(true, false), array(true),
+  array(2Y, 1Y), array(3Y, 4Y),
+  array(2S, 1S), array(3S, 4S),
+  array(2, 1), array(3, 4),
+  array(2L, 1L), array(3L, 4L),
+  array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
+  array(2.0D, 1.0D), array(3.0D, 4.0D),
+  array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+  array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  array(timestamp '2016-11-11 20:54:00.000'),
+  array('a', 'b'), array('c', 'd'),
+  array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+  array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+  array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+  boolean_array1, boolean_array2,
+  tinyint_array1, tinyint_array2,
+  smallint_array1, smallint_array2,
+  int_array1, int_array2,
+  bigint_array1, bigint_array2,
+  decimal_array1, decimal_array2,
+  double_array1, double_array2,
+  float_array1, float_array2,
+  date_array1, data_array2,
+  timestamp_array1, timestamp_array2,
+  string_array1, string_array2,
+  array_array1, array_array2,
+  struct_array1, struct_array2,
+  map_array1, map_array2
+);
+
+-- Concatenate arrays of the same type
+SELECT
+    (boolean_array1 || boolean_array2) boolean_array,
+    (tinyint_array1 || tinyint_array2) tinyint_array,
+    (smallint_array1 || smallint_array2) smallint_array,
+    (int_array1 || int_array2) int_array,
+    (bigint_array1 || bigint_array2) bigint_array,
+    (decimal_array1 || decimal_array2) decimal_array,
+    (double_array1 || double_array2) double_array,
+    (float_array1 || float_array2) float_array,
+    (date_array1 || data_array2) data_array,
+    (timestamp_array1 || timestamp_array2) timestamp_array,
+    (string_array1 || string_array2) string_array,
+    (array_array1 || array_array2) array_array,
+    (struct_array1 || struct_array2) struct_array,
+    (map_array1 || map_array2) map_array
+FROM various_arrays;
+
+-- Concatenate arrays of different types
+SELECT
+    (tinyint_array1 || smallint_array2) ts_array,
+    (smallint_array1 || int_array2) si_array,
+    (int_array1 || bigint_array2) ib_array,
+    (double_array1 || float_array2) df_array,
+    (string_array1 || data_array2) std_array,
+    (timestamp_array1 || string_array2) tst_array,
+    (string_array1 || int_array2) sti_array
+FROM various_arrays;
```
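
In these queries the SQL `||` operator is the surface syntax for `concat`, so `array1 || array2` exercises the new array concatenation path; the second query additionally relies on the TypeCoercion rule above to widen mismatched element types.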


```
@@ -237,3 +237,81 @@ struct<col:binary>
 78910
 891011
 9101112
+
+
+-- !query 11
+CREATE TEMPORARY VIEW various_arrays AS SELECT * FROM VALUES (
+  array(true, false), array(true),
+  array(2Y, 1Y), array(3Y, 4Y),
+  array(2S, 1S), array(3S, 4S),
+  array(2, 1), array(3, 4),
+  array(2L, 1L), array(3L, 4L),
+  array(9223372036854775809, 9223372036854775808), array(9223372036854775808, 9223372036854775809),
+  array(2.0D, 1.0D), array(3.0D, 4.0D),
+  array(float(2.0), float(1.0)), array(float(3.0), float(4.0)),
+  array(date '2016-03-14', date '2016-03-13'), array(date '2016-03-12', date '2016-03-11'),
+  array(timestamp '2016-11-15 20:54:00.000', timestamp '2016-11-12 20:54:00.000'),
+  array(timestamp '2016-11-11 20:54:00.000'),
+  array('a', 'b'), array('c', 'd'),
+  array(array('a', 'b'), array('c', 'd')), array(array('e'), array('f')),
+  array(struct('a', 1), struct('b', 2)), array(struct('c', 3), struct('d', 4)),
+  array(map('a', 1), map('b', 2)), array(map('c', 3), map('d', 4))
+) AS various_arrays(
+  boolean_array1, boolean_array2,
+  tinyint_array1, tinyint_array2,
+  smallint_array1, smallint_array2,
+  int_array1, int_array2,
+  bigint_array1, bigint_array2,
+  decimal_array1, decimal_array2,
+  double_array1, double_array2,
+  float_array1, float_array2,
+  date_array1, data_array2,
+  timestamp_array1, timestamp_array2,
+  string_array1, string_array2,
+  array_array1, array_array2,
+  struct_array1, struct_array2,
+  map_array1, map_array2
+)
+-- !query 11 schema
+struct<>
+-- !query 11 output
+
+
+-- !query 12
+SELECT
+    (boolean_array1 || boolean_array2) boolean_array,
+    (tinyint_array1 || tinyint_array2) tinyint_array,
+    (smallint_array1 || smallint_array2) smallint_array,
+    (int_array1 || int_array2) int_array,
+    (bigint_array1 || bigint_array2) bigint_array,
+    (decimal_array1 || decimal_array2) decimal_array,
+    (double_array1 || double_array2) double_array,
+    (float_array1 || float_array2) float_array,
+    (date_array1 || data_array2) data_array,
+    (timestamp_array1 || timestamp_array2) timestamp_array,
+    (string_array1 || string_array2) string_array,
+    (array_array1 || array_array2) array_array,
+    (struct_array1 || struct_array2) struct_array,
+    (map_array1 || map_array2) map_array
+FROM various_arrays
+-- !query 12 schema
+struct<boolean_array:array<boolean>,tinyint_array:array<tinyint>,smallint_array:array<smallint>,int_array:array<int>,bigint_array:array<bigint>,decimal_array:array<decimal(19,0)>,double_array:array<double>,float_array:array<float>,data_array:array<date>,timestamp_array:array<timestamp>,string_array:array<string>,array_array:array<array<string>>,struct_array:array<struct<col1:string,col2:int>>,map_array:array<map<string,int>>>
+-- !query 12 output
+[true,false,true] [2,1,3,4] [2,1,3,4] [2,1,3,4] [2,1,3,4] [9223372036854775809,9223372036854775808,9223372036854775808,9223372036854775809] [2.0,1.0,3.0,4.0] [2.0,1.0,3.0,4.0] [2016-03-14,2016-03-13,2016-03-12,2016-03-11] [2016-11-15 20:54:00.0,2016-11-12 20:54:00.0,2016-11-11 20:54:00.0] ["a","b","c","d"] [["a","b"],["c","d"],["e"],["f"]] [{"col1":"a","col2":1},{"col1":"b","col2":2},{"col1":"c","col2":3},{"col1":"d","col2":4}] [{"a":1},{"b":2},{"c":3},{"d":4}]
+
+
+-- !query 13
+SELECT
+    (tinyint_array1 || smallint_array2) ts_array,
+    (smallint_array1 || int_array2) si_array,
+    (int_array1 || bigint_array2) ib_array,
+    (double_array1 || float_array2) df_array,
+    (string_array1 || data_array2) std_array,
+    (timestamp_array1 || string_array2) tst_array,
+    (string_array1 || int_array2) sti_array
+FROM various_arrays
+-- !query 13 schema
+struct<ts_array:array<smallint>,si_array:array<int>,ib_array:array<bigint>,df_array:array<double>,std_array:array<string>,tst_array:array<string>,sti_array:array<string>>
+-- !query 13 output
+[2,1,3,4] [2,1,3,4] [2,1,3,4] [2.0,1.0,3.0,4.0] ["a","b","2016-03-12","2016-03-11"] ["2016-11-15 20:54:00","2016-11-12 20:54:00","c","d"] ["a","b","3","4"]
```


```
@@ -617,6 +617,80 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
     )
   }

+  test("concat function - arrays") {
+    val nseqi : Seq[Int] = null
+    val nseqs : Seq[String] = null
+    val df = Seq(
+      (Seq(1), Seq(2, 3), Seq(5L, 6L), nseqi, Seq("a", "b", "c"), Seq("d", "e"), Seq("f"), nseqs),
+      (Seq(1, 0), Seq.empty[Int], Seq(2L), nseqi, Seq("a"), Seq.empty[String], Seq(null), nseqs)
+    ).toDF("i1", "i2", "i3", "in", "s1", "s2", "s3", "sn")
+
+    val dummyFilter = (c: Column) => c.isNull || c.isNotNull // switch codeGen on
+
+    // Simple test cases
+    checkAnswer(
+      df.selectExpr("array(1, 2, 3L)"),
+      Seq(Row(Seq(1L, 2L, 3L)), Row(Seq(1L, 2L, 3L)))
+    )
+
+    checkAnswer(
+      df.select(concat($"i1", $"s1")),
+      Seq(Row(Seq("1", "a", "b", "c")), Row(Seq("1", "0", "a")))
+    )
+    checkAnswer(
+      df.select(concat($"i1", $"i2", $"i3")),
+      Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+    )
+    checkAnswer(
+      df.filter(dummyFilter($"i1")).select(concat($"i1", $"i2", $"i3")),
+      Seq(Row(Seq(1, 2, 3, 5, 6)), Row(Seq(1, 0, 2)))
+    )
+    checkAnswer(
+      df.selectExpr("concat(array(1, null), i2, i3)"),
+      Seq(Row(Seq(1, null, 2, 3, 5, 6)), Row(Seq(1, null, 2)))
+    )
+    checkAnswer(
+      df.select(concat($"s1", $"s2", $"s3")),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+    checkAnswer(
+      df.selectExpr("concat(s1, s2, s3)"),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+    checkAnswer(
+      df.filter(dummyFilter($"s1")).select(concat($"s1", $"s2", $"s3")),
+      Seq(Row(Seq("a", "b", "c", "d", "e", "f")), Row(Seq("a", null)))
+    )
+
+    // Null test cases
+    checkAnswer(
+      df.select(concat($"i1", $"in")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"in", $"i1")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"s1", $"sn")),
+      Seq(Row(null), Row(null))
+    )
+    checkAnswer(
+      df.select(concat($"sn", $"s1")),
+      Seq(Row(null), Row(null))
+    )
+
+    // Type error test cases
+    intercept[AnalysisException] {
+      df.selectExpr("concat(i1, i2, null)")
+    }
+
+    intercept[AnalysisException] {
+      df.selectExpr("concat(i1, array(i1, i2))")
+    }
+  }
+
   private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
     import DataFrameFunctionsSuite.CodegenFallbackExpr
     for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {
```


```
@@ -1742,8 +1742,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils {
       sql("DESCRIBE FUNCTION 'concat'"),
       Row("Class: org.apache.spark.sql.catalyst.expressions.Concat") ::
         Row("Function: concat") ::
-        Row("Usage: concat(str1, str2, ..., strN) - " +
-          "Returns the concatenation of str1, str2, ..., strN.") :: Nil
+        Row("Usage: concat(col1, col2, ..., colN) - " +
+          "Returns the concatenation of col1, col2, ..., colN.") :: Nil
     )
     // extended mode
     checkAnswer(
```