[SPARK-24313][SQL] Fix collection operations' interpreted evaluation for complex types
## What changes were proposed in this pull request? The interpreted evaluation of several collection operations works only for simple data types. For complex data types, `array_contains`, for instance, always returns `false`. The affected functions are `array_contains`, `array_position`, `element_at` and `GetMapValue`. This PR fixes the behavior for all data types. ## How was this patch tested? Added unit tests. Author: Marco Gaido <marcogaido91@gmail.com> Closes #21361 from mgaido91/SPARK-24313.
This commit is contained in:
parent
a4470bc78c
commit
d3d1807315
|
@ -657,6 +657,9 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
|
||||
override def dataType: DataType = BooleanType
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] =
|
||||
TypeUtils.getInterpretedOrdering(right.dataType)
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] = right.dataType match {
|
||||
case NullType => Seq.empty
|
||||
case _ => left.dataType match {
|
||||
|
@ -673,7 +676,7 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
TypeCheckResult.TypeCheckFailure(
|
||||
"Arguments must be an array followed by a value of same type as the array members")
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -686,7 +689,7 @@ case class ArrayContains(left: Expression, right: Expression)
|
|||
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
|
||||
if (v == null) {
|
||||
hasNull = true
|
||||
} else if (v == value) {
|
||||
} else if (ordering.equiv(v, value)) {
|
||||
return true
|
||||
}
|
||||
)
|
||||
|
@ -735,11 +738,7 @@ case class ArraysOverlap(left: Expression, right: Expression)
|
|||
|
||||
override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
|
||||
case TypeCheckResult.TypeCheckSuccess =>
|
||||
if (RowOrdering.isOrderable(elementType)) {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckFailure(s"${elementType.simpleString} cannot be used in comparison.")
|
||||
}
|
||||
TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
|
||||
case failure => failure
|
||||
}
|
||||
|
||||
|
@ -1391,13 +1390,24 @@ case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCast
|
|||
case class ArrayPosition(left: Expression, right: Expression)
|
||||
extends BinaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] =
|
||||
TypeUtils.getInterpretedOrdering(right.dataType)
|
||||
|
||||
override def dataType: DataType = LongType
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(ArrayType, left.dataType.asInstanceOf[ArrayType].elementType)
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
super.checkInputDataTypes() match {
|
||||
case f: TypeCheckResult.TypeCheckFailure => f
|
||||
case TypeCheckResult.TypeCheckSuccess =>
|
||||
TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
override def nullSafeEval(arr: Any, value: Any): Any = {
|
||||
arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
|
||||
if (v == value) {
|
||||
if (v != null && ordering.equiv(v, value)) {
|
||||
return (i + 1).toLong
|
||||
}
|
||||
)
|
||||
|
@ -1446,6 +1456,9 @@ case class ArrayPosition(left: Expression, right: Expression)
|
|||
since = "2.4.0")
|
||||
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] =
|
||||
TypeUtils.getInterpretedOrdering(left.dataType.asInstanceOf[MapType].keyType)
|
||||
|
||||
override def dataType: DataType = left.dataType match {
|
||||
case ArrayType(elementType, _) => elementType
|
||||
case MapType(_, valueType, _) => valueType
|
||||
|
@ -1460,6 +1473,16 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
|
|||
)
|
||||
}
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
super.checkInputDataTypes() match {
|
||||
case f: TypeCheckResult.TypeCheckFailure => f
|
||||
case TypeCheckResult.TypeCheckSuccess if left.dataType.isInstanceOf[MapType] =>
|
||||
TypeUtils.checkForOrderingExpr(
|
||||
left.dataType.asInstanceOf[MapType].keyType, s"function $prettyName")
|
||||
case TypeCheckResult.TypeCheckSuccess => TypeCheckResult.TypeCheckSuccess
|
||||
}
|
||||
}
|
||||
|
||||
override def nullable: Boolean = true
|
||||
|
||||
override def nullSafeEval(value: Any, ordinal: Any): Any = {
|
||||
|
@ -1484,7 +1507,7 @@ case class ElementAt(left: Expression, right: Expression) extends GetMapValueUti
|
|||
}
|
||||
}
|
||||
case _: MapType =>
|
||||
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType)
|
||||
getValueEval(value, ordinal, left.dataType.asInstanceOf[MapType].keyType, ordering)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData}
|
||||
import org.apache.spark.sql.catalyst.util.{quoteIdentifier, ArrayData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -273,7 +273,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
|
|||
|
||||
abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTypes {
|
||||
// todo: current search is O(n), improve it.
|
||||
def getValueEval(value: Any, ordinal: Any, keyType: DataType): Any = {
|
||||
def getValueEval(value: Any, ordinal: Any, keyType: DataType, ordering: Ordering[Any]): Any = {
|
||||
val map = value.asInstanceOf[MapData]
|
||||
val length = map.numElements()
|
||||
val keys = map.keyArray()
|
||||
|
@ -282,7 +282,7 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
|
|||
var i = 0
|
||||
var found = false
|
||||
while (i < length && !found) {
|
||||
if (keys.get(i, keyType) == ordinal) {
|
||||
if (ordering.equiv(keys.get(i, keyType), ordinal)) {
|
||||
found = true
|
||||
} else {
|
||||
i += 1
|
||||
|
@ -345,8 +345,19 @@ abstract class GetMapValueUtil extends BinaryExpression with ImplicitCastInputTy
|
|||
case class GetMapValue(child: Expression, key: Expression)
|
||||
extends GetMapValueUtil with ExtractValue with NullIntolerant {
|
||||
|
||||
@transient private lazy val ordering: Ordering[Any] =
|
||||
TypeUtils.getInterpretedOrdering(keyType)
|
||||
|
||||
private def keyType = child.dataType.asInstanceOf[MapType].keyType
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
super.checkInputDataTypes() match {
|
||||
case f: TypeCheckResult.TypeCheckFailure => f
|
||||
case TypeCheckResult.TypeCheckSuccess =>
|
||||
TypeUtils.checkForOrderingExpr(keyType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
// We have done type checking for child in `ExtractValue`, so only need to check the `key`.
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType)
|
||||
|
||||
|
@ -363,7 +374,7 @@ case class GetMapValue(child: Expression, key: Expression)
|
|||
|
||||
// todo: current search is O(n), improve it.
|
||||
override def nullSafeEval(value: Any, ordinal: Any): Any = {
|
||||
getValueEval(value, ordinal, keyType)
|
||||
getValueEval(value, ordinal, keyType, ordering)
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
|
|
|
@ -157,6 +157,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
|
||||
checkEvaluation(ArrayContains(a3, Literal("")), null)
|
||||
checkEvaluation(ArrayContains(a3, Literal.create(null, StringType)), null)
|
||||
|
||||
// binary
|
||||
val b0 = Literal.create(Seq[Array[Byte]](Array[Byte](5, 6), Array[Byte](1, 2)),
|
||||
ArrayType(BinaryType))
|
||||
val b1 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), Array[Byte](4, 3)),
|
||||
ArrayType(BinaryType))
|
||||
val b2 = Literal.create(Seq[Array[Byte]](Array[Byte](2, 1), null),
|
||||
ArrayType(BinaryType))
|
||||
val b3 = Literal.create(Seq[Array[Byte]](null, Array[Byte](1, 2)),
|
||||
ArrayType(BinaryType))
|
||||
val be = Literal.create(Array[Byte](1, 2), BinaryType)
|
||||
val nullBinary = Literal.create(null, BinaryType)
|
||||
|
||||
checkEvaluation(ArrayContains(b0, be), true)
|
||||
checkEvaluation(ArrayContains(b1, be), false)
|
||||
checkEvaluation(ArrayContains(b0, nullBinary), null)
|
||||
checkEvaluation(ArrayContains(b2, be), null)
|
||||
checkEvaluation(ArrayContains(b3, be), true)
|
||||
|
||||
// complex data types
|
||||
val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
|
||||
ArrayType(ArrayType(IntegerType)))
|
||||
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
|
||||
ArrayType(ArrayType(IntegerType)))
|
||||
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
|
||||
checkEvaluation(ArrayContains(aa0, aae), true)
|
||||
checkEvaluation(ArrayContains(aa1, aae), false)
|
||||
}
|
||||
|
||||
test("ArraysOverlap") {
|
||||
|
@ -372,6 +399,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
|
||||
checkEvaluation(ArrayPosition(a3, Literal("")), null)
|
||||
checkEvaluation(ArrayPosition(a3, Literal.create(null, StringType)), null)
|
||||
|
||||
val aa0 = Literal.create(Seq[Seq[Int]](Seq[Int](1, 2), Seq[Int](3, 4)),
|
||||
ArrayType(ArrayType(IntegerType)))
|
||||
val aa1 = Literal.create(Seq[Seq[Int]](Seq[Int](5, 6), Seq[Int](2, 1)),
|
||||
ArrayType(ArrayType(IntegerType)))
|
||||
val aae = Literal.create(Seq[Int](1, 2), ArrayType(IntegerType))
|
||||
checkEvaluation(ArrayPosition(aa0, aae), 1L)
|
||||
checkEvaluation(ArrayPosition(aa1, aae), 0L)
|
||||
}
|
||||
|
||||
test("elementAt") {
|
||||
|
@ -409,7 +444,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
|
||||
val m2 = Literal.create(null, MapType(StringType, StringType))
|
||||
|
||||
checkEvaluation(ElementAt(m0, Literal(1.0)), null)
|
||||
assert(ElementAt(m0, Literal(1.0)).checkInputDataTypes().isFailure)
|
||||
|
||||
checkEvaluation(ElementAt(m0, Literal("d")), null)
|
||||
|
||||
|
@ -420,6 +455,18 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
|
|||
checkEvaluation(ElementAt(m0, Literal("c")), null)
|
||||
|
||||
checkEvaluation(ElementAt(m2, Literal("a")), null)
|
||||
|
||||
// test binary type as keys
|
||||
val mb0 = Literal.create(
|
||||
Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
|
||||
MapType(BinaryType, StringType))
|
||||
val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
|
||||
|
||||
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](1, 2, 3))), null)
|
||||
|
||||
checkEvaluation(ElementAt(mb1, Literal(Array[Byte](1, 2))), null)
|
||||
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
|
||||
checkEvaluation(ElementAt(mb0, Literal(Array[Byte](3, 4))), null)
|
||||
}
|
||||
|
||||
test("Concat") {
|
||||
|
|
|
@ -439,4 +439,17 @@ class ComplexTypesSuite extends PlanTest with ExpressionEvalHelper {
|
|||
.select('c as 'sCol2, 'a as 'sCol1)
|
||||
checkRule(originalQuery, correctAnswer)
|
||||
}
|
||||
|
||||
test("SPARK-24313: support binary type as map keys in GetMapValue") {
|
||||
val mb0 = Literal.create(
|
||||
Map(Array[Byte](1, 2) -> "1", Array[Byte](3, 4) -> null, Array[Byte](2, 1) -> "2"),
|
||||
MapType(BinaryType, StringType))
|
||||
val mb1 = Literal.create(Map[Array[Byte], String](), MapType(BinaryType, StringType))
|
||||
|
||||
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](1, 2, 3))), null)
|
||||
|
||||
checkEvaluation(GetMapValue(mb1, Literal(Array[Byte](1, 2))), null)
|
||||
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](2, 1), BinaryType)), "2")
|
||||
checkEvaluation(GetMapValue(mb0, Literal(Array[Byte](3, 4))), null)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2265,4 +2265,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|
|||
val df = spark.range(1).select($"id", new Column(Uuid()))
|
||||
checkAnswer(df, df.collect())
|
||||
}
|
||||
|
||||
test("SPARK-24313: access map with binary keys") {
|
||||
val mapWithBinaryKey = map(lit(Array[Byte](1.toByte)), lit(1))
|
||||
checkAnswer(spark.range(1).select(mapWithBinaryKey.getItem(Array[Byte](1.toByte))), Row(1))
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue