diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 4870093e92..f2b252259b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -113,7 +113,7 @@ case class Abs(child: Expression)
   protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
 }
 
-abstract class BinaryArithmetic extends BinaryOperator {
+abstract class BinaryArithmetic extends BinaryOperator with NullIntolerant {
 
   override def dataType: DataType = left.dataType
 
@@ -146,7 +146,7 @@ object BinaryArithmetic {
       > SELECT 1 _FUNC_ 2;
        3
   """)
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -182,8 +182,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit
       > SELECT 2 _FUNC_ 1;
        1
   """)
-case class Subtract(left: Expression, right: Expression)
-  extends BinaryArithmetic with NullIntolerant {
+case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
 
@@ -219,8 +218,7 @@ case class Subtract(left: Expression, right: Expression)
       > SELECT 2 _FUNC_ 3;
        6
   """)
-case class Multiply(left: Expression, right: Expression)
-  extends BinaryArithmetic with NullIntolerant {
+case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -243,8 +241,7 @@ case class Multiply(left: Expression, right: Expression)
        1.0
   """)
 // scalastyle:on line.size.limit
-case class Divide(left: Expression, right: Expression)
-  extends BinaryArithmetic with NullIntolerant {
+case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
 
@@ -324,8 +321,7 @@ case class Divide(left: Expression, right: Expression)
       > SELECT 2 _FUNC_ 1.8;
        0.2
   """)
-case class Remainder(left: Expression, right: Expression)
-  extends BinaryArithmetic with NullIntolerant {
+case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def inputType: AbstractDataType = NumericType
 
@@ -412,7 +408,7 @@ case class Remainder(left: Expression, right: Expression)
       > SELECT _FUNC_(-10, 3);
        2
   """)
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
 
   override def toString: String = s"pmod($left, $right)"
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
index 0c256c3d89..de1594d119 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala
@@ -104,7 +104,7 @@ trait ExtractValue extends Expression
  * For example, when get field `yEAr` from `<year: int, month: int>`, we should pass in `yEAr`.
  */
 case class GetStructField(child: Expression, ordinal: Int, name: Option[String] = None)
-  extends UnaryExpression with ExtractValue {
+  extends UnaryExpression with ExtractValue with NullIntolerant {
 
   lazy val childSchema = child.dataType.asInstanceOf[StructType]
 
@@ -152,7 +152,7 @@ case class GetArrayStructFields(
     field: StructField,
     ordinal: Int,
     numFields: Int,
-    containsNull: Boolean) extends UnaryExpression with ExtractValue {
+    containsNull: Boolean) extends UnaryExpression with ExtractValue with NullIntolerant {
 
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
   override def toString: String = s"$child.${field.name}"
@@ -213,7 +213,7 @@ case class GetArrayStructFields(
  * We need to do type checking here as `ordinal` expression maybe unresolved.
  */
 case class GetArrayItem(child: Expression, ordinal: Expression)
-  extends BinaryExpression with ExpectsInputTypes with ExtractValue {
+  extends BinaryExpression with ExpectsInputTypes with ExtractValue with NullIntolerant {
 
   // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`.
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType)
@@ -260,7 +260,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
  * We need to do type checking here as `key` expression maybe unresolved.
  */
 case class GetMapValue(child: Expression, key: Expression)
-  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue {
+  extends BinaryExpression with ImplicitCastInputTypes with ExtractValue with NullIntolerant {
 
   private def keyType = child.dataType.asInstanceOf[MapType].keyType
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1b00c9e79d..4c8b177237 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -138,5 +138,5 @@ package object expressions {
    * input will result in null output). We will use this information during constructing IsNotNull
    * constraints.
    */
-  trait NullIntolerant
+  trait NullIntolerant extends Expression
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
index 4896a6225a..b23da537be 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 
-trait StringRegexExpression extends ImplicitCastInputTypes {
-  self: BinaryExpression =>
+abstract class StringRegexExpression extends BinaryExpression
+  with ImplicitCastInputTypes with NullIntolerant {
 
   def escape(v: String): String
   def matches(regex: Pattern, str: String): Boolean
@@ -69,8 +69,7 @@ trait StringRegexExpression extends ImplicitCastInputTypes {
  */
 @ExpressionDescription(
   usage = "str _FUNC_ pattern - Returns true if `str` matches `pattern`, or false otherwise.")
-case class Like(left: Expression, right: Expression)
-  extends BinaryExpression with StringRegexExpression {
+case class Like(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = StringUtils.escapeLikeRegex(v)
 
@@ -122,8 +121,7 @@ case class Like(left: Expression, right: Expression)
 
 @ExpressionDescription(
   usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
-case class RLike(left: Expression, right: Expression)
-  extends BinaryExpression with StringRegexExpression {
+case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
 
   override def escape(v: String): String = v
   override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 908aa44f81..5598a14699 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -297,8 +297,8 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx
 }
 
 /** A base trait for functions that compare two strings, returning a boolean. */
-trait StringPredicate extends Predicate with ImplicitCastInputTypes {
-  self: BinaryExpression =>
+abstract class StringPredicate extends BinaryExpression
+  with Predicate with ImplicitCastInputTypes with NullIntolerant {
 
   def compare(l: UTF8String, r: UTF8String): Boolean
 
@@ -313,8 +313,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes {
 /**
  * A function that returns true if the string `left` contains the string `right`.
  */
-case class Contains(left: Expression, right: Expression)
-  extends BinaryExpression with StringPredicate {
+case class Contains(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
@@ -324,8 +323,7 @@ case class Contains(left: Expression, right: Expression)
 /**
  * A function that returns true if the string `left` starts with the string `right`.
  */
-case class StartsWith(left: Expression, right: Expression)
-  extends BinaryExpression with StringPredicate {
+case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
@@ -335,8 +333,7 @@ case class StartsWith(left: Expression, right: Expression)
 /**
  * A function that returns true if the string `left` ends with the string `right`.
  */
-case class EndsWith(left: Expression, right: Expression)
-  extends BinaryExpression with StringPredicate {
+case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
   override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
@@ -1122,7 +1119,7 @@ case class StringSpace(child: Expression)
   """)
 // scalastyle:on line.size.limit
 case class Substring(str: Expression, pos: Expression, len: Expression)
-  extends TernaryExpression with ImplicitCastInputTypes {
+  extends TernaryExpression with ImplicitCastInputTypes with NullIntolerant {
 
   def this(str: Expression, pos: Expression) = {
     this(str, pos, Literal(Integer.MAX_VALUE))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 21d1cd5932..33039127f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -347,35 +347,30 @@ object LikeSimplification extends Rule[LogicalPlan] {
  * Null value propagation from bottom to top of the expression tree.
  */
 case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
-  private def nonNullLiteral(e: Expression): Boolean = e match {
-    case Literal(null, _) => false
-    case _ => true
+  private def isNullLiteral(e: Expression): Boolean = e match {
+    case Literal(null, _) => true
+    case _ => false
   }
 
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case q: LogicalPlan => q transformExpressionsUp {
       case e @ WindowExpression(Cast(Literal(0L, _), _, _), _) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
-      case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) =>
+      case e @ AggregateExpression(Count(exprs), _, _, _) if exprs.forall(isNullLiteral) =>
         Cast(Literal(0L), e.dataType, Option(conf.sessionLocalTimeZone))
-      case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
-      case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
-      case e @ GetArrayItem(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetArrayItem(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ GetMapValue(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ GetMapValue(_, Literal(null, _)) => Literal.create(null, e.dataType)
-      case e @ GetStructField(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ GetArrayStructFields(Literal(null, _), _, _, _, _) =>
-        Literal.create(null, e.dataType)
-      case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r)
-      case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l)
       case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) =>
         // This rule should be only triggered when isDistinct field is false.
         ae.copy(aggregateFunction = Count(Literal(1)))
 
+      case IsNull(c) if !c.nullable => Literal.create(false, BooleanType)
+      case IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType)
+
+      case EqualNullSafe(Literal(null, _), r) => IsNull(r)
+      case EqualNullSafe(l, Literal(null, _)) => IsNull(l)
+
       // For Coalesce, remove null literals.
       case e @ Coalesce(children) =>
-        val newChildren = children.filter(nonNullLiteral)
+        val newChildren = children.filterNot(isNullLiteral)
         if (newChildren.isEmpty) {
           Literal.create(null, e.dataType)
         } else if (newChildren.length == 1) {
@@ -384,33 +379,13 @@ case class NullPropagation(conf: CatalystConf) extends Rule[LogicalPlan] {
           Coalesce(newChildren)
         }
 
-      case e @ Substring(Literal(null, _), _, _) => Literal.create(null, e.dataType)
-      case e @ Substring(_, Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ Substring(_, _, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      // Put exceptional cases above if any
-      case e @ BinaryArithmetic(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ BinaryArithmetic(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      case e @ BinaryComparison(Literal(null, _), _) => Literal.create(null, e.dataType)
-      case e @ BinaryComparison(_, Literal(null, _)) => Literal.create(null, e.dataType)
-
-      case e: StringRegexExpression => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
-        case _ => e
-      }
-
-      case e: StringPredicate => e.children match {
-        case Literal(null, _) :: right :: Nil => Literal.create(null, e.dataType)
-        case left :: Literal(null, _) :: Nil => Literal.create(null, e.dataType)
-        case _ => e
-      }
-
-      // If the value expression is NULL then transform the In expression to
-      // Literal(null)
-      case In(Literal(null, _), list) => Literal.create(null, BooleanType)
+      // If the value expression is NULL then transform the In expression to null literal.
+      case In(Literal(null, _), _) => Literal.create(null, BooleanType)
+
+      // Non-leaf NullIntolerant expressions will return null, if at least one of its children is
+      // a null literal.
+      case e: NullIntolerant if e.children.exists(isNullLiteral) =>
+        Literal.create(null, e.dataType)
     }
   }
 }
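
Note (not part of the patch): the single generic NullIntolerant case added at the end of NullPropagation is what replaces all of the per-expression null-folding cases removed above. A minimal sketch of the net effect, reusing names from the patch; the snippet is illustrative only and would be run, e.g., in spark-shell with Catalyst on the classpath:

  import org.apache.spark.sql.catalyst.expressions._
  import org.apache.spark.sql.types._

  // Mirrors the patch's private helper in NullPropagation.
  def isNullLiteral(e: Expression): Boolean = e match {
    case Literal(null, _) => true
    case _ => false
  }

  // Add extends BinaryArithmetic, which now mixes in NullIntolerant, and
  // NullIntolerant now extends Expression, so children/dataType are available.
  val expr: Expression = Add(Literal.create(null, IntegerType), Literal(1))

  val folded = expr match {
    case e: NullIntolerant if e.children.exists(isNullLiteral) =>
      Literal.create(null, e.dataType)
    case other => other
  }
  // folded is Literal.create(null, IntegerType): the one generic case covers what the removed
  // BinaryArithmetic / BinaryComparison / Substring / StringPredicate cases used to handle,
  // and any future expression opts in simply by mixing in NullIntolerant.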