[SPARK-9403] [SQL] Add codegen support in In and InSet

This continues tarekauel's work in #7778.

Author: Liang-Chi Hsieh <viirya@appier.com>
Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7893 from viirya/codegen_in and squashes the following commits:

81ff97b [Liang-Chi Hsieh] For comments.
47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
cf4bf41 [Liang-Chi Hsieh] For comments.
f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
446bbcd [Liang-Chi Hsieh] Fix bug.
b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in
4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test.
224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types.
86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold.
b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet
This commit is contained in:
Liang-Chi Hsieh 2015-08-05 11:38:56 -07:00 committed by Davies Liu
parent 1f8c364b9c
commit e1e05873fc
6 changed files with 119 additions and 10 deletions

View file

@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst.expressions
import scala.collection.mutable
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@ -97,32 +100,80 @@ case class Not(child: Expression)
/**
* Evaluates to `true` if `list` contains `value`.
*
* Type checking rejects the expression unless every element of `list` has the same data type
* as `value`. Both the interpreted path (`eval`) and the generated code (`genCode`) test the
* value against the list items one by one.
*/
case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
case class In(value: Expression, list: Seq[Expression]) extends Predicate
with ImplicitCastInputTypes {
// Expected input types: the data type of `value` followed by the data type of each list item.
override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
// Fails the type check when any list item's data type differs from that of `value`.
override def checkInputDataTypes(): TypeCheckResult = {
if (list.exists(l => l.dataType != value.dataType)) {
TypeCheckResult.TypeCheckFailure(
"Arguments must be same type")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
// Children are the tested value followed by the list items, in order.
override def children: Seq[Expression] = value +: list
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
// Interpreted evaluation: evaluates `value` once, then returns whether any list item
// evaluates to something `==` to it.
override def eval(input: InternalRow): Any = {
val evaluatedValue = value.eval(input)
list.exists(e => e.eval(input) == evaluatedValue)
}
}
// Emits one guarded block per list item: each block runs only while the result is still
// false, evaluates the item's code, and sets the result to true when the comparison built by
// ctx.genEqual holds. The result is always declared as non-null.
// NOTE(review): the emitted code never reads valueGen.isNull or the items' isNull flags --
// confirm that null inputs behave as intended.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val valueGen = value.gen(ctx)
val listGen = list.map(_.gen(ctx))
val listCode = listGen.map(x =>
s"""
if (!${ev.primitive}) {
${x.code}
if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
${ev.primitive} = true;
}
}
""").mkString("\n")
s"""
${valueGen.code}
boolean ${ev.primitive} = false;
boolean ${ev.isNull} = false;
$listCode
"""
}
}
/**
* Optimized version of In clause, when all filter values of In clause are
* static.
*
* The candidate values are held in `hset`; membership is decided by a single
* `contains` call on that set, both in `eval` and in the generated code.
*/
case class InSet(child: Expression, hset: Set[Any])
extends UnaryExpression with Predicate with CodegenFallback {
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
// Interpreted evaluation: looks the evaluated child value up in `hset`.
override def eval(input: InternalRow): Any = {
hset.contains(child.eval(input))
}
// Accessor for the set; the initialisation code emitted by genCode calls it.
def getHSet(): Set[Any] = hset
// Appends this expression to ctx.references and registers mutable state named by a fresh
// "hset" term. Its initialisation code casts expressions[<index of the reference just added>]
// back to InSet and calls getHSet(). The emitted code declares the result non-null and
// computes it with one contains() call on that set.
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val setName = classOf[Set[Any]].getName
val InSetName = classOf[InSet].getName
val childGen = child.gen(ctx)
ctx.references += this
val hsetTerm = ctx.freshName("hset")
ctx.addMutableState(setName, hsetTerm,
s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
s"""
${childGen.code}
boolean ${ev.isNull} = false;
boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
"""
}
}
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {

View file

@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
object OptimizeIn extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 =>
val hSet = list.map(e => e.eval(EmptyRow))
InSet(v, HashSet() ++ hSet)
}

View file

@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
import org.apache.spark.sql.RandomDataGenerator
import org.apache.spark.sql.types._
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
case d: Double if d.isNaN => 0.0d
case f: Float if f.isNaN => 0.0f
case _ => value
}
}
val input = inputData.map(Literal(_))
checkEvaluation(In(input(0), input.slice(1, 10)),
inputData.slice(1, 10).contains(inputData(0)))
}
}
test("INSET") {
@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(InSet(three, hS), false)
checkEvaluation(InSet(three, nS), false)
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
primitiveTypes.map { t =>
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
val inputData = Seq.fill(10) {
val value = dataGen.apply()
value match {
case d: Double if d.isNaN => 0.0d
case f: Float if f.isNaN => 0.0f
case _ => value
}
}
val input = inputData.map(Literal(_))
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
inputData.slice(1, 10).contains(inputData(0)))
}
}
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))

View file

@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
test("OptimizedIn test: In clause optimized to InSet") {
test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
// Uses a two-literal IN list and expects the optimized plan to equal the original plan,
// i.e. the In predicate is left as-is rather than rewritten.
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
.analyze
val optimized = Optimize.execute(originalQuery.analyze)
comparePlans(optimized, originalQuery)
}
test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
val originalQuery =
testRelation
.where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
.analyze
val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer =
testRelation
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
.analyze
comparePlans(optimized, correctAnswer)

View file

@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
case expressions.InSet(a: Attribute, set) =>
Some(sources.In(a.name, set.toArray))
// The optimizer only converts In to InSet when the list has more than a certain number of
// items, so an In expression whose list contains only literals may still reach this point
// and needs to be pushed down as well.
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
val hSet = list.map(e => e.eval(EmptyRow))
Some(sources.In(a.name, hSet.toArray))
case expressions.IsNull(a: Attribute) =>
Some(sources.IsNull(a.name))
case expressions.IsNotNull(a: Attribute) =>

View file

@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
checkAnswer(df.filter($"b".in("z", "y")),
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
intercept[AnalysisException] {
df2.filter($"a".in($"b"))
}
}
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(