[SPARK-19695][SQL] Throw an exception if a columnNameOfCorruptRecord field violates requirements in json formats

## What changes were proposed in this pull request?

This PR comes from #16928 and fixes the JSON behaviour to match the CSV one: if the user-specified `columnNameOfCorruptRecord` field is not a nullable `StringType` field, the read now throws an `AnalysisException` ("The field for corrupt records must be string type and nullable") on the driver side instead of failing later inside tasks.

## How was this patch tested?

Added tests in `JsonSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #17023 from maropu/SPARK-19695.
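As a quick illustration of the new behaviour, here is a minimal sketch (assuming a running `spark` session; the schema and input path are made up, and `_corrupt_record` is the default corrupt-record column name) that now fails with the message added in this patch:

```scala
import org.apache.spark.sql.types._

// The corrupt-record column is declared as IntegerType instead of a nullable
// StringType, which violates the requirement enforced by this patch.
val badSchema = StructType(
  StructField("_corrupt_record", IntegerType, nullable = true) ::
  StructField("a", StringType, nullable = true) :: Nil)

spark.read
  .option("mode", "PERMISSIVE")
  .schema(badSchema)
  .json("/tmp/records.json") // hypothetical path
  .collect()                 // fails: "The field for corrupt records must be string type and nullable"
```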
commit 769aa0f1d2
parent 66c4b79afd
`sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala`:

```diff
@@ -58,7 +58,10 @@ class JacksonParser(
   private val emptyRow: Seq[InternalRow] = Seq(new GenericInternalRow(schema.length))
 
   private val corruptFieldIndex = schema.getFieldIndex(options.columnNameOfCorruptRecord)
-  corruptFieldIndex.foreach(idx => require(schema(idx).dataType == StringType))
+  corruptFieldIndex.foreach { corrFieldIndex =>
+    require(schema(corrFieldIndex).dataType == StringType)
+    require(schema(corrFieldIndex).nullable)
+  }
 
   @transient
   private[this] var isWarningPrinted: Boolean = false
```
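`schema.getFieldIndex` returns an `Option[Int]`, so the two `require` calls run only when the corrupt-record column is actually present in the schema; this change adds the nullability requirement alongside the existing string-type one. Below is a standalone sketch of the same guard, substituting a public `fieldNames` lookup for Spark's internal `getFieldIndex` helper:

```scala
import org.apache.spark.sql.types._

// Sketch of the guard above: locate the corrupt-record column and, only if it
// is present, require that it is a nullable string field.
def checkCorruptRecordField(schema: StructType, corruptColumn: String): Unit = {
  val corruptFieldIndex = Some(schema.fieldNames.indexOf(corruptColumn)).filter(_ >= 0)
  corruptFieldIndex.foreach { idx =>
    require(schema(idx).dataType == StringType)
    require(schema(idx).nullable)
  }
}

// Passes: a nullable StringType column.
checkCorruptRecordField(
  StructType(StructField("_unparsed", StringType, nullable = true) :: Nil), "_unparsed")

// Would fail the first require: an IntegerType column.
// checkCorruptRecordField(
//   StructType(StructField("_unparsed", IntegerType, nullable = true) :: Nil), "_unparsed")
```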
`sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala`:

```diff
@@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.DDLUtils
 import org.apache.spark.sql.execution.datasources.DataSource
 import org.apache.spark.sql.execution.datasources.jdbc._
 import org.apache.spark.sql.execution.datasources.json.JsonInferSchema
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 /**
```
```diff
@@ -365,6 +365,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
         createParser)
     }
 
+    // Check a field requirement for corrupt records here to throw an exception in a driver side
+    schema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = schema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     val parsed = jsonDataset.rdd.mapPartitions { iter =>
       val parser = new JacksonParser(schema, parsedOptions)
       iter.flatMap(parser.parse(_, createParser, UTF8String.fromString))
```
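Because this check sits in `DataFrameReader.json(Dataset[String])` before the `mapPartitions` job is launched, a bad schema now surfaces as a single driver-side `AnalysisException` rather than per-task `require` failures on executors. A minimal sketch (assuming a `spark` session with its implicits imported):

```scala
import org.apache.spark.sql.types._
import spark.implicits._ // assumes `spark: SparkSession` is in scope

// A non-nullable corrupt-record column: violates the nullability requirement.
val badSchema = StructType(
  StructField("_corrupt_record", StringType, nullable = false) :: Nil)

val jsonLines = Seq("""{"a": 1""").toDS() // one malformed record

// Fails on the driver before any executor work is scheduled:
// AnalysisException: The field for corrupt records must be string type and nullable
spark.read.option("mode", "PERMISSIVE").schema(badSchema).json(jsonLines)
```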
`sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonFileFormat.scala`:

```diff
@@ -22,13 +22,13 @@ import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{Row, SparkSession}
+import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.json.{JacksonGenerator, JacksonParser, JSONOptions}
 import org.apache.spark.sql.catalyst.util.CompressionCodecs
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StringType, StructType}
 import org.apache.spark.util.SerializableConfiguration
 
 class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
```
```diff
@@ -102,6 +102,15 @@ class JsonFileFormat extends TextBasedFileFormat with DataSourceRegister {
       sparkSession.sessionState.conf.sessionLocalTimeZone,
       sparkSession.sessionState.conf.columnNameOfCorruptRecord)
 
+    // Check a field requirement for corrupt records here to throw an exception in a driver side
+    dataSchema.getFieldIndex(parsedOptions.columnNameOfCorruptRecord).foreach { corruptFieldIndex =>
+      val f = dataSchema(corruptFieldIndex)
+      if (f.dataType != StringType || !f.nullable) {
+        throw new AnalysisException(
+          "The field for corrupt records must be string type and nullable")
+      }
+    }
+
     (file: PartitionedFile) => {
       val parser = new JacksonParser(requiredSchema, parsedOptions)
       JsonDataSource(parsedOptions).readFile(
```
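The same guard in `JsonFileFormat.buildReader` covers file-based reads. `buildReader` runs when the scan is executed, which appears to be why the file-based case in the test below needs a `.collect` to trigger the exception. A sketch (session and path assumed):

```scala
import org.apache.spark.sql.types._

// An integer-typed corrupt-record column: violates the string-type requirement.
val badSchema = StructType(
  StructField("_unparsed", IntegerType, nullable = true) :: Nil)

// The .collect() action executes the scan, and JsonFileFormat.buildReader throws
// AnalysisException("The field for corrupt records must be string type and nullable").
spark.read
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_unparsed")
  .schema(badSchema)
  .json("/tmp/records.json") // hypothetical path
  .collect()
```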
`sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala`:

```diff
@@ -1944,4 +1944,35 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
       assert(exceptionTwo.getMessage.contains("Malformed line in FAILFAST mode"))
     }
   }
+
+  test("Throw an exception if a `columnNameOfCorruptRecord` field violates requirements") {
+    val columnNameOfCorruptRecord = "_unparsed"
+    val schema = StructType(
+      StructField(columnNameOfCorruptRecord, IntegerType, true) ::
+      StructField("a", StringType, true) ::
+      StructField("b", StringType, true) ::
+      StructField("c", StringType, true) :: Nil)
+    val errMsg = intercept[AnalysisException] {
+      spark.read
+        .option("mode", "PERMISSIVE")
+        .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+        .schema(schema)
+        .json(corruptRecords)
+    }.getMessage
+    assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      corruptRecords.toDF("value").write.text(path)
+      val errMsg = intercept[AnalysisException] {
+        spark.read
+          .option("mode", "PERMISSIVE")
+          .option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
+          .schema(schema)
+          .json(path)
+          .collect
+      }.getMessage
+      assert(errMsg.startsWith("The field for corrupt records must be string type and nullable"))
+    }
+  }
 }
```
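For contrast, here is a sketch of a schema that satisfies the requirement, under the same assumptions as the sketches above: with a nullable `StringType` corrupt-record column, malformed rows land in that column in PERMISSIVE mode instead of raising an exception.

```scala
import org.apache.spark.sql.types._

// A corrupt-record column that passes the new check: nullable StringType.
val okSchema = StructType(
  StructField("_unparsed", StringType, nullable = true) ::
  StructField("a", StringType, nullable = true) :: Nil)

spark.read
  .option("mode", "PERMISSIVE")
  .option("columnNameOfCorruptRecord", "_unparsed")
  .schema(okSchema)
  .json("/tmp/records.json") // hypothetical path
  .show()
```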