From a5ecf2a490727fec97790b149f59bdc498b445be Mon Sep 17 00:00:00 2001
From: Angerszhuuuu
Date: Mon, 9 Aug 2021 16:47:56 +0800
Subject: [PATCH] [SPARK-36352][SQL] Spark should check result plan's output schema name

### What changes were proposed in this pull request?
Add a check that the result plan's output schema, including attribute names, matches the original plan's schema.

### Why are the changes needed?
Currently, some optimizer rules can change a plan's output schema, because the existing integrity checks compare outputs only by semantic equality, which ignores attribute names. For example, with SchemaPruning, given the plan
```
Project[a, B]
|--Scan[A, b, c]
```
the original output schema is `a, B`, but after SchemaPruning it becomes
```
Project[A, b]
|--Scan[A, b]
```
which changes the plan's schema. For CTAS, the table schema is taken from the query plan's output, so the changed schema is no longer consistent with the original SQL. We therefore need to check the final result plan's schema against the original plan's schema.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing UTs.

Closes #33583 from AngersZhuuuu/SPARK-36352.

Authored-by: Angerszhuuuu
Signed-off-by: Wenchen Fan
(cherry picked from commit e051a540a10cdda42dc86a6195c0357aea8900e4)
Signed-off-by: Wenchen Fan
---
 .../sql/catalyst/analysis/Analyzer.scala      |  6 +++--
 .../sql/catalyst/optimizer/Optimizer.scala    | 22 ++++++++-----------
 .../sql/catalyst/rules/RuleExecutor.scala     |  6 ++---
 .../org/apache/spark/sql/types/DataType.scala |  2 +-
 .../apache/spark/sql/util/SchemaUtils.scala   | 11 ++++++++++
 .../catalyst/trees/RuleExecutorSuite.scala    |  8 +++++--
 .../sql/execution/adaptive/AQEOptimizer.scala | 12 ++++++----
 .../datasources/DataSourceStrategy.scala      |  2 +-
 .../execution/datasources/SchemaPruning.scala | 10 +++++----
 .../v2/V2ScanRelationPushDown.scala           |  3 ++-
 .../datasources/SchemaPruningSuite.scala      | 12 ++++++++++
 11 files changed, 63 insertions(+), 31 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 963b42bd30..b6228d1861 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -174,8 +174,10 @@ class Analyzer(override val catalogManager: CatalogManager) private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan) + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) } override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 369fb51657..40b4c01be7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.SchemaUtils._ import org.apache.spark.util.Utils /** @@ -46,10 +47,14 @@ abstract class Optimizer(catalogManager: CatalogManager) // - is still resolved // - only host special expressions in supported operators // - has globally-unique attribute IDs - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || (plan.resolved && - plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && - LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + // - the optimized plan has the same schema as the previous plan. + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || (currentPlan.resolved && + currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && + DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } override protected val excludedOnceBatches: Set[String] = @@ -515,15 +520,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] { * Remove no-op operators from the query plan that do not make any modifications. */ object RemoveNoopOperators extends Rule[LogicalPlan] { - def restoreOriginalOutputNames( - projectList: Seq[NamedExpression], - originalNames: Seq[String]): Seq[NamedExpression] = { - projectList.zip(originalNames).map { - case (attr: Attribute, name) => attr.withName(name) - case (alias: Alias, name) => alias.withName(name) - case (other, _) => other - } - } def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning( _.containsAnyPattern(PROJECT, WINDOW), ruleId) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index 17d7794292..759eba690a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -156,7 +156,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { * `Optimizer`, so we can catch rules that return invalid plans. The check function returns * `false` if the given plan doesn't pass the structural integrity check. */ - protected def isPlanIntegral(plan: TreeType): Boolean = true + protected def isPlanIntegral(previousPlan: TreeType, currentPlan: TreeType): Boolean = true /** * Util method for checking whether a plan remains the same if re-optimized. @@ -192,7 +192,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { val beforeMetrics = RuleExecutor.getCurrentMetrics() // Run the structural integrity checker against the initial input - if (!isPlanIntegral(plan)) { + if (!isPlanIntegral(plan, plan)) { throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError( this.getClass.getName.stripSuffix("$")) } @@ -224,7 +224,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective)) // Run the structural integrity checker against the plan after each rule.
- if (effective && !isPlanIntegral(result)) { + if (effective && !isPlanIntegral(plan, result)) { throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError( rule.ruleName, batch.name) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 585045d898..ef1aeec0d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -292,7 +292,7 @@ object DataType { /** * Compares two types, ignoring nullability of ArrayType, MapType, StructType. */ - private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { + private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = { (left, right) match { case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) => equalsIgnoreNullability(leftElementType, rightElementType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index da105af9e4..63c1f1869d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -21,6 +21,7 @@ import java.util.Locale import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType} @@ -273,6 +274,16 @@ private[spark] object SchemaUtils { field._1 } + def restoreOriginalOutputNames( + projectList: Seq[NamedExpression], + originalNames: Seq[String]): Seq[NamedExpression] = { + projectList.zip(originalNames).map { + case (attr: Attribute, name) => attr.withName(name) + case (alias: Alias, name) => alias.withName(name) + case (other, _) => other + } + } + /** * @param str The string to be escaped. * @return The escaped string. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala index 25352e2d24..b14686beff 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/RuleExecutorSuite.scala @@ -73,7 +73,9 @@ class RuleExecutorSuite extends SparkFunSuite { test("structural integrity checker - verify initial input") { object WithSIChecker extends RuleExecutor[Expression] { - override protected def isPlanIntegral(expr: Expression): Boolean = expr match { + override protected def isPlanIntegral( + previousPlan: Expression, + currentPlan: Expression): Boolean = currentPlan match { case IntegerLiteral(_) => true case _ => false } @@ -91,7 +93,9 @@ class RuleExecutorSuite extends SparkFunSuite { test("structural integrity checker - verify rule execution result") { object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] { - override protected def isPlanIntegral(expr: Expression): Boolean = expr match { + override protected def isPlanIntegral( + previousPlan: Expression, + currentPlan: Expression): Boolean = currentPlan match { case IntegerLiteral(i) if i > 0 => true case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala index 0767039d10..f8cba90d68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils /** @@ -64,9 +65,12 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] { } } - override protected def isPlanIntegral(plan: LogicalPlan): Boolean = { - !Utils.isTesting || (plan.resolved && - plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && - LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)) + override protected def isPlanIntegral( + previousPlan: LogicalPlan, + currentPlan: LogicalPlan): Boolean = { + !Utils.isTesting || (currentPlan.resolved && + currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty && + LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) && + DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 81ecb2cb27..11d23f482f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -464,7 +464,7 @@ object DataSourceStrategy */ protected[sql] def normalizeExprs( exprs: Seq[Expression], - attributes: Seq[AttributeReference]): Seq[Expression] = { + attributes: Seq[Attribute]): Seq[Expression] = { exprs.map { e => e transform { case a: 
AttributeReference => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala index a1974455fb..4f331c7bf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SchemaPruning.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} +import org.apache.spark.sql.util.SchemaUtils._ /** * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation. @@ -82,8 +83,8 @@ object SchemaPruning extends Rule[LogicalPlan] { val prunedRelation = leafNodeBuilder(prunedDataSchema) val projectionOverSchema = ProjectionOverSchema(prunedDataSchema) - Some(buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation, - projectionOverSchema)) + Some(buildNewProjection(projects, normalizedProjects, normalizedFilters, + prunedRelation, projectionOverSchema)) } else { None } @@ -125,6 +126,7 @@ object SchemaPruning extends Rule[LogicalPlan] { */ private def buildNewProjection( projects: Seq[NamedExpression], + normalizedProjects: Seq[NamedExpression], filters: Seq[Expression], leafNode: LeafNode, projectionOverSchema: ProjectionOverSchema): Project = { @@ -143,7 +145,7 @@ object SchemaPruning extends Rule[LogicalPlan] { // Construct the new projections of our Project by // rewriting the original projections - val newProjects = projects.map(_.transformDown { + val newProjects = normalizedProjects.map(_.transformDown { case projectionOverSchema(expr) => expr }).map { case expr: NamedExpression => expr } @@ -151,7 +153,7 @@ object SchemaPruning extends Rule[LogicalPlan] { logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}") } - Project(newProjects, projectionChild) + Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala index d05519b880..ab5a0feb62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ScanRelationPushDown.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownA import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.sources import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.SchemaUtils._ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { import DataSourceV2Implicits._ @@ -207,7 +208,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper { val newProjects = normalizedProjects .map(projectionFunc) .asInstanceOf[Seq[NamedExpression]] - Project(newProjects, withFilter) + Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter) } else { withFilter } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index ac5c28953a..395ee6fab0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -870,4 +870,16 @@ abstract class SchemaPruningSuite checkAnswer(query, Row(1) :: Row(2) :: Nil) } } + + test("SPARK-36352: Spark should check result plan's output schema name") { + withMixedCaseData { + val query = sql("select cOL1, cOl2.B from mixedcase") + assert(query.queryExecution.executedPlan.schema.catalogString == + "struct<cOL1:string,B:int>") + checkAnswer(query.orderBy("id"), + Row("r0c1", 1) :: + Row("r1c1", 2) :: + Nil) + } + } }
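
As an illustration of the technique used in this patch: pruning and pushdown rules rebuild a projection from relation attributes whose name case can differ from the query text, so the final `Project` re-applies the names from the original project list via `restoreOriginalOutputNames`. Below is a minimal, self-contained Scala sketch of that idea; `NamedExpr`, `Attr`, and `Alias` are simplified stand-ins for Catalyst's `NamedExpression` hierarchy (the real helper pattern-matches on `Attribute` vs. `Alias`), not the Spark API.

```scala
// Simplified stand-ins for Catalyst's NamedExpression hierarchy (illustration only).
sealed trait NamedExpr { def name: String; def withName(newName: String): NamedExpr }
case class Attr(name: String) extends NamedExpr {
  def withName(newName: String): NamedExpr = copy(name = newName)
}
case class Alias(name: String) extends NamedExpr {
  def withName(newName: String): NamedExpr = copy(name = newName)
}

object RestoreNamesSketch extends App {
  // Same idea as the helper this patch moves into SchemaUtils: pair each output
  // expression with the name it had in the original query and re-apply that name.
  def restoreOriginalOutputNames(
      projectList: Seq[NamedExpr],
      originalNames: Seq[String]): Seq[NamedExpr] =
    projectList.zip(originalNames).map { case (expr, name) => expr.withName(name) }

  // After pruning, the scan exposes `CoL1` and `b`, but the query asked for `cOL1` and `B`.
  val pruned = Seq(Attr("CoL1"), Attr("b"))
  println(restoreOriginalOutputNames(pruned, Seq("cOL1", "B"))) // List(Attr(cOL1), Attr(B))
}
```

Restoring the query's own names is what keeps the CTAS table schema, and the test's `struct<cOL1:string,B:int>` assertion above, consistent with the case used in the original SQL.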