[SPARK-36352][SQL] Spark should check result plan's output schema name
### What changes were proposed in this pull request?
Spark should check result plan's output schema name
### Why are the changes needed?
In current code, some optimizer rule may change plan's output schema, since in the code we always use semantic equal to check output, but it may change the plan's output schema.
For example, for SchemaPruning, if we have a plan
```
Project[a, B]
|--Scan[A, b, c]
```
the origin output schema is `a, B`, after SchemaPruning. it become
```
Project[A, b]
|--Scan[A, b]
```
It change the plan's schema. when we use CTAS, the schema is same as query plan's output.
Then since we change the schema, it not consistent with origin SQL. So we need to check final result plan's schema with origin plan's schema
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
existed UT
Closes #33583 from AngersZhuuuu/SPARK-36352.
Authored-by: Angerszhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e051a540a1
)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
parent
94dc3c77c2
commit
a5ecf2a490
|
@ -174,8 +174,10 @@ class Analyzer(override val catalogManager: CatalogManager)
|
|||
|
||||
private val v1SessionCatalog: SessionCatalog = catalogManager.v1SessionCatalog
|
||||
|
||||
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan)
|
||||
override protected def isPlanIntegral(
|
||||
previousPlan: LogicalPlan,
|
||||
currentPlan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan)
|
||||
}
|
||||
|
||||
override def isView(nameParts: Seq[String]): Boolean = v1SessionCatalog.isView(nameParts)
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager
|
|||
import org.apache.spark.sql.errors.QueryCompilationErrors
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.sql.util.SchemaUtils._
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -46,10 +47,14 @@ abstract class Optimizer(catalogManager: CatalogManager)
|
|||
// - is still resolved
|
||||
// - only host special expressions in supported operators
|
||||
// - has globally-unique attribute IDs
|
||||
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || (plan.resolved &&
|
||||
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
|
||||
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
|
||||
// - optimized plan have same schema with previous plan.
|
||||
override protected def isPlanIntegral(
|
||||
previousPlan: LogicalPlan,
|
||||
currentPlan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || (currentPlan.resolved &&
|
||||
currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
|
||||
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
|
||||
DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
|
||||
}
|
||||
|
||||
override protected val excludedOnceBatches: Set[String] =
|
||||
|
@ -515,15 +520,6 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
|
|||
* Remove no-op operators from the query plan that do not make any modifications.
|
||||
*/
|
||||
object RemoveNoopOperators extends Rule[LogicalPlan] {
|
||||
def restoreOriginalOutputNames(
|
||||
projectList: Seq[NamedExpression],
|
||||
originalNames: Seq[String]): Seq[NamedExpression] = {
|
||||
projectList.zip(originalNames).map {
|
||||
case (attr: Attribute, name) => attr.withName(name)
|
||||
case (alias: Alias, name) => alias.withName(name)
|
||||
case (other, _) => other
|
||||
}
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
|
||||
_.containsAnyPattern(PROJECT, WINDOW), ruleId) {
|
||||
|
|
|
@ -156,7 +156,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
|
|||
* `Optimizer`, so we can catch rules that return invalid plans. The check function returns
|
||||
* `false` if the given plan doesn't pass the structural integrity check.
|
||||
*/
|
||||
protected def isPlanIntegral(plan: TreeType): Boolean = true
|
||||
protected def isPlanIntegral(previousPlan: TreeType, currentPlan: TreeType): Boolean = true
|
||||
|
||||
/**
|
||||
* Util method for checking whether a plan remains the same if re-optimized.
|
||||
|
@ -192,7 +192,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
|
|||
val beforeMetrics = RuleExecutor.getCurrentMetrics()
|
||||
|
||||
// Run the structural integrity checker against the initial input
|
||||
if (!isPlanIntegral(plan)) {
|
||||
if (!isPlanIntegral(plan, plan)) {
|
||||
throw QueryExecutionErrors.structuralIntegrityOfInputPlanIsBrokenInClassError(
|
||||
this.getClass.getName.stripSuffix("$"))
|
||||
}
|
||||
|
@ -224,7 +224,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging {
|
|||
tracker.foreach(_.recordRuleInvocation(rule.ruleName, runTime, effective))
|
||||
|
||||
// Run the structural integrity checker against the plan after each rule.
|
||||
if (effective && !isPlanIntegral(result)) {
|
||||
if (effective && !isPlanIntegral(plan, result)) {
|
||||
throw QueryExecutionErrors.structuralIntegrityIsBrokenAfterApplyingRuleError(
|
||||
rule.ruleName, batch.name)
|
||||
}
|
||||
|
|
|
@ -292,7 +292,7 @@ object DataType {
|
|||
/**
|
||||
* Compares two types, ignoring nullability of ArrayType, MapType, StructType.
|
||||
*/
|
||||
private[types] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
|
||||
private[sql] def equalsIgnoreNullability(left: DataType, right: DataType): Boolean = {
|
||||
(left, right) match {
|
||||
case (ArrayType(leftElementType, _), ArrayType(rightElementType, _)) =>
|
||||
equalsIgnoreNullability(leftElementType, rightElementType)
|
||||
|
|
|
@ -21,6 +21,7 @@ import java.util.Locale
|
|||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.analysis._
|
||||
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression}
|
||||
import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform}
|
||||
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
|
||||
|
@ -273,6 +274,16 @@ private[spark] object SchemaUtils {
|
|||
field._1
|
||||
}
|
||||
|
||||
def restoreOriginalOutputNames(
|
||||
projectList: Seq[NamedExpression],
|
||||
originalNames: Seq[String]): Seq[NamedExpression] = {
|
||||
projectList.zip(originalNames).map {
|
||||
case (attr: Attribute, name) => attr.withName(name)
|
||||
case (alias: Alias, name) => alias.withName(name)
|
||||
case (other, _) => other
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @param str The string to be escaped.
|
||||
* @return The escaped string.
|
||||
|
|
|
@ -73,7 +73,9 @@ class RuleExecutorSuite extends SparkFunSuite {
|
|||
|
||||
test("structural integrity checker - verify initial input") {
|
||||
object WithSIChecker extends RuleExecutor[Expression] {
|
||||
override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
|
||||
override protected def isPlanIntegral(
|
||||
previousPlan: Expression,
|
||||
currentPlan: Expression): Boolean = currentPlan match {
|
||||
case IntegerLiteral(_) => true
|
||||
case _ => false
|
||||
}
|
||||
|
@ -91,7 +93,9 @@ class RuleExecutorSuite extends SparkFunSuite {
|
|||
|
||||
test("structural integrity checker - verify rule execution result") {
|
||||
object WithSICheckerForPositiveLiteral extends RuleExecutor[Expression] {
|
||||
override protected def isPlanIntegral(expr: Expression): Boolean = expr match {
|
||||
override protected def isPlanIntegral(
|
||||
previousPlan: Expression,
|
||||
currentPlan: Expression): Boolean = currentPlan match {
|
||||
case IntegerLiteral(i) if i > 0 => true
|
||||
case _ => false
|
||||
}
|
||||
|
|
|
@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
|
|||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity, PlanHelper}
|
||||
import org.apache.spark.sql.catalyst.rules.RuleExecutor
|
||||
import org.apache.spark.sql.internal.SQLConf
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.util.Utils
|
||||
|
||||
/**
|
||||
|
@ -64,9 +65,12 @@ class AQEOptimizer(conf: SQLConf) extends RuleExecutor[LogicalPlan] {
|
|||
}
|
||||
}
|
||||
|
||||
override protected def isPlanIntegral(plan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || (plan.resolved &&
|
||||
plan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
|
||||
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(plan))
|
||||
override protected def isPlanIntegral(
|
||||
previousPlan: LogicalPlan,
|
||||
currentPlan: LogicalPlan): Boolean = {
|
||||
!Utils.isTesting || (currentPlan.resolved &&
|
||||
currentPlan.find(PlanHelper.specialExpressionsInUnsupportedOperator(_).nonEmpty).isEmpty &&
|
||||
LogicalPlanIntegrity.checkIfExprIdsAreGloballyUnique(currentPlan) &&
|
||||
DataType.equalsIgnoreNullability(previousPlan.schema, currentPlan.schema))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -464,7 +464,7 @@ object DataSourceStrategy
|
|||
*/
|
||||
protected[sql] def normalizeExprs(
|
||||
exprs: Seq[Expression],
|
||||
attributes: Seq[AttributeReference]): Seq[Expression] = {
|
||||
attributes: Seq[Attribute]): Seq[Expression] = {
|
||||
exprs.map { e =>
|
||||
e transform {
|
||||
case a: AttributeReference =>
|
||||
|
|
|
@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
|
|||
import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat
|
||||
import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
|
||||
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
|
||||
import org.apache.spark.sql.util.SchemaUtils._
|
||||
|
||||
/**
|
||||
* Prunes unnecessary physical columns given a [[PhysicalOperation]] over a data source relation.
|
||||
|
@ -82,8 +83,8 @@ object SchemaPruning extends Rule[LogicalPlan] {
|
|||
val prunedRelation = leafNodeBuilder(prunedDataSchema)
|
||||
val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
|
||||
|
||||
Some(buildNewProjection(normalizedProjects, normalizedFilters, prunedRelation,
|
||||
projectionOverSchema))
|
||||
Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
|
||||
prunedRelation, projectionOverSchema))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
|
@ -125,6 +126,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
|
|||
*/
|
||||
private def buildNewProjection(
|
||||
projects: Seq[NamedExpression],
|
||||
normalizedProjects: Seq[NamedExpression],
|
||||
filters: Seq[Expression],
|
||||
leafNode: LeafNode,
|
||||
projectionOverSchema: ProjectionOverSchema): Project = {
|
||||
|
@ -143,7 +145,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
|
|||
|
||||
// Construct the new projections of our Project by
|
||||
// rewriting the original projections
|
||||
val newProjects = projects.map(_.transformDown {
|
||||
val newProjects = normalizedProjects.map(_.transformDown {
|
||||
case projectionOverSchema(expr) => expr
|
||||
}).map { case expr: NamedExpression => expr }
|
||||
|
||||
|
@ -151,7 +153,7 @@ object SchemaPruning extends Rule[LogicalPlan] {
|
|||
logDebug(s"New projects:\n${newProjects.map(_.treeString).mkString("\n")}")
|
||||
}
|
||||
|
||||
Project(newProjects, projectionChild)
|
||||
Project(restoreOriginalOutputNames(newProjects, projects.map(_.name)), projectionChild)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -30,6 +30,7 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownA
|
|||
import org.apache.spark.sql.execution.datasources.DataSourceStrategy
|
||||
import org.apache.spark.sql.sources
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.sql.util.SchemaUtils._
|
||||
|
||||
object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
||||
import DataSourceV2Implicits._
|
||||
|
@ -207,7 +208,7 @@ object V2ScanRelationPushDown extends Rule[LogicalPlan] with PredicateHelper {
|
|||
val newProjects = normalizedProjects
|
||||
.map(projectionFunc)
|
||||
.asInstanceOf[Seq[NamedExpression]]
|
||||
Project(newProjects, withFilter)
|
||||
Project(restoreOriginalOutputNames(newProjects, project.map(_.name)), withFilter)
|
||||
} else {
|
||||
withFilter
|
||||
}
|
||||
|
|
|
@ -870,4 +870,16 @@ abstract class SchemaPruningSuite
|
|||
checkAnswer(query, Row(1) :: Row(2) :: Nil)
|
||||
}
|
||||
}
|
||||
|
||||
test("SPARK-36352: Spark should check result plan's output schema name") {
|
||||
withMixedCaseData {
|
||||
val query = sql("select cOL1, cOl2.B from mixedcase")
|
||||
assert(query.queryExecution.executedPlan.schema.catalogString ==
|
||||
"struct<cOL1:string,B:int>")
|
||||
checkAnswer(query.orderBy("id"),
|
||||
Row("r0c1", 1) ::
|
||||
Row("r1c1", 2) ::
|
||||
Nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue