[SPARK-23821][SQL] Collection function: flatten
## What changes were proposed in this pull request?

This PR adds a new collection function that transforms an array of arrays into a single array. The PR comprises:
- An expression for flattening array structure
- Flatten function
- A wrapper for PySpark

## How was this patch tested?

New tests added into:
- CollectionExpressionsSuite
- DataFrameFunctionsSuite

## Codegen examples

### Primitive type
```
val df = Seq(
  Seq(Seq(1, 2), Seq(4, 5)),
  Seq(null, Seq(1))
).toDF("i")
df.filter($"i".isNotNull || $"i".isNull).select(flatten($"i")).debugCodegen
```
Result:
```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */ null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ boolean filter_value = true;
/* 038 */
/* 039 */ if (!(!inputadapter_isNull)) {
/* 040 */   filter_value = inputadapter_isNull;
/* 041 */ }
/* 042 */ if (!filter_value) continue;
/* 043 */
/* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */ boolean project_isNull = inputadapter_isNull;
/* 047 */ ArrayData project_value = null;
/* 048 */
/* 049 */ if (!inputadapter_isNull) {
/* 050 */   for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */     project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */   }
/* 053 */   if (!project_isNull) {
/* 054 */     long project_numElements = 0;
/* 055 */     for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */       project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */     }
/* 058 */     if (project_numElements > 2147483632) {
/* 059 */       throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */         project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */     }
/* 062 */
/* 063 */     long project_size = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
/* 064 */       project_numElements,
/* 065 */       4);
/* 066 */     if (project_size > 2147483632) {
/* 067 */       throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 068 */         project_size + " bytes of data due to exceeding the limit 2147483632" +
/* 069 */         " bytes for UnsafeArrayData.");
/* 070 */     }
/* 071 */
/* 072 */     byte[] project_array = new byte[(int)project_size];
/* 073 */     UnsafeArrayData project_tempArrayData = new UnsafeArrayData();
/* 074 */     Platform.putLong(project_array, 16, project_numElements);
/* 075 */     project_tempArrayData.pointTo(project_array, 16, (int)project_size);
/* 076 */     int project_counter = 0;
/* 077 */     for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 078 */       ArrayData arr = inputadapter_value.getArray(k);
/* 079 */       for (int l = 0; l < arr.numElements(); l++) {
/* 080 */         if (arr.isNullAt(l)) {
/* 081 */           project_tempArrayData.setNullAt(project_counter);
/* 082 */         } else {
/* 083 */           project_tempArrayData.setInt(
/* 084 */             project_counter,
/* 085 */             arr.getInt(l)
/* 086 */           );
/* 087 */         }
/* 088 */         project_counter++;
/* 089 */       }
/* 090 */     }
/* 091 */     project_value = project_tempArrayData;
/* 092 */
/* 093 */   }
/* 094 */
/* 095 */ }
```

### Non-primitive type
```
val df = Seq(
  Seq(Seq("a", "b"), Seq(null, "d")),
  Seq(null, Seq("a"))
).toDF("s")
df.filter($"s".isNotNull || $"s".isNull).select(flatten($"s")).debugCodegen
```
Result:
```
/* 033 */ boolean inputadapter_isNull = inputadapter_row.isNullAt(0);
/* 034 */ ArrayData inputadapter_value = inputadapter_isNull ?
/* 035 */ null : (inputadapter_row.getArray(0));
/* 036 */
/* 037 */ boolean filter_value = true;
/* 038 */
/* 039 */ if (!(!inputadapter_isNull)) {
/* 040 */   filter_value = inputadapter_isNull;
/* 041 */ }
/* 042 */ if (!filter_value) continue;
/* 043 */
/* 044 */ ((org.apache.spark.sql.execution.metric.SQLMetric) references[0] /* numOutputRows */).add(1);
/* 045 */
/* 046 */ boolean project_isNull = inputadapter_isNull;
/* 047 */ ArrayData project_value = null;
/* 048 */
/* 049 */ if (!inputadapter_isNull) {
/* 050 */   for (int z = 0; !project_isNull && z < inputadapter_value.numElements(); z++) {
/* 051 */     project_isNull |= inputadapter_value.isNullAt(z);
/* 052 */   }
/* 053 */   if (!project_isNull) {
/* 054 */     long project_numElements = 0;
/* 055 */     for (int z = 0; z < inputadapter_value.numElements(); z++) {
/* 056 */       project_numElements += inputadapter_value.getArray(z).numElements();
/* 057 */     }
/* 058 */     if (project_numElements > 2147483632) {
/* 059 */       throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
/* 060 */         project_numElements + " elements due to exceeding the array size limit 2147483632.");
/* 061 */     }
/* 062 */
/* 063 */     Object[] project_arrayObject = new Object[(int)project_numElements];
/* 064 */     int project_counter = 0;
/* 065 */     for (int k = 0; k < inputadapter_value.numElements(); k++) {
/* 066 */       ArrayData arr = inputadapter_value.getArray(k);
/* 067 */       for (int l = 0; l < arr.numElements(); l++) {
/* 068 */         project_arrayObject[project_counter] = arr.getUTF8String(l);
/* 069 */         project_counter++;
/* 070 */       }
/* 071 */     }
/* 072 */     project_value = new org.apache.spark.sql.catalyst.util.GenericArrayData(project_arrayObject);
/* 073 */
/* 074 */   }
/* 075 */
/* 076 */ }
```

Author: mn-mikke <mrkAha12346github>

Closes #20938 from mn-mikke/feature/array-api-flatten-to-master.
commit 5fea17b3be (parent d6c26d1c9a)
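As the docstring below notes, only one level of nesting is removed when arrays are nested more than two levels deep. A minimal spark-shell sketch of that behavior (assumes a build with this patch; `spark.implicits._` are in scope in the shell):

```
import org.apache.spark.sql.functions.flatten

// Three levels of nesting: flatten strips exactly the outermost one.
val nested = Seq(Seq(Seq(Seq(1, 2), Seq(3)), Seq(Seq(4)))).toDF("a")
nested.select(flatten($"a")).collect()
// expected: a single row holding [[1, 2], [3], [4]]
```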
@@ -2191,6 +2191,23 @@ def reverse(col):
    return Column(sc._jvm.functions.reverse(_to_java_column(col)))


@since(2.4)
def flatten(col):
    """
    Collection function: creates a single array from an array of arrays.
    If a structure of nested arrays is deeper than two levels,
    only one level of nesting is removed.

    :param col: name of column or expression

    >>> df = spark.createDataFrame([([[1, 2, 3], [4, 5], [6]],), ([None, [4, 5]],)], ['data'])
    >>> df.select(flatten(df.data).alias('r')).collect()
    [Row(r=[1, 2, 3, 4, 5, 6]), Row(r=None)]
    """
    sc = SparkContext._active_spark_context
    return Column(sc._jvm.functions.flatten(_to_java_column(col)))


@since(2.3)
def map_keys(col):
    """
@@ -413,6 +413,7 @@ object FunctionRegistry {
    expression[ArrayMax]("array_max"),
    expression[Reverse]("reverse"),
    expression[Concat]("concat"),
    expression[Flatten]("flatten"),
    CreateStruct.registryEntry,

    // misc functions
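Registering the expression makes it resolvable by name from SQL as well; a quick sanity check (a sketch, assuming a session built with this patch):

```
spark.sql("SELECT flatten(array(array(1, 2), array(3, 4)))").collect()
// expected: a single row holding [1, 2, 3, 4]
```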
@@ -883,3 +883,179 @@ case class Concat(children: Seq[Expression]) extends Expression {

  override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}

/**
 * Transforms an array of arrays into a single array.
 */
@ExpressionDescription(
  usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(array(1, 2), array(3, 4)));
       [1,2,3,4]
  """,
  since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

  private val MAX_ARRAY_LENGTH = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

  private lazy val childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

  override def nullable: Boolean = child.nullable || childDataType.containsNull

  override def dataType: DataType = childDataType.elementType

  lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: ArrayType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        s"The argument should be an array of arrays, " +
        s"but '${child.sql}' is of ${child.dataType.simpleString} type."
      )
  }

  override def nullSafeEval(child: Any): Any = {
    val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)

    if (elements.contains(null)) {
      null
    } else {
      val arrayData = elements.map(_.asInstanceOf[ArrayData])
      val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
      if (numberOfElements > MAX_ARRAY_LENGTH) {
        throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
          s"$numberOfElements elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.")
      }
      val flattenedData = new Array(numberOfElements.toInt)
      var position = 0
      for (ad <- arrayData) {
        val arr = ad.toObjectArray(elementType)
        Array.copy(arr, 0, flattenedData, position, arr.length)
        position += arr.length
      }
      new GenericArrayData(flattenedData)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val code = if (CodeGenerator.isPrimitiveType(elementType)) {
        genCodeForFlattenOfPrimitiveElements(ctx, c, ev.value)
      } else {
        genCodeForFlattenOfNonPrimitiveElements(ctx, c, ev.value)
      }
      if (childDataType.containsNull) nullElementsProtection(ev, c, code) else code
    })
  }

  private def nullElementsProtection(
      ev: ExprCode,
      childVariableName: String,
      coreLogic: String): String = {
    s"""
    |for (int z = 0; !${ev.isNull} && z < $childVariableName.numElements(); z++) {
    |  ${ev.isNull} |= $childVariableName.isNullAt(z);
    |}
    |if (!${ev.isNull}) {
    |  $coreLogic
    |}
    """.stripMargin
  }

  private def genCodeForNumberOfElements(
      ctx: CodegenContext,
      childVariableName: String) : (String, String) = {
    val variableName = ctx.freshName("numElements")
    val code = s"""
      |long $variableName = 0;
      |for (int z = 0; z < $childVariableName.numElements(); z++) {
      |  $variableName += $childVariableName.getArray(z).numElements();
      |}
      |if ($variableName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $variableName + " elements due to exceeding the array size limit $MAX_ARRAY_LENGTH.");
      |}
      """.stripMargin
    (code, variableName)
  }

  private def genCodeForFlattenOfPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val arrayName = ctx.freshName("array")
    val arraySizeName = ctx.freshName("size")
    val counter = ctx.freshName("counter")
    val tempArrayDataName = ctx.freshName("tempArrayData")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    val unsafeArraySizeInBytes = s"""
      |long $arraySizeName = UnsafeArrayData.calculateSizeOfUnderlyingByteArray(
      |  $numElemName,
      |  ${elementType.defaultSize});
      |if ($arraySizeName > $MAX_ARRAY_LENGTH) {
      |  throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
      |    $arraySizeName + " bytes of data due to exceeding the limit $MAX_ARRAY_LENGTH" +
      |    " bytes for UnsafeArrayData.");
      |}
      """.stripMargin
    val baseOffset = Platform.BYTE_ARRAY_OFFSET

    val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)

    s"""
    |$numElemCode
    |$unsafeArraySizeInBytes
    |byte[] $arrayName = new byte[(int)$arraySizeName];
    |UnsafeArrayData $tempArrayDataName = new UnsafeArrayData();
    |Platform.putLong($arrayName, $baseOffset, $numElemName);
    |$tempArrayDataName.pointTo($arrayName, $baseOffset, (int)$arraySizeName);
    |int $counter = 0;
    |for (int k = 0; k < $childVariableName.numElements(); k++) {
    |  ArrayData arr = $childVariableName.getArray(k);
    |  for (int l = 0; l < arr.numElements(); l++) {
    |    if (arr.isNullAt(l)) {
    |      $tempArrayDataName.setNullAt($counter);
    |    } else {
    |      $tempArrayDataName.set$primitiveValueTypeName(
    |        $counter,
    |        ${CodeGenerator.getValue("arr", elementType, "l")}
    |      );
    |    }
    |    $counter++;
    |  }
    |}
    |$arrayDataName = $tempArrayDataName;
    """.stripMargin
  }

  private def genCodeForFlattenOfNonPrimitiveElements(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val genericArrayClass = classOf[GenericArrayData].getName
    val arrayName = ctx.freshName("arrayObject")
    val counter = ctx.freshName("counter")
    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    s"""
    |$numElemCode
    |Object[] $arrayName = new Object[(int)$numElemName];
    |int $counter = 0;
    |for (int k = 0; k < $childVariableName.numElements(); k++) {
    |  ArrayData arr = $childVariableName.getArray(k);
    |  for (int l = 0; l < arr.numElements(); l++) {
    |    $arrayName[$counter] = ${CodeGenerator.getValue("arr", elementType, "l")};
    |    $counter++;
    |  }
    |}
    |$arrayDataName = new $genericArrayClass($arrayName);
    """.stripMargin
  }

  override def prettyName: String = "flatten"
}
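The interpreted (non-codegen) path of the expression can also be exercised directly, in the same style as the suite below; a minimal sketch, assuming the Catalyst classes are on the classpath:

```
import org.apache.spark.sql.catalyst.expressions.{Flatten, Literal}
import org.apache.spark.sql.types.{ArrayType, IntegerType}

// Interpreted evaluation; the literal's type mirrors the suite's intArrayType.
val input = Literal.create(Seq(Seq(1, 2), Seq(3)), ArrayType(ArrayType(IntegerType)))
val result = Flatten(input).eval() // an ArrayData holding [1, 2, 3]
```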
@@ -280,4 +280,99 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
    checkEvaluation(Concat(Seq(aa0, aa1)), Seq(Seq("a", "b"), Seq("c"), Seq("d"), Seq("e", "f")))
  }

  test("Flatten") {
    // Primitive-type test cases
    val intArrayType = ArrayType(ArrayType(IntegerType))

    // Main test cases (primitive type)
    val aim1 = Literal.create(Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6)), intArrayType)
    val aim2 = Literal.create(Seq(Seq(1, 2, 3)), intArrayType)

    checkEvaluation(Flatten(aim1), Seq(1, 2, 3, 4, 5, 6))
    checkEvaluation(Flatten(aim2), Seq(1, 2, 3))

    // Test cases with an empty array (primitive type)
    val aie1 = Literal.create(Seq(Seq.empty, Seq(1, 2), Seq(3, 4)), intArrayType)
    val aie2 = Literal.create(Seq(Seq(1, 2), Seq.empty, Seq(3, 4)), intArrayType)
    val aie3 = Literal.create(Seq(Seq(1, 2), Seq(3, 4), Seq.empty), intArrayType)
    val aie4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), intArrayType)
    val aie5 = Literal.create(Seq(Seq.empty), intArrayType)
    val aie6 = Literal.create(Seq.empty, intArrayType)

    checkEvaluation(Flatten(aie1), Seq(1, 2, 3, 4))
    checkEvaluation(Flatten(aie2), Seq(1, 2, 3, 4))
    checkEvaluation(Flatten(aie3), Seq(1, 2, 3, 4))
    checkEvaluation(Flatten(aie4), Seq.empty)
    checkEvaluation(Flatten(aie5), Seq.empty)
    checkEvaluation(Flatten(aie6), Seq.empty)

    // Test cases with null elements (primitive type)
    val ain1 = Literal.create(Seq(Seq(null, null, null), Seq(4, null)), intArrayType)
    val ain2 = Literal.create(Seq(Seq(null, 2, null), Seq(null, null)), intArrayType)
    val ain3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), intArrayType)

    checkEvaluation(Flatten(ain1), Seq(null, null, null, 4, null))
    checkEvaluation(Flatten(ain2), Seq(null, 2, null, null, null))
    checkEvaluation(Flatten(ain3), Seq(null, null, null, null))

    // Test cases with a null array (primitive type)
    val aia1 = Literal.create(Seq(null, Seq(1, 2)), intArrayType)
    val aia2 = Literal.create(Seq(Seq(1, 2), null), intArrayType)
    val aia3 = Literal.create(Seq(null), intArrayType)
    val aia4 = Literal.create(null, intArrayType)

    checkEvaluation(Flatten(aia1), null)
    checkEvaluation(Flatten(aia2), null)
    checkEvaluation(Flatten(aia3), null)
    checkEvaluation(Flatten(aia4), null)

    // Non-primitive-type test cases
    val strArrayType = ArrayType(ArrayType(StringType))
    val arrArrayType = ArrayType(ArrayType(ArrayType(StringType)))

    // Main test cases (non-primitive type)
    val asm1 = Literal.create(Seq(Seq("a"), Seq("b", "c"), Seq("d", "e", "f")), strArrayType)
    val asm2 = Literal.create(Seq(Seq("a", "b")), strArrayType)
    val asm3 = Literal.create(Seq(Seq(Seq("a", "b"), Seq("c")), Seq(Seq("d", "e"))), arrArrayType)

    checkEvaluation(Flatten(asm1), Seq("a", "b", "c", "d", "e", "f"))
    checkEvaluation(Flatten(asm2), Seq("a", "b"))
    checkEvaluation(Flatten(asm3), Seq(Seq("a", "b"), Seq("c"), Seq("d", "e")))

    // Test cases with an empty array (non-primitive type)
    val ase1 = Literal.create(Seq(Seq.empty, Seq("a", "b"), Seq("c", "d")), strArrayType)
    val ase2 = Literal.create(Seq(Seq("a", "b"), Seq.empty, Seq("c", "d")), strArrayType)
    val ase3 = Literal.create(Seq(Seq("a", "b"), Seq("c", "d"), Seq.empty), strArrayType)
    val ase4 = Literal.create(Seq(Seq.empty, Seq.empty, Seq.empty), strArrayType)
    val ase5 = Literal.create(Seq(Seq.empty), strArrayType)
    val ase6 = Literal.create(Seq.empty, strArrayType)

    checkEvaluation(Flatten(ase1), Seq("a", "b", "c", "d"))
    checkEvaluation(Flatten(ase2), Seq("a", "b", "c", "d"))
    checkEvaluation(Flatten(ase3), Seq("a", "b", "c", "d"))
    checkEvaluation(Flatten(ase4), Seq.empty)
    checkEvaluation(Flatten(ase5), Seq.empty)
    checkEvaluation(Flatten(ase6), Seq.empty)

    // Test cases with null elements (non-primitive type)
    val asn1 = Literal.create(Seq(Seq(null, null, "c"), Seq(null, null)), strArrayType)
    val asn2 = Literal.create(Seq(Seq(null, null, null), Seq("d", null)), strArrayType)
    val asn3 = Literal.create(Seq(Seq(null, null), Seq(null, null)), strArrayType)

    checkEvaluation(Flatten(asn1), Seq(null, null, "c", null, null))
    checkEvaluation(Flatten(asn2), Seq(null, null, null, "d", null))
    checkEvaluation(Flatten(asn3), Seq(null, null, null, null))

    // Test cases with a null array (non-primitive type)
    val asa1 = Literal.create(Seq(null, Seq("a", "b")), strArrayType)
    val asa2 = Literal.create(Seq(Seq("a", "b"), null), strArrayType)
    val asa3 = Literal.create(Seq(null), strArrayType)
    val asa4 = Literal.create(null, strArrayType)

    checkEvaluation(Flatten(asa1), null)
    checkEvaluation(Flatten(asa2), null)
    checkEvaluation(Flatten(asa3), null)
    checkEvaluation(Flatten(asa4), null)
  }
}
@@ -3340,6 +3340,14 @@ object functions {
   */
  def reverse(e: Column): Column = withExpr { Reverse(e.expr) }

  /**
   * Creates a single array from an array of arrays. If a structure of nested arrays is deeper than
   * two levels, only one level of nesting is removed.
   * @group collection_funcs
   * @since 2.4.0
   */
  def flatten(e: Column): Column = withExpr { Flatten(e.expr) }

  /**
   * Returns an unordered array containing the keys of the map.
   * @group collection_funcs
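A hedged usage sketch of the new Scala API in spark-shell (the result column name follows the expression's prettyName; the exact `show()` rendering is illustrative):

```
import org.apache.spark.sql.functions.flatten

val df = Seq(Seq(Seq(1, 2), Seq(3, 4))).toDF("arr")
df.select(flatten($"arr")).show()
// +------------+
// |flatten(arr)|
// +------------+
// |[1, 2, 3, 4]|
// +------------+
```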
@@ -691,6 +691,85 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
    }
  }

  test("flatten function") {
    val dummyFilter = (c: Column) => c.isNull || c.isNotNull // to switch codeGen on
    val oneRowDF = Seq((1, "a", Seq(1, 2, 3))).toDF("i", "s", "arr")

    // Test cases with a primitive type
    val intDF = Seq(
      (Seq(Seq(1, 2, 3), Seq(4, 5), Seq(6))),
      (Seq(Seq(1, 2))),
      (Seq(Seq(1), Seq.empty)),
      (Seq(Seq.empty, Seq(1))),
      (Seq(Seq.empty, Seq.empty)),
      (Seq(Seq(1), null)),
      (Seq(null, Seq(1))),
      (Seq(null, null))
    ).toDF("i")

    val intDFResult = Seq(
      Row(Seq(1, 2, 3, 4, 5, 6)),
      Row(Seq(1, 2)),
      Row(Seq(1)),
      Row(Seq(1)),
      Row(Seq.empty),
      Row(null),
      Row(null),
      Row(null))

    checkAnswer(intDF.select(flatten($"i")), intDFResult)
    checkAnswer(intDF.filter(dummyFilter($"i")).select(flatten($"i")), intDFResult)
    checkAnswer(intDF.selectExpr("flatten(i)"), intDFResult)
    checkAnswer(
      oneRowDF.selectExpr("flatten(array(arr, array(null, 5), array(6, null)))"),
      Seq(Row(Seq(1, 2, 3, null, 5, 6, null))))

    // Test cases with non-primitive types
    val strDF = Seq(
      (Seq(Seq("a", "b"), Seq("c"), Seq("d", "e", "f"))),
      (Seq(Seq("a", "b"))),
      (Seq(Seq("a", null), Seq(null, "b"), Seq(null, null))),
      (Seq(Seq("a"), Seq.empty)),
      (Seq(Seq.empty, Seq("a"))),
      (Seq(Seq.empty, Seq.empty)),
      (Seq(Seq("a"), null)),
      (Seq(null, Seq("a"))),
      (Seq(null, null))
    ).toDF("s")

    val strDFResult = Seq(
      Row(Seq("a", "b", "c", "d", "e", "f")),
      Row(Seq("a", "b")),
      Row(Seq("a", null, null, "b", null, null)),
      Row(Seq("a")),
      Row(Seq("a")),
      Row(Seq.empty),
      Row(null),
      Row(null),
      Row(null))

    checkAnswer(strDF.select(flatten($"s")), strDFResult)
    checkAnswer(strDF.filter(dummyFilter($"s")).select(flatten($"s")), strDFResult)
    checkAnswer(strDF.selectExpr("flatten(s)"), strDFResult)
    checkAnswer(
      oneRowDF.selectExpr("flatten(array(array(arr, arr), array(arr)))"),
      Seq(Row(Seq(Seq(1, 2, 3), Seq(1, 2, 3), Seq(1, 2, 3)))))

    // Error test cases
    intercept[AnalysisException] {
      oneRowDF.select(flatten($"arr"))
    }
    intercept[AnalysisException] {
      oneRowDF.select(flatten($"i"))
    }
    intercept[AnalysisException] {
      oneRowDF.select(flatten($"s"))
    }
    intercept[AnalysisException] {
      oneRowDF.selectExpr("flatten(null)")
    }
  }

  private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {
    import DataFrameFunctionsSuite.CodegenFallbackExpr
    for ((codegenFallback, wholeStage) <- Seq((true, false), (false, false), (false, true))) {