[SPARK-36352][SQL] Spark should check result plan's output schema name

### What changes were proposed in this pull request?
Make Spark check that the result plan's output schema names match those of the original plan.

### Why are the changes needed?
In the current code, some optimizer rules may change a plan's output schema, because we only check plan outputs with semantic equality, which ignores differences such as attribute name case.
For example, with `SchemaPruning`, given the plan
```
Project[a, B]
|--Scan[A, b, c]
```
the original output schema is `a, B`. After `SchemaPruning` it becomes
```
Project[A, b]
|--Scan[A, b]
```
This changes the plan's schema. For CTAS, the created table's schema is taken from the query plan's output, so after such a change the stored schema is no longer consistent with the original SQL. We therefore need to check the final plan's output schema against the original plan's schema.
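To make the intent concrete, here is a minimal, self-contained sketch (an illustration only, not the actual `private[sql]` helper) of the kind of comparison the new `isPlanIntegral` check performs via `DataType.equalsIgnoreNullability`: field names must match exactly, including case, while nullability differences are ignored.
```scala
import org.apache.spark.sql.types._

object SchemaNameCheckSketch {
  // Mirrors the comparison used by the new check: field names are compared
  // case-sensitively, nullability of arrays/maps/structs is ignored.
  def sameSchemaIgnoringNullability(left: DataType, right: DataType): Boolean =
    (left, right) match {
      case (ArrayType(le, _), ArrayType(re, _)) =>
        sameSchemaIgnoringNullability(le, re)
      case (MapType(lk, lv, _), MapType(rk, rv, _)) =>
        sameSchemaIgnoringNullability(lk, rk) && sameSchemaIgnoringNullability(lv, rv)
      case (StructType(lf), StructType(rf)) =>
        lf.length == rf.length && lf.zip(rf).forall { case (l, r) =>
          l.name == r.name && sameSchemaIgnoringNullability(l.dataType, r.dataType)
        }
      case _ => left == right
    }

  def main(args: Array[String]): Unit = {
    // Schemas from the example above: the original plan outputs `a, B`,
    // the pruned plan outputs `A, b` (same types, different name case).
    val original = StructType(Seq(StructField("a", StringType), StructField("B", IntegerType)))
    val pruned = StructType(Seq(StructField("A", StringType), StructField("b", IntegerType)))

    println(sameSchemaIgnoringNullability(original, original)) // true
    println(sameSchemaIgnoringNullability(original, pruned))   // false: names differ only by case
  }
}
```
With the stricter `isPlanIntegral` in place (gated by `Utils.isTesting`), a rule that silently changes output attribute names, like the `SchemaPruning` case above, now fails the optimizer's structural integrity check instead of leaking the renamed schema into results such as CTAS.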

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing UTs, plus a new test in `SchemaPruningSuite`.

Closes #33583 from AngersZhuuuu/SPARK-36352.

Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e051a540a1)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Angerszhuuuu 2021-08-09 16:47:56 +08:00 committed by Wenchen Fan
parent 94dc3c77c2
commit a5ecf2a490
11 changed files with 63 additions and 31 deletions

@@ -174,8 +174,10 @@ class Analyzer(override val catalogManager: CatalogManager)
   private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan)
   }
   override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)

@@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.SchemaUtils._
 import org.apache.spark.util.Utils
 /**
@@ -46,10 +47,14 @@ abstract class Optimizer(catalogManager: CatalogManager)
   //  - is still resolved
   //  - only host special expressions in supported operators
   //  - has globally-unique attribute IDs
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || (plan.resolved &&
-      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
-      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  //  - optimized plan have same schema with previous plan.
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (currentPlan.resolved &&
+      currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
+      DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
   }
   override protected val excludedOnceBatches: Set[String] =
@@ -515,15 +520,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
  * Remove no-op operators from the query plan that do not make any modifications.
  */
 object RemoveNoopOperators extends Rule[LogicalPlan] {
-  def restoreOriginalOutputNames(
-      projectList: Seq[NamedExpression],
-      originalNames: Seq[String]): Seq[NamedExpression] = {
-    projectList.zip(originalNames).map {
-      case (attr: Attribute, name) => attr.withName(name)
-      case (alias: Alias, name) => alias.withName(name)
-      case (other, _) => other
-    }
-  }
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsAnyPattern(PROJECT, WINDOW), ruleId) {

@@ -156,7 +156,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
    * `Optimizer`, so we can catch rules that return invalid plans. The check function returns
    * `false` if the given plan doesn't pass the structural integrity check.
    */
-  protected def isPlanIntegral(plan: TreeType): Boolean = true
+  protected def isPlanIntegral(previousPlan: TreeType, currentPlan: TreeType): Boolean = true
   /**
    * Util method for checking whether a plan remains the same if re-optimized.
@@ -192,7 +192,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
     val beforeMetrics = RuleExecutor.getCurrentMetrics()
     // Run the structural integrity checker against the initial input
-    if (!isPlanIntegral(plan)) {
+    if (!isPlanIntegral(plan, plan)) {
       throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError(
         this.getClass.getName.stripSuffix("$"))
     }
@@ -224,7 +224,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
           tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective))
           // Run the structural integrity checker against the plan after each rule.
-          if (effective && !isPlanIntegral(result)) {
+          if (effective && !isPlanIntegral(plan, result)) {
             throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError(
               rule.ruleName, batch.name)
           }

@@ -292,7 +292,7 @@ object DataType {
   /**
    * Compares two types, ignoring nullability of ArrayType, MapType, StructType.
    */
-  private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
+  private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
     (left, right) match {
       case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
         equalsIgnoreNullability(leftElementType, rightElementType)

@@ -21,6 +21,7 @@ import java.util.Locale
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
 import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -273,6 +274,16 @@ private[spark] object SchemaUtils {
     field._1
   }
+  def restoreOriginalOutputNames(
+      projectList: Seq[NamedExpression],
+      originalNames: Seq[String]): Seq[NamedExpression] = {
+    projectList.zip(originalNames).map {
+      case (attr: Attribute, name) => attr.withName(name)
+      case (alias: Alias, name) => alias.withName(name)
+      case (other, _) => other
+    }
+  }
   /**
    * @param str The string to be escaped.
    * @return The escaped string.

@@ -73,7 +73,9 @@ class RuleExecutorSuite extends SparkFunSuite {
   test("structural integrity checker - verify initial input") {
     object WithSIChecker extends RuleExecutor[Expression] {
-      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+      override protected def isPlanIntegral(
+          previousPlan: Expression,
+          currentPlan: Expression): Boolean = currentPlan match {
         case IntegerLiteral(_) => true
         case _ => false
       }
@@ -91,7 +93,9 @@ class RuleExecutorSuite extends SparkFunSuite {
   test("structural integrity checker - verify rule execution result") {
     object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] {
-      override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
+      override protected def isPlanIntegral(
+          previousPlan: Expression,
+          currentPlan: Expression): Boolean = currentPlan match {
         case IntegerLiteral(i) if i > 0 => true
         case _ => false
       }

@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.DataType
 import org.apache.spark.util.Utils
 /**
@@ -64,9 +65,12 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
     }
   }
-  override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
-    !Utils.isTesting || (plan.resolved &&
-      plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
-      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
+  override protected def isPlanIntegral(
+      previousPlan: LogicalPlan,
+      currentPlan: LogicalPlan): Boolean = {
+    !Utils.isTesting || (currentPlan.resolved &&
+      currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
+      LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
+      DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
   }
 }

@@ -464,7 +464,7 @@ object DataSourceStrategy
    */
   protected[sql] def normalizeExprs(
       exprs: Seq[Expression],
-      attributes: Seq[AttributeReference]): Seq[Expression] = {
+      attributes: Seq[Attribute]): Seq[Expression] = {
     exprs.map { e =>
       e transform {
         case a: AttributeReference =>

@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
 import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
 import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+import org.apache.spark.sql.util.SchemaUtils._
 /**
  * Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
@@ -82,8 +83,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
         val prunedRelation = leafNodeBuilder(prunedDataSchema)
         val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
-        Some(buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation,
-          projectionOverSchema))
+        Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
+          prunedRelation, projectionOverSchema))
       } else {
         None
       }
@@ -125,6 +126,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
    */
   private def buildNewProjection(
       projects: Seq[NamedExpression],
+      normalizedProjects: Seq[NamedExpression],
       filters: Seq[Expression],
       leafNode: LeafNode,
       projectionOverSchema: ProjectionOverSchema): Project = {
@@ -143,7 +145,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
     // Construct the new projections of our Project by
     // rewriting the original projections
-    val newProjects = projects.map(_.transformDown {
+    val newProjects = normalizedProjects.map(_.transformDown {
       case projectionOverSchema(expr) => expr
     }).map { case expr: NamedExpression => expr }
@@ -151,7 +153,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
       logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
     }
-    Project(newProjects, projectionChild)
+    Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild)
   }
   /**

@@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownA
 import org.apache.spark.sql.execution.datasources.DataSourceStrategy
 import org.apache.spark.sql.sources
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.SchemaUtils._
 object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
   import DataSourceV2Implicits._
@@ -207,7 +208,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
       val newProjects = normalizedProjects
         .map(projectionFunc)
         .asInstanceOf[Seq[NamedExpression]]
-      Project(newProjects, withFilter)
+      Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter)
     } else {
       withFilter
     }

@@ -870,4 +870,16 @@ abstract class SchemaPruningSuite
       checkAnswer(query, Row(1) :: Row(2) :: Nil)
     }
   }
+  test("SPARK-36352: Spark should check result plan's output schema name") {
+    withMixedCaseData {
+      val query = sql("select cOL1, cOl2.B from mixedcase")
+      assert(query.queryExecution.executedPlan.schema.catalogString ==
+        "struct<cOL1:string,B:int>")
+      checkAnswer(query.orderBy("id"),
+        Row("r0c1", 1) ::
+        Row("r1c1", 2) ::
+        Nil)
+    }
+  }
 }