[SPARK-23835][SQL] Add not-null check to Tuples' arguments deserialization
## What changes were proposed in this pull request? There was no check on nullability for arguments of `Tuple`s. This could lead to weird behavior when a null value had to be deserialized into a non-nullable Scala object: in those cases, the `null` was silently transformed into a valid value (like `-1` for `Int`), corresponding to the default value we use in the SQL codebase. This situation was very likely to happen when deserializing to a Tuple of primitive Scala types (like Double, Int, ...). The PR adds `AssertNotNull` to the arguments of tuples which have been asked to be converted to non-nullable types. ## How was this patch tested? added UT Author: Marco Gaido <marcogaido91@gmail.com> Closes #20976 from mgaido91/SPARK-23835.
This commit is contained in:
parent
30ffb53cad
commit
0a9172a05e
|
@ -79,7 +79,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
|
|||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
|
||||
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
|
||||
.as[(Int, Int)]
|
||||
.as[(Option[Int], Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
|
@ -119,7 +119,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
|
|||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
|
||||
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
|
||||
.as[(Int, Int)]
|
||||
.as[(Option[Int], Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
|
@ -167,7 +167,7 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest {
|
|||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
|
||||
.selectExpr("CAST(key AS INT)", "CAST(value AS INT)")
|
||||
.as[(Int, Int)]
|
||||
.as[(Option[Int], Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
|
|
|
@ -138,7 +138,7 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext {
|
|||
val reader = createKafkaReader(topic)
|
||||
.selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value")
|
||||
.selectExpr("CAST(key as INT) key", "CAST(value as INT) value")
|
||||
.as[(Int, Int)]
|
||||
.as[(Option[Int], Int)]
|
||||
.map(_._2)
|
||||
|
||||
try {
|
||||
|
|
|
@ -382,22 +382,22 @@ object ScalaReflection extends ScalaReflection {
|
|||
val clsName = getClassNameFromType(fieldType)
|
||||
val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
|
||||
// For tuples, we based grab the inner fields by ordinal instead of name.
|
||||
if (cls.getName startsWith "scala.Tuple") {
|
||||
val constructor = if (cls.getName startsWith "scala.Tuple") {
|
||||
deserializerFor(
|
||||
fieldType,
|
||||
Some(addToPathOrdinal(i, dataType, newTypePath)),
|
||||
newTypePath)
|
||||
} else {
|
||||
val constructor = deserializerFor(
|
||||
deserializerFor(
|
||||
fieldType,
|
||||
Some(addToPath(fieldName, dataType, newTypePath)),
|
||||
newTypePath)
|
||||
}
|
||||
|
||||
if (!nullable) {
|
||||
AssertNotNull(constructor, newTypePath)
|
||||
} else {
|
||||
constructor
|
||||
}
|
||||
if (!nullable) {
|
||||
AssertNotNull(constructor, newTypePath)
|
||||
} else {
|
||||
constructor
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp}
|
|||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
|
||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Literal, SpecificInternalRow, UpCast}
|
||||
import org.apache.spark.sql.catalyst.expressions.{BoundReference, Expression, Literal, SpecificInternalRow, UpCast}
|
||||
import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
@ -365,4 +365,14 @@ class ScalaReflectionSuite extends SparkFunSuite {
|
|||
StructField("_2", NullType, nullable = true))),
|
||||
nullable = true))
|
||||
}
|
||||
|
||||
test("SPARK-23835: add null check to non-nullable types in Tuples") {
|
||||
def numberOfCheckedArguments(deserializer: Expression): Int = {
|
||||
assert(deserializer.isInstanceOf[NewInstance])
|
||||
deserializer.asInstanceOf[NewInstance].arguments.count(_.isInstanceOf[AssertNotNull])
|
||||
}
|
||||
assert(numberOfCheckedArguments(deserializerFor[(Double, Double)]) == 2)
|
||||
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Double, Int)]) == 1)
|
||||
assert(numberOfCheckedArguments(deserializerFor[(java.lang.Integer, java.lang.Integer)]) == 0)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1453,6 +1453,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
val group2 = cached.groupBy("x").agg(min(col("z")) as "value")
|
||||
checkAnswer(group1.union(group2), Row(4, 5) :: Row(1, 2) :: Row(4, 6) :: Row(1, 3) :: Nil)
|
||||
}
|
||||
|
||||
test("SPARK-23835: null primitive data type should throw NullPointerException") {
|
||||
val ds = Seq[(Option[Int], Option[Int])]((Some(1), None)).toDS()
|
||||
intercept[NullPointerException](ds.as[(Int, Int)].collect())
|
||||
}
|
||||
}
|
||||
|
||||
case class TestDataUnion(x: Int, y: Int, z: Int)
|
||||
|
|
Loading…
Reference in a new issue