[SPARK-2052] [SQL] Add optimization for CaseConversionExpression's.
Add optimization for `CaseConversionExpression`'s. Author: Takuya UESHIN <ueshin@happy-camper.st> Closes #990 from ueshin/issues/SPARK-2052 and squashes the following commits: 2568666 [Takuya UESHIN] Move some rules back. dde7ede [Takuya UESHIN] Add tests to check if ConstantFolding can handle null literals and remove the unneeded rules from NullPropagation. c4eea67 [Takuya UESHIN] Fix toString methods. 23e2363 [Takuya UESHIN] Make CaseConversionExpressions foldable if the child is foldable. 0ff7568 [Takuya UESHIN] Add tests for collapsing case statements. 3977d80 [Takuya UESHIN] Add optimization for CaseConversionExpression's.
This commit is contained in:
parent
d45e0c6b98
commit
9a2448daf9
|
@ -76,7 +76,8 @@ trait CaseConversionExpression {
|
|||
type EvaluatedType = Any
|
||||
|
||||
def convert(v: String): String
|
||||
|
||||
|
||||
override def foldable: Boolean = child.foldable
|
||||
def nullable: Boolean = child.nullable
|
||||
def dataType: DataType = StringType
|
||||
|
||||
|
@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression)
|
|||
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
|
||||
|
||||
override def convert(v: String): String = v.toUpperCase()
|
||||
|
||||
override def toString() = s"Upper($child)"
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
|
|||
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
|
||||
|
||||
override def convert(v: String): String = v.toLowerCase()
|
||||
|
||||
override def toString() = s"Lower($child)"
|
||||
}
|
||||
|
|
|
@ -36,7 +36,8 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
|
|||
ConstantFolding,
|
||||
BooleanSimplification,
|
||||
SimplifyFilters,
|
||||
SimplifyCasts) ::
|
||||
SimplifyCasts,
|
||||
SimplifyCaseConversionExpressions) ::
|
||||
Batch("Filter Pushdown", FixedPoint(100),
|
||||
CombineFilters,
|
||||
PushPredicateThroughProject,
|
||||
|
@ -132,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] {
|
|||
case Literal(candidate, _) if candidate == v => true
|
||||
case _ => false
|
||||
})) => Literal(true, BooleanType)
|
||||
case e: UnaryMinus => e.child match {
|
||||
case Literal(null, _) => Literal(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
case e: Cast => e.child match {
|
||||
case Literal(null, _) => Literal(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
case e: Not => e.child match {
|
||||
case Literal(null, _) => Literal(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
// Put exceptional cases above if any
|
||||
case e: BinaryArithmetic => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
|
||||
|
@ -375,3 +364,18 @@ object CombineLimits extends Rule[LogicalPlan] {
|
|||
Limit(If(LessThan(ne, le), ne, le), grandChild)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the inner [[catalyst.expressions.CaseConversionExpression]] that are unnecessary because
|
||||
* the inner conversion is overwritten by the outer one.
|
||||
*/
|
||||
object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case Upper(Upper(child)) => Upper(child)
|
||||
case Upper(Lower(child)) => Upper(child)
|
||||
case Lower(Upper(child)) => Lower(child)
|
||||
case Lower(Lower(child)) => Lower(child)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
|
|||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}
|
||||
import org.apache.spark.sql.catalyst.types._
|
||||
|
||||
// For implicit conversions
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
|
@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest {
|
|||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("Constant folding test: expressions have null literals") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.select(
|
||||
IsNull(Literal(null)) as 'c1,
|
||||
IsNotNull(Literal(null)) as 'c2,
|
||||
|
||||
GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
|
||||
GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
|
||||
GetField(
|
||||
Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
|
||||
"a") as 'c5,
|
||||
|
||||
UnaryMinus(Literal(null, IntegerType)) as 'c6,
|
||||
Cast(Literal(null), IntegerType) as 'c7,
|
||||
Not(Literal(null, BooleanType)) as 'c8,
|
||||
|
||||
Add(Literal(null, IntegerType), 1) as 'c9,
|
||||
Add(1, Literal(null, IntegerType)) as 'c10,
|
||||
|
||||
Equals(Literal(null, IntegerType), 1) as 'c11,
|
||||
Equals(1, Literal(null, IntegerType)) as 'c12,
|
||||
|
||||
Like(Literal(null, StringType), "abc") as 'c13,
|
||||
Like("abc", Literal(null, StringType)) as 'c14,
|
||||
|
||||
Upper(Literal(null, StringType)) as 'c15)
|
||||
|
||||
val optimized = Optimize(originalQuery.analyze)
|
||||
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(
|
||||
Literal(true) as 'c1,
|
||||
Literal(false) as 'c2,
|
||||
|
||||
Literal(null, IntegerType) as 'c3,
|
||||
Literal(null, IntegerType) as 'c4,
|
||||
Literal(null, IntegerType) as 'c5,
|
||||
|
||||
Literal(null, IntegerType) as 'c6,
|
||||
Literal(null, IntegerType) as 'c7,
|
||||
Literal(null, BooleanType) as 'c8,
|
||||
|
||||
Literal(null, IntegerType) as 'c9,
|
||||
Literal(null, IntegerType) as 'c10,
|
||||
|
||||
Literal(null, BooleanType) as 'c11,
|
||||
Literal(null, BooleanType) as 'c12,
|
||||
|
||||
Literal(null, BooleanType) as 'c13,
|
||||
Literal(null, BooleanType) as 'c14,
|
||||
|
||||
Literal(null, StringType) as 'c15)
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
|
||||
/* Implicit conversions */
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
|
||||
class SimplifyCaseConversionExpressionsSuite extends OptimizerTest {
|
||||
|
||||
object Optimize extends RuleExecutor[LogicalPlan] {
|
||||
val batches =
|
||||
Batch("Simplify CaseConversionExpressions", Once,
|
||||
SimplifyCaseConversionExpressions) :: Nil
|
||||
}
|
||||
|
||||
val testRelation = LocalRelation('a.string)
|
||||
|
||||
test("simplify UPPER(UPPER(str))") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.select(Upper(Upper('a)) as 'u)
|
||||
|
||||
val optimized = Optimize(originalQuery.analyze)
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(Upper('a) as 'u)
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("simplify UPPER(LOWER(str))") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.select(Upper(Lower('a)) as 'u)
|
||||
|
||||
val optimized = Optimize(originalQuery.analyze)
|
||||
val correctAnswer =
|
||||
testRelation
|
||||
.select(Upper('a) as 'u)
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("simplify LOWER(UPPER(str))") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.select(Lower(Upper('a)) as 'l)
|
||||
|
||||
val optimized = Optimize(originalQuery.analyze)
|
||||
val correctAnswer = testRelation
|
||||
.select(Lower('a) as 'l)
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
|
||||
test("simplify LOWER(LOWER(str))") {
|
||||
val originalQuery =
|
||||
testRelation
|
||||
.select(Lower(Lower('a)) as 'l)
|
||||
|
||||
val optimized = Optimize(originalQuery.analyze)
|
||||
val correctAnswer = testRelation
|
||||
.select(Lower('a) as 'l)
|
||||
.analyze
|
||||
|
||||
comparePlans(optimized, correctAnswer)
|
||||
}
|
||||
}
|
Loading…
Reference in a new issue