[SPARK-9403] [SQL] Add codegen support in In and InSet
This continues tarekauel's work in #7778. Author: Liang-Chi Hsieh <viirya@appier.com> Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7893 from viirya/codegen_in and squashes the following commits: 81ff97b [Liang-Chi Hsieh] For comments. 47761c6 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in cf4bf41 [Liang-Chi Hsieh] For comments. f532b3c [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 446bbcd [Liang-Chi Hsieh] Fix bug. b3d0ab4 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into codegen_in 4610eff [Liang-Chi Hsieh] Relax the types of references and update optimizer test. 224f18e [Liang-Chi Hsieh] Beef up the test cases for In and InSet to include all primitive data types. 86dc8aa [Liang-Chi Hsieh] Only convert In to InSet when the number of items in set is more than the threshold. b7ded7e [Tarek Auel] [SPARK-9403][SQL] codeGen in / inSet
This commit is contained in:
parent
1f8c364b9c
commit
e1e05873fc
|
@ -17,6 +17,9 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import scala.collection.mutable
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenFallback, GeneratedExpressionCode, CodeGenContext}
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
|
||||
|
@ -97,32 +100,80 @@ case class Not(child: Expression)
|
|||
/**
|
||||
* Evaluates to `true` if `list` contains `value`.
|
||||
*/
|
||||
case class In(value: Expression, list: Seq[Expression]) extends Predicate with CodegenFallback {
|
||||
case class In(value: Expression, list: Seq[Expression]) extends Predicate
|
||||
with ImplicitCastInputTypes {
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] = value.dataType +: list.map(_.dataType)
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (list.exists(l => l.dataType != value.dataType)) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
"Arguments must be same type")
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
}
|
||||
}
|
||||
|
||||
override def children: Seq[Expression] = value +: list
|
||||
|
||||
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
|
||||
override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
|
||||
override def toString: String = s"$value IN ${list.mkString("(", ",", ")")}"
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val evaluatedValue = value.eval(input)
|
||||
list.exists(e => e.eval(input) == evaluatedValue)
|
||||
}
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
val valueGen = value.gen(ctx)
|
||||
val listGen = list.map(_.gen(ctx))
|
||||
val listCode = listGen.map(x =>
|
||||
s"""
|
||||
if (!${ev.primitive}) {
|
||||
${x.code}
|
||||
if (${ctx.genEqual(value.dataType, valueGen.primitive, x.primitive)}) {
|
||||
${ev.primitive} = true;
|
||||
}
|
||||
}
|
||||
""").mkString("\n")
|
||||
s"""
|
||||
${valueGen.code}
|
||||
boolean ${ev.primitive} = false;
|
||||
boolean ${ev.isNull} = false;
|
||||
$listCode
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Optimized version of In clause, when all filter values of In clause are
|
||||
* static.
|
||||
*/
|
||||
case class InSet(child: Expression, hset: Set[Any])
|
||||
extends UnaryExpression with Predicate with CodegenFallback {
|
||||
case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with Predicate {
|
||||
|
||||
override def nullable: Boolean = true // TODO: Figure out correct nullability semantics of IN.
|
||||
override def nullable: Boolean = false // TODO: Figure out correct nullability semantics of IN.
|
||||
override def toString: String = s"$child INSET ${hset.mkString("(", ",", ")")}"
|
||||
|
||||
override def eval(input: InternalRow): Any = {
|
||||
hset.contains(child.eval(input))
|
||||
}
|
||||
|
||||
def getHSet(): Set[Any] = hset
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
val setName = classOf[Set[Any]].getName
|
||||
val InSetName = classOf[InSet].getName
|
||||
val childGen = child.gen(ctx)
|
||||
ctx.references += this
|
||||
val hsetTerm = ctx.freshName("hset")
|
||||
ctx.addMutableState(setName, hsetTerm,
|
||||
s"$hsetTerm = (($InSetName)expressions[${ctx.references.size - 1}]).getHSet();")
|
||||
s"""
|
||||
${childGen.code}
|
||||
boolean ${ev.isNull} = false;
|
||||
boolean ${ev.primitive} = $hsetTerm.contains(${childGen.primitive});
|
||||
"""
|
||||
}
|
||||
}
|
||||
|
||||
case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate {
|
||||
|
|
|
@ -393,7 +393,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
|
|||
object OptimizeIn extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsDown {
|
||||
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) =>
|
||||
case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) && list.size > 10 =>
|
||||
val hSet = list.map(e => e.eval(EmptyRow))
|
||||
InSet(v, HashSet() ++ hSet)
|
||||
}
|
||||
|
|
|
@ -21,7 +21,8 @@ import scala.collection.immutable.HashSet
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.types.{Decimal, DoubleType, IntegerType, BooleanType}
|
||||
import org.apache.spark.sql.RandomDataGenerator
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
|
@ -118,6 +119,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("^Ba*n"))), true)
|
||||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^Ba*n"))), true)
|
||||
checkEvaluation(In(Literal("^Ba*n"), Seq(Literal("aa"), Literal("^n"))), false)
|
||||
|
||||
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
|
||||
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
|
||||
primitiveTypes.map { t =>
|
||||
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
|
||||
val inputData = Seq.fill(10) {
|
||||
val value = dataGen.apply()
|
||||
value match {
|
||||
case d: Double if d.isNaN => 0.0d
|
||||
case f: Float if f.isNaN => 0.0f
|
||||
case _ => value
|
||||
}
|
||||
}
|
||||
val input = inputData.map(Literal(_))
|
||||
checkEvaluation(In(input(0), input.slice(1, 10)),
|
||||
inputData.slice(1, 10).contains(inputData(0)))
|
||||
}
|
||||
}
|
||||
|
||||
test("INSET") {
|
||||
|
@ -134,6 +152,23 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(InSet(three, hS), false)
|
||||
checkEvaluation(InSet(three, nS), false)
|
||||
checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
|
||||
|
||||
val primitiveTypes = Seq(IntegerType, FloatType, DoubleType, StringType, ByteType, ShortType,
|
||||
LongType, BinaryType, BooleanType, DecimalType.USER_DEFAULT, TimestampType)
|
||||
primitiveTypes.map { t =>
|
||||
val dataGen = RandomDataGenerator.forType(t, nullable = false).get
|
||||
val inputData = Seq.fill(10) {
|
||||
val value = dataGen.apply()
|
||||
value match {
|
||||
case d: Double if d.isNaN => 0.0d
|
||||
case f: Float if f.isNaN => 0.0f
|
||||
case _ => value
|
||||
}
|
||||
}
|
||||
val input = inputData.map(Literal(_))
|
||||
checkEvaluation(InSet(input(0), inputData.slice(1, 10).toSet),
|
||||
inputData.slice(1, 10).contains(inputData(0)))
|
||||
}
|
||||
}
|
||||
|
||||
private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d, false).map(Literal(_))
|
||||
|
|
|
@ -43,16 +43,26 @@ class OptimizeInSuite extends PlanTest {
|
|||
|
||||
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
|
||||
|
||||
test("OptimizedIn test: In clause optimized to InSet") {
|
||||
test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
|
||||
.analyze
|
||||
|
||||
val optimized = Optimize.execute(originalQuery.analyze)
|
||||
comparePlans(optimized, originalQuery)
|
||||
}
|
||||
|
||||
test("OptimizedIn test: In clause optimized to InSet when more than 10 items") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.where(In(UnresolvedAttribute("a"), (1 to 11).map(Literal(_))))
|
||||
.analyze
|
||||
|
||||
val optimized = Optimize.execute(originalQuery.analyze)
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.where(InSet(UnresolvedAttribute("a"), HashSet[Any]() + 1 + 2))
|
||||
.where(InSet(UnresolvedAttribute("a"), (1 to 11).toSet))
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
|
|
|
@ -366,6 +366,13 @@ private[sql] object DataSourceStrategy extends Strategy with Logging {
|
|||
case expressions.InSet(a: Attribute, set) =>
|
||||
Some(sources.In(a.name, set.toArray))
|
||||
|
||||
// Because we only convert In to InSet in Optimizer when there are more than certain
|
||||
// items. So it is possible we still get an In expression here that needs to be pushed
|
||||
// down.
|
||||
case expressions.In(a: Attribute, list) if !list.exists(!_.isInstanceOf[Literal]) =>
|
||||
val hSet = list.map(e => e.eval(EmptyRow))
|
||||
Some(sources.In(a.name, hSet.toArray))
|
||||
|
||||
case expressions.IsNull(a: Attribute) =>
|
||||
Some(sources.IsNull(a.name))
|
||||
case expressions.IsNotNull(a: Attribute) =>
|
||||
|
|
|
@ -357,6 +357,12 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
|
|||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "x"))
|
||||
checkAnswer(df.filter($"b".in("z", "y")),
|
||||
df.collect().toSeq.filter(r => r.getString(1) == "z" || r.getString(1) == "y"))
|
||||
|
||||
val df2 = Seq((1, Seq(1)), (2, Seq(2)), (3, Seq(3))).toDF("a", "b")
|
||||
|
||||
intercept[AnalysisException] {
|
||||
df2.filter($"a".in($"b"))
|
||||
}
|
||||
}
|
||||
|
||||
val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
|
||||
|
|
Loading…
Reference in a new issue