[SPARK-21228][SQL] InSet incorrect handling of structs
## What changes were proposed in this pull request? When data type is struct, InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In other cases it will use a HashSet as before (which should be faster). Similarly, In.eval uses Ordering.equiv instead of equals. ## How was this patch tested? New test in SQLQuerySuite. Author: Bogdan Raducanu <bogdan@databricks.com> Closes #18455 from bogdanrdc/SPARK-21228.
This commit is contained in:
parent
565e7a8d4a
commit
26ac085deb
|
@ -17,10 +17,11 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import scala.collection.immutable.TreeSet
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
import org.apache.spark.sql.catalyst.util.TypeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
|||
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
|
||||
""".stripMargin)
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
case _ =>
|
||||
if (list.exists(l => l.dataType != value.dataType)) {
|
||||
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
|
||||
val mismatchOpt = list.find(l => l.dataType != value.dataType)
|
||||
if (mismatchOpt.isDefined) {
|
||||
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
|
||||
s"${value.dataType} != ${mismatchOpt.get.dataType}")
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def children: Seq[Expression] = value +: list
|
||||
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
|
||||
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
|
||||
|
||||
override def nullable: Boolean = children.exists(_.nullable)
|
||||
override def foldable: Boolean = children.forall(_.foldable)
|
||||
|
@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|
|||
var hasNull = false
|
||||
list.foreach { e =>
|
||||
val v = e.eval(input)
|
||||
if (v == evaluatedValue) {
|
||||
return true
|
||||
} else if (v == null) {
|
||||
if (v == null) {
|
||||
hasNull = true
|
||||
} else if (ordering.equiv(v, evaluatedValue)) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if (hasNull) {
|
||||
|
@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
|||
override def nullable: Boolean = child.nullable || hasNull
|
||||
|
||||
protected override def nullSafeEval(value: Any): Any = {
|
||||
if (hset.contains(value)) {
|
||||
if (set.contains(value)) {
|
||||
true
|
||||
} else if (hasNull) {
|
||||
null
|
||||
|
@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
|
|||
}
|
||||
}
|
||||
|
||||
def getHSet(): Set[Any] = hset
|
||||
@transient private[this] lazy val set = child.dataType match {
|
||||
case _: AtomicType => hset
|
||||
case _: NullType => hset
|
||||
case _ =>
|
||||
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
|
||||
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
|
||||
}
|
||||
|
||||
def getSet(): Set[Any] = set
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val setName = classOf[Set[Any]].getName
|
||||
val InSetName = classOf[InSet].getName
|
||||
val childGen = child.genCode(ctx)
|
||||
ctx.references += this
|
||||
val hsetTerm = ctx.freshName("hset")
|
||||
val hasNullTerm = ctx.freshName("hasNull")
|
||||
ctx.addMutableState(setName, hsetTerm,
|
||||
s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
|
||||
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
|
||||
val setTerm = ctx.freshName("set")
|
||||
val setNull = if (hasNull) {
|
||||
s"""
|
||||
|if (!${ev.value}) {
|
||||
| ${ev.isNull} = true;
|
||||
|}
|
||||
""".stripMargin
|
||||
} else {
|
||||
""
|
||||
}
|
||||
ctx.addMutableState(setName, setTerm,
|
||||
s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
|
||||
ev.copy(code = s"""
|
||||
${childGen.code}
|
||||
boolean ${ev.isNull} = ${childGen.isNull};
|
||||
boolean ${ev.value} = false;
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $hsetTerm.contains(${childGen.value});
|
||||
if (!${ev.value} && $hasNullTerm) {
|
||||
${ev.isNull} = true;
|
||||
}
|
||||
${ev.value} = $setTerm.contains(${childGen.value});
|
||||
$setNull
|
||||
}
|
||||
""")
|
||||
}
|
||||
|
|
|
@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test(s"3VL $name") {
|
||||
truthTable.foreach {
|
||||
case (l, r, answer) =>
|
||||
val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
|
||||
val expr = op(NonFoldableLiteral.create(l, BooleanType),
|
||||
NonFoldableLiteral.create(r, BooleanType))
|
||||
checkEvaluation(expr, answer)
|
||||
}
|
||||
}
|
||||
|
@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
(false, true) ::
|
||||
(null, null) :: Nil
|
||||
notTrueTable.foreach { case (v, answer) =>
|
||||
checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
|
||||
checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
|
||||
}
|
||||
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
|
||||
}
|
||||
|
@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
(null, null, null) :: Nil)
|
||||
|
||||
test("IN") {
|
||||
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
|
||||
checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
|
||||
Seq(NonFoldableLiteral(null, IntegerType))), null)
|
||||
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
|
||||
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
|
||||
Literal(2))), null)
|
||||
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
|
||||
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
|
||||
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
|
||||
checkEvaluation(In(Literal(1), Seq.empty), false)
|
||||
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
|
||||
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
|
||||
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
|
||||
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
|
||||
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
|
||||
true)
|
||||
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
|
||||
null)
|
||||
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
|
||||
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
|
||||
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
|
||||
checkEvaluation(
|
||||
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
|
||||
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
|
||||
Literal(2)))),
|
||||
true)
|
||||
|
||||
val ns = NonFoldableLiteral(null, StringType)
|
||||
val ns = NonFoldableLiteral.create(null, StringType)
|
||||
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
|
||||
checkEvaluation(In(ns, Seq(ns)), null)
|
||||
checkEvaluation(In(Literal("a"), Seq(ns)), null)
|
||||
|
@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
case _ => value
|
||||
}
|
||||
}
|
||||
val input = inputData.map(NonFoldableLiteral(_, t))
|
||||
val input = inputData.map(NonFoldableLiteral.create(_, t))
|
||||
val expected = if (inputData(0) == null) {
|
||||
null
|
||||
} else if (inputData.slice(1, 10).contains(inputData(0))) {
|
||||
|
@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("BinaryComparison: null test") {
|
||||
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
|
||||
val normalInt = Literal(-1)
|
||||
val nullInt = NonFoldableLiteral(null, IntegerType)
|
||||
val nullInt = NonFoldableLiteral.create(null, IntegerType)
|
||||
|
||||
def nullTest(op: (Expression, Expression) => Expression): Unit = {
|
||||
checkEvaluation(op(normalInt, nullInt), null)
|
||||
|
|
|
@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest {
|
|||
val optimizedPlan = OptimizeIn(plan)
|
||||
optimizedPlan match {
|
||||
case Filter(cond, _)
|
||||
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
|
||||
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
|
||||
// pass
|
||||
case _ => fail("Unexpected result for OptimizedIn")
|
||||
}
|
||||
|
|
|
@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
|
|||
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
|
||||
assert(e.message.contains("Invalid number of arguments"))
|
||||
}
|
||||
|
||||
test("SPARK-21228: InSet incorrect handling of structs") {
|
||||
withTempView("A") {
|
||||
// reduce this from the default of 10 so the repro query text is not too long
|
||||
withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
|
||||
// a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
|
||||
spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
|
||||
.createOrReplaceTempView("A")
|
||||
val df = sql(
|
||||
"""
|
||||
|SELECT * from
|
||||
| (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
|
||||
| -- the IN will become InSet with a Set of GenericInternalRows
|
||||
| -- a GenericInternalRow is never equal to an UnsafeRow so the query would
|
||||
| -- returns 0 results, which is incorrect
|
||||
| WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
|
||||
| NAMED_STRUCT('a', 3L, 'b', 3L))
|
||||
""".stripMargin)
|
||||
checkAnswer(df, Row(Row(1, 1)))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue