[SPARK-23922][SQL] Add arrays_overlap function

## What changes were proposed in this pull request?

This PR adds the function `arrays_overlap`. The function returns `true` if the two input arrays have at least one non-null element in common. Otherwise, it returns `null` if both arrays are non-empty and either of them contains a `null` element, and `false` in all other cases.

## How was this patch tested?

Added unit tests.

Author: Marco Gaido <marcogaido91@gmail.com>

Closes #21028 from mgaido91/SPARK-23922.
This commit is contained in:
Marco Gaido 2018-05-17 20:45:32 +08:00 committed by Wenchen Fan
parent 6ec05826d7
commit 69350aa2f0
6 changed files with 388 additions and 1 deletions

View file

@ -1855,6 +1855,21 @@ def array_contains(col, value):
return Column(sc._jvm.functions.array_contains(_to_java_column(col), value))
@since(2.4)
def arrays_overlap(a1, a2):
    """
    Collection function: checks whether two array columns share an element. The result is
    true when the arrays have at least one non-null element in common. When they do not,
    the result is null if both arrays are non-empty and either of them contains a null
    element, and false otherwise.

    >>> df = spark.createDataFrame([(["a", "b"], ["b", "c"]), (["a"], ["b", "c"])], ['x', 'y'])
    >>> df.select(arrays_overlap(df.x, df.y).alias("overlap")).collect()
    [Row(overlap=True), Row(overlap=False)]
    """
    # Resolve the JVM-side function namespace from the active context, convert both
    # arguments to Java columns, then wrap the JVM result back into a Python Column.
    jvm_functions = SparkContext._active_spark_context._jvm.functions
    left_col, right_col = _to_java_column(a1), _to_java_column(a2)
    return Column(jvm_functions.arrays_overlap(left_col, right_col))
@since(2.4)
def slice(x, start, length):
"""

View file

@ -410,6 +410,7 @@ object FunctionRegistry {
// collection functions
expression[CreateArray]("array"),
expression[ArrayContains]("array_contains"),
expression[ArraysOverlap]("arrays_overlap"),
expression[ArrayJoin]("array_join"),
expression[ArrayPosition]("array_position"),
expression[ArraySort]("array_sort"),

View file

@ -18,15 +18,51 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.Comparator
import scala.collection.mutable
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
/**
 * Common base for [[BinaryExpression]]s that take two array arguments whose element types are
 * implicitly cast to a shared type. Type checking succeeds only when both children are arrays
 * with the same element type.
 */
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
  with ImplicitCastInputTypes {

  /** Element type of the (coerced) input arrays, read from the first expected input type. */
  @transient protected lazy val elementType: DataType =
    inputTypes.head.asInstanceOf[ArrayType].elementType

  /**
   * Expected input types: both children as arrays of the common element type returned by
   * `TypeCoercion.findTightestCommonType`, each keeping its own `containsNull` flag. The
   * result is empty when either child is not an array or when no common type is found.
   */
  override def inputTypes: Seq[AbstractDataType] = (left.dataType, right.dataType) match {
    case (ArrayType(leftElem, leftContainsNull), ArrayType(rightElem, rightContainsNull)) =>
      TypeCoercion.findTightestCommonType(leftElem, rightElem)
        .map { common =>
          Seq[AbstractDataType](
            ArrayType(common, leftContainsNull),
            ArrayType(common, rightContainsNull))
        }
        .getOrElse(Seq.empty)
    case _ => Seq.empty
  }

  /** Accepts the inputs only if both children are arrays whose element types are the same. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val sameElementType = (left.dataType, right.dataType) match {
      case (ArrayType(leftElem, _), ArrayType(rightElem, _)) => leftElem.sameType(rightElem)
      case _ => false
    }
    if (sameElementType) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"input to function $prettyName should have been two ${ArrayType.simpleString}s " +
          s"with same element type, but it's " +
          s"[${left.dataType.simpleString}, ${right.dataType.simpleString}]")
    }
  }
}
/**
* Given an array or map, returns its size. Returns -1 if null.
*/
@ -529,6 +565,235 @@ case class ArrayContains(left: Expression, right: Expression)
override def prettyName: String = "array_contains"
}
/**
 * Checks if the two arrays contain at least one common element.
 *
 * The expression evaluates to `true` when a non-null element occurs in both arrays. Otherwise
 * the result is `null` if both arrays are non-empty and at least one of them contains a null
 * element, and `false` in all remaining cases. A null input array yields `null`.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
  examples = """
    Examples:
      > SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
       true
  """, since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
  extends BinaryArrayExpressionWithImplicitCast {

  /**
   * On top of the checks of the parent trait, requires the element type to be orderable,
   * because the elements have to be compared with each other.
   */
  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      if (RowOrdering.isOrderable(elementType)) {
        TypeCheckResult.TypeCheckSuccess
      } else {
        TypeCheckResult.TypeCheckFailure(
          s"${elementType.simpleString} cannot be used in comparison.")
      }
    case failure => failure
  }

  /** Interpreted ordering used to compare elements in the brute-force evaluation. */
  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  /**
   * Whether elements can be put in a hash set and looked up through `equals`. This is enabled
   * for atomic types except binary; every other type goes through the nested-loop path.
   */
  @transient private lazy val elementTypeSupportEquals = elementType match {
    case BinaryType => false
    case _: AtomicType => true
    case _ => false
  }

  /** Interpreted evaluation strategy, chosen once from the element type. */
  @transient private lazy val doEvaluation = if (elementTypeSupportEquals) {
    fastEval _
  } else {
    bruteForceEval _
  }

  override def dataType: DataType = BooleanType

  /**
   * The result can be null when either input can be null or when either array type can
   * contain null elements.
   */
  override def nullable: Boolean = {
    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
      right.dataType.asInstanceOf[ArrayType].containsNull
  }

  override def nullSafeEval(a1: Any, a2: Any): Any = {
    doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
  }

  /**
   * A fast implementation which puts all the elements from the smaller array in a set
   * and then performs a lookup on it for each element of the bigger one.
   * This eval mode works only for data types which implements properly the equals method.
   *
   * @return `true` on the first common non-null element; otherwise `null` if a null element
   *         was seen while both arrays are non-empty, else `false`.
   */
  private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
      (arr1, arr2)
    } else {
      (arr2, arr1)
    }
    if (smaller.numElements() > 0) {
      val smallestSet = new mutable.HashSet[Any]
      smaller.foreach(elementType, (_, v) =>
        if (v == null) {
          hasNull = true
        } else {
          smallestSet += v
        })
      // Plain index loop: the early exit is an ordinary method return rather than a
      // non-local return thrown out of a closure.
      val biggerSize = bigger.numElements()
      var i = 0
      while (i < biggerSize) {
        if (bigger.isNullAt(i)) {
          hasNull = true
        } else if (smallestSet.contains(bigger.get(i, elementType))) {
          return true
        }
        i += 1
      }
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  /**
   * A slower evaluation which performs a nested loop and supports all the data types.
   *
   * @return `true` on the first pair of equivalent non-null elements; otherwise `null` if a
   *         null element was seen while both arrays are non-empty, else `false`.
   */
  private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    val size1 = arr1.numElements()
    val size2 = arr2.numElements()
    if (size1 > 0 && size2 > 0) {
      var i = 0
      while (i < size1) {
        if (arr1.isNullAt(i)) {
          hasNull = true
        } else {
          val v1 = arr1.get(i, elementType)
          var j = 0
          while (j < size2) {
            if (arr2.isNullAt(j)) {
              hasNull = true
            } else if (ordering.equiv(v1, arr2.get(j, elementType))) {
              return true
            }
            j += 1
          }
        }
        i += 1
      }
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  /**
   * Generates code that orders the two arrays by size and, when the smaller one is non-empty,
   * runs either the set-based or the nested-loop comparison depending on the element type.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (a1, a2) => {
      val smaller = ctx.freshName("smallerArray")
      val bigger = ctx.freshName("biggerArray")
      val comparisonCode = if (elementTypeSupportEquals) {
        fastCodegen(ctx, ev, smaller, bigger)
      } else {
        bruteForceCodegen(ctx, ev, smaller, bigger)
      }
      s"""
         |ArrayData $smaller;
         |ArrayData $bigger;
         |if ($a1.numElements() > $a2.numElements()) {
         |  $bigger = $a1;
         |  $smaller = $a2;
         |} else {
         |  $smaller = $a1;
         |  $bigger = $a2;
         |}
         |if ($smaller.numElements() > 0) {
         |  $comparisonCode
         |}
       """.stripMargin
    })
  }

  /**
   * Statement that clears the null flag once a common element has been found. It is emitted
   * only when this expression is nullable: otherwise no generated statement ever sets the flag
   * to true (see `nullSafeElementCodegen`), so there is nothing to reset and the generated
   * code does not need to write to `ev.isNull` at all.
   */
  private def resetIsNullCode(ev: ExprCode): String = {
    if (nullable) s"${ev.isNull} = false;" else ""
  }

  /**
   * Code generation for a fast implementation which puts all the elements from the smaller array
   * in a set and then performs a lookup on it for each element of the bigger one.
   * It works only for data types which implements properly the equals method.
   */
  private def fastCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val javaElementClass = CodeGenerator.boxedType(elementType)
    val javaSet = classOf[java.util.HashSet[_]].getName
    val set = ctx.freshName("set")
    val resetIsNull = resetIsNullCode(ev)
    val addToSetFromSmallerCode = nullSafeElementCodegen(
      smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
    val elementIsInSetCode = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |if ($set.contains($getFromBigger)) {
         |  $resetIsNull
         |  ${ev.value} = true;
         |  break;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
       |for (int $i = 0; $i < $smaller.numElements(); $i ++) {
       |  $addToSetFromSmallerCode
       |}
       |for (int $i = 0; $i < $bigger.numElements(); $i ++) {
       |  $elementIsInSetCode
       |}
     """.stripMargin
  }

  /**
   * Code generation for a slower evaluation which performs a nested loop and supports all the
   * data types. Both loops stop as soon as `ev.value` becomes true.
   */
  private def bruteForceCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val j = ctx.freshName("j")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val resetIsNull = resetIsNullCode(ev)
    val compareValues = nullSafeElementCodegen(
      smaller,
      j,
      s"""
         |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
         |  $resetIsNull
         |  ${ev.value} = true;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    val isInSmaller = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
         |  $compareValues
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
       |  $isInSmaller
       |}
     """.stripMargin
  }

  /**
   * Wraps `code` in a null check on element `index` of `arrayVar`, running `isNullCode` for
   * null elements. When neither input array type can contain nulls the check is omitted and
   * `code` is returned unchanged.
   */
  def nullSafeElementCodegen(
      arrayVar: String,
      index: String,
      code: String,
      isNullCode: String): String = {
    if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
      s"""
         |if ($arrayVar.isNullAt($index)) {
         |  $isNullCode
         |} else {
         |  $code
         |}
       """.stripMargin
    } else {
      code
    }
  }

  override def prettyName: String = "arrays_overlap"
}
/**
* Slices an array according to the requested start index and length
*/

View file

@ -136,6 +136,72 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
}
test("ArraysOverlap") {
  // Evaluates ArraysOverlap for every (left, right, expected) triple, in the given order.
  def assertOverlap(expectations: (Literal, Literal, Any)*): Unit = {
    expectations.foreach { case (lhs, rhs, expected) =>
      checkEvaluation(ArraysOverlap(lhs, rhs), expected)
    }
  }

  val intArrayType = ArrayType(IntegerType)
  val stringArrayType = ArrayType(StringType)
  val binaryArrayType = ArrayType(BinaryType)
  val nestedStringArrayType = ArrayType(ArrayType(StringType))

  val ints123 = Literal.create(Seq(1, 2, 3), intArrayType)
  val ints453 = Literal.create(Seq(4, 5, 3), intArrayType)
  val intsNull56 = Literal.create(Seq(null, 5, 6), intArrayType)
  val ints78 = Literal.create(Seq(7, 8), intArrayType)
  val stringsNullAndEmpty = Literal.create(Seq[String](null, ""), stringArrayType)
  val stringsEmptyAndAbc = Literal.create(Seq[String]("", "abc"), stringArrayType)
  val stringsDefGhi = Literal.create(Seq[String]("def", "ghi"), stringArrayType)
  val noInts = Literal.create(Seq.empty[Int], intArrayType)

  assertOverlap(
    (ints123, ints453, true),
    (ints123, intsNull56, null),
    (ints453, intsNull56, true),
    (ints453, ints78, false),
    (ints123, noInts, false),
    (intsNull56, noInts, false),
    (noInts, intsNull56, false),
    (stringsNullAndEmpty, stringsEmptyAndAbc, true),
    (stringsNullAndEmpty, stringsDefGhi, null),
    (stringsEmptyAndAbc, stringsDefGhi, false))

  // null handling
  val onlyNullInt = Literal.create(Seq(null), intArrayType)
  val nullIntArray = Literal.create(null, intArrayType)
  assertOverlap(
    (noInts, intsNull56, false),
    (noInts, onlyNullInt, false),
    (nullIntArray, ints123, null),
    (ints123, nullIntArray, null),
    (onlyNullInt, onlyNullInt, null))

  // arrays of binaries
  val bytes12And34 = Literal.create(
    Seq[Array[Byte]](Array[Byte](1, 2), Array[Byte](3, 4)), binaryArrayType)
  val bytes56And12 = Literal.create(
    Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)), binaryArrayType)
  val bytes21And43 = Literal.create(
    Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)), binaryArrayType)
  assertOverlap(
    (bytes12And34, bytes56And12, true),
    (bytes12And34, bytes21And43, false))

  // arrays of complex data types
  val nestedAbAndCd = Literal.create(
    Seq[Array[String]](Array[String]("a", "b"), Array[String]("c", "d")),
    nestedStringArrayType)
  val nestedEfAndAb = Literal.create(
    Seq[Array[String]](Array[String]("e", "f"), Array[String]("a", "b")),
    nestedStringArrayType)
  val nestedBaAndFg = Literal.create(
    Seq[Array[String]](Array[String]("b", "a"), Array[String]("f", "g")),
    nestedStringArrayType)
  assertOverlap(
    (nestedAbAndCd, nestedEfAndAb, true),
    (nestedAbAndCd, nestedBaAndFg, false))

  // null handling with complex datatypes
  val noBinaries = Literal.create(Seq.empty[Array[Byte]], binaryArrayType)
  val onlyNullBinary = Literal.create(Seq(null), binaryArrayType)
  assertOverlap(
    (noBinaries, bytes12And34, false),
    (bytes12And34, noBinaries, false),
    (noBinaries, onlyNullBinary, false),
    (onlyNullBinary, noBinaries, false),
    (onlyNullBinary, bytes12And34, null),
    (bytes12And34, onlyNullBinary, null))
}
test("Slice") {
val a0 = Literal.create(Seq(1, 2, 3, 4, 5, 6), ArrayType(IntegerType))
val a1 = Literal.create(Seq[String]("a", "b", "c", "d"), ArrayType(StringType))

View file

@ -3085,6 +3085,17 @@ object functions {
ArrayContains(column.expr, Literal(value))
}
/**
 * Checks whether the array columns `a1` and `a2` overlap. The result is `true` when they share
 * at least one non-null element. When they do not, the result is `null` if both arrays are
 * non-empty and either of them contains a `null`, and `false` in every other case.
 *
 * @group collection_funcs
 * @since 2.4.0
 */
def arrays_overlap(a1: Column, a2: Column): Column = {
  val overlap = ArraysOverlap(a1.expr, a2.expr)
  withExpr(overlap)
}
/**
* Returns an array containing all the elements in `x` from index `start` (or starting from the
* end if `start` is negative) with the specified `length`.

View file

@ -442,6 +442,35 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
test("arrays_overlap function") {
  // One row per expected outcome: no overlap, no overlap but a null present, overlap.
  val df = Seq[(Seq[Option[Int]], Seq[Option[Int]])](
    (Seq(Some(1), Some(2)), Seq(Some(-1), Some(10))),
    (Seq(Some(1), Some(2)), Seq(Some(-1), None)),
    (Seq(Some(3), Some(2)), Seq(Some(1), Some(2)))
  ).toDF("a", "b")
  val expected = Seq(Row(false), Row(null), Row(true))

  // The Column API and the SQL expression must give the same answer.
  val viaColumnApi = df.select(arrays_overlap(df("a"), df("b")))
  val viaSqlExpr = df.selectExpr("arrays_overlap(a, b)")
  Seq(viaColumnApi, viaSqlExpr).foreach { result =>
    checkAnswer(result, expected)
  }

  // Arrays with different but compatible element types (int vs. double) are accepted.
  val mixedNumeric = Seq((Seq(1, 2, 3), Seq(2.0, 2.5))).toDF("a", "b")
  checkAnswer(mixedNumeric.selectExpr("arrays_overlap(a, b)"), Row(true))

  // Each of these queries is expected to be rejected with an AnalysisException.
  val rejectedQueries = Seq(
    "select arrays_overlap(array(1, 2, 3), array('a', 'b', 'c'))",
    "select arrays_overlap(null, null)",
    "select arrays_overlap(map(1, 2), map(3, 4))")
  rejectedQueries.foreach { query =>
    intercept[AnalysisException] {
      sql(query)
    }
  }
}
test("slice function") {
val df = Seq(
Seq(1, 2, 3),