[SPARK-32968][SQL] Prune unnecessary columns from CsvToStructs
### What changes were proposed in this pull request?
This patch proposes to do column pruning for the `CsvToStructs` expression when only some of its fields are required.

### Why are the changes needed?
`CsvToStructs` takes a schema parameter that tells the CSV parser which fields to parse. When `CsvToStructs` is followed by `GetStructField`, we can prune the schema so that only the accessed field is parsed.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Unit test

Closes #30912 from viirya/SPARK-32968.

Lead-authored-by: Liang-Chi Hsieh <viirya@gmail.com>
Co-authored-by: Hyukjin Kwon <gurwls223@gmail.com>
Signed-off-by: HyukjinKwon <gurwls223@apache.org>
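For context, this is the query shape the pruning targets. A minimal sketch, assuming a local `SparkSession` (the object name `FromCsvPruningExample` is invented for illustration; `from_csv` and `explain` are the standard Spark APIs):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, from_csv}
import org.apache.spark.sql.types.StructType

object FromCsvPruningExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("from_csv pruning")
      .getOrCreate()
    import spark.implicits._

    val schema = StructType.fromDDL("a INT, b INT")
    val parsed = Seq("1,2").toDF("csv")
      .select(from_csv(col("csv"), schema, Map.empty[String, String]).as("parsed"))

    // Only field `a` is accessed; with this patch the optimizer rewrites the
    // plan so CsvToStructs carries a pruned requiredSchema ("a INT" only).
    parsed.select(col("parsed.a")).explain(true)

    spark.stop()
  }
}
```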
parent 2627825647 | commit f9fe742442
`CsvToStructs` gains an optional `requiredSchema` parameter, passes it to the parser, and reports the pruned type:

```diff
@@ -51,7 +51,8 @@ case class CsvToStructs(
     schema: StructType,
     options: Map[String, String],
     child: Expression,
-    timeZoneId: Option[String] = None)
+    timeZoneId: Option[String] = None,
+    requiredSchema: Option[StructType] = None)
   extends UnaryExpression
   with TimeZoneAwareExpression
   with CodegenFallback
@@ -113,7 +114,12 @@ case class CsvToStructs(
     val actualSchema =
       StructType(nullableSchema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
-    val rawParser = new UnivocityParser(actualSchema, actualSchema, parsedOptions)
+    val actualRequiredSchema =
+      StructType(requiredSchema.map(_.asNullable).getOrElse(nullableSchema)
+        .filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
+    val rawParser = new UnivocityParser(actualSchema,
+      actualRequiredSchema,
+      parsedOptions)
     new FailureSafeParser[String](
       input => rawParser.parse(input),
       mode,
@@ -121,7 +127,7 @@ case class CsvToStructs(
       parsedOptions.columnNameOfCorruptRecord)
   }
 
-  override def dataType: DataType = nullableSchema
+  override def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
 
   override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
     copy(timeZoneId = Option(timeZoneId))
```
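The `dataType` change is what lets downstream operators see the pruned struct. A minimal sketch of that contract, using a toy stand-in rather than the real Catalyst expression (`ToyCsvToStructs` is an invented name, not Spark code):

```scala
import org.apache.spark.sql.types._

// Toy stand-in: the output type follows the pruned schema when one is
// set, and is always nullable, mirroring the new dataType definition.
case class ToyCsvToStructs(schema: StructType, requiredSchema: Option[StructType] = None) {
  def dataType: DataType = requiredSchema.getOrElse(schema).asNullable
}

object ToyCsvToStructsDemo extends App {
  val full = StructType.fromDDL("a INT, b INT")
  val pruned = StructType.fromDDL("a INT")
  assert(ToyCsvToStructs(full).dataType == full.asNullable)                 // unpruned
  assert(ToyCsvToStructs(full, Some(pruned)).dataType == pruned.asNullable) // pruned
}
```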
A new rule, `OptimizeCsvJsonExprs`, replaces the old `OptimizeJsonExprs` (deleted below), keeping the JSON optimizations and adding case 4 for CSV:

```diff
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, StructType}
+
+/**
+ * Simplify redundant csv/json related expressions.
+ *
+ * The optimization includes:
+ * 1. JsonToStructs(StructsToJson(child)) => child.
+ * 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
+ * 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
+ *      If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
+ *    if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
+ *    contains all accessed fields in original CreateNamedStruct.
+ * 4. Prune unnecessary columns from GetStructField + CsvToStructs.
+ */
+object OptimizeCsvJsonExprs extends Rule[LogicalPlan] {
+  private def nameOfCorruptRecord = SQLConf.get.getConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD)
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case p =>
+      val optimized = if (SQLConf.get.jsonExpressionOptimization) {
+        p.transformExpressions(jsonOptimization)
+      } else {
+        p
+      }
+
+      if (SQLConf.get.csvExpressionOptimization) {
+        optimized.transformExpressions(csvOptimization)
+      } else {
+        optimized
+      }
+  }
+
+  private val jsonOptimization: PartialFunction[Expression, Expression] = {
+    case c: CreateNamedStruct
+        // If we create struct from various fields of the same `JsonToStructs`.
+        if c.valExprs.forall { v =>
+          v.isInstanceOf[GetStructField] &&
+            v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
+            v.children.head.semanticEquals(c.valExprs.head.children.head)
+        } =>
+      val jsonToStructs = c.valExprs.map(_.children.head)
+      val sameFieldName = c.names.zip(c.valExprs).forall {
+        case (name, valExpr: GetStructField) =>
+          name.toString == valExpr.childSchema(valExpr.ordinal).name
+        case _ => false
+      }
+
+      // Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
+      // `JsonToStructs` does not support parsing json with duplicated field names.
+      val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
+
+      // If we create struct from various fields of the same `JsonToStructs` and we don't
+      // alias field names and there is no duplicated field in the struct.
+      if (sameFieldName && !duplicateFields) {
+        val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
+        val nullFields = c.children.grouped(2).flatMap {
+          case Seq(name, value) => Seq(name, Literal(null, value.dataType))
+        }.toSeq
+
+        If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
+      } else {
+        c
+      }
+
+    case jsonToStructs @ JsonToStructs(_, options1,
+        StructsToJson(options2, child, timeZoneId2), timeZoneId1)
+        if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
+          jsonToStructs.dataType == child.dataType =>
+      // `StructsToJson` only fails when `JacksonGenerator` encounters data types it
+      // cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
+      // verifies its child's data types is convertible to JSON. But in
+      // `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
+      // so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
+      // this case similarly.
+      child
+
+    case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
+        if schema.length > 1 =>
+      val prunedSchema = StructType(Seq(schema(ordinal)))
+      g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
+
+    case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
+        if schema.elementType.asInstanceOf[StructType].length > 1 =>
+      val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
+      g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
+  }
+
+  private val csvOptimization: PartialFunction[Expression, Expression] = {
+    case g @ GetStructField(c @ CsvToStructs(schema: StructType, _, _, _, None), ordinal, _)
+        if schema.length > 1 && c.options.isEmpty && schema(ordinal).name != nameOfCorruptRecord =>
+      // When the parse mode is permissive, and corrupt column is not selected, we can prune here
+      // from `GetStructField`. To be more conservative, it does not optimize when any option
+      // is set.
+      val prunedSchema = StructType(Seq(schema(ordinal)))
+      g.copy(child = c.copy(requiredSchema = Some(prunedSchema)), ordinal = 0)
+  }
+}
```
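The CSV branch mirrors the JSON pruning case. As a standalone sketch of the rewrite shape it implements, with toy types standing in for Catalyst's (`ParseCsv` and `GetField` are invented names, not Spark classes):

```scala
// Toy expression tree; not Catalyst classes.
sealed trait Expr
case class ParseCsv(schema: Seq[String], required: Option[Seq[String]] = None) extends Expr
case class GetField(child: Expr, ordinal: Int) extends Expr

object PrunePatternDemo {
  // Mirrors csvOptimization: fires only when nothing has been pruned yet
  // and more than one field exists; the kept field moves to ordinal 0.
  val prune: PartialFunction[Expr, Expr] = {
    case GetField(p @ ParseCsv(schema, None), ordinal) if schema.length > 1 =>
      GetField(p.copy(required = Some(Seq(schema(ordinal)))), 0)
  }

  def main(args: Array[String]): Unit = {
    val before = GetField(ParseCsv(Seq("a", "b")), 1)
    val after = prune.applyOrElse(before, identity[Expr])
    assert(after == GetField(ParseCsv(Seq("a", "b"), Some(Seq("b"))), 0))
  }
}
```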
The previous `OptimizeJsonExprs` rule is removed; its JSON cases move into `OptimizeCsvJsonExprs` above:

```diff
@@ -1,96 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.optimizer
-
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{ArrayType, StructType}
-
-/**
- * Simplify redundant json related expressions.
- *
- * The optimization includes:
- * 1. JsonToStructs(StructsToJson(child)) => child.
- * 2. Prune unnecessary columns from GetStructField/GetArrayStructFields + JsonToStructs.
- * 3. CreateNamedStruct(JsonToStructs(json).col1, JsonToStructs(json).col2, ...) =>
- *      If(IsNull(json), nullStruct, KnownNotNull(JsonToStructs(prunedSchema, ..., json)))
- *    if JsonToStructs(json) is shared among all fields of CreateNamedStruct. `prunedSchema`
- *    contains all accessed fields in original CreateNamedStruct.
- */
-object OptimizeJsonExprs extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case p if SQLConf.get.jsonExpressionOptimization => p.transformExpressions {
-
-      case c: CreateNamedStruct
-          // If we create struct from various fields of the same `JsonToStructs`.
-          if c.valExprs.forall { v =>
-            v.isInstanceOf[GetStructField] &&
-              v.asInstanceOf[GetStructField].child.isInstanceOf[JsonToStructs] &&
-              v.children.head.semanticEquals(c.valExprs.head.children.head)
-          } =>
-        val jsonToStructs = c.valExprs.map(_.children.head)
-        val sameFieldName = c.names.zip(c.valExprs).forall {
-          case (name, valExpr: GetStructField) =>
-            name.toString == valExpr.childSchema(valExpr.ordinal).name
-          case _ => false
-        }
-
-        // Although `CreateNamedStruct` allows duplicated field names, e.g. "a int, a int",
-        // `JsonToStructs` does not support parsing json with duplicated field names.
-        val duplicateFields = c.names.map(_.toString).distinct.length != c.names.length
-
-        // If we create struct from various fields of the same `JsonToStructs` and we don't
-        // alias field names and there is no duplicated field in the struct.
-        if (sameFieldName && !duplicateFields) {
-          val fromJson = jsonToStructs.head.asInstanceOf[JsonToStructs].copy(schema = c.dataType)
-          val nullFields = c.children.grouped(2).flatMap {
-            case Seq(name, value) => Seq(name, Literal(null, value.dataType))
-          }.toSeq
-
-          If(IsNull(fromJson.child), c.copy(children = nullFields), KnownNotNull(fromJson))
-        } else {
-          c
-        }
-
-      case jsonToStructs @ JsonToStructs(_, options1,
-          StructsToJson(options2, child, timeZoneId2), timeZoneId1)
-          if options1.isEmpty && options2.isEmpty && timeZoneId1 == timeZoneId2 &&
-            jsonToStructs.dataType == child.dataType =>
-        // `StructsToJson` only fails when `JacksonGenerator` encounters data types it
-        // cannot convert to JSON. But `StructsToJson.checkInputDataTypes` already
-        // verifies its child's data types is convertible to JSON. But in
-        // `StructsToJson(JsonToStructs(...))` case, we cannot verify input json string
-        // so `JsonToStructs` might throw error in runtime. Thus we cannot optimize
-        // this case similarly.
-        child
-
-      case g @ GetStructField(j @ JsonToStructs(schema: StructType, _, _, _), ordinal, _)
-          if schema.length > 1 =>
-        val prunedSchema = StructType(Seq(schema(ordinal)))
-        g.copy(child = j.copy(schema = prunedSchema), ordinal = 0)
-
-      case g @ GetArrayStructFields(j @ JsonToStructs(schema: ArrayType, _, _, _), _, _, _, _)
-          if schema.elementType.asInstanceOf[StructType].length > 1 =>
-        val prunedSchema = ArrayType(StructType(Seq(g.field)), g.containsNull)
-        g.copy(child = j.copy(schema = prunedSchema), ordinal = 0, numFields = 1)
-
-    }
-  }
-}
```
The optimizer batch registers the renamed rule:

```diff
@@ -114,7 +114,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       RemoveNoopOperators,
       OptimizeUpdateFields,
       SimplifyExtractValueOps,
-      OptimizeJsonExprs,
+      OptimizeCsvJsonExprs,
       CombineConcats) ++
     extendedOperatorOptimizationRules
```
A new config gates the CSV optimization, mirroring the existing JSON one:

```diff
@@ -1631,6 +1631,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val CSV_EXPRESSION_OPTIMIZATION =
+    buildConf("spark.sql.optimizer.enableCsvExpressionOptimization")
+      .doc("Whether to optimize CSV expressions in SQL optimizer. It includes pruning " +
+        "unnecessary columns from from_csv.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val FILE_SINK_LOG_DELETION = buildConf("spark.sql.streaming.fileSink.log.deletion")
     .internal()
     .doc("Whether to delete the expired log files in file stream sink.")
@@ -3489,6 +3497,8 @@ class SQLConf extends Serializable with Logging {
 
   def jsonExpressionOptimization: Boolean = getConf(SQLConf.JSON_EXPRESSION_OPTIMIZATION)
 
+  def csvExpressionOptimization: Boolean = getConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION)
+
   def parallelFileListingInStatsComputation: Boolean =
     getConf(SQLConf.PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION)
```
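The flag defaults to true and can be flipped per session. A quick sketch, assuming a running `SparkSession` named `spark`:

```scala
// Disable the CSV expression optimization for the current session.
spark.conf.set("spark.sql.optimizer.enableCsvExpressionOptimization", "false")

// Or equivalently through SQL:
spark.sql("SET spark.sql.optimizer.enableCsvExpressionOptimization=false")
```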
A new planner suite, `OptimizeCsvExprsSuite`, checks the rewrite fires only when expected:

```diff
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+
+class OptimizeCsvExprsSuite extends PlanTest with ExpressionEvalHelper {
+
+  private var csvExpressionOptimizeEnabled: Boolean = _
+  protected override def beforeAll(): Unit = {
+    csvExpressionOptimizeEnabled = SQLConf.get.csvExpressionOptimization
+  }
+
+  protected override def afterAll(): Unit = {
+    SQLConf.get.setConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION, csvExpressionOptimizeEnabled)
+  }
+
+  object Optimizer extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Csv optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
+  }
+
+  val schema = StructType.fromDDL("a int, b int")
+
+  private val csvAttr = 'csv.string
+  private val testRelation = LocalRelation(csvAttr)
+
+  test("SPARK-32968: prune unnecessary columns from GetStructField + from_csv") {
+    val options = Map.empty[String, String]
+
+    val query1 = testRelation
+      .select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
+    val optimized1 = Optimizer.execute(query1.analyze)
+
+    val prunedSchema1 = StructType.fromDDL("a int")
+    val expected1 = testRelation
+      .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema1)), 0))
+      .analyze
+    comparePlans(optimized1, expected1)
+
+    val query2 = testRelation
+      .select(GetStructField(CsvToStructs(schema, options, 'csv), 1))
+    val optimized2 = Optimizer.execute(query2.analyze)
+
+    val prunedSchema2 = StructType.fromDDL("b int")
+    val expected2 = testRelation
+      .select(GetStructField(CsvToStructs(schema, options, 'csv, None, Some(prunedSchema2)), 0))
+      .analyze
+    comparePlans(optimized2, expected2)
+  }
+
+  test("SPARK-32968: don't prune columns if options is not empty") {
+    val options = Map("mode" -> "failfast")
+
+    val query = testRelation
+      .select(GetStructField(CsvToStructs(schema, options, 'csv), 0))
+    val optimized = Optimizer.execute(query.analyze)
+
+    val expected = query.analyze
+    comparePlans(optimized, expected)
+  }
+}
```
The existing JSON suite is updated to the renamed rule:

```diff
@@ -39,7 +39,7 @@ class OptimizeJsonExprsSuite extends PlanTest with ExpressionEvalHelper {
   }
 
   object Optimizer extends RuleExecutor[LogicalPlan] {
-    val batches = Batch("Json optimization", FixedPoint(10), OptimizeJsonExprs) :: Nil
+    val batches = Batch("Json optimization", FixedPoint(10), OptimizeCsvJsonExprs) :: Nil
   }
 
   val schema = StructType.fromDDL("a int, b int")
```
End-to-end tests in `CsvFunctionsSuite` verify that pruning never changes results, including failfast and corrupt-record cases, with the optimization both on and off:

```diff
@@ -250,4 +250,52 @@ class CsvFunctionsSuite extends QueryTest with SharedSparkSession {
         """.stripMargin)
     checkAnswer(toDF("yyyy-MM-dd'T'HH:mm:ss.SSSXXX"), toDF("yyyy-MM-dd'T'HH:mm:ss[.SSS][XXX]"))
   }
+
+  test("SPARK-32968: Pruning csv field should not change result") {
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+        val df1 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
+          .selectExpr("from_csv(csv, 'a string, b string', map('mode', 'failfast')) as parsed")
+        checkAnswer(df1.selectExpr("parsed.a"), Seq(Row("a")))
+        checkAnswer(df1.selectExpr("parsed.b"), Seq(Row("b")))
+
+        val df2 = sparkContext.parallelize(Seq("a,b")).toDF("csv")
+          .selectExpr("from_csv(csv, 'a string, b string') as parsed")
+        checkAnswer(df2.selectExpr("parsed.a"), Seq(Row("a")))
+        checkAnswer(df2.selectExpr("parsed.b"), Seq(Row("b")))
+      }
+    }
+  }
+
+  test("SPARK-32968: bad csv input with csv pruning optimization") {
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+        val df = sparkContext.parallelize(Seq("1,\u0001\u0000\u0001234")).toDF("csv")
+          .selectExpr("from_csv(csv, 'a int, b int', map('mode', 'failfast')) as parsed")
+
+        val err1 = intercept[SparkException] {
+          df.selectExpr("parsed.a").collect
+        }
+
+        val err2 = intercept[SparkException] {
+          df.selectExpr("parsed.b").collect
+        }
+
+        assert(err1.getMessage.contains("Malformed records are detected in record parsing"))
+        assert(err2.getMessage.contains("Malformed records are detected in record parsing"))
+      }
+    }
+  }
+
+  test("SPARK-32968: csv pruning optimization with corrupt record field") {
+    Seq("true", "false").foreach { enabled =>
+      withSQLConf(SQLConf.CSV_EXPRESSION_OPTIMIZATION.key -> enabled) {
+        val df = sparkContext.parallelize(Seq("a,b,c,d")).toDF("csv")
+          .selectExpr("from_csv(csv, 'a string, b string, _corrupt_record string') as parsed")
+          .selectExpr("parsed._corrupt_record")
+
+        checkAnswer(df, Seq(Row("a,b,c,d")))
+      }
+    }
+  }
 }
```