[SPARK-20121][SQL] simplify NullPropagation with NullIntolerant
## What changes were proposed in this pull request? Instead of enumerating every expression that returns null for null inputs, `NullPropagation` can simply check whether an expression is `NullIntolerant`. ## How was this patch tested? Existing tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #17450 from cloud-fan/null.
This commit is contained in:
parent
5e00a5de14
commit
c734fc504a
|
@ -113,7 +113,7 @@ case class Abs(child: Expression)
|
|||
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
|
||||
}
|
||||
|
||||
abstract class BinaryArithmetic extends BinaryOperator {
|
||||
abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
|
||||
|
||||
override def dataType: DataType = left.dataType
|
||||
|
||||
|
@ -146,7 +146,7 @@ object BinaryArithmetic {
|
|||
> SELECT 1 _FUNC_ 2;
|
||||
3
|
||||
""")
|
||||
case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
|
||||
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
|
||||
|
||||
|
@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit
|
|||
> SELECT 2 _FUNC_ 1;
|
||||
1
|
||||
""")
|
||||
case class Subtract(left: Expression, right: Expression)
|
||||
extends BinaryArithmetic with NullIntolerant {
|
||||
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
|
||||
|
||||
|
@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression)
|
|||
> SELECT 2 _FUNC_ 3;
|
||||
6
|
||||
""")
|
||||
case class Multiply(left: Expression, right: Expression)
|
||||
extends BinaryArithmetic with NullIntolerant {
|
||||
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def inputType: AbstractDataType = NumericType
|
||||
|
||||
|
@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression)
|
|||
1.0
|
||||
""")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Divide(left: Expression, right: Expression)
|
||||
extends BinaryArithmetic with NullIntolerant {
|
||||
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
|
||||
|
||||
|
@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression)
|
|||
> SELECT 2 _FUNC_ 1.8;
|
||||
0.2
|
||||
""")
|
||||
case class Remainder(left: Expression, right: Expression)
|
||||
extends BinaryArithmetic with NullIntolerant {
|
||||
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def inputType: AbstractDataType = NumericType
|
||||
|
||||
|
@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression)
|
|||
> SELECT _FUNC_(-10, 3);
|
||||
2
|
||||
""")
|
||||
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
|
||||
case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
|
||||
|
||||
override def toString: String = s"pmod($left, $right)"
|
||||
|
||||
|
|
|
@ -104,7 +104,7 @@ trait ExtractValue extends Expression
|
|||
* For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
|
||||
*/
|
||||
case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
|
||||
extends UnaryExpression with ExtractValue {
|
||||
extends UnaryExpression with ExtractValue with NullIntolerant {
|
||||
|
||||
lazy val childSchema = child.dataType.asInstanceOf[StructType]
|
||||
|
||||
|
@ -152,7 +152,7 @@ case class GetArrayStructFields(
|
|||
field: StructField,
|
||||
ordinal: Int,
|
||||
numFields: Int,
|
||||
containsNull: Boolean) extends UnaryExpression with ExtractValue {
|
||||
containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant {
|
||||
|
||||
override def dataType: DataType = ArrayType(field.dataType, containsNull)
|
||||
override def toString: String = s"$child.${field.name}"
|
||||
|
@ -213,7 +213,7 @@ case class GetArrayStructFields(
|
|||
* We need to do type checking here as `ordinal` expression may be unresolved.
|
||||
*/
|
||||
case class GetArrayItem(child: Expression, ordinal: Expression)
|
||||
extends BinaryExpression with ExpectsInputTypes with ExtractValue {
|
||||
extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
|
||||
|
||||
// We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
|
||||
|
@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
|
|||
* We need to do type checking here as `key` expression may be unresolved.
|
||||
*/
|
||||
case class GetMapValue(child: Expression, key: Expression)
|
||||
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue {
|
||||
extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
|
||||
|
||||
private def keyType = child.dataType.asInstanceOf[MapType].keyType
|
||||
|
||||
|
|
|
@ -138,5 +138,5 @@ package object expressions {
|
|||
* input will result in null output). We will use this information during constructing IsNotNull
|
||||
* constraints.
|
||||
*/
|
||||
trait NullIntolerant
|
||||
trait NullIntolerant extends Expression
|
||||
}
|
||||
|
|
|
@ -27,8 +27,8 @@ import org.apache.spark.sql.types._
|
|||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
||||
trait StringRegexExpression extends ImplicitCastInputTypes {
|
||||
self: BinaryExpression =>
|
||||
abstract class StringRegexExpression extends BinaryExpression
|
||||
with ImplicitCastInputTypes with NullIntolerant {
|
||||
|
||||
def escape(v: String): String
|
||||
def matches(regex: Pattern, str: String): Boolean
|
||||
|
@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
|
|||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.")
|
||||
case class Like(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringRegexExpression {
|
||||
case class Like(left: Expression, right: Expression) extends StringRegexExpression {
|
||||
|
||||
override def escape(v: String): String = StringUtils.escapeLikeRegex(v)
|
||||
|
||||
|
@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression)
|
|||
|
||||
@ExpressionDescription(
|
||||
usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
|
||||
case class RLike(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringRegexExpression {
|
||||
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
|
||||
|
||||
override def escape(v: String): String = v
|
||||
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
|
||||
|
|
|
@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
|
|||
}
|
||||
|
||||
/** A base trait for functions that compare two strings, returning a boolean. */
|
||||
trait StringPredicate extends Predicate with ImplicitCastInputTypes {
|
||||
self: BinaryExpression =>
|
||||
abstract class StringPredicate extends BinaryExpression
|
||||
with Predicate with ImplicitCastInputTypes with NullIntolerant {
|
||||
|
||||
def compare(l: UTF8String, r: UTF8String): Boolean
|
||||
|
||||
|
@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes {
|
|||
/**
|
||||
* A function that returns true if the string `left` contains the string `right`.
|
||||
*/
|
||||
case class Contains(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringPredicate {
|
||||
case class Contains(left: Expression, right: Expression) extends StringPredicate {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
|
||||
|
@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression)
|
|||
/**
|
||||
* A function that returns true if the string `left` starts with the string `right`.
|
||||
*/
|
||||
case class StartsWith(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringPredicate {
|
||||
case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
|
||||
|
@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression)
|
|||
/**
|
||||
* A function that returns true if the string `left` ends with the string `right`.
|
||||
*/
|
||||
case class EndsWith(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringPredicate {
|
||||
case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
|
||||
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
|
||||
|
@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression)
|
|||
""")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Substring(str: Expression, pos: Expression, len: Expression)
|
||||
extends TernaryExpression with ImplicitCastInputTypes {
|
||||
extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
|
||||
|
||||
def this(str: Expression, pos: Expression) = {
|
||||
this(str, pos, Literal(Integer.MAX_VALUE))
|
||||
|
|
|
@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] {
|
|||
* Null value propagation from bottom to top of the expression tree.
|
||||
*/
|
||||
case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
|
||||
private def nonNullLiteral(e: Expression): Boolean = e match {
|
||||
case Literal(null, _) => false
|
||||
case _ => true
|
||||
private def isNullLiteral(e: Expression): Boolean = e match {
|
||||
case Literal(null, _) => true
|
||||
case _ => false
|
||||
}
|
||||
|
||||
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
|
||||
case q: LogicalPlan => q transformExpressionsUp {
|
||||
case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
|
||||
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
|
||||
case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
|
||||
case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
|
||||
Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
|
||||
case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
|
||||
case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
|
||||
case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
|
||||
Literal.create(null, e.dataType)
|
||||
case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
|
||||
case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
|
||||
case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
|
||||
// This rule should be only triggered when isDistinct field is false.
|
||||
ae.copy(aggregateFunction = Count(Literal(1)))
|
||||
|
||||
case IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
|
||||
case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
|
||||
|
||||
case EqualNullSafe(Literal(null, _), r) => IsNull(r)
|
||||
case EqualNullSafe(l, Literal(null, _)) => IsNull(l)
|
||||
|
||||
// For Coalesce, remove null literals.
|
||||
case e @ Coalesce(children) =>
|
||||
val newChildren = children.filter(nonNullLiteral)
|
||||
val newChildren = children.filterNot(isNullLiteral)
|
||||
if (newChildren.isEmpty) {
|
||||
Literal.create(null, e.dataType)
|
||||
} else if (newChildren.length == 1) {
|
||||
|
@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
|
|||
Coalesce(newChildren)
|
||||
}
|
||||
|
||||
case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
// Put exceptional cases above if any
|
||||
case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
|
||||
case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
|
||||
|
||||
case e: StringRegexExpression => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
case e: StringPredicate => e.children match {
|
||||
case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
|
||||
case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
|
||||
case _ => e
|
||||
}
|
||||
|
||||
// If the value expression is NULL then transform the In expression to
|
||||
// Literal(null)
|
||||
case In(Literal(null, _), list) => Literal.create(null, BooleanType)
|
||||
// If the value expression is NULL then transform the In expression to null literal.
|
||||
case In(Literal(null, _), _) => Literal.create(null, BooleanType)
|
||||
|
||||
// A non-leaf NullIntolerant expression will return null, if at least one of its children is
|
||||
// a null literal.
|
||||
case e: NullIntolerant if e.children.exists(isNullLiteral) =>
|
||||
Literal.create(null, e.dataType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue