[SPARK-35287][SQL] Allow RemoveRedundantProjects to preserve ProjectExec which generates UnsafeRow for DataSourceV2ScanRelation

### What changes were proposed in this pull request?

This PR fixes an issue where `RemoveRedundantProjects` removes a `ProjectExec` whose purpose is to generate `UnsafeRow`s.
In `DataSourceV2Strategy`, a `ProjectExec` is inserted to ensure that internal rows are `UnsafeRow`s:

```
  private def withProjectAndFilter(
      project: Seq[NamedExpression],
      filters: Seq[Expression],
      scan: LeafExecNode,
      needsUnsafeConversion: Boolean): SparkPlan = {
    val filterCondition = filters.reduceLeftOption(And)
    val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan)

    if (withFilter.output != project || needsUnsafeConversion) {
      ProjectExec(project, withFilter)
    } else {
      withFilter
    }
  }
...
    case PhysicalOperation(project, filters, relation: DataSourceV2ScanRelation) =>
      // projection and filters were already pushed down in the optimizer.
      // this uses PhysicalOperation to get the projection and ensure that if the batch scan does
      // not support columnar, a projection is added to convert the rows to UnsafeRow.
      val batchExec = BatchScanExec(relation.output, relation.scan)
      withProjectAndFilter(project, filters, batchExec, !batchExec.supportsColumnar) :: Nil
```
So the hierarchy of the partial tree should be `ProjectExec(FilterExec(BatchScan))`.
But `RemoveRedundantProjects` doesn't take this shape into account, leading to a `ClassCastException`.
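
Concretely, the relevant part of the physical plan looks like this (a hand-drawn sketch, not actual `explain` output):
```
ProjectExec [output]          <- converts the scan's rows to UnsafeRow; must be preserved
+- FilterExec [condition]
   +- BatchScanExec [...]     <- row-based scan, may emit rows that are not UnsafeRow
```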

A concrete example to reproduce this issue was reported:
```
import scala.collection.JavaConverters._

import org.apache.iceberg.{PartitionSpec, TableProperties}
import org.apache.iceberg.hadoop.HadoopTables
import org.apache.iceberg.spark.SparkSchemaUtil
import org.apache.spark.sql.{DataFrame, QueryTest, SparkSession}
import org.apache.spark.sql.internal.SQLConf

class RemoveRedundantProjectsTest extends QueryTest {
  override val spark: SparkSession = SparkSession
    .builder()
    .master("local[4]")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .appName(suiteName)
    .getOrCreate()
  test("RemoveRedundantProjects removes non-redundant projects") {
    withSQLConf(
      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1",
      SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false",
      SQLConf.REMOVE_REDUNDANT_PROJECTS_ENABLED.key -> "true") {
      withTempDir { dir =>
        val path = dir.getCanonicalPath
        val data = spark.range(3).toDF
        val table = new HadoopTables().create(
          SparkSchemaUtil.convert(data.schema),
          PartitionSpec.unpartitioned(),
          Map(TableProperties.WRITE_NEW_DATA_LOCATION -> path).asJava,
          path)
        data.write.format("iceberg").mode("overwrite").save(path)
        table.refresh()

        val df = spark.read.format("iceberg").load(path)
        val dfX = df.as("x")
        val dfY = df.as("y")
        val join = dfX.filter(dfX("id") > 0).join(dfY, "id")
        join.explain("extended")
        assert(join.count() == 2)
      }
    }
  }
}
```
```
[info] - RemoveRedundantProjects removes non-redundant projects *** FAILED ***
[info]   org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 1.0 failed 1 times, most recent failure: Lost task 0.0 in stage 1.0 (TID 4) (xeroxms100.northamerica.corp.microsoft.com executor driver): java.lang.ClassCastException: org.apache.spark.sql.catalyst.expressions.GenericInternalRow cannot be cast to org.apache.spark.sql.catalyst.expressions.UnsafeRow
[info]  at org.apache.spark.sql.execution.UnsafeExternalRowSorter.sort(UnsafeExternalRowSorter.java:226)
[info]  at org.apache.spark.sql.execution.SortExec.$anonfun$doExecute$1(SortExec.scala:119)
```
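
The cast fails because a row-based DataSource V2 scan can emit `GenericInternalRow`s, while `SortExec` (via `UnsafeExternalRowSorter`, as seen in the stack trace) expects its input rows to already be `UnsafeRow`s. A minimal sketch of the conversion that the removed `ProjectExec` was responsible for (illustrative only, not the planner's actual code path):
```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{DataType, LongType}

// A row-based scan may produce generic rows like this one.
val generic: InternalRow = new GenericInternalRow(Array[Any](1L))
// generic.asInstanceOf[UnsafeRow]  // effectively the cast that fails above

// ProjectExec performs the conversion through an UnsafeProjection.
val toUnsafe = UnsafeProjection.create(Array[DataType](LongType))
val unsafe: UnsafeRow = toUnsafe(generic)  // safe to hand to SortExec
```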

### Why are the changes needed?

To fix a bug. Without this change, `RemoveRedundantProjects` can remove the `ProjectExec` which converts a row-based DataSource V2 scan's output to `UnsafeRow`, leading to a `ClassCastException` at runtime.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

New test.

Closes #32606 from sarutak/fix-project-removal-issue.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Commit d4fb98354a (parent c709efc1e7), authored by Kousuke Saruta on 2021-05-25 00:26:10 +08:00, committed by Wenchen Fan.
2 changed files with 21 additions and 0 deletions.

`sql/core/src/main/scala/org/apache/spark/sql/execution/RemoveRedundantProjects.scala`

```
@@ -97,6 +97,7 @@ object RemoveRedundantProjects extends Rule[SparkPlan] {
       // If a DataSourceV2ScanExec node does not support columnar, a ProjectExec node is required
       // to convert the rows to UnsafeRow. See DataSourceV2Strategy for more details.
       case d: DataSourceV2ScanExecBase if !d.supportsColumnar => false
+      case FilterExec(_, d: DataSourceV2ScanExecBase) if !d.supportsColumnar => false
       case _ =>
         if (requireOrdering) {
           project.output.map(_.exprId.id) == child.output.map(_.exprId.id) &&
```
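
The added case mirrors the existing scan-only case one level deeper: the conversion `ProjectExec` is now preserved whether the row-based scan is the project's direct child or sits behind a `FilterExec`, which is exactly the `ProjectExec(FilterExec(BatchScan))` shape produced by `DataSourceV2Strategy`.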

`sql/core/src/test/scala/org/apache/spark/sql/execution/RemoveRedundantProjectsSuite.scala`

```
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution

 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
+import org.apache.spark.sql.connector.SimpleWritableDataSource
 import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanHelper, DisableAdaptiveExecutionSuite, EnableAdaptiveExecutionSuite}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -28,6 +29,7 @@ abstract class RemoveRedundantProjectsSuiteBase
   extends QueryTest
   with SharedSparkSession
   with AdaptiveSparkPlanHelper {
+  import testImplicits._

   private def assertProjectExecCount(df: DataFrame, expected: Int): Unit = {
     withClue(df.queryExecution) {
@@ -215,6 +217,24 @@ abstract class RemoveRedundantProjectsSuiteBase
         |LIMIT 10
         |""".stripMargin
     assertProjectExec(query, 0, 3)
   }

+  Seq("true", "false").foreach { codegenEnabled =>
+    test("SPARK-35287: project generating unsafe row for DataSourceV2ScanRelation " +
+      s"should not be removed (codegen=$codegenEnabled)") {
+      withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled) {
+        withTempPath { path =>
+          val format = classOf[SimpleWritableDataSource].getName
+          spark.range(3).select($"id" as "i", $"id" as "j")
+            .write.format(format).mode("overwrite").save(path.getCanonicalPath)
+          val df =
+            spark.read.format(format).load(path.getCanonicalPath).filter($"i" > 0).orderBy($"i")
+          assert(df.collect === Array(Row(1, 1), Row(2, 2)))
+        }
+      }
+    }
+  }
 }
```