[SPARK-25669][SQL] Check CSV header only when it exists

## What changes were proposed in this pull request?

Currently the first row of dataset of CSV strings is compared to field names of user specified or inferred schema independently of presence of CSV header. It causes false-positive error messages. For example, parsing `"1,2"` outputs the error:

```java
java.lang.IllegalArgumentException: CSV header does not conform to the schema.
 Header: 1, 2
 Schema: _c0, _c1
Expected: _c0 but found: 1
```

In the PR, I propose:
- Checking CSV header only when it exists
- Filter header from the input dataset only if it exists

## How was this patch tested?

Added a test to `CSVSuite` which reproduces the issue.

Closes #22656 from MaxGekk/inferred-header-check.

Authored-by: Maxim Gekk <maxim.gekk@databricks.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
Maxim Gekk 2018-10-09 14:35:00 +08:00 committed by hyukjinkwon
parent a4b14a9cf8
commit 46fe40838a
2 changed files with 11 additions and 2 deletions

View file

@ -505,7 +505,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
val actualSchema =
StructType(schema.filterNot(_.name == parsedOptions.columnNameOfCorruptRecord))
val linesWithoutHeader: RDD[String] = maybeFirstLine.map { firstLine =>
val linesWithoutHeader = if (parsedOptions.headerFlag && maybeFirstLine.isDefined) {
val firstLine = maybeFirstLine.get
val parser = new CsvParser(parsedOptions.asParserSettings)
val columnNames = parser.parseLine(firstLine)
CSVDataSource.checkHeaderColumnNames(
@ -515,7 +516,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
parsedOptions.enforceSchema,
sparkSession.sessionState.conf.caseSensitiveAnalysis)
filteredLines.rdd.mapPartitions(CSVUtils.filterHeaderLine(_, firstLine, parsedOptions))
}.getOrElse(filteredLines.rdd)
} else {
filteredLines.rdd
}
val parsed = linesWithoutHeader.mapPartitions { iter =>
val rawParser = new UnivocityParser(actualSchema, parsedOptions)

View file

@ -1820,4 +1820,10 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils with Te
checkAnswer(spark.read.option("multiLine", true).schema(schema).csv(input), Row(null))
assert(spark.read.csv(input).collect().toSet == Set(Row()))
}
test("field names of inferred schema shouldn't compare to the first row") {
val input = Seq("1,2").toDS()
val df = spark.read.option("enforceSchema", false).csv(input)
checkAnswer(df, Row("1", "2"))
}
}