[SPARK-27217][SQL] Nested column aliasing for more operators that can prune nested columns
### What changes were proposed in this pull request? Currently, we only push nested column pruning from a Project through a few operators, such as LIMIT and SAMPLE. A few operators, such as Aggregate and Expand, can prune nested columns by themselves, without a Project on top. This patch extends the feature to those operators. ### Why are the changes needed? Currently, nested column pruning is applied in only a few cases, which limits its benefit. Extending the coverage of nested column pruning makes this feature applicable to a wider range of queries. ### Does this PR introduce _any_ user-facing change? Yes. More SQL operators are covered by nested column pruning. ### How was this patch tested? Added unit tests and end-to-end tests. Closes #28560 from viirya/SPARK-27217-2. Authored-by: Liang-Chi Hsieh <viirya@gmail.com> Signed-off-by: HyukjinKwon <gurwls223@apache.org>
This commit is contained in:
parent
82ff29be7a
commit
43063e2db2
|
@ -35,6 +35,11 @@ object NestedColumnAliasing {
|
|||
case Project(projectList, child)
|
||||
if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
|
||||
getAliasSubMap(projectList)
|
||||
|
||||
case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) =>
|
||||
val exprCandidatesToPrune = plan.expressions
|
||||
getAliasSubMap(exprCandidatesToPrune, plan.producedAttributes.toSeq)
|
||||
|
||||
case _ => None
|
||||
}
|
||||
|
||||
|
@ -48,7 +53,11 @@ object NestedColumnAliasing {
|
|||
case Project(projectList, child) =>
|
||||
Project(
|
||||
getNewProjectList(projectList, nestedFieldToAlias),
|
||||
replaceChildrenWithAliases(child, attrToAliases))
|
||||
replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases))
|
||||
|
||||
// The operators reaching here were already guarded by `canPruneOn`.
|
||||
case other =>
|
||||
replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases)
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -68,10 +77,23 @@ object NestedColumnAliasing {
|
|||
*/
|
||||
def replaceChildrenWithAliases(
|
||||
plan: LogicalPlan,
|
||||
nestedFieldToAlias: Map[ExtractValue, Alias],
|
||||
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
|
||||
plan.withNewChildren(plan.children.map { plan =>
|
||||
Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan)
|
||||
})
|
||||
}).transformExpressions {
|
||||
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
|
||||
nestedFieldToAlias(f).toAttribute
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns true for operators on which nested columns can be pruned directly.
|
||||
*/
|
||||
private def canPruneOn(plan: LogicalPlan) = plan match {
|
||||
case _: Aggregate => true
|
||||
case _: Expand => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -204,15 +226,8 @@ object GeneratorNestedColumnAliasing {
|
|||
g: Generate,
|
||||
nestedFieldToAlias: Map[ExtractValue, Alias],
|
||||
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
|
||||
val newGenerator = g.generator.transform {
|
||||
case f: ExtractValue if nestedFieldToAlias.contains(f) =>
|
||||
nestedFieldToAlias(f).toAttribute
|
||||
}.asInstanceOf[Generator]
|
||||
|
||||
// Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
|
||||
val newGenerate = g.copy(generator = newGenerator)
|
||||
|
||||
NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
|
||||
NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases)
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -341,6 +341,100 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
|
|||
.analyze
|
||||
comparePlans(optimized, expected)
|
||||
}
|
||||
|
||||
test("Nested field pruning for Aggregate") {
|
||||
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
|
||||
val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze
|
||||
val optimized1 = Optimize.execute(query1)
|
||||
val aliases1 = collectGeneratedAliases(optimized1)
|
||||
|
||||
val expected1 = basePlan(
|
||||
contact
|
||||
.select($"id", 'name.getField("first").as(aliases1(0)))
|
||||
).groupBy($"id")(first($"${aliases1(0)}").as("first")).analyze
|
||||
comparePlans(optimized1, expected1)
|
||||
|
||||
val query2 = basePlan(contact).groupBy($"name.last")(first($"name.first").as("first")).analyze
|
||||
val optimized2 = Optimize.execute(query2)
|
||||
val aliases2 = collectGeneratedAliases(optimized2)
|
||||
|
||||
val expected2 = basePlan(
|
||||
contact
|
||||
.select('name.getField("last").as(aliases2(0)), 'name.getField("first").as(aliases2(1)))
|
||||
).groupBy($"${aliases2(0)}")(first($"${aliases2(1)}").as("first")).analyze
|
||||
comparePlans(optimized2, expected2)
|
||||
}
|
||||
|
||||
Seq(
|
||||
(plan: LogicalPlan) => plan,
|
||||
(plan: LogicalPlan) => plan.limit(100),
|
||||
(plan: LogicalPlan) => plan.repartition(100),
|
||||
(plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
|
||||
runTest(base)
|
||||
}
|
||||
|
||||
val query3 = contact.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
|
||||
val optimized3 = Optimize.execute(query3)
|
||||
val expected3 = contact.select($"id", $"name")
|
||||
.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
|
||||
comparePlans(optimized3, expected3)
|
||||
}
|
||||
|
||||
test("Nested field pruning for Expand") {
|
||||
def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
|
||||
val query1 = Expand(
|
||||
Seq(
|
||||
Seq($"name.first", $"name.middle"),
|
||||
Seq(ConcatWs(Seq($"name.first", $"name.middle")),
|
||||
ConcatWs(Seq($"name.middle", $"name.first")))
|
||||
),
|
||||
Seq('a.string, 'b.string),
|
||||
basePlan(contact)
|
||||
).analyze
|
||||
val optimized1 = Optimize.execute(query1)
|
||||
val aliases1 = collectGeneratedAliases(optimized1)
|
||||
|
||||
val expected1 = Expand(
|
||||
Seq(
|
||||
Seq($"${aliases1(0)}", $"${aliases1(1)}"),
|
||||
Seq(ConcatWs(Seq($"${aliases1(0)}", $"${aliases1(1)}")),
|
||||
ConcatWs(Seq($"${aliases1(1)}", $"${aliases1(0)}")))
|
||||
),
|
||||
Seq('a.string, 'b.string),
|
||||
basePlan(contact.select(
|
||||
'name.getField("first").as(aliases1(0)),
|
||||
'name.getField("middle").as(aliases1(1))))
|
||||
).analyze
|
||||
comparePlans(optimized1, expected1)
|
||||
}
|
||||
|
||||
Seq(
|
||||
(plan: LogicalPlan) => plan,
|
||||
(plan: LogicalPlan) => plan.limit(100),
|
||||
(plan: LogicalPlan) => plan.repartition(100),
|
||||
(plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
|
||||
runTest(base)
|
||||
}
|
||||
|
||||
val query2 = Expand(
|
||||
Seq(
|
||||
Seq($"name", $"name.middle"),
|
||||
Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
|
||||
),
|
||||
Seq('a.string, 'b.string),
|
||||
contact
|
||||
).analyze
|
||||
val optimized2 = Optimize.execute(query2)
|
||||
val expected2 = Expand(
|
||||
Seq(
|
||||
Seq($"name", $"name.middle"),
|
||||
Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
|
||||
),
|
||||
Seq('a.string, 'b.string),
|
||||
contact.select($"name")
|
||||
).analyze
|
||||
comparePlans(optimized2, expected2)
|
||||
}
|
||||
}
|
||||
|
||||
object NestedColumnAliasingSuite {
|
||||
|
|
|
@ -23,7 +23,9 @@ import org.scalactic.Equality
|
|||
|
||||
import org.apache.spark.sql.{DataFrame, QueryTest, Row}
|
||||
import org.apache.spark.sql.catalyst.SchemaPruningTest
|
||||
import org.apache.spark.sql.catalyst.expressions.Concat
|
||||
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
|
||||
import org.apache.spark.sql.catalyst.plans.logical.Expand
|
||||
import org.apache.spark.sql.execution.FileSourceScanExec
|
||||
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -338,6 +340,75 @@ abstract class SchemaPruningSuite
|
|||
}
|
||||
}
|
||||
|
||||
testSchemaPruning("select one deep nested complex field after repartition") {
|
||||
val query = sql("select * from contacts")
|
||||
.repartition(100)
|
||||
.where("employer.company.address is not null")
|
||||
.selectExpr("employer.id as employer_id")
|
||||
checkScan(query,
|
||||
"struct<employer:struct<id:int,company:struct<address:string>>>")
|
||||
checkAnswer(query, Row(0) :: Nil)
|
||||
}
|
||||
|
||||
testSchemaPruning("select nested field in aggregation function of Aggregate") {
|
||||
val query1 = sql("select count(name.first) from contacts group by name.last")
|
||||
checkScan(query1, "struct<name:struct<first:string,last:string>>")
|
||||
checkAnswer(query1, Row(2) :: Row(2) :: Nil)
|
||||
|
||||
val query2 = sql("select count(name.first), sum(pets) from contacts group by id")
|
||||
checkScan(query2, "struct<id:int,name:struct<first:string>,pets:int>")
|
||||
checkAnswer(query2, Row(1, 1) :: Row(1, null) :: Row(1, 3) :: Row(1, null) :: Nil)
|
||||
|
||||
val query3 = sql("select count(name.first), first(name) from contacts group by id")
|
||||
checkScan(query3, "struct<id:int,name:struct<first:string,middle:string,last:string>>")
|
||||
checkAnswer(query3,
|
||||
Row(1, Row("Jane", "X.", "Doe")) ::
|
||||
Row(1, Row("Jim", null, "Jones")) ::
|
||||
Row(1, Row("John", "Y.", "Doe")) ::
|
||||
Row(1, Row("Janet", null, "Jones")) :: Nil)
|
||||
|
||||
val query4 = sql("select count(name.first), sum(pets) from contacts group by name.last")
|
||||
checkScan(query4, "struct<name:struct<first:string,last:string>,pets:int>")
|
||||
checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil)
|
||||
}
|
||||
|
||||
testSchemaPruning("select nested field in Expand") {
|
||||
import org.apache.spark.sql.catalyst.dsl.expressions._
|
||||
|
||||
val query1 = Expand(
|
||||
Seq(
|
||||
Seq($"name.first", $"name.last"),
|
||||
Seq(Concat(Seq($"name.first", $"name.last")),
|
||||
Concat(Seq($"name.last", $"name.first")))
|
||||
),
|
||||
Seq('a.string, 'b.string),
|
||||
sql("select * from contacts").logicalPlan
|
||||
).toDF()
|
||||
checkScan(query1, "struct<name:struct<first:string,last:string>>")
|
||||
checkAnswer(query1,
|
||||
Row("Jane", "Doe") ::
|
||||
Row("JaneDoe", "DoeJane") ::
|
||||
Row("John", "Doe") ::
|
||||
Row("JohnDoe", "DoeJohn") ::
|
||||
Row("Jim", "Jones") ::
|
||||
Row("JimJones", "JonesJim") ::
|
||||
Row("Janet", "Jones") ::
|
||||
Row("JanetJones", "JonesJanet") :: Nil)
|
||||
|
||||
val name = StructType.fromDDL("first string, middle string, last string")
|
||||
val query2 = Expand(
|
||||
Seq(Seq($"name", $"name.last")),
|
||||
Seq('a.struct(name), 'b.string),
|
||||
sql("select * from contacts").logicalPlan
|
||||
).toDF()
|
||||
checkScan(query2, "struct<name:struct<first:string,middle:string,last:string>>")
|
||||
checkAnswer(query2,
|
||||
Row(Row("Jane", "X.", "Doe"), "Doe") ::
|
||||
Row(Row("John", "Y.", "Doe"), "Doe") ::
|
||||
Row(Row("Jim", null, "Jones"), "Jones") ::
|
||||
Row(Row("Janet", null, "Jones"), "Jones") ::Nil)
|
||||
}
|
||||
|
||||
protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
|
||||
test(s"Spark vectorized reader - without partition data column - $testName") {
|
||||
withSQLConf(vectorizedReaderEnabledKey -> "true") {
|
||||
|
|
Loading…
Reference in a new issue