[SPARK-26837][SQL] Pruning nested fields from object serializers

## What changes were proposed in this pull request?

In SPARK-26619, we made a change to prune unnecessary individual serializers when serializing objects. This PR extends that work: we can further prune nested fields from object serializers if those fields are not used.

For example, in the following query, only one field of a struct column is used:

```scala
val data = Seq((("a", 1), 1), (("b", 2), 2), (("c", 3), 3))
val df = data.toDS().map(t => (t._1, t._2 + 1)).select("_1._1")
```

So, instead of having the serializer create a two-field struct, we can prune the unnecessary field from it. This is what this PR proposes to do.
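
As a quick illustration of the effect (with the new config mentioned below turned on), the pruned serializer can be observed by inspecting the `SerializeFromObject` operator in the optimized plan. This is only a sketch along the lines of the tests added in this PR, assuming a running Spark session with `spark.implicits._` imported:

```scala
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject

val data = Seq((("a", 1), 1), (("b", 2), 2), (("c", 3), 3))
val df = data.toDS().map(t => (t._1, t._2 + 1)).select("_1._1")

// Grab the serializer operator from the optimized plan.
val serializeOp = df.queryExecution.optimizedPlan.collect {
  case s: SerializeFromObject => s
}.head

// With nested serializer pruning enabled, the struct created for `_1` should
// only contain the single field the query actually reads.
val structFieldNames = serializeOp.serializer.head.collect {
  case c: CreateNamedStruct => c.names.map(_.toString)
}
// expected: List(List(_1)) instead of List(List(_1, _2))
println(structFieldNames)
```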

To keep this change conservative and safe, a SQL config is added to control it. It is disabled by default.
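
For reference, the new flag can be enabled per session like any other SQL config (a usage sketch, not part of this diff):

```scala
// Enable nested field pruning in object serializers (off by default).
spark.conf.set("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled", "true")

// In test code, the same flag can be scoped with withSQLConf:
// withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") { ... }
```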

TODO: Support pruning nested fields inside MapType keys and values.

## How was this patch tested?

Added tests.

Closes #23740 from viirya/nested-pruning-serializer-2.

Authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
Liang-Chi Hsieh 2019-02-27 12:45:24 +08:00 committed by Wenchen Fan
parent 9c283662c6
commit 0f2c0b53e8
14 changed files with 545 additions and 160 deletions


@@ -15,9 +15,8 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Expression, GetStructField}
import org.apache.spark.sql.types.StructField
/**
@@ -25,7 +24,7 @@ import org.apache.spark.sql.types.StructField
* This is in contrast to the [[GetStructField]] case class extractor which returns the field
* ordinal instead of the field itself.
*/
private[execution] object GetStructFieldObject {
object GetStructFieldObject {
def unapply(getStructField: GetStructField): Option[(Expression, StructField)] =
Some((
getStructField.child,


@@ -15,9 +15,8 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
@@ -26,7 +25,7 @@ import org.apache.spark.sql.types._
* are adjusted to fit the schema. All other expressions are left as-is. This
* class is motivated by columnar nested schema pruning.
*/
private[execution] case class ProjectionOverSchema(schema: StructType) {
case class ProjectionOverSchema(schema: StructType) {
private val fieldNames = schema.fieldNames.toSet
def unapply(expr: Expression): Option[Expression] = getProjection(expr)


@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.types._
object SchemaPruning {
/**
* Filters the schema by the requested fields. For example, if the schema is struct<a:int, b:int>,
* and the requested field is "a", the field "b" is pruned in the returned schema.
* Note that the field ordering of the original schema is preserved in the pruned schema.
*/
def pruneDataSchema(
dataSchema: StructType,
requestedRootFields: Seq[RootField]): StructType = {
// Merge the requested root fields into a single schema. Note the ordering of the fields
// in the resulting schema may differ from their ordering in the logical relation's
// original schema
val mergedSchema = requestedRootFields
.map { case root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
val dataSchemaFieldNames = dataSchema.fieldNames.toSet
val mergedDataSchema =
StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
}
/**
* Sorts the fields and descendant fields of structs in left according to their order in
* right. This function assumes that the fields of left are a subset of the fields of
* right, recursively. That is, left is a "subschema" of right, ignoring order of
* fields.
*/
private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) =>
ArrayType(
sortLeftFieldsByRight(leftElementType, rightElementType),
containsNull)
case (MapType(leftKeyType, leftValueType, containsNull),
MapType(rightKeyType, rightValueType, _)) =>
MapType(
sortLeftFieldsByRight(leftKeyType, rightKeyType),
sortLeftFieldsByRight(leftValueType, rightValueType),
containsNull)
case (leftStruct: StructType, rightStruct: StructType) =>
val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
val leftFieldType = leftStruct(fieldName).dataType
val rightFieldType = rightStruct(fieldName).dataType
val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
StructField(fieldName, sortedLeftFieldType, nullable = leftStruct(fieldName).nullable)
}
StructType(sortedLeftFields)
case _ => left
}
/**
* Returns the set of fields from projection and filtering predicates that the query plan needs.
*/
def identifyRootFields(
projects: Seq[NamedExpression],
filters: Seq[Expression]): Seq[RootField] = {
val projectionRootFields = projects.flatMap(getRootFields)
val filterRootFields = filters.flatMap(getRootFields)
// Some kinds of expressions don't need to access any fields of a root field, e.g., `IsNotNull`.
// For them, if any nested field is accessed in the query, we don't need to add the root
// field access for those expressions.
// For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`,
// we don't need to read nested fields of `name` struct other than `first` field.
val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields)
.distinct.partition(!_.prunedIfAnyChildAccessed)
optRootFields.filter { opt =>
!rootFields.exists { root =>
root.field.name == opt.field.name && {
// Checking if current optional root field can be pruned.
// For each required root field, we merge it with the optional root field:
// 1. If this optional root field has nested fields and any nested field of it is used
// in the query, the merged field type must be equal to the optional root field type.
// We can prune this optional root field. For example, for optional root field
// `struct<name:struct<middle:string,last:string>>`, if its field
// `struct<name:struct<last:string>>` is used, we don't need to add this optional
// root field.
// 2. If this optional root field has no nested fields, the merged field type equals
// the optional root field type only if they are the same. If they are, we can prune
// this optional root field too.
val rootFieldType = StructType(Array(root.field))
val optFieldType = StructType(Array(opt.field))
val merged = optFieldType.merge(rootFieldType)
merged.sameType(optFieldType)
}
}
} ++ rootFields
}
/**
* Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]].
* When expr is an [[Attribute]], construct a field around it and indicate that that
* field was derived from an attribute.
*/
private def getRootFields(expr: Expression): Seq[RootField] = {
expr match {
case att: Attribute =>
RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil
case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil
// Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions
// don't actually use any nested fields. These root field accesses might be excluded later
// if there are any nested field accesses in the query plan.
case IsNotNull(SelectedField(field)) =>
RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
case IsNull(SelectedField(field)) =>
RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = true))
case _ =>
expr.children.flatMap(getRootFields)
}
}
/**
* This represents a "root" schema field (aka top-level, no-parent). `field` is the
* `StructField` for field name and datatype. `derivedFromAtt` indicates whether it
* was derived from an attribute or had a proper child. `prunedIfAnyChildAccessed` indicates
* whether this root field can be pruned if any of its child fields is used in the query.
*/
case class RootField(field: StructField, derivedFromAtt: Boolean,
prunedIfAnyChildAccessed: Boolean = false)
}


@@ -15,10 +15,9 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
@@ -53,7 +52,7 @@ import org.apache.spark.sql.types._
* is "name" and its data type is a [[org.apache.spark.sql.types.StructType]] with a single string
* field named "first".
*/
private[execution] object SelectedField {
object SelectedField {
def unapply(expr: Expression): Option[StructField] = {
// If this expression is an alias, work on its child instead
val unaliased = expr match {


@@ -197,7 +197,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog)
DecimalAggregates) :+
Batch("Object Expressions Optimization", fixedPoint,
EliminateMapObjects,
CombineTypedFilters) :+
CombineTypedFilters,
ObjectSerializerPruning) :+
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation,
PropagateEmptyRelation) :+
@@ -594,11 +595,6 @@ object ColumnPruning extends Rule[LogicalPlan] {
case d @ DeserializeToObject(_, _, child) if !child.outputSet.subsetOf(d.references) =>
d.copy(child = prunedChild(child, d.references))
case p @ Project(_, s: SerializeFromObject) if p.references != s.outputSet =>
val usedRefs = p.references
val prunedSerializer = s.serializer.filter(usedRefs.contains)
p.copy(child = SerializeFromObject(prunedSerializer, s.child))
// Prunes the unused columns from child of Aggregate/Expand/Generate/ScriptTransformation
case a @ Aggregate(_, _, child) if !child.outputSet.subsetOf(a.references) =>
a.copy(child = prunedChild(child, a.references))


@@ -17,11 +17,15 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StructType}
/*
* This file defines optimization rules related to object manipulation (for the Dataset API).
@@ -109,3 +113,101 @@ object EliminateMapObjects extends Rule[LogicalPlan] {
case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData
}
}
/**
* Prunes unnecessary object serializers from the query plan. This rule prunes both individual
* serializers and nested fields in serializers.
*/
object ObjectSerializerPruning extends Rule[LogicalPlan] {
/**
* Collects all struct types from the given data type, recursively. Supports struct and array
* types for now.
* TODO(SPARK-26847): support map type.
*/
def collectStructType(dt: DataType, structs: ArrayBuffer[StructType]): ArrayBuffer[StructType] = {
dt match {
case s @ StructType(fields) =>
structs += s
fields.map(f => collectStructType(f.dataType, structs))
case ArrayType(elementType, _) =>
collectStructType(elementType, structs)
case _ =>
}
structs
}
/**
* Prunes the given serializer expression by the given pruned data type. For example,
* given a serializer creating struct(a int, b int) and the pruned data type struct(a int),
* this method returns a pruned serializer creating struct(a int). For now it supports
* pruning nested fields in structs and arrays of structs.
* TODO(SPARK-26847): support to prune nested fields in key and value of map type.
*/
def pruneSerializer(
serializer: NamedExpression,
prunedDataType: DataType): NamedExpression = {
val prunedStructTypes = collectStructType(prunedDataType, ArrayBuffer.empty[StructType])
var structTypeIndex = 0
val prunedSerializer = serializer.transformDown {
case s: CreateNamedStruct if structTypeIndex < prunedStructTypes.size =>
val prunedType = prunedStructTypes(structTypeIndex)
// Filters out the pruned fields.
val prunedFields = s.nameExprs.zip(s.valExprs).filter { case (nameExpr, _) =>
val name = nameExpr.eval(EmptyRow).toString
prunedType.fieldNames.exists { fieldName =>
if (SQLConf.get.caseSensitiveAnalysis) {
fieldName.equals(name)
} else {
fieldName.equalsIgnoreCase(name)
}
}
}.flatMap(pair => Seq(pair._1, pair._2))
structTypeIndex += 1
CreateNamedStruct(prunedFields)
}.transformUp {
// When we change a nested serializer's data type, the `If` expression becomes unresolved
// because the data type of the null literal no longer matches. We need to align it with the
// new data type.
// Note: we should do `transformUp` explicitly to change data types.
case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
i.copy(trueValue = Literal(null, ser.dataType))
}.asInstanceOf[NamedExpression]
if (prunedSerializer.dataType.sameType(prunedDataType)) {
prunedSerializer
} else {
serializer
}
}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p @ Project(_, s: SerializeFromObject) =>
// Prunes individual serializers that are not used at all by the projection above.
val usedRefs = p.references
val prunedSerializer = s.serializer.filter(usedRefs.contains)
val rootFields = SchemaPruning.identifyRootFields(p.projectList, Seq.empty)
if (SQLConf.get.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) {
// Prunes nested fields in serializers.
val prunedSchema = SchemaPruning.pruneDataSchema(
StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields)
val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) =>
pruneSerializer(serializer, prunedSchema(idx).dataType)
}
// Builds new projection.
val projectionOverSchema = ProjectionOverSchema(prunedSchema)
val newProjects = p.projectList.map(_.transformDown {
case projectionOverSchema(expr) => expr
}).map { case expr: NamedExpression => expr }
p.copy(projectList = newProjects,
child = SerializeFromObject(nestedPrunedSerializer, s.child))
} else {
p.copy(child = SerializeFromObject(prunedSerializer, s.child))
}
}
}


@@ -1528,6 +1528,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
val SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED =
buildConf("spark.sql.optimizer.serializer.nestedSchemaPruning.enabled")
.internal()
.doc("Prune nested fields from object serialization operator which are unnecessary in " +
"satisfying a query. This optimization allows object serializers to avoid " +
"executing unnecessary nested expressions.")
.booleanConf
.createWithDefault(false)
val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()
@@ -2077,6 +2086,9 @@ class SQLConf extends Serializable with Logging {
def nestedSchemaPruningEnabled: Boolean = getConf(NESTED_SCHEMA_PRUNING_ENABLED)
def serializerNestedSchemaPruningEnabled: Boolean =
getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)
def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)
def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)


@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.types._
class SchemaPruningSuite extends SparkFunSuite {
test("prune schema by the requested fields") {
def testPrunedSchema(
schema: StructType,
requestedFields: StructField*): Unit = {
val requestedRootFields = requestedFields.map { f =>
// `derivedFromAtt` doesn't affect the result of pruned schema.
SchemaPruning.RootField(field = f, derivedFromAtt = true)
}
val expectedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
assert(expectedSchema == StructType(requestedFields))
}
testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("a", IntegerType))
testPrunedSchema(StructType.fromDDL("a int, b int"), StructField("b", IntegerType))
val structOfStruct = StructType.fromDDL("a struct<a:int, b:int>, b int")
testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("a int, b int")))
testPrunedSchema(structOfStruct, StructField("b", IntegerType))
testPrunedSchema(structOfStruct, StructField("a", StructType.fromDDL("b int")))
val arrayOfStruct = StructField("a", ArrayType(StructType.fromDDL("a int, b int, c string")))
val mapOfStruct = StructField("d", MapType(StructType.fromDDL("a int, b int, c string"),
StructType.fromDDL("d int, e int, f string")))
val complexStruct = StructType(
arrayOfStruct :: StructField("b", structOfStruct) :: StructField("c", IntegerType) ::
mapOfStruct :: Nil)
testPrunedSchema(complexStruct, StructField("a", ArrayType(StructType.fromDDL("b int"))),
StructField("b", StructType.fromDDL("a int")))
testPrunedSchema(complexStruct,
StructField("a", ArrayType(StructType.fromDDL("b int, c string"))),
StructField("b", StructType.fromDDL("b int")))
val selectFieldInMap = StructField("d", MapType(StructType.fromDDL("a int, b int"),
StructType.fromDDL("e int, f string")))
testPrunedSchema(complexStruct, StructField("c", IntegerType), selectFieldInMap)
}
}


@@ -15,14 +15,13 @@
* limitations under the License.
*/
package org.apache.spark.sql.execution
package org.apache.spark.sql.catalyst.expressions
import org.scalatest.BeforeAndAfterAll
import org.scalatest.exceptions.TestFailedException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.NamedExpression
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._


@@ -399,15 +399,5 @@ class ColumnPruningSuite extends PlanTest {
val expected = input.where(rand(0L) > 0.5).where('key < 10).select('key).analyze
comparePlans(optimized, expected)
}
test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
val testRelation = LocalRelation('_1.int, '_2.int)
val serializerObject = CatalystSerde.serialize[(Int, Int)](
CatalystSerde.deserialize[(Int, Int)](testRelation))
val query = serializerObject.select('_1)
val optimized = Optimize.execute(query.analyze)
val expected = serializerObject.copy(serializer = Seq(serializerObject.serializer.head)).analyze
comparePlans(optimized, expected)
}
// todo: add more tests for column pruning
}


@@ -0,0 +1,103 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.mutable.ArrayBuffer
import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
class ObjectSerializerPruningSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Object serializer pruning", FixedPoint(100),
ObjectSerializerPruning,
RemoveNoopOperators) :: Nil
}
implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
test("collect struct types") {
val dataTypes = Seq(
IntegerType,
ArrayType(IntegerType),
StructType.fromDDL("a int, b int"),
ArrayType(StructType.fromDDL("a int, b int, c string")),
StructType.fromDDL("a struct<a:int, b:int>, b int")
)
val expectedTypes = Seq(
Seq.empty[StructType],
Seq.empty[StructType],
Seq(StructType.fromDDL("a int, b int")),
Seq(StructType.fromDDL("a int, b int, c string")),
Seq(StructType.fromDDL("a struct<a:int, b:int>, b int"),
StructType.fromDDL("a int, b int"))
)
dataTypes.zipWithIndex.foreach { case (dt, idx) =>
val structs = ObjectSerializerPruning.collectStructType(dt, ArrayBuffer.empty[StructType])
assert(structs === expectedTypes(idx))
}
}
test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
val testRelation = LocalRelation('_1.int, '_2.int)
val serializerObject = CatalystSerde.serialize[(Int, Int)](
CatalystSerde.deserialize[(Int, Int)](testRelation))
val query = serializerObject.select('_1)
val optimized = Optimize.execute(query.analyze)
val expected = serializerObject.copy(serializer = Seq(serializerObject.serializer.head)).analyze
comparePlans(optimized, expected)
}
test("Prune nested serializers") {
withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
val testRelation = LocalRelation('_1.struct(StructType.fromDDL("_1 int, _2 string")), '_2.int)
val serializerObject = CatalystSerde.serialize[((Int, String), Int)](
CatalystSerde.deserialize[((Int, String), Int)](testRelation))
val query = serializerObject.select($"_1._1")
val optimized = Optimize.execute(query.analyze)
val prunedSerializer = serializerObject.serializer.head.transformDown {
case CreateNamedStruct(children) =>
CreateNamedStruct(children.take(2))
}.transformUp {
// Aligns null literal in `If` expression to make it resolvable.
case i @ If(_: IsNull, Literal(null, dt), ser) if !dt.sameType(ser.dataType) =>
i.copy(trueValue = Literal(null, ser.dataType))
}.asInstanceOf[NamedExpression]
// `name` in `GetStructField` affects `comparePlans`. Maybe we can ignore
// `name` in `GetStructField.equals`?
val expected = serializerObject.copy(serializer = Seq(prunedSerializer))
.select($"_1._1").analyze.transformAllExpressions {
case g: GetStructField => g.copy(name = None)
}
comparePlans(optimized, expected)
}
}
}


@@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.{ProjectionOverSchema, SelectedField}
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
@@ -32,7 +31,9 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructField, StructType}
* Parquet format. In Spark SQL, a root-level Parquet column corresponds to a
* SQL column, and a nested Parquet column corresponds to a [[StructField]].
*/
private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
object ParquetSchemaPruning extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._
override def apply(plan: LogicalPlan): LogicalPlan =
if (SQLConf.get.nestedSchemaPruningEnabled) {
apply0(plan)
@@ -103,44 +104,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
(normalizedProjects, normalizedFilters)
}
/**
* Returns the set of fields from the Parquet file that the query plan needs.
*/
private def identifyRootFields(projects: Seq[NamedExpression], filters: Seq[Expression]) = {
val projectionRootFields = projects.flatMap(getRootFields)
val filterRootFields = filters.flatMap(getRootFields)
// Kind of expressions don't need to access any fields of a root fields, e.g., `IsNotNull`.
// For them, if there are any nested fields accessed in the query, we don't need to add root
// field access of above expressions.
// For example, for a query `SELECT name.first FROM contacts WHERE name IS NOT NULL`,
// we don't need to read nested fields of `name` struct other than `first` field.
val (rootFields, optRootFields) = (projectionRootFields ++ filterRootFields)
.distinct.partition(!_.prunedIfAnyChildAccessed)
optRootFields.filter { opt =>
!rootFields.exists { root =>
root.field.name == opt.field.name && {
// Checking if current optional root field can be pruned.
// For each required root field, we merge it with the optional root field:
// 1. If this optional root field has nested fields and any nested field of it is used
// in the query, the merged field type must equal to the optional root field type.
// We can prune this optional root field. For example, for optional root field
// `struct<name:struct<middle:string,last:string>>`, if its field
// `struct<name:struct<last:string>>` is used, we don't need to add this optional
// root field.
// 2. If this optional root field has no nested fields, the merged field type equals
// to the optional root field only if they are the same. If they are, we can prune
// this optional root field too.
val rootFieldType = StructType(Array(root.field))
val optFieldType = StructType(Array(opt.field))
val merged = optFieldType.merge(rootFieldType)
merged.sameType(optFieldType)
}
}
} ++ rootFields
}
/**
* Builds the new output [[Project]] Spark SQL operator that has the pruned output relation.
*/
@@ -173,27 +136,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
Project(newProjects, projectionChild)
}
/**
* Filters the schema from the given file by the requested fields.
* Schema field ordering from the file is preserved.
*/
private def pruneDataSchema(
fileDataSchema: StructType,
requestedRootFields: Seq[RootField]) = {
// Merge the requested root fields into a single schema. Note the ordering of the fields
// in the resulting schema may differ from their ordering in the logical relation's
// original schema
val mergedSchema = requestedRootFields
.map { case root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
val dataSchemaFieldNames = fileDataSchema.fieldNames.toSet
val mergedDataSchema =
StructType(mergedSchema.filter(f => dataSchemaFieldNames.contains(f.name)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, fileDataSchema).asInstanceOf[StructType]
}
/**
* Builds a pruned logical relation from the output of the output relation and the schema of the
* pruned base relation.
@@ -217,30 +159,6 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
outputRelation.copy(relation = prunedBaseRelation, output = prunedRelationOutput)
}
/**
* Gets the root (aka top-level, no-parent) [[StructField]]s for the given [[Expression]].
* When expr is an [[Attribute]], construct a field around it and indicate that that
* field was derived from an attribute.
*/
private def getRootFields(expr: Expression): Seq[RootField] = {
expr match {
case att: Attribute =>
RootField(StructField(att.name, att.dataType, att.nullable), derivedFromAtt = true) :: Nil
case SelectedField(field) => RootField(field, derivedFromAtt = false) :: Nil
// Root field accesses by `IsNotNull` and `IsNull` are special cases as the expressions
// don't actually use any nested fields. These root field accesses might be excluded later
// if there are any nested fields accesses in the query plan.
case IsNotNull(SelectedField(field)) =>
RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
case IsNull(SelectedField(field)) =>
RootField(field, derivedFromAtt = false, prunedIfAnyChildAccessed = true) :: Nil
case IsNotNull(_: Attribute) | IsNull(_: Attribute) =>
expr.children.flatMap(getRootFields).map(_.copy(prunedIfAnyChildAccessed = true))
case _ =>
expr.children.flatMap(getRootFields)
}
}
/**
* Counts the "leaf" fields of the given dataType. Informally, this is the
* number of fields of non-complex data type in the tree representation of
@@ -256,42 +174,5 @@ private[sql] object ParquetSchemaPruning extends Rule[LogicalPlan] {
}
}
/**
* Sorts the fields and descendant fields of structs in left according to their order in
* right. This function assumes that the fields of left are a subset of the fields of
* right, recursively. That is, left is a "subschema" of right, ignoring order of
* fields.
*/
private def sortLeftFieldsByRight(left: DataType, right: DataType): DataType =
(left, right) match {
case (ArrayType(leftElementType, containsNull), ArrayType(rightElementType, _)) =>
ArrayType(
sortLeftFieldsByRight(leftElementType, rightElementType),
containsNull)
case (MapType(leftKeyType, leftValueType, containsNull),
MapType(rightKeyType, rightValueType, _)) =>
MapType(
sortLeftFieldsByRight(leftKeyType, rightKeyType),
sortLeftFieldsByRight(leftValueType, rightValueType),
containsNull)
case (leftStruct: StructType, rightStruct: StructType) =>
val filteredRightFieldNames = rightStruct.fieldNames.filter(leftStruct.fieldNames.contains)
val sortedLeftFields = filteredRightFieldNames.map { fieldName =>
val leftFieldType = leftStruct(fieldName).dataType
val rightFieldType = rightStruct(fieldName).dataType
val sortedLeftFieldType = sortLeftFieldsByRight(leftFieldType, rightFieldType)
StructField(fieldName, sortedLeftFieldType)
}
StructType(sortedLeftFields)
case _ => left
}
/**
* This represents a "root" schema field (aka top-level, no-parent). `field` is the
* `StructField` for field name and datatype. `derivedFromAtt` indicates whether it
* was derived from an attribute or had a proper child. `prunedIfAnyChildAccessed` means
* whether this root field can be pruned if any of child field is used in the query.
*/
private case class RootField(field: StructField, derivedFromAtt: Boolean,
prunedIfAnyChildAccessed: Boolean = false)
}


@@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
class DatasetOptimizationSuite extends QueryTest with SharedSQLContext {
import testImplicits._
test("SPARK-26619: Prune the unused serializers from SerializeFromObject") {
val data = Seq(("a", 1), ("b", 2), ("c", 3))
val ds = data.toDS().map(t => (t._1, t._2 + 1)).select("_1")
val serializer = ds.queryExecution.optimizedPlan.collect {
case s: SerializeFromObject => s
}.head
assert(serializer.serializer.size == 1)
checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
}
// This method checks whether the given DataFrame has the specified struct fields in its object
// serializers. The varargs parameter `structFields` holds the expected struct fields for the
// object serializers: the first `structFields` entry is aligned with the first serializer,
// and likewise for the other entries.
private def testSerializer(df: DataFrame, structFields: Seq[Seq[String]]*): Unit = {
val serializer = df.queryExecution.optimizedPlan.collect {
case s: SerializeFromObject => s
}.head
serializer.serializer.zip(structFields).foreach { case (serializer, fields) =>
val structs = serializer.collect {
case c: CreateNamedStruct => c
}
assert(structs.size == fields.size)
structs.zip(fields).foreach { case (struct, fieldNames) =>
assert(struct.names.map(_.toString) == fieldNames)
}
}
}
test("Prune nested serializers: struct") {
withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
val data = Seq((("a", 1, ("aa", 1.0)), 1), (("b", 2, ("bb", 2.0)), 2),
(("c", 3, ("cc", 3.0)), 3))
val ds = data.toDS().map(t => (t._1, t._2 + 1))
val df1 = ds.select("_1._1")
testSerializer(df1, Seq(Seq("_1")))
checkAnswer(df1, Seq(Row("a"), Row("b"), Row("c")))
val df2 = ds.select("_1._2")
testSerializer(df2, Seq(Seq("_2")))
checkAnswer(df2, Seq(Row(1), Row(2), Row(3)))
val df3 = ds.select("_1._3._1")
testSerializer(df3, Seq(Seq("_3"), Seq("_1")))
checkAnswer(df3, Seq(Row("aa"), Row("bb"), Row("cc")))
val df4 = ds.select("_1._3._1", "_1._2")
testSerializer(df4, Seq(Seq("_2", "_3"), Seq("_1")))
checkAnswer(df4, Seq(Row("aa", 1), Row("bb", 2), Row("cc", 3)))
}
}
test("Prune nested serializers: array of struct") {
withSQLConf(SQLConf.SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED.key -> "true") {
val arrayData = Seq((Seq(("a", 1, ("a_1", 11)), ("b", 2, ("b_1", 22))), 1, ("aa", 1.0)),
(Seq(("c", 3, ("c_1", 33)), ("d", 4, ("d_1", 44))), 2, ("bb", 2.0)))
val arrayDs = arrayData.toDS().map(t => (t._1, t._2 + 1, t._3))
val df1 = arrayDs.select("_1._1")
// The serializer creates an array of structs with one field "_1".
testSerializer(df1, Seq(Seq("_1")))
checkAnswer(df1, Seq(Row(Seq("a", "b")), Row(Seq("c", "d"))))
val df2 = arrayDs.select("_3._2")
testSerializer(df2, Seq(Seq("_2")))
checkAnswer(df2, Seq(Row(1.0), Row(2.0)))
// This is a more complex case. We select two root fields "_1" and "_3".
// The first serializer creates an array of structs with two fields ("_1", "_3"), and
// the field "_3" is a struct with one field "_2".
// The second serializer creates a struct of just one field "_1".
val df3 = arrayDs.select("_1._1", "_1._3._2", "_3._1")
testSerializer(df3, Seq(Seq("_1", "_3"), Seq("_2")), Seq(Seq("_1")))
checkAnswer(df3, Seq(Row(Seq("a", "b"), Seq(11, 22), "aa"),
Row(Seq("c", "d"), Seq(33, 44), "bb")))
}
}
}


@@ -26,7 +26,6 @@ import org.apache.spark.{SparkException, TaskContext}
import org.apache.spark.sql.catalyst.ScroogeLikeExample
import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
import org.apache.spark.sql.catalyst.plans.logical.SerializeFromObject
import org.apache.spark.sql.catalyst.util.sideBySide
import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec, SQLExecution}
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec}
@@ -1707,16 +1706,6 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkAnswer(inputDF.except(exceptDF), Seq(Row("1", null)))
}
test("SPARK-26619: Prune the unused serializers from SerializeFromObjec") {
val data = Seq(("a", 1), ("b", 2), ("c", 3))
val ds = data.toDS().map(t => (t._1, t._2 + 1)).select("_1")
val serializer = ds.queryExecution.optimizedPlan.collect {
case s: SerializeFromObject => s
}.head
assert(serializer.serializer.size == 1)
checkAnswer(ds, Seq(Row("a"), Row("b"), Row("c")))
}
test("SPARK-26706: Fix Cast.mayTruncate for bytes") {
val thrownException = intercept[AnalysisException] {
spark.range(Long.MaxValue - 10, Long.MaxValue).as[Byte]