[SPARK-8075] [SQL] apply type check interface to more expressions
A follow-up to https://github.com/apache/spark/pull/6405. Note: it's not a big change; much of the diff comes from reordering code in `aggregates.scala` so that each aggregate function sits directly below its corresponding aggregate expression.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6723 from cloud-fan/type-check and squashes the following commits:

2124301 [Wenchen Fan] fix tests
5a658bb [Wenchen Fan] add tests
287d3bb [Wenchen Fan] apply type check interface to more expressions
This commit is contained in:

parent 7daa70292e
commit b71d3254e5
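The recurring pattern in the hunks below: instead of each expression hard-coding its input-type constraints into `resolved`, it reports them through `checkInputDataTypes()`, and `resolved` (or the analyzer) folds the result in, turning a silent resolution failure into a readable error. A minimal, self-contained sketch of that interface (names mirror the diff; the bodies are illustrative, not the actual Spark code):

```scala
// Illustrative sketch only -- the real types live in
// org.apache.spark.sql.catalyst.analysis and are richer than this.
sealed trait TypeCheckResult { def isSuccess: Boolean }
case object TypeCheckSuccess extends TypeCheckResult { def isSuccess: Boolean = true }
case class TypeCheckFailure(message: String) extends TypeCheckResult { def isSuccess: Boolean = false }

trait Expr {
  def childrenResolved: Boolean
  // Only valid to call once childrenResolved == true; the default accepts anything.
  def checkInputDataTypes(): TypeCheckResult = TypeCheckSuccess
  // Resolution folds the type check in, as Alias and WindowSpecDefinition do below.
  def resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess
}
```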
```diff
@@ -587,8 +587,8 @@ class Analyzer(
           failAnalysis(
             s"""Expect multiple names given for ${g.getClass.getName},
                |but only single name '${name}' specified""".stripMargin)
-      case Alias(g: Generator, name) => Some((g, name :: Nil))
-      case MultiAlias(g: Generator, names) => Some(g, names)
+      case Alias(g: Generator, name) if g.resolved => Some((g, name :: Nil))
+      case MultiAlias(g: Generator, names) if g.resolved => Some(g, names)
       case _ => None
     }
   }
```
```diff
@@ -317,6 +317,7 @@ trait HiveTypeCoercion {
         i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
 
       case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
+      case SumDistinct(e @ StringType()) => Sum(Cast(e, DoubleType))
       case Average(e @ StringType()) => Average(Cast(e, DoubleType))
     }
   }
```
```diff
@@ -590,11 +591,12 @@ trait HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case a @ CreateArray(children) if !a.resolved =>
-        val commonType = a.childTypes.reduce(
-          (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
-        CreateArray(
-          children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))
+      case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+        val types = children.map(_.dataType)
+        findTightestCommonTypeAndPromoteToString(types) match {
+          case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
+          case None => a
+        }
 
       // Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
       case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
```
```diff
@@ -620,12 +622,11 @@ trait HiveTypeCoercion {
       // Coalesce should return the first non-null value, which could be any column
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
-      case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+      case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
         val types = es.map(_.dataType)
         findTightestCommonTypeAndPromoteToString(types) match {
           case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
-          case None =>
-            sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
+          case None => c
         }
     }
   }
```
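Both rewritten coercion rules (`CreateArray` and `Coalesce`) follow the same recipe: if the children disagree on type, find the tightest common type, fall back to string promotion, and cast every child to the winner; if the helper gives up and returns `None`, leave the node unchanged so `checkInputDataTypes` can report it. A toy model of that recipe over a three-type lattice (illustrative only; `findTightestCommonTypeAndPromoteToString` in Spark covers the full Catalyst lattice and can genuinely return `None`):

```scala
// Toy three-type lattice standing in for Catalyst's type system.
sealed trait DType
case object IntT extends DType
case object DoubleT extends DType
case object StringT extends DType

// Tightest common type of two types, if one exists.
def tightest(a: DType, b: DType): Option[DType] = (a, b) match {
  case _ if a == b                       => Some(a)
  case (IntT, DoubleT) | (DoubleT, IntT) => Some(DoubleT)
  case _                                 => None
}

// Widen pairwise, promoting to string whenever no tighter common type exists.
def widenAll(types: Seq[DType]): Option[DType] =
  types.reduceLeftOption((a, b) => tightest(a, b).getOrElse(StringT))

// widenAll(Seq(IntT, DoubleT)) == Some(DoubleT)  -> cast children to DoubleT
// widenAll(Seq(IntT, StringT)) == Some(StringT)  -> string promotion
```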
```diff
@@ -22,7 +22,7 @@ import java.sql.{Date, Timestamp}
 import java.text.{DateFormat, SimpleDateFormat}
 
 import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
```
```diff
@@ -31,7 +31,14 @@ import org.apache.spark.unsafe.types.UTF8String
 /** Cast the child expression to the target data type. */
 case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {
 
-  override lazy val resolved = childrenResolved && resolve(child.dataType, dataType)
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (resolve(child.dataType, dataType)) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"cannot cast ${child.dataType} to $dataType")
+    }
+  }
 
   override def foldable: Boolean = child.foldable
```
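`Cast` keeps its `resolve(from, to)` castability predicate; what changes is only how a failure surfaces: as a `TypeCheckFailure("cannot cast ...")` message, which the updated "bad casts" test in AnalysisSuite near the end of this diff asserts on. A hedged sketch of the same structure, with the predicate stubbed since `resolve`'s full rules are not part of this diff:

```scala
// `canCast` stands in for Cast.resolve; one known-false pair is enough
// to show the failure path the "bad casts" test exercises.
def canCast(from: String, to: String): Boolean = (from, to) match {
  case ("int", "binary") => false // Literal(1).cast(BinaryType) in AnalysisSuite
  case _                 => true  // stub: assume everything else is castable
}

def checkCast(from: String, to: String): Either[String, Unit] =
  if (canCast(from, to)) Right(())
  else Left(s"cannot cast $from to $to")

// checkCast("int", "binary") == Left("cannot cast int to binary")
```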
```diff
@@ -162,9 +162,7 @@ abstract class Expression extends TreeNode[Expression] {
   /**
    * Checks the input data types, returns `TypeCheckResult.success` if it's valid,
    * or returns a `TypeCheckResult` with an error message if invalid.
-   * Note: it's not valid to call this method until `childrenResolved == true`
-   * TODO: we should remove the default implementation and implement it for all
-   * expressions with proper error message.
+   * Note: it's not valid to call this method until `childrenResolved == true`.
    */
   def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
 }
```
```diff
@@ -96,6 +96,11 @@ object ExtractValue {
   }
 }
 
+/**
+ * A common interface of all kinds of extract value expressions.
+ * Note: concrete extract value expressions are created only by `ExtractValue.apply`,
+ * we don't need to do type check for them.
+ */
 trait ExtractValue extends UnaryExpression {
   self: Product =>
 }
```
```diff
@@ -179,9 +184,6 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
 
   override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
 
-  override lazy val resolved = childrenResolved &&
-    child.dataType.isInstanceOf[ArrayType] && ordinal.dataType.isInstanceOf[IntegralType]
-
   protected def evalNotNull(value: Any, ordinal: Any) = {
     // TODO: consider using Array[_] for ArrayType child to avoid
     // boxing of primitives
```
```diff
@@ -203,8 +205,6 @@ case class GetMapValue(child: Expression, ordinal: Expression)
 
   override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
 
-  override lazy val resolved = childrenResolved && child.dataType.isInstanceOf[MapType]
-
   protected def evalNotNull(value: Any, ordinal: Any) = {
     val baseValue = value.asInstanceOf[Map[Any, _]]
     baseValue.get(ordinal).orNull
```
```diff
@@ -19,9 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
 
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 
-import org.apache.spark.sql.catalyst
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.trees
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.collection.OpenHashSet
 
 
```
```diff
@@ -101,6 +102,9 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   }
 
   override def newInstance(): MinFunction = new MinFunction(child, this)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(child.dataType, "function min")
 }
 
 case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
```
```diff
@@ -132,6 +136,9 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   }
 
   override def newInstance(): MaxFunction = new MaxFunction(child, this)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForOrderingExpr(child.dataType, "function max")
 }
 
 case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
```
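`Min` and `Max` share one helper, `TypeUtils.checkForOrderingExpr`, which (per the new tests at the bottom of this diff) rejects complex types, since arrays, maps, and structs have no ordering to compare by. A hypothetical standalone sketch of that check; the message shape matches the test expectations, but the real implementation is not shown in this diff:

```scala
sealed trait DType
case object IntT extends DType
case class ArrayT(element: DType) extends DType // stands in for complex types

def checkForOrderingExpr(dt: DType, caller: String): Either[String, Unit] = dt match {
  case _: ArrayT => Left(s"$caller accepts non-complex type, not $dt")
  case _         => Right(())
}

// checkForOrderingExpr(ArrayT(IntT), "function min")
//   == Left("function min accepts non-complex type, not ArrayT(IntT)")
```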
```diff
@@ -165,6 +172,21 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def newInstance(): CountFunction = new CountFunction(child, this)
 }
 
+case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var count: Long = _
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      count += 1L
+    }
+  }
+
+  override def eval(input: InternalRow): Any = count
+}
+
 case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate {
   def this() = this(null)
 
```
```diff
@@ -183,6 +205,28 @@ case class CountDistinct(expressions: Seq[Expression]) extends PartialAggregate
   }
 }
 
+case class CountDistinctFunction(
+    @transient expr: Seq[Expression],
+    @transient base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  val seen = new OpenHashSet[Any]()
+
+  @transient
+  val distinctValue = new InterpretedProjection(expr)
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = distinctValue(input)
+    if (!evaluatedExpr.anyNull) {
+      seen.add(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = seen.size.toLong
+}
+
 case class CollectHashSet(expressions: Seq[Expression]) extends AggregateExpression {
   def this() = this(null)
 
```
```diff
@@ -278,6 +322,25 @@ case class ApproxCountDistinctPartition(child: Expression, relativeSD: Double)
   }
 }
 
+case class ApproxCountDistinctPartitionFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      hyperLogLog.offer(evaluatedExpr)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = hyperLogLog
+}
+
 case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
   extends AggregateExpression with trees.UnaryNode[Expression] {
 
```
```diff
@@ -289,6 +352,23 @@ case class ApproxCountDistinctMerge(child: Expression, relativeSD: Double)
   }
 }
 
+case class ApproxCountDistinctMergeFunction(
+    expr: Expression,
+    base: AggregateExpression,
+    relativeSD: Double)
+  extends AggregateFunction {
+  def this() = this(null, null, 0) // Required for serialization.
+
+  private val hyperLogLog = new HyperLogLog(relativeSD)
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
+  }
+
+  override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
+}
+
 case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05)
   extends PartialAggregate with trees.UnaryNode[Expression] {
 
```
```diff
@@ -349,6 +429,56 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   }
 
   override def newInstance(): AverageFunction = new AverageFunction(child, this)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function average")
 }
 
+case class AverageFunction(expr: Expression, base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  private val calcType =
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        DecimalType.Unlimited
+      case _ =>
+        expr.dataType
+    }
+
+  private val zero = Cast(Literal(0), calcType)
+
+  private var count: Long = _
+  private val sum = MutableLiteral(zero.eval(null), calcType)
+
+  private def addFunction(value: Any) = Add(sum,
+    Cast(Literal.create(value, expr.dataType), calcType))
+
+  override def eval(input: InternalRow): Any = {
+    if (count == 0L) {
+      null
+    } else {
+      expr.dataType match {
+        case DecimalType.Fixed(_, _) =>
+          Cast(Divide(
+            Cast(sum, DecimalType.Unlimited),
+            Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
+        case _ =>
+          Divide(
+            Cast(sum, dataType),
+            Cast(Literal(count), dataType)).eval(null)
+      }
+    }
+  }
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      count += 1
+      sum.update(addFunction(evaluatedExpr), input)
+    }
+  }
+}
+
 case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
 
```
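Worth noting in `AverageFunction`: `calcType` widens fixed-precision decimals to `DecimalType.Unlimited` before accumulating, then casts back when dividing. The reason is plain arithmetic: a running sum needs more precision than any single input. Illustrative BigDecimal arithmetic (not Catalyst code):

```scala
// Two values that each fit DecimalType(2, 1), i.e. precision 2, scale 1.
val xs = Seq(BigDecimal("9.9"), BigDecimal("9.9"))

val sum = xs.sum        // 19.8 -- already needs precision 3, so the
                        // accumulator must be wider than the input type
val avg = sum / xs.size // 9.9 -- representable again, which is why the
                        // final Cast back to dataType is sound
```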
```diff
@@ -383,6 +513,40 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   }
 
   override def newInstance(): SumFunction = new SumFunction(child, this)
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function sum")
 }
 
+case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  private val calcType =
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        DecimalType.Unlimited
+      case _ =>
+        expr.dataType
+    }
+
+  private val zero = Cast(Literal(0), calcType)
+
+  private val sum = MutableLiteral(null, calcType)
+
+  private val addFunction =
+    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+
+  override def update(input: InternalRow): Unit = {
+    sum.update(addFunction, input)
+  }
+
+  override def eval(input: InternalRow): Any = {
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        Cast(sum, dataType).eval(null)
+      case _ => sum.eval(null)
+    }
+  }
+}
+
 /**
```
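The oddest-looking line in `SumFunction` is the accumulator: `Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))`. Reading it inside-out: compute `(sum or 0) + value`; since `Add` evaluates to null when `value` is null, fall back to the current `sum`, and failing that to `zero`. The same stepping logic in plain Scala over `Option` (illustrative only, not the Catalyst evaluation path):

```scala
// One update step of SumFunction's addFunction, modeled with Option.
def step(sum: Option[Double], value: Option[Double]): Option[Double] = {
  val added = value.map(v => sum.getOrElse(0.0) + v) // Add(Coalesce(sum, zero), value)
  added.orElse(sum).orElse(Some(0.0))                // Coalesce(added, sum, zero)
}

// Seq(Some(1.0), None, Some(2.0)).foldLeft(Option.empty[Double])(step) == Some(3.0)
```

`CombineSumFunction` below uses the same expression but additionally skips the update entirely when the partial result is null, per its inline comment.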
```diff
@@ -409,6 +573,43 @@ case class CombineSum(child: Expression) extends AggregateExpression {
   override def newInstance(): CombineSumFunction = new CombineSumFunction(child, this)
 }
 
+case class CombineSumFunction(expr: Expression, base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  private val calcType =
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        DecimalType.Unlimited
+      case _ =>
+        expr.dataType
+    }
+
+  private val zero = Cast(Literal(0), calcType)
+
+  private val sum = MutableLiteral(null, calcType)
+
+  private val addFunction =
+    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
+
+  override def update(input: InternalRow): Unit = {
+    val result = expr.eval(input)
+    // partial sum result can be null only when no input rows present
+    if(result != null) {
+      sum.update(addFunction, input)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    expr.dataType match {
+      case DecimalType.Fixed(_, _) =>
+        Cast(sum, dataType).eval(null)
+      case _ => sum.eval(null)
+    }
+  }
+}
+
 case class SumDistinct(child: Expression)
   extends PartialAggregate with trees.UnaryNode[Expression] {
 
```
```diff
@@ -431,6 +632,35 @@ case class SumDistinct(child: Expression)
       CombineSetsAndSum(partialSet.toAttribute, this),
       partialSet :: Nil)
   }
+
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForNumericExpr(child.dataType, "function sumDistinct")
 }
 
+case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
+  extends AggregateFunction {
+
+  def this() = this(null, null) // Required for serialization.
+
+  private val seen = new scala.collection.mutable.HashSet[Any]()
+
+  override def update(input: InternalRow): Unit = {
+    val evaluatedExpr = expr.eval(input)
+    if (evaluatedExpr != null) {
+      seen += evaluatedExpr
+    }
+  }
+
+  override def eval(input: InternalRow): Any = {
+    if (seen.size == 0) {
+      null
+    } else {
+      Cast(Literal(
+        seen.reduceLeft(
+          dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
+        dataType).eval(null)
+    }
+  }
+}
+
 case class CombineSetsAndSum(inputSet: Expression, base: Expression) extends AggregateExpression {
```
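`SumDistinctFunction.eval` folds the deduplicated values with the data type's `Numeric` instance, erased to `Numeric[Any]` so one code path serves every numeric type. The same trick in isolation (illustrative):

```scala
// Summing a HashSet[Any] of ints through an erased Numeric, as
// SumDistinctFunction does with its dataType's numeric instance.
val seen = scala.collection.mutable.HashSet[Any](1, 2, 3)
val plus = implicitly[Numeric[Int]].asInstanceOf[Numeric[Any]].plus _

val total = seen.reduceLeft(plus) // 6 (order doesn't matter: + is associative)
```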
```diff
@@ -489,6 +719,20 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def newInstance(): FirstFunction = new FirstFunction(child, this)
 }
 
+case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
+  def this() = this(null, null) // Required for serialization.
+
+  var result: Any = null
+
+  override def update(input: InternalRow): Unit = {
+    if (result == null) {
+      result = expr.eval(input)
+    }
+  }
+
+  override def eval(input: InternalRow): Any = result
+}
+
 case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def references: AttributeSet = child.references
   override def nullable: Boolean = true
```
```diff
@@ -504,234 +748,6 @@ case class Last(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
   override def newInstance(): LastFunction = new LastFunction(child, this)
 }
 
-case class AverageFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private var count: Long = _
-  private val sum = MutableLiteral(zero.eval(null), calcType)
-
-  private def addFunction(value: Any) = Add(sum,
-    Cast(Literal.create(value, expr.dataType), calcType))
-
-  override def eval(input: InternalRow): Any = {
-    if (count == 0L) {
-      null
-    } else {
-      expr.dataType match {
-        case DecimalType.Fixed(_, _) =>
-          Cast(Divide(
-            Cast(sum, DecimalType.Unlimited),
-            Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null)
-        case _ =>
-          Divide(
-            Cast(sum, dataType),
-            Cast(Literal(count), dataType)).eval(null)
-      }
-    }
-  }
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      count += 1
-      sum.update(addFunction(evaluatedExpr), input)
-    }
-  }
-}
-
-case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
-  def this() = this(null, null) // Required for serialization.
-
-  var count: Long = _
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      count += 1L
-    }
-  }
-
-  override def eval(input: InternalRow): Any = count
-}
-
-case class ApproxCountDistinctPartitionFunction(
-    expr: Expression,
-    base: AggregateExpression,
-    relativeSD: Double)
-  extends AggregateFunction {
-  def this() = this(null, null, 0) // Required for serialization.
-
-  private val hyperLogLog = new HyperLogLog(relativeSD)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      hyperLogLog.offer(evaluatedExpr)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = hyperLogLog
-}
-
-case class ApproxCountDistinctMergeFunction(
-    expr: Expression,
-    base: AggregateExpression,
-    relativeSD: Double)
-  extends AggregateFunction {
-  def this() = this(null, null, 0) // Required for serialization.
-
-  private val hyperLogLog = new HyperLogLog(relativeSD)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    hyperLogLog.addAll(evaluatedExpr.asInstanceOf[HyperLogLog])
-  }
-
-  override def eval(input: InternalRow): Any = hyperLogLog.cardinality()
-}
-
-case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private val sum = MutableLiteral(null, calcType)
-
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
-  override def update(input: InternalRow): Unit = {
-    sum.update(addFunction, input)
-  }
-
-  override def eval(input: InternalRow): Any = {
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        Cast(sum, dataType).eval(null)
-      case _ => sum.eval(null)
-    }
-  }
-}
-
-case class CombineSumFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val calcType =
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        DecimalType.Unlimited
-      case _ =>
-        expr.dataType
-    }
-
-  private val zero = Cast(Literal(0), calcType)
-
-  private val sum = MutableLiteral(null, calcType)
-
-  private val addFunction =
-    Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum, zero))
-
-  override def update(input: InternalRow): Unit = {
-    val result = expr.eval(input)
-    // partial sum result can be null only when no input rows present
-    if(result != null) {
-      sum.update(addFunction, input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    expr.dataType match {
-      case DecimalType.Fixed(_, _) =>
-        Cast(sum, dataType).eval(null)
-      case _ => sum.eval(null)
-    }
-  }
-}
-
-case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
-  extends AggregateFunction {
-
-  def this() = this(null, null) // Required for serialization.
-
-  private val seen = new scala.collection.mutable.HashSet[Any]()
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = expr.eval(input)
-    if (evaluatedExpr != null) {
-      seen += evaluatedExpr
-    }
-  }
-
-  override def eval(input: InternalRow): Any = {
-    if (seen.size == 0) {
-      null
-    } else {
-      Cast(Literal(
-        seen.reduceLeft(
-          dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)),
-        dataType).eval(null)
-    }
-  }
-}
-
-case class CountDistinctFunction(
-    @transient expr: Seq[Expression],
-    @transient base: AggregateExpression)
-  extends AggregateFunction {
-
-  def this() = this(null, null) // Required for serialization.
-
-  val seen = new OpenHashSet[Any]()
-
-  @transient
-  val distinctValue = new InterpretedProjection(expr)
-
-  override def update(input: InternalRow): Unit = {
-    val evaluatedExpr = distinctValue(input)
-    if (!evaluatedExpr.anyNull) {
-      seen.add(evaluatedExpr)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = seen.size.toLong
-}
-
-case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
-  def this() = this(null, null) // Required for serialization.
-
-  var result: Any = null
-
-  override def update(input: InternalRow): Unit = {
-    if (result == null) {
-      result = expr.eval(input)
-    }
-  }
-
-  override def eval(input: InternalRow): Any = result
-}
-
 case class LastFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
   def this() = this(null, null) // Required for serialization.
 
```
```diff
@@ -25,8 +25,6 @@ import org.apache.spark.sql.types._
 abstract class UnaryArithmetic extends UnaryExpression {
   self: Product =>
 
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def dataType: DataType = child.dataType
 
   override def eval(input: InternalRow): Any = {
```
```diff
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
-
 
 /**
  * Returns an Array containing the evaluation of all children expressions.
  */
```
```diff
@@ -27,15 +28,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  lazy val childTypes = children.map(_.dataType).distinct
-
-  override lazy val resolved =
-    childrenResolved && childTypes.size <= 1
+  override def checkInputDataTypes(): TypeCheckResult =
+    TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
 
   override def dataType: DataType = {
-    assert(resolved, s"Invalid dataType of mixed ArrayType ${childTypes.mkString(",")}")
     ArrayType(
-      childTypes.headOption.getOrElse(NullType),
+      children.headOption.map(_.dataType).getOrElse(NullType),
       containsNull = children.exists(_.nullable))
   }
 
```
```diff
@@ -56,19 +54,15 @@ case class CreateStruct(children: Seq[Expression]) extends Expression {
 
   override def foldable: Boolean = children.forall(_.foldable)
 
-  override lazy val resolved: Boolean = childrenResolved
-
   override lazy val dataType: StructType = {
-    assert(resolved,
-      s"CreateStruct contains unresolvable children: ${children.filterNot(_.resolved)}.")
-    val fields = children.zipWithIndex.map { case (child, idx) =>
-      child match {
-        case ne: NamedExpression =>
-          StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
-        case _ =>
-          StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
-      }
+    val fields = children.zipWithIndex.map { case (child, idx) =>
+      child match {
+        case ne: NamedExpression =>
+          StructField(ne.name, ne.dataType, ne.nullable, ne.metadata)
+        case _ =>
+          StructField(s"col${idx + 1}", child.dataType, child.nullable, Metadata.empty)
+      }
     }
     StructType(fields)
   }
```
```diff
@@ -17,16 +17,17 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
 import org.apache.spark.sql.types._
 
-/** Return the unscaled Long value of a Decimal, assuming it fits in a Long */
+/**
+ * Return the unscaled Long value of a Decimal, assuming it fits in a Long.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
 case class UnscaledValue(child: Expression) extends UnaryExpression {
 
   override def dataType: DataType = LongType
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def toString: String = s"UnscaledValue($child)"
 
   override def eval(input: InternalRow): Any = {
```
```diff
@@ -43,12 +44,14 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
   }
 }
 
-/** Create a Decimal from an unscaled Long value */
+/**
+ * Create a Decimal from an unscaled Long value.
+ * Note: this expression is internal and created only by the optimizer,
+ * we don't need to do type check for it.
+ */
 case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression {
 
   override def dataType: DataType = DecimalType(precision, scale)
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def toString: String = s"MakeDecimal($child,$precision,$scale)"
 
   override def eval(input: InternalRow): Decimal = {
```
```diff
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
 
 import scala.collection.Map
 
-import org.apache.spark.sql.catalyst
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, trees}
 import org.apache.spark.sql.types._
 
```
```diff
@@ -100,9 +100,14 @@ case class UserDefinedGenerator(
 case class Explode(child: Expression)
   extends Generator with trees.UnaryNode[Expression] {
 
-  override lazy val resolved =
-    child.resolved &&
-    (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to function explode should be array or map type, not ${child.dataType}")
+    }
+  }
 
   override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
     case ArrayType(et, containsNull) => (et, containsNull) :: Nil
```
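`Explode` follows the same conversion from a `resolved` guard to an explicit check, with a message naming the offending type. A toy version of the shape test (illustrative; the real code pattern-matches Catalyst's `ArrayType`/`MapType`):

```scala
// Toy stand-ins for Catalyst's ArrayType/MapType.
sealed trait DType
case object IntT extends DType
case class ArrayT(element: DType) extends DType
case class MapT(key: DType, value: DType) extends DType

def explodeCheck(dt: DType): Either[String, Unit] = dt match {
  case _: ArrayT | _: MapT => Right(()) // only collections can be exploded
  case other => Left(s"input to function explode should be array or map type, not $other")
}

// explodeCheck(IntT) is the failure the new type-checking test asserts on.
```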
```diff
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.lang.{Long => JLong}
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StringType}
 import org.apache.spark.unsafe.types.UTF8String
```
```diff
@@ -60,7 +59,6 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
 
   override def expectedChildTypes: Seq[DataType] = Seq(DoubleType)
   override def dataType: DataType = DoubleType
-  override def foldable: Boolean = child.foldable
   override def nullable: Boolean = true
   override def toString: String = s"$name($child)"
 
```
```diff
@@ -224,7 +222,7 @@ case class Bin(child: Expression)
 
   def funcName: String = name.toLowerCase
 
-  override def eval(input: catalyst.InternalRow): Any = {
+  override def eval(input: InternalRow): Any = {
     val evalE = child.eval(input)
     if (evalE == null) {
       null
```
```diff
@@ -17,7 +17,6 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.errors.TreeNodeException
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
```
```diff
@@ -113,7 +112,8 @@ case class Alias(child: Expression, name: String)(
   extends NamedExpression with trees.UnaryNode[Expression] {
 
   // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
-  override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
+  override lazy val resolved =
+    childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
 
   override def eval(input: InternalRow): Any = child.eval(input)
 
```
```diff
@@ -17,33 +17,32 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types.DataType
 
 case class Coalesce(children: Seq[Expression]) extends Expression {
 
   /** Coalesce is nullable if all of its children are nullable, or if it has no children. */
-  override def nullable: Boolean = !children.exists(!_.nullable)
+  override def nullable: Boolean = children.forall(_.nullable)
 
   // Coalesce is foldable if all children are foldable.
-  override def foldable: Boolean = !children.exists(!_.foldable)
+  override def foldable: Boolean = children.forall(_.foldable)
 
-  // Only resolved if all the children are of the same type.
-  override lazy val resolved = childrenResolved && (children.map(_.dataType).distinct.size == 1)
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children == Nil) {
+      TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
+    } else {
+      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
+    }
+  }
 
   override def toString: String = s"Coalesce(${children.mkString(",")})"
 
-  override def dataType: DataType = if (resolved) {
-    children.head.dataType
-  } else {
-    val childTypes = children.map(c => s"$c: ${c.dataType}").mkString(", ")
-    throw new UnresolvedException(
-      this, s"Coalesce cannot have children of different types. $childTypes")
-  }
+  override def dataType: DataType = children.head.dataType
 
   override def eval(input: InternalRow): Any = {
     var i = 0
     var result: Any = null
     val childIterator = children.iterator
     while (childIterator.hasNext && result == null) {
```
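The `forall` rewrites are pure De Morgan cleanups (`!xs.exists(!p(_))` is equivalent to `xs.forall(p)`), but they also make the comment's edge case visible: `forall` is vacuously true on an empty list, so a childless `Coalesce` counts as nullable, and the new `checkInputDataTypes` rejects it anyway. For example:

```scala
val noChildren = Seq.empty[Boolean]
noChildren.forall(identity)        // true: vacuously "all children nullable"
Seq(true, false).forall(identity)  // false: one non-nullable child suffices
```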
```diff
@@ -75,7 +74,6 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
 }
 
 case class IsNull(child: Expression) extends UnaryExpression with Predicate {
-  override def foldable: Boolean = child.foldable
   override def nullable: Boolean = false
 
   override def eval(input: InternalRow): Any = {
```
```diff
@@ -93,7 +91,6 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate {
 }
 
 case class IsNotNull(child: Expression) extends UnaryExpression with Predicate {
-  override def foldable: Boolean = child.foldable
   override def nullable: Boolean = false
   override def toString: String = s"IS NOT NULL $child"
 
```
```diff
@@ -78,6 +78,8 @@ case class NewSet(elementType: DataType) extends LeafExpression {
 /**
  * Adds an item to a set.
  * For performance, this expression mutates its input during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
  */
 case class AddItemToSet(item: Expression, set: Expression) extends Expression {
 
```
```diff
@@ -85,7 +87,7 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
 
   override def nullable: Boolean = set.nullable
 
-  override def dataType: OpenHashSetUDT = set.dataType.asInstanceOf[OpenHashSetUDT]
+  override def dataType: DataType = set.dataType
 
   override def eval(input: InternalRow): Any = {
     val itemEval = item.eval(input)
```
```diff
@@ -128,12 +130,14 @@ case class AddItemToSet(item: Expression, set: Expression) extends Expression {
 /**
  * Combines the elements of two sets.
  * For performance, this expression mutates its left input set during evaluation.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
  */
 case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
 
   override def nullable: Boolean = left.nullable || right.nullable
 
-  override def dataType: OpenHashSetUDT = left.dataType.asInstanceOf[OpenHashSetUDT]
+  override def dataType: DataType = left.dataType
 
   override def symbol: String = "++="
 
```
```diff
@@ -176,6 +180,8 @@ case class CombineSets(left: Expression, right: Expression) extends BinaryExpression {
 
 /**
  * Returns the number of elements in the input set.
+ * Note: this expression is internal and created only by the GeneratedAggregate,
+ * we don't need to do type check for it.
  */
 case class CountSet(child: Expression) extends UnaryExpression {
 
```
```diff
@@ -117,8 +117,6 @@ trait CaseConversionExpression extends ExpectsInputTypes {
 
   def convert(v: UTF8String): UTF8String
 
-  override def foldable: Boolean = child.foldable
-  override def nullable: Boolean = child.nullable
   override def dataType: DataType = StringType
   override def expectedChildTypes: Seq[DataType] = Seq(StringType)
 
```
```diff
@@ -68,7 +68,8 @@ case class WindowSpecDefinition(
   override def children: Seq[Expression] = partitionSpec ++ orderSpec
 
   override lazy val resolved: Boolean =
-    childrenResolved && frameSpecification.isInstanceOf[SpecifiedWindowFrame]
+    childrenResolved && checkInputDataTypes().isSuccess &&
+      frameSpecification.isInstanceOf[SpecifiedWindowFrame]
 
 
   override def toString: String = simpleString
```
```diff
@@ -48,6 +48,15 @@ object TypeUtils {
     }
   }
 
+  def checkForSameTypeInputExpr(types: Seq[DataType], caller: String): TypeCheckResult = {
+    if (types.distinct.size > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
+    } else {
+      TypeCheckResult.TypeCheckSuccess
+    }
+  }
+
   def getNumeric(t: DataType): Numeric[Any] =
     t.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
 
```
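This new `TypeUtils` helper is what `CreateArray` and `Coalesce` delegate to; everything interesting is in the message format. A standalone demo of the same logic (types modeled as strings for brevity; the real helper takes `Seq[DataType]`):

```scala
def checkForSameTypeInputExpr(types: Seq[String], caller: String): Either[String, Unit] =
  if (types.distinct.size > 1)
    Left(s"input to $caller should all be the same type, but it's ${types.mkString("[", ", ", "]")}")
  else Right(())

// checkForSameTypeInputExpr(Seq("IntegerType", "BooleanType"), "function array") ==
//   Left("input to function array should all be the same type, but it's [IntegerType, BooleanType]")
```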
```diff
@@ -193,7 +193,7 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
   errorTest(
     "bad casts",
     testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
-    "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+    "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
 
   errorTest(
     "non-boolean filters",
```
```diff
@@ -264,9 +264,9 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
     val plan =
       Aggregate(
         Nil,
-        Alias(Sum(AttributeReference("a", StringType)(exprId = ExprId(1))), "b")() :: Nil,
+        Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
         LocalRelation(
-          AttributeReference("a", StringType)(exprId = ExprId(2))))
+          AttributeReference("a", IntegerType)(exprId = ExprId(2))))
 
     assert(plan.resolved)
 
```
```diff
@@ -15,13 +15,13 @@
  * limitations under the License.
  */
 
-package org.apache.spark.sql.catalyst.expressions
+package org.apache.spark.sql.catalyst.analysis
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.types.StringType
 
```
```diff
@@ -136,6 +136,28 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
     assertError(
       CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
       "WHEN expressions in CaseWhen should all be boolean type")
   }
 
+  test("check types for aggregates") {
+    // We will cast String to Double for sum and average
+    assertSuccess(Sum('stringField))
+    assertSuccess(SumDistinct('stringField))
+    assertSuccess(Average('stringField))
+
+    assertError(Min('complexField), "function min accepts non-complex type")
+    assertError(Max('complexField), "function max accepts non-complex type")
+    assertError(Sum('booleanField), "function sum accepts numeric type")
+    assertError(SumDistinct('booleanField), "function sumDistinct accepts numeric type")
+    assertError(Average('booleanField), "function average accepts numeric type")
+  }
+
+  test("check types for others") {
+    assertError(CreateArray(Seq('intField, 'booleanField)),
+      "input to function array should all be the same type")
+    assertError(Coalesce(Seq('intField, 'booleanField)),
+      "input to function coalesce should all be the same type")
+    assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
+    assertError(Explode('intField),
+      "input to function explode should be array or map type")
+  }
 }
```
```diff
@@ -55,7 +55,7 @@ private[spark] case class PythonUDF(
 
   override def toString: String = s"PythonUDF#$name(${children.mkString(",")})"
 
-  def nullable: Boolean = true
+  override def nullable: Boolean = true
 
   override def eval(input: InternalRow): Any = {
     throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.")
```
```diff
@@ -59,10 +59,4 @@ class HiveTypeCoercionSuite extends HiveComparisonTest {
     }
     assert(numEquals === 1)
   }
-
-  test("COALESCE with different types") {
-    intercept[RuntimeException] {
-      TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect()
-    }
-  }
 }
```