[SPARK-2052] [SQL] Add optimization for CaseConversionExpression's.

Add optimization for `CaseConversionExpression`s.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #990 from ueshin/issues/SPARK-2052 and squashes the following commits:

2568666 [Takuya UESHIN] Move some rules back.
dde7ede [Takuya UESHIN] Add tests to check if ConstantFolding can handle null literals and remove the unneeded rules from NullPropagation.
c4eea67 [Takuya UESHIN] Fix toString methods.
23e2363 [Takuya UESHIN] Make CaseConversionExpressions foldable if the child is foldable.
0ff7568 [Takuya UESHIN] Add tests for collapsing case statements.
3977d80 [Takuya UESHIN] Add optimization for CaseConversionExpression's.
This commit is contained in:
Takuya UESHIN 2014-06-11 17:58:35 -07:00 committed by Michael Armbrust
parent d45e0c6b98
commit 9a2448daf9
4 changed files with 174 additions and 15 deletions

View file

@ -76,7 +76,8 @@ trait CaseConversionExpression {
// Runtime result type for all case-conversion expressions.
type EvaluatedType = Any
// Performs the actual case conversion; implemented by Upper/Lower.
def convert(v: String): String
// Foldable iff the child is foldable, so ConstantFolding can pre-compute literal inputs.
override def foldable: Boolean = child.foldable
// The result is null only when the child can be null.
def nullable: Boolean = child.nullable
// Case conversion always yields a string.
def dataType: DataType = StringType
@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression)
/** Converts the string result of `child` to upper case. */
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
  override def convert(v: String): String = v.toUpperCase
  override def toString(): String = s"Upper($child)"
}
/**
@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
/** Converts the string result of `child` to lower case. */
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
  override def convert(v: String): String = v.toLowerCase
  override def toString(): String = s"Lower($child)"
}

View file

@ -36,7 +36,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
SimplifyCasts) ::
SimplifyCasts,
SimplifyCaseConversionExpressions) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushPredicateThroughProject,
@ -132,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(candidate, _) if candidate == v => true
case _ => false
})) => Literal(true, BooleanType)
case e: UnaryMinus => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Cast => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
case e: Not => e.child match {
case Literal(null, _) => Literal(null, e.dataType)
case _ => e
}
// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
@ -375,3 +364,18 @@ object CombineLimits extends Rule[LogicalPlan] {
Limit(If(LessThan(ne, le), ne, le), grandChild)
}
}
/**
 * Simplifies nested [[catalyst.expressions.CaseConversionExpression]]s: when one case conversion
 * directly wraps another, the outer conversion fully determines the result, so the inner
 * conversion can be dropped.
 */
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case node: LogicalPlan =>
      node transformExpressionsUp {
        // Only the outermost conversion matters; discard the redundant inner one.
        case Upper(Lower(grandChild)) => Upper(grandChild)
        case Upper(Upper(grandChild)) => Upper(grandChild)
        case Lower(Lower(grandChild)) => Lower(grandChild)
        case Lower(Upper(grandChild)) => Lower(grandChild)
      }
  }
}

View file

@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}
import org.apache.spark.sql.catalyst.types._
// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
// Verifies that ConstantFolding collapses expressions over null literals at optimization time,
// covering IsNull/IsNotNull, item/field access, unary ops, casts, binary arithmetic,
// comparisons, LIKE, and case conversion.
test("Constant folding test: expressions have null literals") {
val originalQuery =
testRelation
.select(
IsNull(Literal(null)) as 'c1,
IsNotNull(Literal(null)) as 'c2,
GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
GetField(
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
"a") as 'c5,
UnaryMinus(Literal(null, IntegerType)) as 'c6,
Cast(Literal(null), IntegerType) as 'c7,
Not(Literal(null, BooleanType)) as 'c8,
Add(Literal(null, IntegerType), 1) as 'c9,
Add(1, Literal(null, IntegerType)) as 'c10,
Equals(Literal(null, IntegerType), 1) as 'c11,
Equals(1, Literal(null, IntegerType)) as 'c12,
Like(Literal(null, StringType), "abc") as 'c13,
Like("abc", Literal(null, StringType)) as 'c14,
Upper(Literal(null, StringType)) as 'c15)
val optimized = Optimize(originalQuery.analyze)
// Every projection above should fold to a literal: true/false for the null checks,
// and a typed null literal for everything else.
val correctAnswer =
testRelation
.select(
Literal(true) as 'c1,
Literal(false) as 'c2,
Literal(null, IntegerType) as 'c3,
Literal(null, IntegerType) as 'c4,
Literal(null, IntegerType) as 'c5,
Literal(null, IntegerType) as 'c6,
Literal(null, IntegerType) as 'c7,
Literal(null, BooleanType) as 'c8,
Literal(null, IntegerType) as 'c9,
Literal(null, IntegerType) as 'c10,
Literal(null, BooleanType) as 'c11,
Literal(null, BooleanType) as 'c12,
Literal(null, BooleanType) as 'c13,
Literal(null, BooleanType) as 'c14,
Literal(null, StringType) as 'c15)
.analyze
comparePlans(optimized, correctAnswer)
}
}

View file

@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
class SimplifyCaseConversionExpressionsSuite extends OptimizerTest {

  /** Runs only the rule under test so failures point directly at it. */
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches =
      Batch("Simplify CaseConversionExpressions", Once,
        SimplifyCaseConversionExpressions) :: Nil
  }

  val testRelation = LocalRelation('a.string)

  /**
   * Optimizes a single-projection query over `testRelation` and checks that the
   * result matches the expected simplified projection.
   */
  private def checkSimplification(input: Expression, expected: Expression, alias: Symbol): Unit = {
    val optimized = Optimize(testRelation.select(input as alias).analyze)
    val correctAnswer = testRelation.select(expected as alias).analyze
    comparePlans(optimized, correctAnswer)
  }

  test("simplify UPPER(UPPER(str))") {
    checkSimplification(Upper(Upper('a)), Upper('a), 'u)
  }

  test("simplify UPPER(LOWER(str))") {
    checkSimplification(Upper(Lower('a)), Upper('a), 'u)
  }

  test("simplify LOWER(UPPER(str))") {
    checkSimplification(Lower(Upper('a)), Lower('a), 'l)
  }

  test("simplify LOWER(LOWER(str))") {
    checkSimplification(Lower(Lower('a)), Lower('a), 'l)
  }
}