[SPARK-23922][SQL] Add arrays_overlap function
## What changes were proposed in this pull request? The PR adds the function `arrays_overlap`. This function returns `true` if the input arrays have at least one non-null element in common; if they do not, it returns `null` when both arrays are non-empty and at least one of them contains a `null` element, and `false` otherwise. ## How was this patch tested? Added unit tests. Author: Marco Gaido <marcogaido91@gmail.com> Closes #21028 from mgaido91/SPARK-23922.
This commit is contained in:
parent
6ec05826d7
commit
69350aa2f0
|
@ -1855,6 +1855,21 @@ def array_contains(col, value):
|
|||
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
|
||||
|
||||
|
||||
@since(2.4)
def arrays_overlap(a1, a2):
    """
    Collection function: returns true if the arrays contain any common non-null element; if not,
    returns null if both the arrays are non-empty and any of them contains a null element; returns
    false otherwise.

    :param a1: the first array column
    :param a2: the second array column

    >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
    >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
    [Row(overlap=True), Row(overlap=False)]
    """
    # Delegate to the JVM-side `functions.arrays_overlap` through the active SparkContext.
    jvm_overlap = SparkContext._active_spark_context._jvm.functions.arrays_overlap
    left_column = _to_java_column(a1)
    right_column = _to_java_column(a2)
    return Column(jvm_overlap(left_column, right_column))
|
||||
|
||||
|
||||
@since(2.4)
|
||||
def slice(x, start, length):
|
||||
"""
|
||||
|
|
|
@ -410,6 +410,7 @@ object FunctionRegistry {
|
|||
// collection functions
|
||||
expression[CreateArray]("array"),
|
||||
expression[ArrayContains]("array_contains"),
|
||||
expression[ArraysOverlap]("arrays_overlap"),
|
||||
expression[ArrayJoin]("array_join"),
|
||||
expression[ArrayPosition]("array_position"),
|
||||
expression[ArraySort]("array_sort"),
|
||||
|
|
|
@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import java.util.Comparator
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
|
||||
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.array.ByteArrayMethods
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
||||
/**
 * Base trait for [[BinaryExpression]]s taking two arrays of the same element type, with implicit
 * casting of both children to arrays over their tightest common element type.
 */
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
  with ImplicitCastInputTypes {

  /** Element type of the first expected input type, i.e. the common element type. */
  @transient protected lazy val elementType: DataType = {
    val firstInputType = inputTypes.head.asInstanceOf[ArrayType]
    firstInputType.elementType
  }

  /**
   * Expected input types: two arrays over the tightest common element type of the children, each
   * keeping its own `containsNull` flag. Empty when the children are not both arrays or when no
   * common element type exists.
   */
  override def inputTypes: Seq[AbstractDataType] = (left.dataType, right.dataType) match {
    case (ArrayType(leftElem, leftContainsNull), ArrayType(rightElem, rightContainsNull)) =>
      TypeCoercion.findTightestCommonType(leftElem, rightElem)
        .map { common =>
          Seq(ArrayType(common, leftContainsNull), ArrayType(common, rightContainsNull))
        }
        .getOrElse(Seq.empty)
    case _ => Seq.empty
  }

  /** Succeeds only when both children are arrays whose element types are the same type. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val leftType = left.dataType
    val rightType = right.dataType
    (leftType, rightType) match {
      case (ArrayType(leftElem, _), ArrayType(rightElem, _)) if leftElem.sameType(rightElem) =>
        TypeCheckResult.TypeCheckSuccess
      case _ =>
        val expected = s"two ${ArrayType.simpleString}s with same element type"
        val actual = s"[${leftType.simpleString}, ${rightType.simpleString}]"
        TypeCheckResult.TypeCheckFailure(
          s"input to function $prettyName should have been $expected, but it's $actual")
    }
  }
}
|
||||
|
||||
|
||||
/**
|
||||
* Given an array or map, returns its size. Returns -1 if null.
|
||||
*/
|
||||
|
@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
override def prettyName: String = "array_contains"
|
||||
}
|
||||
|
||||
/**
 * Checks if the two arrays contain at least one common element.
 *
 * Evaluates to `true` when a non-null element is present in both arrays. Otherwise it evaluates
 * to `null` when both arrays are non-empty and a null element was seen, and to `false` in every
 * other case. If either array itself is null the result is null (null-safe evaluation).
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
       true
  """, since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
  extends BinaryArrayExpressionWithImplicitCast {

  /**
   * In addition to the checks of the parent trait, requires the common element type to be
   * orderable, because the brute-force evaluation compares elements through an ordering.
   */
  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      if (RowOrdering.isOrderable(elementType)) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        TypeCheckResult.TypeCheckFailure(
          s"${elementType.simpleString} cannot be used in comparison.")
      }
    case failure => failure
  }

  /** Ordering used by the brute-force evaluation to test two elements for equivalence. */
  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /**
   * Whether elements can be compared through a hash set lookup. Binary is excluded explicitly,
   * every other atomic type is accepted and all remaining (non-atomic) types are rejected.
   */
  @transient private lazy val elementTypeSupportEquals = elementType match {
    case BinaryType => false
    case _: AtomicType => true
    case _ => false
  }

  /** Interpreted evaluation strategy, selected once from the element type. */
  @transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
    fastEval _
  } else {
    bruteForceEval _
  }

  override def dataType: DataType = BooleanType

  /**
   * The result can be null when either child is nullable (null array) or when either array type
   * may contain null elements (no common element found but a null was seen).
   */
  override def nullable: Boolean = {
    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
      right.dataType.asInstanceOf[ArrayType].containsNull
  }

  override def nullSafeEval(a1: Any, a2: Any): Any = {
    doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
  }

  /**
   * A fast implementation which puts all the elements from the smaller array in a set
   * and then performs a lookup on it for each element of the bigger one.
   * This eval mode works only for data types which implements properly the equals method.
   *
   * @return `true` on the first common non-null element; otherwise `null` if a null element was
   *         seen while both arrays were non-empty, `false` if not
   */
  private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
      (arr1, arr2)
    } else {
      (arr2, arr1)
    }
    // If the smaller array is empty there can be no overlap and nulls are not even inspected,
    // so the result is false.
    if (smaller.numElements() > 0) {
      val smallestSet = new mutable.HashSet[Any]
      smaller.foreach(elementType, (_, v) =>
        if (v == null) {
          hasNull = true
        } else {
          smallestSet += v
        })
      bigger.foreach(elementType, (_, v1) =>
        if (v1 == null) {
          hasNull = true
        } else if (smallestSet.contains(v1)) {
          // Non-local return out of the lambda: stops the scan at the first match.
          return true
        }
      )
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  /**
   * A slower evaluation which performs a nested loop and supports all the data types.
   *
   * @return `true` on the first pair of equivalent non-null elements; otherwise `null` if a null
   *         element was seen while both arrays were non-empty, `false` if not
   */
  private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    if (arr1.numElements() > 0 && arr2.numElements() > 0) {
      arr1.foreach(elementType, (_, v1) =>
        if (v1 == null) {
          hasNull = true
        } else {
          arr2.foreach(elementType, (_, v2) =>
            if (v2 == null) {
              hasNull = true
            } else if (ordering.equiv(v1, v2)) {
              // Non-local return out of both lambdas: stops at the first match.
              return true
            }
          )
        })
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  /**
   * Generates code that orders the two arrays by size and, when the smaller one is non-empty,
   * runs either the set-based or the nested-loop comparison depending on the element type.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (a1, a2) => {
      val smaller = ctx.freshName("smallerArray")
      val bigger = ctx.freshName("biggerArray")
      val comparisonCode = if (elementTypeSupportEquals) {
        fastCodegen(ctx, ev, smaller, bigger)
      } else {
        bruteForceCodegen(ctx, ev, smaller, bigger)
      }
      s"""
         |ArrayData $smaller;
         |ArrayData $bigger;
         |if ($a1.numElements() > $a2.numElements()) {
         |  $bigger = $a1;
         |  $smaller = $a2;
         |} else {
         |  $smaller = $a1;
         |  $bigger = $a2;
         |}
         |if ($smaller.numElements() > 0) {
         |  $comparisonCode
         |}
       """.stripMargin
    })
  }

  /**
   * Statement resetting the result's null flag once a common element has been found.
   *
   * It is emitted only when this expression is nullable. NOTE(review): for a non-nullable
   * expression `nullSafeCodeGen` is expected to replace `ev.isNull` with a constant instead of
   * declaring a variable, so assigning to it would not compile - confirm against
   * `BinaryExpression.nullSafeCodeGen`. In that case no null element code path exists either
   * (see [[nullSafeElementCodegen]]), so there is nothing to reset.
   */
  private def clearIsNullCode(ev: ExprCode): String = {
    if (nullable) s"${ev.isNull} = false;" else ""
  }

  /**
   * Code generation for a fast implementation which puts all the elements from the smaller array
   * in a set and then performs a lookup on it for each element of the bigger one.
   * It works only for data types which implements properly the equals method.
   */
  private def fastCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val javaElementClass = CodeGenerator.boxedType(elementType)
    val javaSet = classOf[java.util.HashSet[_]].getName
    val set = ctx.freshName("set")
    val addToSetFromSmallerCode = nullSafeElementCodegen(
      smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
    val setIsNullCode = clearIsNullCode(ev)
    val elementIsInSetCode = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |if ($set.contains($getFromBigger)) {
         |  $setIsNullCode
         |  ${ev.value} = true;
         |  break;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
       |for (int $i = 0; $i < $smaller.numElements(); $i ++) {
       |  $addToSetFromSmallerCode
       |}
       |for (int $i = 0; $i < $bigger.numElements(); $i ++) {
       |  $elementIsInSetCode
       |}
     """.stripMargin
  }

  /**
   * Code generation for a slower evaluation which performs a nested loop and supports all the
   * data types.
   */
  private def bruteForceCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val j = ctx.freshName("j")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val setIsNullCode = clearIsNullCode(ev)
    val compareValues = nullSafeElementCodegen(
      smaller,
      j,
      s"""
         |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
         |  $setIsNullCode
         |  ${ev.value} = true;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    // Both loops stop as soon as the result value has been set to true.
    val isInSmaller = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
         |  $compareValues
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
       |  $isInSmaller
       |}
     """.stripMargin
  }

  /**
   * Wraps `code` in a null check on `arrayVar[index]` when any of the input array types may
   * contain nulls; otherwise returns `code` unchanged.
   *
   * @param arrayVar name of the generated array variable
   * @param index name of the generated index variable
   * @param code code to run when the element is not null
   * @param isNullCode code to run when the element is null
   */
  def nullSafeElementCodegen(
      arrayVar: String,
      index: String,
      code: String,
      isNullCode: String): String = {
    if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
      s"""
         |if ($arrayVar.isNullAt($index)) {
         |  $isNullCode
         |} else {
         |  $code
         |}
       """.stripMargin
    } else {
      code
    }
  }

  override def prettyName: String = "arrays_overlap"
}
|
||||
|
||||
/**
|
||||
* Slices an array according to the requested start index and length
|
||||
*/
|
||||
|
|
|
@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
|
||||
}
|
||||
|
||||
test("ArraysOverlap") {
  // Runs every (left, right, expected) case through ArraysOverlap, in the listed order.
  def checkOverlaps(cases: Seq[(Expression, Expression, Any)]): Unit = {
    cases.foreach { case (lhs, rhs, expected) =>
      checkEvaluation(ArraysOverlap(lhs, rhs), expected)
    }
  }

  val intArrayType = ArrayType(IntegerType)
  val stringArrayType = ArrayType(StringType)
  val binaryArrayType = ArrayType(BinaryType)
  val nestedStringArrayType = ArrayType(ArrayType(StringType))

  val ints123 = Literal.create(Seq(1, 2, 3), intArrayType)
  val ints453 = Literal.create(Seq(4, 5, 3), intArrayType)
  val intsNull56 = Literal.create(Seq(null, 5, 6), intArrayType)
  val ints78 = Literal.create(Seq(7, 8), intArrayType)
  val strsNullEmpty = Literal.create(Seq[String](null, ""), stringArrayType)
  val strsEmptyAbc = Literal.create(Seq[String]("", "abc"), stringArrayType)
  val strsDefGhi = Literal.create(Seq[String]("def", "ghi"), stringArrayType)

  val emptyIntArray = Literal.create(Seq.empty[Int], intArrayType)

  checkOverlaps(Seq[(Expression, Expression, Any)](
    (ints123, ints453, true),
    (ints123, intsNull56, null),
    (ints453, intsNull56, true),
    (ints453, ints78, false),
    (ints123, emptyIntArray, false),
    (intsNull56, emptyIntArray, false),
    (emptyIntArray, intsNull56, false)))

  checkOverlaps(Seq[(Expression, Expression, Any)](
    (strsNullEmpty, strsEmptyAbc, true),
    (strsNullEmpty, strsDefGhi, null),
    (strsEmptyAbc, strsDefGhi, false)))

  // null handling
  checkOverlaps(Seq[(Expression, Expression, Any)](
    (emptyIntArray, intsNull56, false),
    (emptyIntArray, Literal.create(Seq(null), intArrayType), false),
    (Literal.create(null, intArrayType), ints123, null),
    (ints123, Literal.create(null, intArrayType), null),
    (Literal.create(Seq(null), intArrayType), Literal.create(Seq(null), intArrayType), null)))

  // arrays of binaries
  val bin12And34 = Literal.create(
    Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), binaryArrayType)
  val bin56And12 = Literal.create(
    Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), binaryArrayType)
  val bin21And43 = Literal.create(
    Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), binaryArrayType)

  checkOverlaps(Seq[(Expression, Expression, Any)](
    (bin12And34, bin56And12, true),
    (bin12And34, bin21And43, false)))

  // arrays of complex data types
  val nestedAbAndCd = Literal.create(
    Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")), nestedStringArrayType)
  val nestedEfAndAb = Literal.create(
    Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")), nestedStringArrayType)
  val nestedBaAndFg = Literal.create(
    Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")), nestedStringArrayType)

  checkOverlaps(Seq[(Expression, Expression, Any)](
    (nestedAbAndCd, nestedEfAndAb, true),
    (nestedAbAndCd, nestedBaAndFg, false)))

  // null handling with complex datatypes
  val emptyBinaryArray = Literal.create(Seq.empty[Array[Byte]], binaryArrayType)
  val arrayWithBinaryNull = Literal.create(Seq(null), binaryArrayType)

  checkOverlaps(Seq[(Expression, Expression, Any)](
    (emptyBinaryArray, bin12And34, false),
    (bin12And34, emptyBinaryArray, false),
    (emptyBinaryArray, arrayWithBinaryNull, false),
    (arrayWithBinaryNull, emptyBinaryArray, false),
    (arrayWithBinaryNull, bin12And34, null),
    (bin12And34, arrayWithBinaryNull, null)))
}
|
||||
|
||||
test("Slice") {
|
||||
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
|
||||
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))
|
||||
|
|
|
@ -3085,6 +3085,17 @@ object functions {
|
|||
ArrayContains(column.expr, Literal(value))
|
||||
}
|
||||
|
||||
/**
 * Checks whether two array columns overlap. Returns `true` if `a1` and `a2` have at least one
 * non-null element in common. If they do not, and both arrays are non-empty and any of them
 * contains a `null`, it returns `null`. It returns `false` otherwise.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def arrays_overlap(a1: Column, a2: Column): Column =
  withExpr(ArraysOverlap(a1.expr, a2.expr))
|
||||
|
||||
/**
|
||||
* Returns an array containing all the elements in `x` from index `start` (or starting from the
|
||||
* end if `start` is negative) with the specified `length`.
|
||||
|
|
|
@ -442,6 +442,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
)
|
||||
}
|
||||
|
||||
test("arrays_overlap function") {
  // Rows: no overlap, no overlap but a null element, and a shared element.
  val input = Seq(
    (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),
    (Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), None)),
    (Seq[Option[Int]](Some(3), Some(2)), Seq[Option[Int]](Some(1), Some(2)))
  ).toDF("a", "b")
  val expected = Seq(Row(false), Row(null), Row(true))

  // Same result through the Column API and through the SQL expression parser.
  checkAnswer(input.select(arrays_overlap(input("a"), input("b"))), expected)
  checkAnswer(input.selectExpr("arrays_overlap(a, b)"), expected)

  // Arrays with different numeric element types are accepted.
  val mixedNumeric = Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b")
  checkAnswer(mixedNumeric.selectExpr("arrays_overlap(a, b)"), Row(true))

  // Each of these queries must be rejected during analysis.
  Seq(
    "select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))",
    "select arrays_overlap(null, null)",
    "select arrays_overlap(map(1, 2), map(3, 4))"
  ).foreach { query =>
    intercept[AnalysisException] {
      sql(query)
    }
  }
}
|
||||
|
||||
test("slice function") {
|
||||
val df = Seq(
|
||||
Seq(1, 2, 3),
|
||||
|
|
Loading…
Reference in a new issue