[SPARK-32741][SQL] Check if the same ExprId refers to the unique attribute in logical plans
### What changes were proposed in this pull request?

Some plan transformations (e.g., `RemoveNoopOperators`) implicitly assume that the same `ExprId` refers to a unique attribute, but `RuleExecutor` does not verify this integrity between logical plan transformations. This PR adds the check to `isPlanIntegral` in `Analyzer`/`Optimizer`. It follows up on the discussion with cloud-fan and viirya in https://github.com/apache/spark/pull/29485#discussion_r475346278.

### Why are the changes needed?

For better logical plan integrity checking.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Existing tests.

Closes #29585 from maropu/PlanIntegrityTest.

Authored-by: Takeshi Yamamuro <yamamuro@apache.org>
Signed-off-by: Takeshi Yamamuro <yamamuro@apache.org>
commit 3a299aa648
parent cc06266ade
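To make the invariant concrete, here is a minimal sketch of what the new check accepts and rejects. It mirrors the `LogicalPlanIntegritySuite` added below and uses the catalyst test DSL; nothing beyond `LocalRelation`, `Alias`, and the new `LogicalPlanIntegrity` object from this patch is assumed:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlanIntegrity}

val t = LocalRelation('a.int, 'b.int)
val a = t.output.head

// `a#1 + 1 AS a#1`: the alias reuses the ExprId of an attribute it
// references, so one id now denotes two different things. Rejected.
assert(!LogicalPlanIntegrity.checkIfSameExprIdNotReused(
  t.select(Alias(a + 1, "a")(exprId = a.exprId))))

// A freshly allocated ExprId keeps the plan integral. Accepted.
assert(LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(
  t.select(Alias(a + 1, "a")())))
```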
Analyzer.scala
```diff
@@ -48,6 +48,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.internal.SQLConf.{PartitionOverwriteMode, StoreAssignmentPolicy}
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.Utils
 
 /**
  * A trivial [[Analyzer]] with a dummy [[SessionCatalog]] and [[EmptyFunctionRegistry]].
@@ -136,6 +137,10 @@ class Analyzer(
   private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
 
+  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
+    !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
+  }
+
   override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
 
   // Only for tests.
@@ -2777,8 +2782,8 @@ class Analyzer(
       // a resolved Aggregate will not have Window Functions.
       case f @ UnresolvedHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
         if child.resolved &&
           hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
         val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
         // Create an Aggregate operator to evaluate aggregation functions.
         val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
@@ -2795,7 +2800,7 @@ class Analyzer(
       // Aggregate without Having clause.
       case a @ Aggregate(groupingExprs, aggregateExprs, child)
         if hasWindowFunction(aggregateExprs) &&
           a.expressions.forall(_.resolved) =>
         val (windowExpressions, aggregateExpressions) = extract(aggregateExprs)
         // Create an Aggregate operator to evaluate aggregation functions.
         val withAggregate = Aggregate(groupingExprs, aggregateExpressions, child)
```
Optimizer.scala
```diff
@@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.optimizer
 import scala.collection.mutable
 
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
 import org.apache.spark.sql.catalyst.expressions._
@@ -44,9 +43,11 @@ abstract class Optimizer(catalogManager: CatalogManager)
   // Currently we check after the execution of each rule if a plan:
   // - is still resolved
   // - only host special expressions in supported operators
+  // - has globally-unique attribute IDs
   override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
     !Utils.isTesting || (plan.resolved &&
-      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty)
+      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
   }
 
   override protected val excludedOnceBatches: Set[String] =
@@ -1585,14 +1586,14 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
  * Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
  */
 object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case Deduplicate(keys, child) if !child.isStreaming =>
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
+    case d @ Deduplicate(keys, child) if !child.isStreaming =>
       val keyExprIds = keys.map(_.exprId)
       val aggCols = child.output.map { attr =>
         if (keyExprIds.contains(attr.exprId)) {
           attr
         } else {
-          Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
+          Alias(new First(attr).toAggregateExpression(), attr.name)()
         }
       }
       // SPARK-22951: Physical aggregate operators distinguishes global aggregation and grouping
@@ -1601,7 +1602,9 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
       // we append a literal when the grouping key list is empty so that the result aggregate
       // operator is properly treated as a grouping aggregation.
       val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
-      Aggregate(nonemptyKeys, aggCols, child)
+      val newAgg = Aggregate(nonemptyKeys, aggCols, child)
+      val attrMapping = d.output.zip(newAgg.output)
+      newAgg -> attrMapping
   }
 }
 
```
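Note what changed above: `ReplaceDeduplicateWithAggregate` no longer reuses `attr.exprId` for the `first(...)` aliases; instead it returns `newAgg -> attrMapping` under `transformUpWithNewOutput`, whose contract is that the framework rewrites references in parent operators from the old attributes to the new ones. A minimal sketch of that contract, using a hypothetical toy rule (the `transformUpWithNewOutput` API itself is real catalyst, as used in this patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}

// Toy rule: re-create every project-list entry under a fresh ExprId and
// report the old -> new attribute pairs so parent references get fixed up.
def reAliasProjections(plan: LogicalPlan): LogicalPlan =
  plan transformUpWithNewOutput {
    case p @ Project(projectList, child) =>
      val newProj = Project(projectList.map(e => Alias(e, e.name)()), child)
      newProj -> p.output.zip(newProj.output)
  }
```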
subquery.scala
```diff
@@ -338,15 +338,20 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
 object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
   /**
    * Extract all correlated scalar subqueries from an expression. The subqueries are collected using
-   * the given collector. The expression is rewritten and returned.
+   * the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
+   * for the subqueries and rewrite references in the given `expression`.
+   * This method returns extracted subqueries and the corresponding `exprId`s and these values
+   * will be used later in `constructLeftJoins` for building the child plan that
+   * returns subquery output with the `exprId`s.
    */
   private def extractCorrelatedScalarSubqueries[E <: Expression](
       expression: E,
-      subqueries: ArrayBuffer[ScalarSubquery]): E = {
+      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
     val newExpression = expression transform {
       case s: ScalarSubquery if s.children.nonEmpty =>
-        subqueries += s
-        s.plan.output.head
+        val newExprId = NamedExpression.newExprId
+        subqueries += s -> newExprId
+        s.plan.output.head.withExprId(newExprId)
     }
     newExpression.asInstanceOf[E]
   }
@@ -510,16 +515,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    */
   private def constructLeftJoins(
       child: LogicalPlan,
-      subqueries: ArrayBuffer[ScalarSubquery]): LogicalPlan = {
+      subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
     subqueries.foldLeft(child) {
-      case (currentChild, ScalarSubquery(query, conditions, _)) =>
+      case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
         val origOutput = query.output.head
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
         if (resultWithZeroTups.isEmpty) {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
           Project(
-            currentChild.output :+ origOutput,
+            currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
             Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
         } else {
           // Subquery might have the COUNT bug. Add appropriate corrections.
@@ -544,7 +549,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
             Alias(
               If(IsNull(alwaysTrueRef),
                 resultWithZeroTups.get,
-                aggValRef), origOutput.name)(exprId = origOutput.exprId),
+                aggValRef), origOutput.name)(exprId = newExprId),
             Join(currentChild,
               Project(query.output :+ alwaysTrueExpr, query),
               LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
@@ -571,7 +576,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
                 (IsNull(alwaysTrueRef), resultWithZeroTups.get),
                 (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
               aggValRef),
-            origOutput.name)(exprId = origOutput.exprId)
+            origOutput.name)(exprId = newExprId)
 
           Project(
             currentChild.output :+ caseExpr,
@@ -588,36 +593,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
    * Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
    * subqueries.
    */
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
     case a @ Aggregate(grouping, expressions, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
         // We currently only allow correlated subqueries in an aggregate if they are part of the
         // grouping expressions. As a result we need to replace all the scalar subqueries in the
         // grouping expressions by their result.
         val newGrouping = grouping.map { e =>
-          subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
+          subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
         }
-        Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = a.output.zip(newAgg.output)
+        newAgg -> attrMapping
       } else {
-        a
+        a -> Nil
       }
     case p @ Project(expressions, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
       if (subqueries.nonEmpty) {
-        Project(newExpressions, constructLeftJoins(child, subqueries))
+        val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
+        val attrMapping = p.output.zip(newProj.output)
+        newProj -> attrMapping
       } else {
-        p
+        p -> Nil
       }
     case f @ Filter(condition, child) =>
-      val subqueries = ArrayBuffer.empty[ScalarSubquery]
+      val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
       val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
       if (subqueries.nonEmpty) {
-        Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
+        val attrMapping = f.output.zip(newProj.output)
+        newProj -> attrMapping
       } else {
-        f
+        f -> Nil
       }
   }
 }
```
LogicalPlan.scala
```diff
@@ -203,3 +203,73 @@ abstract class BinaryNode extends LogicalPlan
 abstract class OrderPreservingUnaryNode extends UnaryNode {
   override final def outputOrdering: Seq[SortOrder] = child.outputOrdering
 }
+
+object LogicalPlanIntegrity {
+
+  private def canGetOutputAttrs(p: LogicalPlan): Boolean = {
+    p.resolved && !p.expressions.exists { e =>
+      e.collectFirst {
+        // We cannot call `output` in plans with a `ScalarSubquery` expr having no column,
+        // so, we filter out them in advance.
+        case s: ScalarSubquery if s.plan.schema.fields.isEmpty => true
+      }.isDefined
+    }
+  }
+
+  /**
+   * Since some logical plans (e.g., `Union`) can build `AttributeReference`s in their `output`,
+   * this method checks if the same `ExprId` refers to attributes having the same data type
+   * in plan output.
+   */
+  def hasUniqueExprIdsForOutput(plan: LogicalPlan): Boolean = {
+    val exprIds = plan.collect { case p if canGetOutputAttrs(p) =>
+      // NOTE: we still need to filter resolved expressions here because the output of
+      // some resolved logical plans can have unresolved references,
+      // e.g., outer references in `ExistenceJoin`.
+      p.output.filter(_.resolved).map { a => (a.exprId, a.dataType) }
+    }.flatten
+
+    val ignoredExprIds = plan.collect {
+      // NOTE: `Union` currently reuses input `ExprId`s for output references, but we cannot
+      // simply modify the code for assigning new `ExprId`s in `Union#output` because
+      // the modification will make breaking changes (See SPARK-32741(#29585)).
+      // So, this check just ignores the `exprId`s of `Union` output.
+      case u: Union if u.resolved => u.output.map(_.exprId)
+    }.flatten.toSet
+
+    val groupedDataTypesByExprId = exprIds.filterNot { case (exprId, _) =>
+      ignoredExprIds.contains(exprId)
+    }.groupBy(_._1).values.map(_.distinct)
+
+    groupedDataTypesByExprId.forall(_.length == 1)
+  }
+
+  /**
+   * This method checks if reference `ExprId`s are not reused when assigning a new `ExprId`.
+   * For example, it returns false if plan transformers create an alias having the same `ExprId`
+   * with one of reference attributes, e.g., `a#1 + 1 AS a#1`.
+   */
+  def checkIfSameExprIdNotReused(plan: LogicalPlan): Boolean = {
+    plan.collect { case p if p.resolved =>
+      p.expressions.forall {
+        case a: Alias =>
+          // Even if a plan is resolved, `a.references` can return unresolved references,
+          // e.g., in `Grouping`/`GroupingID`, so we need to filter out them and
+          // check if the same `exprId` in `Alias` does not exist
+          // among reference `exprId`s.
+          !a.references.filter(_.resolved).map(_.exprId).exists(_ == a.exprId)
+        case _ =>
+          true
+      }
+    }.forall(identity)
+  }
+
+  /**
+   * This method checks if the same `ExprId` refers to an unique attribute in a plan tree.
+   * Some plan transformers (e.g., `RemoveNoopOperators`) rewrite logical
+   * plans based on this assumption.
+   */
+  def checkIfExprIdsAreGloballyUnique(plan: LogicalPlan): Boolean = {
+    checkIfSameExprIdNotReused(plan) && hasUniqueExprIdsForOutput(plan)
+  }
+}
```
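A quick sanity sketch of `hasUniqueExprIdsForOutput`, adapted from the `LogicalPlanIntegritySuite` added below (the `OutputTestPlan` node is test-only scaffolding, not part of catalyst proper):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.LongType

// Test-only node whose output we can set freely.
case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
  override val analyzed = true
}

val t = LocalRelation('a.int, 'b.int)
// Output echoes the child's attributes: each ExprId names one attribute. OK.
assert(LogicalPlanIntegrity.hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output)))
// The same ExprIds rebound to LongType: one id, two data types. Rejected.
assert(!LogicalPlanIntegrity.hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.map { a =>
  AttributeReference(a.name, LongType)(a.exprId)
})))
```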
FoldablePropagationSuite.scala
```diff
@@ -156,8 +156,8 @@ class FoldablePropagationSuite extends PlanTest {
     val query = expand.where(a1.isNotNull).select(a1, a2).analyze
     val optimized = Optimize.execute(query)
     val correctExpand = expand.copy(projections = Seq(
-      Seq(Literal(null), c2),
-      Seq(c1, Literal(null))))
+      Seq(Literal(null), Literal(2)),
+      Seq(Literal(1), Literal(null))))
     val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze
     comparePlans(optimized, correctAnswer)
   }
```
LogicalPlanIntegritySuite.scala (new file)
```diff
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.types.LongType
+
+class LogicalPlanIntegritySuite extends PlanTest {
+  import LogicalPlanIntegrity._
+
+  case class OutputTestPlan(child: LogicalPlan, output: Seq[Attribute]) extends UnaryNode {
+    override val analyzed = true
+  }
+
+  test("Checks if the same `ExprId` refers to a semantically-equal attribute in a plan output") {
+    val t = LocalRelation('a.int, 'b.int)
+    assert(hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output)))
+    assert(!hasUniqueExprIdsForOutput(OutputTestPlan(t, t.output.zipWithIndex.map {
+      case (a, i) => AttributeReference(s"c$i", LongType)(a.exprId)
+    })))
+  }
+
+  test("Checks if reference ExprIds are not reused when assigning a new ExprId") {
+    val t = LocalRelation('a.int, 'b.int)
+    val Seq(a, b) = t.output
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")())))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = a.exprId))))
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + 1, "a")(exprId = b.exprId))))
+    assert(checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")())))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = a.exprId))))
+    assert(!checkIfSameExprIdNotReused(t.select(Alias(a + b, "ab")(exprId = b.exprId))))
+  }
+}
```
AQEOptimizer.scala
```diff
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution.adaptive
 
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
@@ -54,4 +54,10 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
       }
     }
   }
+
+  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (plan.resolved &&
+      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  }
 }
```
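For context on where these overrides take effect: `RuleExecutor.execute` re-checks `isPlanIntegral` after every rule application and fails fast when the plan is broken. A simplified, hedged sketch of that enforcement point (a standalone helper mirroring, not quoting, `RuleExecutor.scala`):

```scala
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule

// Applies one rule, then re-validates the result, as RuleExecutor does
// between rules.
def applyAndCheck(
    rule: Rule[LogicalPlan],
    batchName: String,
    plan: LogicalPlan,
    isPlanIntegral: LogicalPlan => Boolean): LogicalPlan = {
  val result = rule(plan)
  if (!isPlanIntegral(result)) {
    // With this patch, "integral" also means globally-unique ExprIds
    // (checked only when Utils.isTesting is true).
    throw new TreeNodeException(result,
      s"After applying rule ${rule.ruleName} in batch $batchName, " +
        "the structural integrity of the plan is broken.", null)
  }
  result
}
```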
StreamSuite.scala
```diff
@@ -36,7 +36,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.plans.logical.Range
 import org.apache.spark.sql.catalyst.streaming.{InternalOutputModes, StreamingRelationV2}
-import org.apache.spark.sql.catalyst.util.DateTimeConstants.MICROS_PER_MILLIS
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.execution.{LocalLimitExec, SimpleMode, SparkPlan}
 import org.apache.spark.sql.execution.command.ExplainCommand
@@ -47,7 +46,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.StreamSourceProvider
 import org.apache.spark.sql.streaming.util.{BlockOnStopSourceProvider, StreamManualClock}
-import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
+import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 class StreamSuite extends StreamTest {
@@ -1268,7 +1267,7 @@ class StreamSuite extends StreamTest {
 }
 
 abstract class FakeSource extends StreamSourceProvider {
-  private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil)
+  private val fakeSchema = StructType(StructField("a", LongType) :: Nil)
 
   override def sourceSchema(
       spark: SQLContext,
@@ -1290,7 +1289,7 @@ class FakeDefaultSource extends FakeSource {
     new Source {
       private var offset = -1L
 
-      override def schema: StructType = StructType(StructField("a", IntegerType) :: Nil)
+      override def schema: StructType = StructType(StructField("a", LongType) :: Nil)
 
       override def getOffset: Option[Offset] = {
         if (offset >= 10) {
```