[SPARK-12867][SQL] Nullability of Intersect can be stricter
JIRA: https://issues.apache.org/jira/browse/SPARK-12867 When intersecting one nullable column with one non-nullable column, the result will not contain any null. Thus, we can make nullability of `intersect` stricter. liancheng Could you please check if the code changes are appropriate? Also added test cases to verify the results. Thanks! Author: gatorsmile <gatorsmile@gmail.com> Closes #10812 from gatorsmile/nullabilityIntersect.
This commit is contained in:
parent
2388de5191
commit
b72e01e821
|
@ -91,11 +91,6 @@ case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
|
|||
}
|
||||
|
||||
abstract class SetOperation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode {
|
||||
override def output: Seq[Attribute] =
|
||||
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
|
||||
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
|
||||
}
|
||||
|
||||
final override lazy val resolved: Boolean =
|
||||
childrenResolved &&
|
||||
left.output.length == right.output.length &&
|
||||
|
@ -108,13 +103,24 @@ private[sql] object SetOperation {
|
|||
|
||||
case class Union(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
|
||||
|
||||
override def output: Seq[Attribute] =
|
||||
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
|
||||
leftAttr.withNullability(leftAttr.nullable || rightAttr.nullable)
|
||||
}
|
||||
|
||||
override def statistics: Statistics = {
|
||||
val sizeInBytes = left.statistics.sizeInBytes + right.statistics.sizeInBytes
|
||||
Statistics(sizeInBytes = sizeInBytes)
|
||||
}
|
||||
}
|
||||
|
||||
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right)
|
||||
case class Intersect(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
|
||||
|
||||
override def output: Seq[Attribute] =
|
||||
left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
|
||||
leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
|
||||
}
|
||||
}
|
||||
|
||||
case class Except(left: LogicalPlan, right: LogicalPlan) extends SetOperation(left, right) {
|
||||
/** We don't use right.output because those rows get excluded from the set. */
|
||||
|
|
|
@ -337,6 +337,27 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
|
|||
checkAnswer(lowerCaseData.intersect(upperCaseData), Nil)
|
||||
}
|
||||
|
||||
test("intersect - nullability") {
|
||||
val nonNullableInts = Seq(Tuple1(1), Tuple1(3)).toDF()
|
||||
assert(nonNullableInts.schema.forall(_.nullable == false))
|
||||
|
||||
val df1 = nonNullableInts.intersect(nullInts)
|
||||
checkAnswer(df1, Row(1) :: Row(3) :: Nil)
|
||||
assert(df1.schema.forall(_.nullable == false))
|
||||
|
||||
val df2 = nullInts.intersect(nonNullableInts)
|
||||
checkAnswer(df2, Row(1) :: Row(3) :: Nil)
|
||||
assert(df2.schema.forall(_.nullable == false))
|
||||
|
||||
val df3 = nullInts.intersect(nullInts)
|
||||
checkAnswer(df3, Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
|
||||
assert(df3.schema.forall(_.nullable == true))
|
||||
|
||||
val df4 = nonNullableInts.intersect(nonNullableInts)
|
||||
checkAnswer(df4, Row(1) :: Row(3) :: Nil)
|
||||
assert(df4.schema.forall(_.nullable == false))
|
||||
}
|
||||
|
||||
test("udf") {
|
||||
val foo = udf((a: Int, b: String) => a.toString + b)
|
||||
|
||||
|
|
Loading…
Reference in a new issue