[SPARK-21228][SQL] InSet incorrect handling of structs

## What changes were proposed in this pull request?
When data type is struct, InSet now uses TypeUtils.getInterpretedOrdering (similar to EqualTo) to build a TreeSet. In other cases it will use a HashSet as before (which should be faster). Similarly, In.eval uses Ordering.equiv instead of equals.

## How was this patch tested?
New test in SQLQuerySuite.

Author: Bogdan Raducanu <bogdan@databricks.com>

Closes #18455 from bogdanrdc/SPARK-21228.
This commit is contained in:
Bogdan Raducanu 2017-07-07 01:04:57 +08:00 committed by Wenchen Fan
parent 565e7a8d4a
commit 26ac085deb
4 changed files with 78 additions and 34 deletions

View file

@ -17,10 +17,11 @@
package org.apache.spark.sql.catalyst.expressions
import scala.collection.immutable.TreeSet
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.expressions.codegen.{Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateSafeProjection, GenerateUnsafeProjection, Predicate => BasePredicate}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
@ -175,20 +176,23 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
|[${sub.output.map(_.dataType.catalogString).mkString(", ")}].
""".stripMargin)
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
case _ =>
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure("Arguments must be same type")
val mismatchOpt = list.find(l => l.dataType != value.dataType)
if (mismatchOpt.isDefined) {
TypeCheckResult.TypeCheckFailure(s"Arguments must be same type but were: " +
s"${value.dataType} != ${mismatchOpt.get.dataType}")
} else {
TypeCheckResult.TypeCheckSuccess
TypeUtils.checkForOrderingExpr(value.dataType, s"function $prettyName")
}
}
}
override def children: Seq[Expression] = value +: list
lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
private lazy val ordering = TypeUtils.getInterpretedOrdering(value.dataType)
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
@ -203,10 +207,10 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate {
var hasNull = false
list.foreach { e =>
val v = e.eval(input)
if (v == evaluatedValue) {
return true
} else if (v == null) {
if (v == null) {
hasNull = true
} else if (ordering.equiv(v, evaluatedValue)) {
return true
}
}
if (hasNull) {
@ -265,7 +269,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
override def nullable: Boolean = child.nullable || hasNull
protected override def nullSafeEval(value: Any): Any = {
if (hset.contains(value)) {
if (set.contains(value)) {
true
} else if (hasNull) {
null
@ -274,27 +278,40 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
}
}
def getHSet(): Set[Any] = hset
@transient private[this] lazy val set = child.dataType match {
case _: AtomicType => hset
case _: NullType => hset
case _ =>
// for structs use interpreted ordering to be able to compare UnsafeRows with non-UnsafeRows
TreeSet.empty(TypeUtils.getInterpretedOrdering(child.dataType)) ++ hset
}
def getSet(): Set[Any] = set
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.genCode(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
val hasNullTerm = ctx.freshName("hasNull")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();")
ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);")
val setTerm = ctx.freshName("set")
val setNull = if (hasNull) {
s"""
|if (!${ev.value}) {
| ${ev.isNull} = true;
|}
""".stripMargin
} else {
""
}
ctx.addMutableState(setName, setTerm,
s"$setTerm = (($InSetName)references[${ctx.references.size - 1}]).getSet();")
ev.copy(code = s"""
${childGen.code}
boolean ${ev.isNull} = ${childGen.isNull};
boolean ${ev.value} = false;
if (!${ev.isNull}) {
${ev.value} = $hsetTerm.contains(${childGen.value});
if (!${ev.value} && $hasNullTerm) {
${ev.isNull} = true;
}
${ev.value} = $setTerm.contains(${childGen.value});
$setNull
}
""")
}

View file

@ -35,7 +35,8 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test(s"3VL $name") {
truthTable.foreach {
case (l, r, answer) =>
val expr = op(NonFoldableLiteral(l, BooleanType), NonFoldableLiteral(r, BooleanType))
val expr = op(NonFoldableLiteral.create(l, BooleanType),
NonFoldableLiteral.create(r, BooleanType))
checkEvaluation(expr, answer)
}
}
@ -72,7 +73,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(false, true) ::
(null, null) :: Nil
notTrueTable.foreach { case (v, answer) =>
checkEvaluation(Not(NonFoldableLiteral(v, BooleanType)), answer)
checkEvaluation(Not(NonFoldableLiteral.create(v, BooleanType)), answer)
}
checkConsistencyBetweenInterpretedAndCodegen(Not, BooleanType)
}
@ -120,22 +121,26 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
(null, null, null) :: Nil)
test("IN") {
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq(Literal(1), Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType),
Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral(null, IntegerType), Seq.empty), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq(Literal(1),
Literal(2))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType),
Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(NonFoldableLiteral.create(null, IntegerType), Seq.empty), null)
checkEvaluation(In(Literal(1), Seq.empty), false)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(NonFoldableLiteral.create(null, IntegerType))), null)
checkEvaluation(In(Literal(1), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
true)
checkEvaluation(In(Literal(2), Seq(Literal(1), NonFoldableLiteral.create(null, IntegerType))),
null)
checkEvaluation(In(Literal(1), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(2), Seq(Literal(1), Literal(2))), true)
checkEvaluation(In(Literal(3), Seq(Literal(1), Literal(2))), false)
checkEvaluation(
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1), Literal(2)))),
And(In(Literal(1), Seq(Literal(1), Literal(2))), In(Literal(2), Seq(Literal(1),
Literal(2)))),
true)
val ns = NonFoldableLiteral(null, StringType)
val ns = NonFoldableLiteral.create(null, StringType)
checkEvaluation(In(ns, Seq(Literal("1"), Literal("2"))), null)
checkEvaluation(In(ns, Seq(ns)), null)
checkEvaluation(In(Literal("a"), Seq(ns)), null)
@ -155,7 +160,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
case _ => value
}
}
val input = inputData.map(NonFoldableLiteral(_, t))
val input = inputData.map(NonFoldableLiteral.create(_, t))
val expected = if (inputData(0) == null) {
null
} else if (inputData.slice(1, 10).contains(inputData(0))) {
@ -279,7 +284,7 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
test("BinaryComparison: null test") {
// Use -1 (default value for codegen) which can trigger some weird bugs, e.g. SPARK-14757
val normalInt = Literal(-1)
val nullInt = NonFoldableLiteral(null, IntegerType)
val nullInt = NonFoldableLiteral.create(null, IntegerType)
def nullTest(op: (Expression, Expression) => Expression): Unit = {
checkEvaluation(op(normalInt, nullInt), null)

View file

@ -169,7 +169,7 @@ class OptimizeInSuite extends PlanTest {
val optimizedPlan = OptimizeIn(plan)
optimizedPlan match {
case Filter(cond, _)
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getHSet().size == 3 =>
if cond.isInstanceOf[InSet] && cond.asInstanceOf[InSet].getSet().size == 3 =>
// pass
case _ => fail("Unexpected result for OptimizedIn")
}

View file

@ -2616,4 +2616,26 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
val e = intercept[AnalysisException](sql("SELECT nvl(1, 2, 3)"))
assert(e.message.contains("Invalid number of arguments"))
}
test("SPARK-21228: InSet incorrect handling of structs") {
withTempView("A") {
// reduce this from the default of 10 so the repro query text is not too long
withSQLConf((SQLConf.OPTIMIZER_INSET_CONVERSION_THRESHOLD.key -> "3")) {
// a relation that has 1 column of struct type with values (1,1), ..., (9, 9)
spark.range(1, 10).selectExpr("named_struct('a', id, 'b', id) as a")
.createOrReplaceTempView("A")
val df = sql(
"""
|SELECT * from
| (SELECT MIN(a) as minA FROM A) AA -- this Aggregate will return UnsafeRows
| -- the IN will become InSet with a Set of GenericInternalRows
| -- a GenericInternalRow is never equal to an UnsafeRow so the query would
| -- returns 0 results, which is incorrect
| WHERE minA IN (NAMED_STRUCT('a', 1L, 'b', 1L), NAMED_STRUCT('a', 2L, 'b', 2L),
| NAMED_STRUCT('a', 3L, 'b', 3L))
""".stripMargin)
checkAnswer(df, Row(Row(1, 1)))
}
}
}
}