[SPARK-20730][SQL] Add an optimizer rule to combine nested Concat

## What changes were proposed in this pull request?
This PR adds a new optimizer rule to combine nested `Concat` expressions. Master has supported the pipe operator `||` for string concatenation since #17711 (this PR is a follow-up). Since the parser currently generates nested `Concat` expressions, the optimizer needs to flatten them into a single `Concat`.

## How was this patch tested?
Added tests in `CombineConcatSuite` and `SQLQueryTestSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17970 from maropu/SPARK-20730.
This commit is contained in:
Takeshi Yamamuro 2017-05-15 16:24:55 +08:00 committed by Wenchen Fan
parent 8da6e8b1f3
commit b0888d1ac3
5 changed files with 134 additions and 2 deletions

View file

@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
RemoveRedundantProject,
SimplifyCreateStructOps,
SimplifyCreateArrayOps,
SimplifyCreateMapOps) ++
SimplifyCreateMapOps,
CombineConcats) ++
extendedOperatorOptimizationRules: _*) ::
Batch("Check Cartesian Products", Once,
CheckCartesianProducts(conf)) ::

View file

@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
import scala.collection.mutable.{ArrayBuffer, Stack}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
@ -543,3 +544,28 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
}
}
}
/**
 * Combine nested [[Concat]] expressions.
 */
object CombineConcats extends Rule[LogicalPlan] {

  /**
   * Splices the children of any nested [[Concat]] into a single flat argument
   * list, preserving the original left-to-right evaluation order.
   */
  private def flattenConcats(concat: Concat): Concat = {
    def flatten(expr: Expression): Seq[Expression] = expr match {
      // A nested Concat contributes its (recursively flattened) children.
      case Concat(children) => children.flatMap(flatten)
      // Any other expression is already a leaf of the flattened list.
      case leaf => Seq(leaf)
    }
    Concat(flatten(concat))
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
    // Only rewrite when at least one child is itself a Concat; otherwise the
    // expression is already flat and rewriting would churn the tree.
    case concat: Concat if concat.children.exists(_.isInstanceOf[Concat]) =>
      flattenConcats(concat)
  }
}

View file

@ -0,0 +1,75 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.types.StringType
class CombineConcatsSuite extends PlanTest {

  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil
  }

  /**
   * Optimizes `e1` wrapped in a one-row projection and checks that the result
   * matches the plan built from `e2`.
   */
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    val expected = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
    val optimized = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
    comparePlans(optimized, expected)
  }

  test("combine nested Concat exprs") {
    def str(s: String): Literal = Literal(s, StringType)
    // All nested shapes below must flatten to this single four-child Concat.
    val flat = Concat(Seq(str("a"), str("b"), str("c"), str("d")))

    // Nested Concat in head, middle, and tail positions.
    assertEquivalent(Concat(Seq(Concat(Seq(str("a"), str("b"))), str("c"), str("d"))), flat)
    assertEquivalent(Concat(Seq(str("a"), Concat(Seq(str("b"), str("c"))), str("d"))), flat)
    assertEquivalent(Concat(Seq(str("a"), str("b"), Concat(Seq(str("c"), str("d"))))), flat)

    // Deeply nested chain of single-child Concats.
    assertEquivalent(
      Concat(Seq(Concat(Seq(str("a"), Concat(Seq(str("b"), Concat(Seq(str("c"), str("d"))))))))),
      flat)
  }
}

View file

@ -4,3 +4,7 @@ select format_string();
-- A pipe operator for string concatenation
select 'a' || 'b' || 'c';
-- Check if catalyst combines nested `Concat`s
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10));

View file

@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 3
-- Number of queries: 4
-- !query 0
@ -26,3 +26,29 @@ select 'a' || 'b' || 'c'
struct<concat(concat(a, b), c):string>
-- !query 2 output
abc
-- !query 3
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
-- !query 3 schema
struct<plan:string>
-- !query 3 output
== Parsed Logical Plan ==
'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x]
+- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x]
+- 'UnresolvedTableValuedFunction range, [10]
== Analyzed Logical Plan ==
col: string
Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x]
+- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL]
+- Range (0, 10, step=1, splits=None)
== Optimized Logical Plan ==
Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
+- Range (0, 10, step=1, splits=None)
== Physical Plan ==
*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
+- *Range (0, 10, step=1, splits=2)