[SPARK-8075] [SQL] apply type check interface to more expressions

A follow-up of https://github.com/apache/spark/pull/6405.
Note: it's not a big change; much of the diff comes from moving code around in `aggregates.scala` so that each aggregate function sits directly below its corresponding aggregate expression.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6723 from cloud-fan/type-check and squashes the following commits:

2124301 [Wenchen Fan] fix tests
5a658bb [Wenchen Fan] add tests
287d3bb [Wenchen Fan] apply type check interface to more expressions
Wenchen Fan authored on 2015-06-24 16:26:00 -07:00; committed by Michael Armbrust
parent 7daa70292e
commit b71d3254e5
21 changed files with 365 additions and 318 deletions


@@ -587,8 +587,8 @@ class Analyzer(
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
|but only single name '${name}' specified""".stripMargin)
case Alias(g: Generator, name) => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) => Some(g, names)
case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
case _ => None
}
}


@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
}
}
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
case a @ CreateArray(children) if !a.resolved =>
val commonType = a.childTypes.reduce(
(a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
CreateArray(
children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
val types = children.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
@@ -620,12 +622,11 @@
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
val types = es.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None =>
sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
case None => c
}
}
}
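
For reference, a hedged illustration of the widening helper these new rules call (`findTightestCommonTypeAndPromoteToString` is defined elsewhere in HiveTypeCoercion; the exact results below are assumptions inferred from the rules shown):

// Illustrative sketch only, not part of this patch:
findTightestCommonTypeAndPromoteToString(Seq(IntegerType, DoubleType))  // Some(DoubleType)
findTightestCommonTypeAndPromoteToString(Seq(IntegerType, StringType))  // Some(StringType)
findTightestCommonTypeAndPromoteToString(Seq(IntegerType, BooleanType)) // None
// On None, CreateArray and Coalesce are left unchanged, so the mismatch now
// surfaces through checkInputDataTypes() instead of the previous sys.error.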


@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
import java.text.{DateFormat, SimpleDateFormat}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
override def checkInputDataTypes(): TypeCheckResult = {
if (resolve(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType} to $dataType")
}
}
override def foldable: Boolean = child.foldable


@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
/**
* Checks the input data types, returns `TypeCheckResult.TypeCheckSuccess` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `childrenResolved == true`
* TODO: we should remove the default implementation and implement it for all
* expressions with proper error message.
* Note: it's not valid to call this method until `childrenResolved == true`.
*/
def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
}
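
To illustrate the interface, here is a minimal sketch of a unary expression overriding the default (a hypothetical expression, not part of this patch; it leans on the TypeUtils helpers added further down):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType

// Hypothetical numeric-only negation, shown for illustration.
case class MyNegate(child: Expression) extends UnaryExpression {
  override def dataType: DataType = child.dataType
  override def nullable: Boolean = child.nullable

  // Only called once childrenResolved == true; a failure here leaves the
  // expression unresolved and is surfaced as an AnalysisException.
  override def checkInputDataTypes(): TypeCheckResult =
    TypeUtils.checkForNumericExpr(child.dataType, "function my_negate")

  override def eval(input: InternalRow): Any = {
    val value = child.eval(input)
    if (value == null) null
    else TypeUtils.getNumeric(child.dataType).negate(value)
  }
}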


@@ -96,6 +96,11 @@ object ExtractValue {
}
}
/**
* A common interface for all kinds of extract value expressions.
* Note: concrete extract value expressions are created only by `ExtractValue.apply`,
* so we don't need to type check them.
*/
trait ExtractValue extends UnaryExpression {
self: Product =>
}
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
override lazy val resolved = childrenResolved &&
child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
protected def evalNotNull(value: Any, ordinal: Any) = {
// TODO: consider using Array[_] for ArrayType child to avoid
// boxing of primitives
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
protected def evalNotNull(value: Any, ordinal: Any) = {
val baseValue = value.asInstanceOf[Map[Any, _]]
baseValue.get(ordinal).orNull


@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashSet
@@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
}
override def newInstance(): MinFunction = new MinFunction(child, this)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, "function min")
}
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
}
override def newInstance(): MaxFunction = new MaxFunction(child, this)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForOrderingExpr(child.dataType, "function max")
}
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def newInstance(): CountFunction = new CountFunction(child, this)
}
case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
var count: Long = _
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1L
}
}
override def eval(input: InternalRow): Any = count
}
case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
def this() = this(null)
@@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
}
}
case class CountDistinctFunction(
@transient expr: Seq[Expression],
@transient base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
val seen = new OpenHashSet[Any]()
@transient
val distinctValue = new InterpretedProjection(expr)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = distinctValue(input)
if (!evaluatedExpr.anyNull) {
seen.add(evaluatedExpr)
}
}
override def eval(input: InternalRow): Any = seen.size.toLong
}
case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
def this() = this(null)
@@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
}
}
case class ApproxCountDistinctPartitionFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
hyperLogLog.offer(evaluatedExpr)
}
}
override def eval(input: InternalRow): Any = hyperLogLog
}
case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
extends AggregateExpression with trees.UnaryNode[Expression] {
@@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
}
}
case class ApproxCountDistinctMergeFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
}
override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
}
case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -349,6 +429,56 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
}
override def newInstance(): AverageFunction = new AverageFunction(child, this)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function average")
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private var count: Long = _
private val sum = MutableLiteral(zero.eval(null), calcType)
private def addFunction(value: Any) = Add(sum,
Cast(Literal.create(value, expr.dataType), calcType))
override def eval(input: InternalRow): Any = {
if (count == 0L) {
null
} else {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(Divide(
Cast(sum, DecimalType.Unlimited),
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
case _ =>
Divide(
Cast(sum, dataType),
Cast(Literal(count), dataType)).eval(null)
}
}
}
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1
sum.update(addFunction(evaluatedExpr), input)
}
}
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -383,6 +513,40 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
}
override def newInstance(): SumFunction = new SumFunction(child, this)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sum")
}
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private val sum = MutableLiteral(null, calcType)
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: InternalRow): Unit = {
sum.update(addFunction, input)
}
override def eval(input: InternalRow): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(sum, dataType).eval(null)
case _ => sum.eval(null)
}
}
}
/**
@@ -409,6 +573,43 @@ case class CombineSum(child: Expression) extends AggregateExpression {
override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
}
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private val sum = MutableLiteral(null, calcType)
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: InternalRow): Unit = {
val result = expr.eval(input)
// the partial sum result can be null only when no input rows are present
if (result != null) {
sum.update(addFunction, input)
}
}
override def eval(input: InternalRow): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(sum, dataType).eval(null)
case _ => sum.eval(null)
}
}
}
case class SumDistinct(child: Expression)
extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -431,6 +632,35 @@ case class SumDistinct(child: Expression)
CombineSetsAndSum(partialSet.toAttribute, this),
partialSet :: Nil)
}
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val seen = new scala.collection.mutable.HashSet[Any]()
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
seen += evaluatedExpr
}
}
override def eval(input: InternalRow): Any = {
if (seen.size == 0) {
null
} else {
Cast(Literal(
seen.reduceLeft(
dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
dataType).eval(null)
}
}
}
case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
@@ -489,6 +719,20 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def newInstance(): FirstFunction = new FirstFunction(child, this)
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
var result: Any = null
override def update(input: InternalRow): Unit = {
if (result == null) {
result = expr.eval(input)
}
}
override def eval(input: InternalRow): Any = result
}
case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references: AttributeSet = child.references
override def nullable: Boolean = true
@@ -504,234 +748,6 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def newInstance(): LastFunction = new LastFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private var count: Long = _
private val sum = MutableLiteral(zero.eval(null), calcType)
private def addFunction(value: Any) = Add(sum,
Cast(Literal.create(value, expr.dataType), calcType))
override def eval(input: InternalRow): Any = {
if (count == 0L) {
null
} else {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(Divide(
Cast(sum, DecimalType.Unlimited),
Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
case _ =>
Divide(
Cast(sum, dataType),
Cast(Literal(count), dataType)).eval(null)
}
}
}
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1
sum.update(addFunction(evaluatedExpr), input)
}
}
}
case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
var count: Long = _
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
count += 1L
}
}
override def eval(input: InternalRow): Any = count
}
case class ApproxCountDistinctPartitionFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
hyperLogLog.offer(evaluatedExpr)
}
}
override def eval(input: InternalRow): Any = hyperLogLog
}
case class ApproxCountDistinctMergeFunction(
expr: Expression,
base: AggregateExpression,
relativeSD: Double)
extends AggregateFunction {
def this() = this(null, null, 0) // Required for serialization.
private val hyperLogLog = new HyperLogLog(relativeSD)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
}
override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
}
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private val sum = MutableLiteral(null, calcType)
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: InternalRow): Unit = {
sum.update(addFunction, input)
}
override def eval(input: InternalRow): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(sum, dataType).eval(null)
case _ => sum.eval(null)
}
}
}
case class CombineSumFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val calcType =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
DecimalType.Unlimited
case _ =>
expr.dataType
}
private val zero = Cast(Literal(0), calcType)
private val sum = MutableLiteral(null, calcType)
private val addFunction =
Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
override def update(input: InternalRow): Unit = {
val result = expr.eval(input)
// the partial sum result can be null only when no input rows are present
if (result != null) {
sum.update(addFunction, input)
}
}
override def eval(input: InternalRow): Any = {
expr.dataType match {
case DecimalType.Fixed(_, _) =>
Cast(sum, dataType).eval(null)
case _ => sum.eval(null)
}
}
}
case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
private val seen = new scala.collection.mutable.HashSet[Any]()
override def update(input: InternalRow): Unit = {
val evaluatedExpr = expr.eval(input)
if (evaluatedExpr != null) {
seen += evaluatedExpr
}
}
override def eval(input: InternalRow): Any = {
if (seen.size == 0) {
null
} else {
Cast(Literal(
seen.reduceLeft(
dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
dataType).eval(null)
}
}
}
case class CountDistinctFunction(
@transient expr: Seq[Expression],
@transient base: AggregateExpression)
extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
val seen = new OpenHashSet[Any]()
@transient
val distinctValue = new InterpretedProjection(expr)
override def update(input: InternalRow): Unit = {
val evaluatedExpr = distinctValue(input)
if (!evaluatedExpr.anyNull) {
seen.add(evaluatedExpr)
}
}
override def eval(input: InternalRow): Any = seen.size.toLong
}
case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
var result: Any = null
override def update(input: InternalRow): Unit = {
if (result == null) {
result = expr.eval(input)
}
}
override def eval(input: InternalRow): Any = result
}
case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.


@@ -25,8 +25,6 @@ import org.apache.spark.sql.types._
abstract class UnaryArithmetic extends UnaryExpression {
self: Product =>
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = {


@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types._
/**
* Returns an Array containing the evaluation of all children expressions.
*/
@@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
lazy val childTypes = children.map(_.dataType).distinct
override lazy val resolved =
childrenResolved && childTypes.size <= 1
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
override def dataType: DataType = {
assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
ArrayType(
childTypes.headOption.getOrElse(NullType),
children.headOption.map(_.dataType).getOrElse(NullType),
containsNull = children.exists(_.nullable))
}
@@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
override lazy val resolved: Boolean = childrenResolved
override lazy val dataType: StructType = {
assert(resolved,
s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
case ne: NamedExpression =>
StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
case _ =>
StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
}
val fields = children.zipWithIndex.map { case (child, idx) =>
child match {
case ne: NamedExpression =>
StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
case _ =>
StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
}
}
StructType(fields)
}


@@ -17,16 +17,17 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._
/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
/**
* Return the unscaled Long value of a Decimal, assuming it fits in a Long.
* Note: this expression is internal and created only by the optimizer,
* so we don't need to type check it.
*/
case class UnscaledValue(child: Expression) extends UnaryExpression {
override def dataType: DataType = LongType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"UnscaledValue($child)"
override def eval(input: InternalRow): Any = {
@@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
}
}
/** Create a Decimal from an unscaled Long value */
/**
* Create a Decimal from an unscaled Long value.
* Note: this expression is internal and created only by the optimizer,
* so we don't need to type check it.
*/
case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
override def dataType: DataType = DecimalType(precision, scale)
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
override def eval(input: InternalRow): Decimal = {


@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import scala.collection.Map
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
import org.apache.spark.sql.types._
@@ -100,9 +100,14 @@ case class UserDefinedGenerator(
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
override def checkInputDataTypes(): TypeCheckResult = {
if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"input to function explode should be array or map type, not ${child.dataType}")
}
}
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil


@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
import java.lang.{Long => JLong}
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
import org.apache.spark.unsafe.types.UTF8String
@@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
override def dataType: DataType = DoubleType
override def foldable: Boolean = child.foldable
override def nullable: Boolean = true
override def toString: String = s"$name($child)"
@@ -224,7 +222,7 @@ case class Bin(child: Expression)
def funcName: String = name.toLowerCase
override def eval(input: catalyst.InternalRow): Any = {
override def eval(input: InternalRow): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null


@@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
@@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {
// Alias(Generator, xx) need to be transformed into Generate(generator, ...)
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
override lazy val resolved =
childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
override def eval(input: InternalRow): Any = child.eval(input)


@@ -17,33 +17,32 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.DataType
case class Coalesce(children: Seq[Expression]) extends Expression {
/** Coalesce is nullable if all of its children are nullable, or if it has no children. */
override def nullable: Boolean = !children.exists(!_.nullable)
override def nullable: Boolean = children.forall(_.nullable)
// Coalesce is foldable if all children are foldable.
override def foldable: Boolean = !children.exists(!_.foldable)
override def foldable: Boolean = children.forall(_.foldable)
// Only resolved if all the children are of the same type.
override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
override def checkInputDataTypes(): TypeCheckResult = {
if (children == Nil) {
TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
}
}
override def toString: String = s"Coalesce(${children.mkString(",")})"
override def dataType: DataType = if (resolved) {
children.head.dataType
} else {
val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
throw new UnresolvedException(
this, s"Coalesce cannot have children of different types. $childTypes")
}
override def dataType: DataType = children.head.dataType
override def eval(input: InternalRow): Any = {
var i = 0
var result: Any = null
val childIterator = children.iterator
while (childIterator.hasNext && result == null) {
@@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
}
case class IsNull(child: Expression) extends UnaryExpression with Predicate {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
@@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
}
case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
override def foldable: Boolean = child.foldable
override def nullable: Boolean = false
override def toString: String = s"IS NOT NULL $child"


@@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
/**
* Adds an item to a set.
* For performance, this expression mutates its input during evaluation.
* Note: this expression is internal and created only by GeneratedAggregate,
* so we don't need to type check it.
*/
case class AddItemToSet(item: Expression, set: Expression) extends Expression {
@@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
override def nullable: Boolean = set.nullable
override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
override def dataType: DataType = set.dataType
override def eval(input: InternalRow): Any = {
val itemEval = item.eval(input)
@@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
/**
* Combines the elements of two sets.
* For performance, this expression mutates its left input set during evaluation.
* Note: this expression is internal and created only by GeneratedAggregate,
* so we don't need to type check it.
*/
case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
override def nullable: Boolean = left.nullable || right.nullable
override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
override def dataType: DataType = left.dataType
override def symbol: String = "++="
@@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
/**
* Returns the number of elements in the input set.
* Note: this expression is internal and created only by GeneratedAggregate,
* so we don't need to type check it.
*/
case class CountSet(child: Expression) extends UnaryExpression {


@@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes {
def convert(v: UTF8String): UTF8String
override def foldable: Boolean = child.foldable
override def nullable: Boolean = child.nullable
override def dataType: DataType = StringType
override def expectedChildTypes: Seq[DataType] = Seq(StringType)


@@ -68,7 +68,8 @@ case class WindowSpecDefinition(
override def children: Seq[Expression] = partitionSpec ++ orderSpec
override lazy val resolved: Boolean =
childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame]
childrenResolved && checkInputDataTypes().isSuccess &&
frameSpecification.isInstanceOf[SpecifiedWindowFrame]
override def toString: String = simpleString


@@ -48,6 +48,15 @@ object TypeUtils {
}
}
def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
if (types.distinct.size > 1) {
TypeCheckResult.TypeCheckFailure(
s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
} else {
TypeCheckResult.TypeCheckSuccess
}
}
def getNumeric(t: DataType): Numeric[Any] =
t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
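
The behavior of the new helper follows directly from the code above; note the empty-input edge case, which is why Coalesce also guards against empty children:

// checkForSameTypeInputExpr(Seq(IntegerType, IntegerType), "function array")
//   => TypeCheckSuccess
// checkForSameTypeInputExpr(Seq(IntegerType, BooleanType), "function coalesce")
//   => TypeCheckFailure("input to function coalesce should all be the same type,
//      but it's [IntegerType, BooleanType]")
// checkForSameTypeInputExpr(Nil, "function coalesce")
//   => TypeCheckSuccess, hence the explicit empty-children check in Coalesce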


@@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
errorTest(
"bad casts",
testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
"invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
"cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
errorTest(
"non-boolean filters",
@@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
val plan =
Aggregate(
Nil,
Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
LocalRelation(
AttributeReference("a", StringType)(exprId = ExprId(2))))
AttributeReference("a", IntegerType)(exprId = ExprId(2))))
assert(plan.resolved)


@@ -15,13 +15,13 @@
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types.StringType
@@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
"WHEN expressions in CaseWhen should all be boolean type")
}
test("check types for aggregates") {
// We will cast String to Double for sum and average
assertSuccess(Sum('stringField))
assertSuccess(SumDistinct('stringField))
assertSuccess(Average('stringField))
assertError(Min('complexField), "function min accepts non-complex type")
assertError(Max('complexField), "function max accepts non-complex type")
assertError(Sum('booleanField), "function sum accepts numeric type")
assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
assertError(Average('booleanField), "function average accepts numeric type")
}
test("check types for others") {
assertError(CreateArray(Seq('intField, 'booleanField)),
"input to function array should all be the same type")
assertError(Coalesce(Seq('intField, 'booleanField)),
"input to function coalesce should all be the same type")
assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
assertError(Explode('intField),
"input to function explode should be array or map type")
}
}
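
The assertError/assertSuccess helpers are defined earlier in this suite, outside the hunk above. A plausible sketch of their shape, assuming the suite's testRelation and SimpleAnalyzer (names here are assumptions, not taken from this diff):

def assertError(expr: Expression, errorMessage: String): Unit = {
  val e = intercept[AnalysisException] {
    assertSuccess(expr)
  }
  assert(e.getMessage.contains(errorMessage))
}

def assertSuccess(expr: Expression): Unit = {
  // Analysis triggers checkInputDataTypes(); a failure becomes an AnalysisException.
  val analyzed = SimpleAnalyzer.execute(testRelation.select(expr.as('c)))
  SimpleAnalyzer.checkAnalysis(analyzed)
}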


@@ -55,7 +55,7 @@ private[spark] case class PythonUDF(
override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
def nullable: Boolean = true
override def nullable: Boolean = true
override def eval(input: InternalRow): Any = {
throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")


@@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
}
assert(numEquals === 1)
}
test("COALESCE with different types") {
intercept[RuntimeException] {
TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
}
}
}