[SPARK-27217][SQL] Nested column aliasing for more operators which can prune nested column

### What changes were proposed in this pull request?

Currently we only push nested column pruning from a Project through a few operators such as LIMIT, SAMPLE, etc. There are a few operators like Aggregate, Expand which can prune nested columns by themselves, without a Project on top.

This patch extends the feature to those operators.

### Why are the changes needed?

Currently nested column pruning is only applied in a few cases, which limits its benefit. Extending nested column pruning coverage makes this feature apply more generally across different queries.

### Does this PR introduce _any_ user-facing change?

Yes. More SQL operators are covered by nested column pruning.

### How was this patch tested?

Added unit test, end-to-end tests.

Closes #28560 from viirya/SPARK-27217-2.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
Liang-Chi Hsieh 2020-06-10 18:08:47 +09:00 committed by HyukjinKwon
parent 82ff29be7a
commit 43063e2db2
3 changed files with 190 additions and 10 deletions

View file

@ -35,6 +35,11 @@ object NestedColumnAliasing {
case Project(projectList, child)
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
getAliasSubMap(projectList)
case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) =>
val exprCandidatesToPrune = plan.expressions
getAliasSubMap(exprCandidatesToPrune, plan.producedAttributes.toSeq)
case _ => None
}
@ -48,7 +53,11 @@ object NestedColumnAliasing {
case Project(projectList, child) =>
Project(
getNewProjectList(projectList, nestedFieldToAlias),
replaceChildrenWithAliases(child, attrToAliases))
replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases))
// The operators reaching here were already guarded by `canPruneOn`.
case other =>
replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases)
}
/**
@ -68,10 +77,23 @@ object NestedColumnAliasing {
*/
def replaceChildrenWithAliases(
plan: LogicalPlan,
nestedFieldToAlias: Map[ExtractValue, Alias],
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
// Insert a Project below each child: attributes with pruned-field aliases
// (keyed by ExprId) are replaced by those aliases, others pass through
// unchanged, so only the needed nested sub-fields flow upward.
plan.withNewChildren(plan.children.map { plan =>
Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan)
}).transformExpressions {
// Rewrite extract-value expressions in this operator to reference the
// alias attributes now produced by the inserted child Projects.
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
nestedFieldToAlias(f).toAttribute
}
}
/**
* Returns true for those operators that we can prune nested column on it.
*/
// Aggregate and Expand consume their child's output through their own
// expression lists, so nested fields can be pruned directly on them
// without a Project on top.
private def canPruneOn(plan: LogicalPlan) = plan match {
case _: Aggregate | _: Expand => true
case _ => false
}
/**
@ -204,15 +226,8 @@ object GeneratorNestedColumnAliasing {
g: Generate,
nestedFieldToAlias: Map[ExtractValue, Alias],
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
val newGenerator = g.generator.transform {
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
nestedFieldToAlias(f).toAttribute
}.asInstanceOf[Generator]
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
val newGenerate = g.copy(generator = newGenerator)
NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases)
}
/**

View file

@ -341,6 +341,100 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.analyze
comparePlans(optimized, expected)
}
test("Nested field pruning for Aggregate") {
// Nested fields referenced only inside Aggregate expressions should be
// pruned, with or without intermediate operators below the Aggregate.
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
// Case 1: nested field only in the aggregate function.
val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze
val optimized1 = Optimize.execute(query1)
val aliases1 = collectGeneratedAliases(optimized1)
// Expect a Project extracting just `name.first` pushed below the base plan.
val expected1 = basePlan(
contact
.select($"id", 'name.getField("first").as(aliases1(0)))
).groupBy($"id")(first($"${aliases1(0)}").as("first")).analyze
comparePlans(optimized1, expected1)
// Case 2: nested fields in both the grouping key and the aggregate function.
val query2 = basePlan(contact).groupBy($"name.last")(first($"name.first").as("first")).analyze
val optimized2 = Optimize.execute(query2)
val aliases2 = collectGeneratedAliases(optimized2)
val expected2 = basePlan(
contact
.select('name.getField("last").as(aliases2(0)), 'name.getField("first").as(aliases2(1)))
).groupBy($"${aliases2(0)}")(first($"${aliases2(1)}").as("first")).analyze
comparePlans(optimized2, expected2)
}
// Exercise the pruning with various operators between scan and Aggregate.
Seq(
(plan: LogicalPlan) => plan,
(plan: LogicalPlan) => plan.limit(100),
(plan: LogicalPlan) => plan.repartition(100),
(plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
runTest(base)
}
// When the whole struct `name` is also needed, no nested pruning happens:
// only whole-column pruning down to (id, name).
val query3 = contact.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
val optimized3 = Optimize.execute(query3)
val expected3 = contact.select($"id", $"name")
.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
comparePlans(optimized3, expected3)
}
test("Nested field pruning for Expand") {
// Nested fields referenced only in Expand projections should be pruned,
// with or without intermediate operators below the Expand.
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
val query1 = Expand(
Seq(
Seq($"name.first", $"name.middle"),
Seq(ConcatWs(Seq($"name.first", $"name.middle")),
ConcatWs(Seq($"name.middle", $"name.first")))
),
Seq('a.string, 'b.string),
basePlan(contact)
).analyze
val optimized1 = Optimize.execute(query1)
val aliases1 = collectGeneratedAliases(optimized1)
// Expect the Expand projections rewritten to the generated aliases, with
// a Project extracting `name.first`/`name.middle` pushed below the base.
val expected1 = Expand(
Seq(
Seq($"${aliases1(0)}", $"${aliases1(1)}"),
Seq(ConcatWs(Seq($"${aliases1(0)}", $"${aliases1(1)}")),
ConcatWs(Seq($"${aliases1(1)}", $"${aliases1(0)}")))
),
Seq('a.string, 'b.string),
basePlan(contact.select(
'name.getField("first").as(aliases1(0)),
'name.getField("middle").as(aliases1(1))))
).analyze
comparePlans(optimized1, expected1)
}
// Exercise the pruning with various operators between scan and Expand.
Seq(
(plan: LogicalPlan) => plan,
(plan: LogicalPlan) => plan.limit(100),
(plan: LogicalPlan) => plan.repartition(100),
(plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
runTest(base)
}
// When the whole struct `name` is also projected, nested pruning is not
// applicable; the plan only prunes down to the `name` column itself.
val query2 = Expand(
Seq(
Seq($"name", $"name.middle"),
Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
),
Seq('a.string, 'b.string),
contact
).analyze
val optimized2 = Optimize.execute(query2)
val expected2 = Expand(
Seq(
Seq($"name", $"name.middle"),
Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
),
Seq('a.string, 'b.string),
contact.select($"name")
).analyze
comparePlans(optimized2, expected2)
}
}
object NestedColumnAliasingSuite {

View file

@ -23,7 +23,9 @@ import org.scalactic.Equality
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
import org.apache.spark.sql.catalyst.SchemaPruningTest
import org.apache.spark.sql.catalyst.expressions.Concat
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Expand
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions._
@ -338,6 +340,75 @@ abstract class SchemaPruningSuite
}
}
testSchemaPruning("select one deep nested complex field after repartition") {
  // Repartition between scan and projection should not block nested pruning.
  val repartitioned = sql("select * from contacts").repartition(100)
  val filtered = repartitioned.where("employer.company.address is not null")
  val query = filtered.selectExpr("employer.id as employer_id")
  // Only employer.id and employer.company.address should be read.
  checkScan(query,
    "struct<employer:struct<id:int,company:struct<address:string>>>")
  checkAnswer(query, Row(0) :: Nil)
}
testSchemaPruning("select nested field in aggregation function of Aggregate") {
// Nested field in aggregate, nested field in grouping key: scan should
// read only name.first and name.last.
val query1 = sql("select count(name.first) from contacts group by name.last")
checkScan(query1, "struct<name:struct<first:string,last:string>>")
checkAnswer(query1, Row(2) :: Row(2) :: Nil)
// Mixed nested and top-level aggregates, grouped by a top-level column.
val query2 = sql("select count(name.first), sum(pets) from contacts group by id")
checkScan(query2, "struct<id:int,name:struct<first:string>,pets:int>")
checkAnswer(query2, Row(1, 1) :: Row(1, null) :: Row(1, 3) :: Row(1, null) :: Nil)
// Whole struct is also aggregated, so all of `name` must be read.
val query3 = sql("select count(name.first), first(name) from contacts group by id")
checkScan(query3, "struct<id:int,name:struct<first:string,middle:string,last:string>>")
checkAnswer(query3,
Row(1, Row("Jane", "X.", "Doe")) ::
Row(1, Row("Jim", null, "Jones")) ::
Row(1, Row("John", "Y.", "Doe")) ::
Row(1, Row("Janet", null, "Jones")) :: Nil)
// Nested grouping key with mixed aggregates: only the referenced
// sub-fields plus `pets` should be scanned.
val query4 = sql("select count(name.first), sum(pets) from contacts group by name.last")
checkScan(query4, "struct<name:struct<first:string,last:string>,pets:int>")
checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil)
}
testSchemaPruning("select nested field in Expand") {
import org.apache.spark.sql.catalyst.dsl.expressions._
// Expand projections reference only name.first/name.last: the scan
// should prune `name` down to those two sub-fields.
val query1 = Expand(
Seq(
Seq($"name.first", $"name.last"),
Seq(Concat(Seq($"name.first", $"name.last")),
Concat(Seq($"name.last", $"name.first")))
),
Seq('a.string, 'b.string),
sql("select * from contacts").logicalPlan
).toDF()
checkScan(query1, "struct<name:struct<first:string,last:string>>")
// Each input row yields two output rows (one per Expand projection).
checkAnswer(query1,
Row("Jane", "Doe") ::
Row("JaneDoe", "DoeJane") ::
Row("John", "Doe") ::
Row("JohnDoe", "DoeJohn") ::
Row("Jim", "Jones") ::
Row("JimJones", "JonesJim") ::
Row("Janet", "Jones") ::
Row("JanetJones", "JonesJanet") :: Nil)
val name = StructType.fromDDL("first string, middle string, last string")
// The whole `name` struct is projected too, so no nested pruning:
// all of `name` must be read.
val query2 = Expand(
Seq(Seq($"name", $"name.last")),
Seq('a.struct(name), 'b.string),
sql("select * from contacts").logicalPlan
).toDF()
checkScan(query2, "struct<name:struct<first:string,middle:string,last:string>>")
checkAnswer(query2,
Row(Row("Jane", "X.", "Doe"), "Doe") ::
Row(Row("John", "Y.", "Doe"), "Doe") ::
Row(Row("Jim", null, "Jones"), "Jones") ::
Row(Row("Janet", null, "Jones"), "Jones") ::Nil)
}
protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
test(s"Spark vectorized reader - without partition data column - $testName") {
withSQLConf(vectorizedReaderEnabledKey -> "true") {