[SPARK-20730][SQL] Add an optimizer rule to combine nested Concat
## What changes were proposed in this pull request? This PR adds a new optimizer rule that combines nested `Concat` expressions. Master gained support for the pipe operator `||` for string concatenation in #17711 (this PR is a follow-up). Because the parser currently generates nested `Concat` expressions for chained `||` operators, the optimizer needs to combine them into a single `Concat`. ## How was this patch tested? Added tests in `CombineConcatsSuite` and `SQLQueryTestSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #17970 from maropu/SPARK-20730.
This commit is contained in:
parent
8da6e8b1f3
commit
b0888d1ac3
|
@ -111,7 +111,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf)
|
|||
RemoveRedundantProject,
|
||||
SimplifyCreateStructOps,
|
||||
SimplifyCreateArrayOps,
|
||||
SimplifyCreateMapOps) ++
|
||||
SimplifyCreateMapOps,
|
||||
CombineConcats) ++
|
||||
extendedOperatorOptimizationRules: _*) ::
|
||||
Batch("Check Cartesian Products", Once,
|
||||
CheckCartesianProducts(conf)) ::
|
||||
|
|
|
@ -18,6 +18,7 @@
|
|||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import scala.collection.immutable.HashSet
|
||||
import scala.collection.mutable.{ArrayBuffer, Stack}
|
||||
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -543,3 +544,28 @@ object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Collapses a [[Concat]] that has other [[Concat]]s among its children into one flat [[Concat]]
 * over the non-[[Concat]] operands, keeping their left-to-right order.
 * For example, `concat(concat(a, b), c)` is rewritten to `concat(a, b, c)`.
 */
object CombineConcats extends Rule[LogicalPlan] {

  /** Returns true when `expr` is itself a [[Concat]] node. */
  private def isConcat(expr: Expression): Boolean = expr match {
    case _: Concat => true
    case _ => false
  }

  /**
   * Walks the tree of nested [[Concat]]s rooted at `concat` with an explicit work stack
   * (no recursion) and gathers every non-[[Concat]] operand in left-to-right order.
   *
   * @param concat the outermost [[Concat]] to flatten
   * @return a single [[Concat]] whose children are the gathered operands
   */
  private def flattenConcats(concat: Concat): Concat = {
    val pending = Stack[Expression](concat)
    val operands = ArrayBuffer.empty[Expression]
    while (pending.nonEmpty) {
      val current = pending.pop()
      current match {
        case Concat(nested) =>
          // Push right-to-left so that the leftmost child ends up on top and is visited first.
          nested.reverseIterator.foreach(child => pending.push(child))
        case operand =>
          operands += operand
      }
    }
    Concat(operands)
  }

  /**
   * Rewrites, top-down, every [[Concat]] in `plan` that has at least one direct [[Concat]] child.
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsDown {
    case outer: Concat if outer.children.exists(isConcat) =>
      flattenConcats(outer)
  }
}
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.optimizer
|
||||
|
||||
import org.apache.spark.sql.catalyst.dsl.plans._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.catalyst.plans.PlanTest
|
||||
import org.apache.spark.sql.catalyst.plans.logical._
|
||||
import org.apache.spark.sql.catalyst.rules._
|
||||
import org.apache.spark.sql.types.StringType
|
||||
|
||||
|
||||
/**
 * Tests for the [[CombineConcats]] optimizer rule.
 */
class CombineConcatsSuite extends PlanTest {

  // Executor that applies only the rule under test, in a single FixedPoint(50) batch.
  object Optimize extends RuleExecutor[LogicalPlan] {
    val batches = Batch("CombineConcatsSuite", FixedPoint(50), CombineConcats) :: Nil
  }

  /**
   * Asserts that optimizing `e1` produces the same plan as the analyzed (not optimized) `e2`.
   * Both expressions are wrapped in a one-column projection named "out" over `OneRowRelation`.
   */
  protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
    def analyzedProjection(expr: Expression): LogicalPlan =
      Project(Alias(expr, "out")() :: Nil, OneRowRelation).analyze

    val expectedPlan = analyzedProjection(e2)
    val optimizedPlan = Optimize.execute(analyzedProjection(e1))
    comparePlans(optimizedPlan, expectedPlan)
  }

  test("combine nested Concat exprs") {
    def str(s: String): Literal = Literal(s, StringType)

    // Every case below must flatten to this expression; built afresh for each assertion.
    def flat: Concat = Concat(Seq(str("a"), str("b"), str("c"), str("d")))

    // Nested Concat in the first position.
    assertEquivalent(
      Concat(Seq(Concat(Seq(str("a"), str("b"))), str("c"), str("d"))),
      flat)

    // Nested Concat in the middle.
    assertEquivalent(
      Concat(Seq(str("a"), Concat(Seq(str("b"), str("c"))), str("d"))),
      flat)

    // Nested Concat in the last position.
    assertEquivalent(
      Concat(Seq(str("a"), str("b"), Concat(Seq(str("c"), str("d"))))),
      flat)

    // Several levels of nesting.
    assertEquivalent(
      Concat(Seq(
        Concat(Seq(
          str("a"),
          Concat(Seq(
            str("b"),
            Concat(Seq(str("c"), str("d"))))))))),
      flat)
  }
}
|
|
@ -4,3 +4,7 @@ select format_string();
|
|||
|
||||
-- A pipe operator for string concatenation
|
||||
select 'a' || 'b' || 'c';
|
||||
|
||||
-- Check that Catalyst combines nested `Concat`s
|
||||
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
|
||||
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10));
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 3
|
||||
-- Number of queries: 4
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -26,3 +26,29 @@ select 'a' || 'b' || 'c'
|
|||
struct<concat(concat(a, b), c):string>
|
||||
-- !query 2 output
|
||||
abc
|
||||
|
||||
|
||||
-- !query 3
|
||||
EXPLAIN EXTENDED SELECT (col1 || col2 || col3 || col4) col
|
||||
FROM (SELECT id col1, id col2, id col3, id col4 FROM range(10))
|
||||
-- !query 3 schema
|
||||
struct<plan:string>
|
||||
-- !query 3 output
|
||||
== Parsed Logical Plan ==
|
||||
'Project [concat(concat(concat('col1, 'col2), 'col3), 'col4) AS col#x]
|
||||
+- 'Project ['id AS col1#x, 'id AS col2#x, 'id AS col3#x, 'id AS col4#x]
|
||||
+- 'UnresolvedTableValuedFunction range, [10]
|
||||
|
||||
== Analyzed Logical Plan ==
|
||||
col: string
|
||||
Project [concat(concat(concat(cast(col1#xL as string), cast(col2#xL as string)), cast(col3#xL as string)), cast(col4#xL as string)) AS col#x]
|
||||
+- Project [id#xL AS col1#xL, id#xL AS col2#xL, id#xL AS col3#xL, id#xL AS col4#xL]
|
||||
+- Range (0, 10, step=1, splits=None)
|
||||
|
||||
== Optimized Logical Plan ==
|
||||
Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
|
||||
+- Range (0, 10, step=1, splits=None)
|
||||
|
||||
== Physical Plan ==
|
||||
*Project [concat(cast(id#xL as string), cast(id#xL as string), cast(id#xL as string), cast(id#xL as string)) AS col#x]
|
||||
+- *Range (0, 10, step=1, splits=2)
|
||||
|
|
Loading…
Reference in a new issue