[SPARK-32741][SQL] Check if the same ExprId refers to the unique attribute in logical plans

### What changes were proposed in this pull request?

Some plan transformations (e.g., `RemoveNoopOperators`) implicitly assume the same `ExprId` refers to the unique attribute. But, `RuleExecutor` does not check this integrity between logical plan transformations. So, this PR intends to add this check in `isPlanIntegral` of `Analyzer`/`Optimizer`.

This PR comes from the talk with cloud-fan viirya in https://github.com/apache/spark/pull/29485#discussion_r475346278

### Why are the changes needed?

For better logical plan integrity checking.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #29585 from maropu/PlanIntegrityTest.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
This commit is contained in:
Takeshi Yamamuro 2020-09-30 21:37:29 +09:00
parent cc06266ade
commit 3a299aa648
8 changed files with 181 additions and 36 deletions

View file

@ -48,6 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.Utils
/**
* A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
@ -136,6 +137,10 @@ class Analyzer(
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
}
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
// Only for tests.
@ -2777,8 +2782,8 @@ class Analyzer(
// a resolved Aggregate will not have Window Functions.
case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
@ -2795,7 +2800,7 @@ class Analyzer(
// Aggregate without Having clause.
case a @ Aggregate(groupingExprs, aggregateExprs, child)
if hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
a.expressions.forall(_.resolved) =>
val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
// Create an Aggregate operator to evaluate aggregation functions.
val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)

View file

@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions._
@ -44,9 +43,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
// Currently we check after the execution of each rule if a plan:
// - is still resolved
// - only host special expressions in supported operators
// - has globally-unique attribute IDs
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || (plan.resolved &&
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
}
override protected val excludedOnceBatches: Set[String] =
@ -1585,14 +1586,14 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
*/
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Deduplicate(keys, child) if !child.isStreaming =>
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case d @ Deduplicate(keys, child) if !child.isStreaming =>
val keyExprIds = keys.map(_.exprId)
val aggCols = child.output.map { attr =>
if (keyExprIds.contains(attr.exprId)) {
attr
} else {
Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
Alias(new First(attr).toAggregateExpression(), attr.name)()
}
}
// SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
@ -1601,7 +1602,9 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
// we append a literal when the grouping key list is empty so that the result aggregate
// operator is properly treated as a grouping aggregation.
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
Aggregate(nonemptyKeys, aggCols, child)
val newAgg = Aggregate(nonemptyKeys, aggCols, child)
val attrMapping = d.output.zip(newAgg.output)
newAgg -> attrMapping
}
}

View file

@ -338,15 +338,20 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
/**
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
* the given collector. The expression is rewritten and returned.
* the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
* for the subqueries and rewrite references in the given `expression`.
* This method returns extracted subqueries and the corresponding `exprId`s and these values
* will be used later in `constructLeftJoins` for building the child plan that
* returns subquery output with the `exprId`s.
*/
private def extractCorrelatedScalarSubqueries[E <: Expression](
expression: E,
subqueries: ArrayBuffer[ScalarSubquery]): E = {
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
val newExpression = expression transform {
case s: ScalarSubquery if s.children.nonEmpty =>
subqueries += s
s.plan.output.head
val newExprId = NamedExpression.newExprId
subqueries += s -> newExprId
s.plan.output.head.withExprId(newExprId)
}
newExpression.asInstanceOf[E]
}
@ -510,16 +515,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
*/
private def constructLeftJoins(
child: LogicalPlan,
subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
subqueries.foldLeft(child) {
case (currentChild, ScalarSubquery(query, conditions, _)) =>
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
val origOutput = query.output.head
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
if (resultWithZeroTups.isEmpty) {
// CASE 1: Subquery guaranteed not to have the COUNT bug
Project(
currentChild.output :+ origOutput,
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
} else {
// Subquery might have the COUNT bug. Add appropriate corrections.
@ -544,7 +549,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
Alias(
If(IsNull(alwaysTrueRef),
resultWithZeroTups.get,
aggValRef), origOutput.name)(exprId = origOutput.exprId),
aggValRef), origOutput.name)(exprId = newExprId),
Join(currentChild,
Project(query.output :+ alwaysTrueExpr, query),
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
@ -571,7 +576,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
aggValRef),
origOutput.name)(exprId = origOutput.exprId)
origOutput.name)(exprId = newExprId)
Project(
currentChild.output :+ caseExpr,
@ -588,36 +593,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
* subqueries.
*/
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
case a @ Aggregate(grouping, expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
// We currently only allow correlated subqueries in an aggregate if they are part of the
// grouping expressions. As a result we need to replace all the scalar subqueries in the
// grouping expressions by their result.
val newGrouping = grouping.map { e =>
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
}
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = a.output.zip(newAgg.output)
newAgg -> attrMapping
} else {
a
a -> Nil
}
case p @ Project(expressions, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
if (subqueries.nonEmpty) {
Project(newExpressions, constructLeftJoins(child, subqueries))
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
val attrMapping = p.output.zip(newProj.output)
newProj -> attrMapping
} else {
p
p -> Nil
}
case f @ Filter(condition, child) =>
val subqueries = ArrayBuffer.empty[ScalarSubquery]
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
if (subqueries.nonEmpty) {
Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
val attrMapping = f.output.zip(newProj.output)
newProj -> attrMapping
} else {
f
f -> Nil
}
}
}

View file

@ -203,3 +203,73 @@ abstract class BinaryNode extends LogicalPlan {
abstract class OrderPreservingUnaryNode extends UnaryNode {
override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
}
object LogicalPlanIntegrity {
private def canGetOutputAttrs(p: LogicalPlan): Boolean = {
p.resolved && !p.expressions.exists { e =>
e.collectFirst {
// We cannot call `output` in plans with a `ScalarSubquery` expr having no column,
// so, we filter out them in advance.
case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true
}.isDefined
}
}
/**
* Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`,
* this method checks if the same `ExprId` refers to attributes having the same data type
* in plan output.
*/
def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = {
val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
// NOTE: we still need to filter resolved expressions here because the output of
// some resolved logical plans can have unresolved references,
// e.g., outer references in `ExistenceJoin`.
p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) }
}.flatten
val ignoredExprIds = plan.collect {
// NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot
// simply modify the code for assigning new `ExprId`s in `Union#output` because
// the modification will make breaking changes (See SPARK-32741(#29585)).
// So, this check just ignores the `exprId`s of `Union` output.
case u: Union if u.resolved => u.output.map(_.exprId)
}.flatten.toSet
val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) =>
ignoredExprIds.contains(exprId)
}.groupBy(_._1).values.map(_.distinct)
groupedDataTypesByExprId.forall(_.length == 1)
}
/**
* This method checks if reference `ExprId`s are not reused when assigning a new `ExprId`.
* For example, it returns false if plan transformers create an alias having the same `ExprId`
* with one of reference attributes, e.g., `a#1 + 1 AS a#1`.
*/
def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
plan.collect { case p if p.resolved =>
p.expressions.forall {
case a: Alias =>
// Even if a plan is resolved, `a.references` can return unresolved references,
// e.g., in `Grouping`/`GroupingID`, so we need to filter out them and
// check if the same `exprId` in `Alias` does not exist
// among reference `exprId`s.
!a.references.filter(_.resolved).map(_.exprId).exists(_ == a.exprId)
case _ =>
true
}
}.forall(identity)
}
/**
* This method checks if the same `ExprId` refers to an unique attribute in a plan tree.
* Some plan transformers (e.g., `RemoveNoopOperators`) rewrite logical
* plans based on this assumption.
*/
def checkIfExprIdsAreGloballyUnique(plan: LogicalPlan): Boolean = {
checkIfSameExprIdNotReused(plan) && hasUniqueExprIdsForOutput(plan)
}
}

View file

@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest {
val query = expand.where(a1.isNotNull).select(a1, a2).analyze
val optimized = Optimize.execute(query)
val correctExpand = expand.copy(projections = Seq(
Seq(Literal(null), c2),
Seq(c1, Literal(null))))
Seq(Literal(null), Literal(2)),
Seq(Literal(1), Literal(null))))
val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
comparePlans(optimized, correctAnswer)
}

View file

@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.types.LongType
class LogicalPlanIntegritySuite extends PlanTest {
import LogicalPlanIntegrity._
case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
override val analyzed = true
}
test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
val t = LocalRelation('a.int, 'b.int)
assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output)))
assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map {
case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId)
})))
}
test("Checks if reference ExprIds are not reused when assigning a new ExprId") {
val t = LocalRelation('a.int, 'b.int)
val Seq(a, b) = t.output
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))
}
}

View file

@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils
@ -54,4 +54,10 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
}
}
}
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
!Utils.isTesting || (plan.resolved &&
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
}
}

View file

@ -36,7 +36,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.plans.logical.Range
import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
import org.apache.spark.sql.execution.command.ExplainCommand
@ -47,7 +46,7 @@ import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.sources.StreamSourceProvider
import org.apache.spark.sql.streaming.util.{BlockOnStopSourceProvider, StreamManualClock}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
import org.apache.spark.util.Utils
class StreamSuite extends StreamTest {
@ -1268,7 +1267,7 @@ class StreamSuite extends StreamTest {
}
abstract class FakeSource extends StreamSourceProvider {
private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
private val fakeSchema = StructType(StructField("a", LongType) :: Nil)
override def sourceSchema(
spark: SQLContext,
@ -1290,7 +1289,7 @@ class FakeDefaultSource extends FakeSource {
new Source {
private var offset = -1L
override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
override def schema: StructType = StructType(StructField("a", LongType) :: Nil)
override def getOffset: Option[Offset] = {
if (offset >= 10) {