[SPARK-32968][SQL] Prune unnecessary columns from CsvToStructs

### What changes were proposed in this pull request?

This patch proposes column pruning for the `CsvToStructs` expression when only some of its fields are required.

### Why are the changes needed?

`CsvToStructs` takes a schema parameter that tells the CSV parser which fields need to be parsed. If `CsvToStructs` is followed by `GetStructField`, we can prune that schema so only the accessed field is parsed.
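
For illustration, a minimal sketch of the query shape this targets (hypothetical data and column names, assuming an active `SparkSession` with `spark.implicits._` in scope):

```scala
import org.apache.spark.sql.functions.{col, from_csv}
import org.apache.spark.sql.types.StructType

// Hypothetical three-column CSV schema; only field `a` is accessed below.
val schema = StructType.fromDDL("a INT, b INT, c INT")

val df = Seq("1,2,3").toDF("csv") // toDF requires spark.implicits._
  .select(from_csv(col("csv"), schema, Map.empty[String, String]).as("parsed"))

// Previously all three fields were converted even though only `a` is used;
// with this pruning rule the parser only needs to convert field `a`.
df.select("parsed.a").show()
```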

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit test

Closes #30912 from viirya/SPARK-32968.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
Liang-Chi Hsieh 2020-12-29 21:37:17 +09:00 committed by HyukjinKwon
parent 2627825647
commit f9fe742442
8 changed files with 272 additions and 101 deletions


@ -51,7 +51,8 @@ case class CsvToStructs(
schema: StructType,
options: Map[String, String],
child: Expression,
-    timeZoneId: Option[String] = None)
+    timeZoneId: Option[String] = None,
+    requiredSchema: Option[StructType] = None)
extends UnaryExpression
with TimeZoneAwareExpression
with CodegenFallback
@ -113,7 +114,12 @@ case class CsvToStructs(
val actualSchema =
StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
-    val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)
+    val actualRequiredSchema =
+      StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
+        .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val rawParser = new UnivocityParser(actualSchema,
+      actualRequiredSchema,
+      parsedOptions)
new FailureSafeParser[String](
input => rawParser.parse(input),
mode,
@ -121,7 +127,7 @@ case class CsvToStructs(
parsedOptions.columnNameOfCorruptRecord)
}
-  override def dataType: DataType = nullableSchema
+  override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
copy(timeZoneId = Option(timeZoneId))
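
To make the new `requiredSchema` parameter concrete, here is a minimal sketch against the internal Catalyst constructor shown above (illustration only, not a public API; argument order as in the diff):

```scala
import org.apache.spark.sql.catalyst.expressions.{CsvToStructs, Literal}
import org.apache.spark.sql.types.StructType

val fullSchema   = StructType.fromDDL("a INT, b INT")
val prunedSchema = StructType.fromDDL("a INT")

// With requiredSchema set, dataType exposes only the pruned field, while the
// full schema still drives tokenization of the input line.
val pruned = CsvToStructs(fullSchema, Map.empty[String, String], Literal("1,2"), None, Some(prunedSchema))
assert(pruned.dataType == prunedSchema.asNullable)

// Without requiredSchema the behavior is unchanged: dataType is the full schema.
val unpruned = CsvToStructs(fullSchema, Map.empty[String, String], Literal("1,2"))
assert(unpruned.dataType == fullSchema.asNullable)
```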


@ -0,0 +1,120 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, StructType}
/**
* Simplify redundant csv/json related expressions.
*
* The optimization includes:
* 1. JsonToStructs(StructsToJson(child)) => child.
* 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
* 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
* If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
* if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
* contains all accessed fields in original CreateNamedStruct.
* 4. Prune unnecessary columns from GetStructField + CsvToStructs.
*/
object OptimizeCsvJsonExprs extends Rule[LogicalPlan] {
private def nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p =>
val optimized = if (SQLConf.get.jsonExpressionOptimization) {
p.transformExpressions(jsonOptimization)
} else {
p
}
if (SQLConf.get.csvExpressionOptimization) {
optimized.transformExpressions(csvOptimization)
} else {
optimized
}
}
private val jsonOptimization: PartialFunction[Expression, Expression] = {
case c: CreateNamedStruct
// If we create struct from various fields of the same `JsonToStructs`.
if c.valExprs.forall { v =>
v.isInstanceOf[GetStructField] &&
v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
v.children.head.semanticEquals(c.valExprs.head.children.head)
} =>
val jsonToStructs = c.valExprs.map(_.children.head)
val sameFieldName = c.names.zip(c.valExprs).forall {
case (name, valExpr: GetStructField) =>
name.toString == valExpr.childSchema(valExpr.ordinal).name
case _ => false
}
// Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
// `JsonToStructs` does not support parsing json with duplicated field names.
val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
// If we create struct from various fields of the same `JsonToStructs` and we don't
// alias field names and there is no duplicated field in the struct.
if (sameFieldName && !duplicateFields) {
val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
val nullFields = c.children.grouped(2).flatMap {
case Seq(name, value) => Seq(name, Literal(null, value.dataType))
}.toSeq
If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
} else {
c
}
case jsonToStructs @ JsonToStructs(_, options1,
StructsToJson(options2, child, timeZoneId2), timeZoneId1)
if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
jsonToStructs.dataType == child.dataType =>
// `StructsToJson` only fails when `JacksonGenerator` encounters data types it
// cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
// verifies its child's data types is convertible to JSON. But in
// `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
// so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
// this case similarly.
child
case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
if schema.length > 1 =>
val prunedSchema = StructType(Seq(schema(ordinal)))
g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
if schema.elementType.asInstanceOf[StructType].length > 1 =>
val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
}
private val csvOptimization: PartialFunction[Expression, Expression] = {
case g @ GetStructField(c @ CsvToStructs(schema: StructType, _, _, _, None), ordinal, _)
if schema.length > 1 && c.options.isEmpty && schema(ordinal).name != nameOfCorruptRecord =>
// When the parse mode is permissive, and corrupt column is not selected, we can prune here
// from `GetStructField`. To be more conservative, it does not optimize when any option
// is set.
val prunedSchema = StructType(Seq(schema(ordinal)))
g.copy(child = c.copy(requiredSchema = Some(prunedSchema)), ordinal = 0)
}
}
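
At the query level, the effect of optimization 4 above can be sketched as follows (hypothetical data; assumes an active `SparkSession` named `spark`):

```scala
import spark.implicits._

val parsed = Seq("1,2").toDF("csv")
  .selectExpr("from_csv(csv, 'a INT, b INT') as parsed")

// With spark.sql.optimizer.enableCsvExpressionOptimization=true (the default),
// the optimized plan should carry a from_csv whose required schema holds only
// field `b`, so the parser converts a single column instead of two.
parsed.select("parsed.b").explain(true)
```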


@ -1,96 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, StructType}
/**
* Simplify redundant json related expressions.
*
* The optimization includes:
* 1. JsonToStructs(StructsToJson(child)) => child.
* 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
* 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
* If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
* if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
* contains all accessed fields in original CreateNamedStruct.
*/
object OptimizeJsonExprs extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case p if SQLConf.get.jsonExpressionOptimization => p.transformExpressions {
case c: CreateNamedStruct
// If we create struct from various fields of the same `JsonToStructs`.
if c.valExprs.forall { v =>
v.isInstanceOf[GetStructField] &&
v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
v.children.head.semanticEquals(c.valExprs.head.children.head)
} =>
val jsonToStructs = c.valExprs.map(_.children.head)
val sameFieldName = c.names.zip(c.valExprs).forall {
case (name, valExpr: GetStructField) =>
name.toString == valExpr.childSchema(valExpr.ordinal).name
case _ => false
}
// Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
// `JsonToStructs` does not support parsing json with duplicated field names.
val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
// If we create struct from various fields of the same `JsonToStructs` and we don't
// alias field names and there is no duplicated field in the struct.
if (sameFieldName && !duplicateFields) {
val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
val nullFields = c.children.grouped(2).flatMap {
case Seq(name, value) => Seq(name, Literal(null, value.dataType))
}.toSeq
If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
} else {
c
}
case jsonToStructs @ JsonToStructs(_, options1,
StructsToJson(options2, child, timeZoneId2), timeZoneId1)
if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
jsonToStructs.dataType == child.dataType =>
// `StructsToJson` only fails when `JacksonGenerator` encounters data types it
// cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
// verifies its child's data types is convertible to JSON. But in
// `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
// so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
// this case similarly.
child
case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
if schema.length > 1 =>
val prunedSchema = StructType(Seq(schema(ordinal)))
g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
if schema.elementType.asInstanceOf[StructType].length > 1 =>
val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
}
}
}


@ -114,7 +114,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
RemoveNoopOperators,
OptimizeUpdateFields,
SimplifyExtractValueOps,
-      OptimizeJsonExprs,
+      OptimizeCsvJsonExprs,
CombineConcats) ++
extendedOperatorOptimizationRules


@ -1631,6 +1631,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+  val CSV_EXPRESSION_OPTIMIZATION =
+    buildConf("spark.sql.optimizer.enableCsvExpressionOptimization")
+      .doc("Whether to optimize CSV expressions in SQL optimizer. It includes pruning " +
+        "unnecessary columns from from_csv.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(true)
val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion")
.internal()
.doc("Whether to delete the expired log files in file stream sink.")
@ -3489,6 +3497,8 @@ class SQLConf extends Serializable with Logging {
def jsonExpressionOptimization: Boolean = getConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION)
+  def csvExpressionOptimization: Boolean = getConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION)
def parallelFileListingInStatsComputation: Boolean =
getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION)
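
The new flag is on by default; a short sketch of disabling it if needed (assumes a `SparkSession` named `spark`):

```scala
// Session-wide, using the key added above:
spark.conf.set("spark.sql.optimizer.enableCsvExpressionOptimization", "false")

// Or scoped to a block inside Spark's own test suites:
// withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> "false") { ... }
```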


@ -0,0 +1,83 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
class OptimizeCsvExprsSuite extends PlanTest with ExpressionEvalHelper {
private var csvExpressionOptimizeEnabled: Boolean = _
protected override def beforeAll(): Unit = {
csvExpressionOptimizeEnabled = SQLConf.get.csvExpressionOptimization
}
protected override def afterAll(): Unit = {
SQLConf.get.setConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION, csvExpressionOptimizeEnabled)
}
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches = Batch("Csv optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
}
val schema = StructType.fromDDL("a int, b int")
private val csvAttr = 'csv.string
private val testRelation = LocalRelation(csvAttr)
test("SPARK-32968: prune unnecessary columns from GetStructField + from_csv") {
val options = Map.empty[String, String]
val query1 = testRelation
.select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
val optimized1 = Optimizer.execute(query1.analyze)
val prunedSchema1 = StructType.fromDDL("a int")
val expected1 = testRelation
.select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema1)), 0))
.analyze
comparePlans(optimized1, expected1)
val query2 = testRelation
.select(GetStructField(CsvToStructs(schema, options, 'csv), 1))
val optimized2 = Optimizer.execute(query2.analyze)
val prunedSchema2 = StructType.fromDDL("b int")
val expected2 = testRelation
.select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema2)), 0))
.analyze
comparePlans(optimized2, expected2)
}
test("SPARK-32968: don't prune columns if options is not empty") {
val options = Map("mode" -> "failfast")
val query = testRelation
.select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
val optimized = Optimizer.execute(query.analyze)
val expected = query.analyze
comparePlans(optimized, expected)
}
}


@ -39,7 +39,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {
}
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches = Batch("Json optimization", FixedPoint(10), OptimizeJsonExprs) :: Nil
val batches = Batch("Json optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
}
val schema = StructType.fromDDL("a int, b int")


@ -250,4 +250,52 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
| """.stripMargin)
checkAnswer(toDF("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), toDF("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]"))
}
test("SPARK-32968: Pruning csv field should not change result") {
Seq("true", "false").foreach { enabled =>
withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
val df1 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
.selectExpr("from_csv(csv, 'a string, b string', map('mode', 'failfast')) as parsed")
checkAnswer(df1.selectExpr("parsed.a"), Seq(Row("a")))
checkAnswer(df1.selectExpr("parsed.b"), Seq(Row("b")))
val df2 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
.selectExpr("from_csv(csv, 'a string, b string') as parsed")
checkAnswer(df2.selectExpr("parsed.a"), Seq(Row("a")))
checkAnswer(df2.selectExpr("parsed.b"), Seq(Row("b")))
}
}
}
test("SPARK-32968: bad csv input with csv pruning optimization") {
Seq("true", "false").foreach { enabled =>
withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
val df = sparkContext.parallelize(Seq("1,\u0001\u0000\u0001234")).toDF("csv")
.selectExpr("from_csv(csv, 'a int, b int', map('mode', 'failfast')) as parsed")
val err1 = intercept[SparkException] {
df.selectExpr("parsed.a").collect
}
val err2 = intercept[SparkException] {
df.selectExpr("parsed.b").collect
}
assert(err1.getMessage.contains("Malformed records are detected in record parsing"))
assert(err2.getMessage.contains("Malformed records are detected in record parsing"))
}
}
}
test("SPARK-32968: csv pruning optimization with corrupt record field") {
Seq("true", "false").foreach { enabled =>
withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
val df = sparkContext.parallelize(Seq("a,b,c,d")).toDF("csv")
.selectExpr("from_csv(csv, 'a string, b string, _corrupt_record string') as parsed")
.selectExpr("parsed._corrupt_record")
checkAnswer(df, Seq(Row("a,b,c,d")))
}
}
}
}