[SPARK-8926][SQL] Good errors for ExpectsInputType expressions

For example: `cannot resolve 'testfunction(null)' due to data type mismatch: argument 1 is expected to be of type int, however, null is of type datetype.`

Author: Michael Armbrust <michael@databricks.com>

Closes #7303 from marmbrus/expectsTypeErrors and squashes the following commits:

c654a0e [Michael Armbrust] fix udts and make errors pretty
137160d [Michael Armbrust] style
5428fda [Michael Armbrust] style
10fac82 [Michael Armbrust] [SPARK-8926][SQL] Good errors for ExpectsInputType expressions
Michael Armbrust authored on 2015-07-08 22:05:58 -07:00; committed by Reynold Xin
parent aba5784dab
commit 768907eb7b
13 changed files with 256 additions and 143 deletions


@@ -702,11 +702,19 @@ object HiveTypeCoercion {
     @Nullable val ret: Expression = (inType, expectedType) match {

       // If the expected type is already a parent of the input type, no need to cast.
-      case _ if expectedType.isParentOf(inType) => e
+      case _ if expectedType.isSameType(inType) => e

       // Cast null type (usually from null literals) into target types
       case (NullType, target) => Cast(e, target.defaultConcreteType)

+      // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
+      // already a number, leave it as is.
+      case (_: NumericType, NumericType) => e
+
+      // If the function accepts any numeric type and the input is a string, we follow the hive
+      // convention and cast that input into a double
+      case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)
+
       // Implicit cast among numeric types
       // If input is a numeric type but not decimal, and we expect a decimal type,
       // cast the input to unlimited precision decimal.
@@ -732,7 +740,7 @@ object HiveTypeCoercion {
       // First see if we can find our input type in the type collection. If we can, then just
       // use the current expression; otherwise, find the first one we can implicitly cast.
       case (_, TypeCollection(types)) =>
-        if (types.exists(_.isParentOf(inType))) {
+        if (types.exists(_.isSameType(inType))) {
           e
         } else {
           types.flatMap(implicitCast(e, _)).headOption.orNull
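Read together, these cases define the implicit-cast policy: an input already of the expected type passes through, null literals are cast to the target's default concrete type, numeric inputs satisfy an abstract NumericType expectation unchanged, and strings follow the Hive convention of casting to double. A minimal standalone sketch of that decision table (plain Scala with stand-in types, not the Spark classes):

// Stand-in types for the sketch; NumericExpected plays the role of the abstract NumericType.
sealed trait T
case object NullT extends T
case object IntT extends T
case object DoubleT extends T
case object StringT extends T
case object NumericExpected extends T

def implicitCastType(in: T, expected: T): Option[T] = (in, expected) match {
  case _ if in == expected               => Some(in)      // same type: no cast needed
  case (NullT, target)                   => Some(target)  // null literals cast to the target type
  case (IntT | DoubleT, NumericExpected) => Some(in)      // already numeric: leave as is
  case (StringT, NumericExpected)        => Some(DoubleT) // Hive convention: string -> double
  case _                                 => None          // no eligible implicit cast
}

assert(implicitCastType(IntT, NumericExpected) == Some(IntT))
assert(implicitCastType(StringT, NumericExpected) == Some(DoubleT))
assert(implicitCastType(NullT, IntT) == Some(IntT))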


@@ -37,7 +37,16 @@ trait ExpectsInputTypes { self: Expression =>
   def inputTypes: Seq[AbstractDataType]

   override def checkInputDataTypes(): TypeCheckResult = {
-    // TODO: implement proper type checking.
-    TypeCheckResult.TypeCheckSuccess
+    val mismatches = children.zip(inputTypes).zipWithIndex.collect {
+      case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
+        s"Argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
+          s"however, ${child.prettyString} is of type ${child.dataType.simpleString}."
+    }
+
+    if (mismatches.isEmpty) {
+      TypeCheckResult.TypeCheckSuccess
+    } else {
+      TypeCheckResult.TypeCheckFailure(mismatches.mkString(" "))
+    }
   }
 }
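The collect builds one sentence per mismatched argument and joins them, which is what produces the multi-argument messages exercised in the new test suite below. The same zip/collect shape as a standalone sketch (plain strings stand in for child expressions and AbstractDataType):

// Sketch of the mismatch reporting above: (name, actualType) pairs stand in for children,
// plain strings stand in for AbstractDataType.
def check(children: Seq[(String, String)], inputTypes: Seq[String]): Either[String, Unit] = {
  val mismatches = children.zip(inputTypes).zipWithIndex.collect {
    case (((name, actual), expected), idx) if actual != expected =>
      s"Argument ${idx + 1} is expected to be of type $expected, " +
        s"however, $name is of type $actual."
  }
  if (mismatches.isEmpty) Right(()) else Left(mismatches.mkString(" "))
}

// A date argument where an int is expected yields the message from the commit summary:
assert(check(Seq("null" -> "date"), Seq("int")) ==
  Left("Argument 1 is expected to be of type int, however, null is of type date."))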


@@ -34,9 +34,16 @@ private[sql] abstract class AbstractDataType {
   private[sql] def defaultConcreteType: DataType

   /**
-   * Returns true if this data type is a parent of the `childCandidate`.
+   * Returns true if this data type is the same type as `other`. This is different from equality,
+   * as equality will also consider data type parametrization, such as decimal precision.
    */
-  private[sql] def isParentOf(childCandidate: DataType): Boolean
+  private[sql] def isSameType(other: DataType): Boolean
+
+  /**
+   * Returns true if `other` is an acceptable input type for a function that expects this,
+   * possibly abstract, DataType.
+   */
+  private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)

   /** Readable string representation for the type. */
   private[sql] def simpleString: String
@@ -58,11 +65,14 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
   require(types.nonEmpty, s"TypeCollection ($types) cannot be empty")

-  private[sql] override def defaultConcreteType: DataType = types.head.defaultConcreteType
+  override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = false
+  override private[sql] def isSameType(other: DataType): Boolean = false

-  private[sql] override def simpleString: String = {
+  override private[sql] def acceptsType(other: DataType): Boolean =
+    types.exists(_.isSameType(other))
+
+  override private[sql] def simpleString: String = {
     types.map(_.simpleString).mkString("(", " or ", ")")
   }
 }
@@ -108,7 +118,7 @@ abstract class NumericType extends AtomicType {
 }

-private[sql] object NumericType {
+private[sql] object NumericType extends AbstractDataType {

   /**
    * Enables matching against NumericType for expressions:
    * {{{
@@ -117,6 +127,14 @@ private[sql] object NumericType {
    * }}}
    */
   def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[NumericType]
+
+  override private[sql] def defaultConcreteType: DataType = DoubleType
+
+  override private[sql] def simpleString: String = "numeric"
+
+  override private[sql] def isSameType(other: DataType): Boolean = false
+
+  override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
 }
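The isSameType/acceptsType split is the heart of the change: isSameType deliberately ignores parametrization (any two decimals are the "same type"), and an abstract type like NumericType is never the same type as a concrete one but accepts every concrete numeric. Expected behaviour as plain assertions (a usage sketch only; these methods are private[sql] inside Spark, so application code cannot call them directly):

// Usage sketch; shown only to illustrate the contract of the new methods.
assert(DecimalType.isSameType(DecimalType(10, 2)))    // parametrization is ignored
assert(!NumericType.isSameType(IntegerType))          // abstract vs concrete: never "same"
assert(NumericType.acceptsType(IntegerType))          // ...but accepted as a function input
assert(!NumericType.acceptsType(StringType))
assert(NumericType.defaultConcreteType == DoubleType) // used when a concrete type is forced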


@@ -26,13 +26,13 @@ object ArrayType extends AbstractDataType {
   /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */
   def apply(elementType: DataType): ArrayType = ArrayType(elementType, containsNull = true)

-  private[sql] override def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)
+  override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
-    childCandidate.isInstanceOf[ArrayType]
+  override private[sql] def isSameType(other: DataType): Boolean = {
+    other.isInstanceOf[ArrayType]
   }

-  private[sql] override def simpleString: String = "array"
+  override private[sql] def simpleString: String = "array"
 }


@@ -76,9 +76,9 @@ abstract class DataType extends AbstractDataType {
    */
   private[spark] def asNullable: DataType

-  private[sql] override def defaultConcreteType: DataType = this
+  override private[sql] def defaultConcreteType: DataType = this

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = this == childCandidate
+  override private[sql] def isSameType(other: DataType): Boolean = this == other
 }


@@ -84,13 +84,13 @@ case class DecimalType(precisionInfo: Option[PrecisionInfo]) extends FractionalType {
 /** Extra factory methods and pattern matchers for Decimals */
 object DecimalType extends AbstractDataType {

-  private[sql] override def defaultConcreteType: DataType = Unlimited
+  override private[sql] def defaultConcreteType: DataType = Unlimited

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
-    childCandidate.isInstanceOf[DecimalType]
+  override private[sql] def isSameType(other: DataType): Boolean = {
+    other.isInstanceOf[DecimalType]
   }

-  private[sql] override def simpleString: String = "decimal"
+  override private[sql] def simpleString: String = "decimal"

   val Unlimited: DecimalType = DecimalType(None)


@@ -69,13 +69,13 @@ case class MapType(
 object MapType extends AbstractDataType {

-  private[sql] override def defaultConcreteType: DataType = apply(NullType, NullType)
+  override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
-    childCandidate.isInstanceOf[MapType]
+  override private[sql] def isSameType(other: DataType): Boolean = {
+    other.isInstanceOf[MapType]
   }

-  private[sql] override def simpleString: String = "map"
+  override private[sql] def simpleString: String = "map"

   /**
    * Construct a [[MapType]] object with the given key type and value type.


@@ -303,13 +303,13 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[StructField] {
 object StructType extends AbstractDataType {

-  private[sql] override def defaultConcreteType: DataType = new StructType
+  override private[sql] def defaultConcreteType: DataType = new StructType

-  private[sql] override def isParentOf(childCandidate: DataType): Boolean = {
-    childCandidate.isInstanceOf[StructType]
+  override private[sql] def isSameType(other: DataType): Boolean = {
+    other.isInstanceOf[StructType]
   }

-  private[sql] override def simpleString: String = "struct"
+  override private[sql] def simpleString: String = "struct"

   private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match {
     case t: StructType => t


@@ -77,5 +77,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable {
    * For UDT, asNullable will not change the nullability of its internal sqlType and just returns
    * itself.
    */
-  private[spark] override def asNullable: UserDefinedType[UserType] = this
+  override private[spark] def asNullable: UserDefinedType[UserType] = this
+
+  override private[sql] def acceptsType(dataType: DataType) =
+    this.getClass == dataType.getClass
 }
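For user-defined types, acceptance falls back to runtime-class equality, so a UDT accepts values of the same UDT and nothing else. A minimal standalone sketch of that rule (simplified stand-ins, not the Spark classes):

// Simplified stand-in for the class-equality acceptance rule above.
abstract class Udt {
  def acceptsType(other: Udt): Boolean = this.getClass == other.getClass
}
class PointUdt extends Udt
class VectorUdt extends Udt

assert((new PointUdt).acceptsType(new PointUdt))   // same UDT class: accepted
assert(!(new PointUdt).acceptsType(new VectorUdt)) // different UDT class: rejected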


@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.analysis

import org.scalatest.BeforeAndAfter

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.{InternalRow, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._

case class TestFunction(
    children: Seq[Expression],
    inputTypes: Seq[AbstractDataType]) extends Expression with ExpectsInputTypes {
  override def nullable: Boolean = true
  override def eval(input: InternalRow): Any = throw new UnsupportedOperationException
  override def dataType: DataType = StringType
}

case class UnresolvedTestPlan() extends LeafNode {
  override lazy val resolved = false
  override def output: Seq[Attribute] = Nil
}

class AnalysisErrorSuite extends SparkFunSuite with BeforeAndAfter {
  import AnalysisSuite._
  def errorTest(
      name: String,
      plan: LogicalPlan,
      errorMessages: Seq[String],
      caseSensitive: Boolean = true): Unit = {
    test(name) {
      val error = intercept[AnalysisException] {
        if (caseSensitive) {
          caseSensitiveAnalyze(plan)
        } else {
          caseInsensitiveAnalyze(plan)
        }
      }
      errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
    }
  }
  val dateLit = Literal.create(null, DateType)

  errorTest(
    "single invalid type, single arg",
    testRelation.select(TestFunction(dateLit :: Nil, IntegerType :: Nil).as('a)),
    "cannot resolve" :: "testfunction" :: "argument 1" :: "expected to be of type int" ::
      "null is of type date" :: Nil)

  errorTest(
    "single invalid type, second arg",
    testRelation.select(
      TestFunction(dateLit :: dateLit :: Nil, DateType :: IntegerType :: Nil).as('a)),
    "cannot resolve" :: "testfunction" :: "argument 2" :: "expected to be of type int" ::
      "null is of type date" :: Nil)

  errorTest(
    "multiple invalid type",
    testRelation.select(
      TestFunction(dateLit :: dateLit :: Nil, IntegerType :: IntegerType :: Nil).as('a)),
    "cannot resolve" :: "testfunction" :: "argument 1" :: "argument 2" ::
      "expected to be of type int" :: "null is of type date" :: Nil)

  errorTest(
    "unresolved window function",
    testRelation2.select(
      WindowExpression(
        UnresolvedWindowFunction(
          "lead",
          UnresolvedAttribute("c") :: Nil),
        WindowSpecDefinition(
          UnresolvedAttribute("a") :: Nil,
          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
          UnspecifiedFrame)).as('window)),
    "lead" :: "window functions currently requires a HiveContext" :: Nil)

  errorTest(
    "too many generators",
    listRelation.select(Explode('list).as('a), Explode('list).as('b)),
    "only one generator" :: "explode" :: Nil)

  errorTest(
    "unresolved attributes",
    testRelation.select('abcd),
    "cannot resolve" :: "abcd" :: Nil)

  errorTest(
    "bad casts",
    testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
    "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)

  errorTest(
    "non-boolean filters",
    testRelation.where(Literal(1)),
    "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)

  errorTest(
    "missing group by",
    testRelation2.groupBy('a)('b),
    "'b'" :: "group by" :: Nil)

  errorTest(
    "ambiguous field",
    nestedRelation.select($"top.duplicateField"),
    "Ambiguous reference to fields" :: "duplicateField" :: Nil,
    caseSensitive = false)

  errorTest(
    "ambiguous field due to case insensitivity",
    nestedRelation.select($"top.differentCase"),
    "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
    caseSensitive = false)

  errorTest(
    "missing field",
    nestedRelation2.select($"top.c"),
    "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
    caseSensitive = false)

  errorTest(
    "catch all unresolved plan",
    UnresolvedTestPlan(),
    "unresolved" :: Nil)

  test("SPARK-6452 regression test") {
    // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
    val plan =
      Aggregate(
        Nil,
        Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
        LocalRelation(
          AttributeReference("a", IntegerType)(exprId = ExprId(2))))

    assert(plan.resolved)

    val message = intercept[AnalysisException] {
      caseSensitiveAnalyze(plan)
    }.getMessage

    assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
  }
}
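Putting the pieces together for the first test above, analysis should fail with an AnalysisException whose message combines the resolution failure with the new per-argument detail, along the lines of (reconstructed from the commit summary and the new checkInputDataTypes, not captured output):

// Reconstructed expected failure for "single invalid type, single arg":
//   org.apache.spark.sql.AnalysisException: cannot resolve 'testfunction(null)' due to data
//   type mismatch: Argument 1 is expected to be of type int, however, null is of type date.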


@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.SimpleCatalystConf
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._

-class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
+object AnalysisSuite {
   val caseSensitiveConf = new SimpleCatalystConf(true)
   val caseInsensitiveConf = new SimpleCatalystConf(false)
@@ -61,25 +61,28 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
   val nestedRelation = LocalRelation(
     AttributeReference("top", StructType(
       StructField("duplicateField", StringType) ::
-      StructField("duplicateField", StringType) ::
-      StructField("differentCase", StringType) ::
-      StructField("differentcase", StringType) :: Nil
+        StructField("duplicateField", StringType) ::
+        StructField("differentCase", StringType) ::
+        StructField("differentcase", StringType) :: Nil
     ))())

   val nestedRelation2 = LocalRelation(
     AttributeReference("top", StructType(
       StructField("aField", StringType) ::
-      StructField("bField", StringType) ::
-      StructField("cField", StringType) :: Nil
+        StructField("bField", StringType) ::
+        StructField("cField", StringType) :: Nil
     ))())

   val listRelation = LocalRelation(
     AttributeReference("list", ArrayType(IntegerType))())

-  before {
-    caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-    caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
-  }
+  caseSensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+  caseInsensitiveCatalog.registerTable(Seq("TaBlE"), testRelation)
+}
+
+class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
+  import AnalysisSuite._

   test("union project *") {
     val plan = (1 to 100)
@@ -149,91 +152,6 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
     caseInsensitiveAnalyzer.execute(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
   }

-  def errorTest(
-      name: String,
-      plan: LogicalPlan,
-      errorMessages: Seq[String],
-      caseSensitive: Boolean = true): Unit = {
-    test(name) {
-      val error = intercept[AnalysisException] {
-        if (caseSensitive) {
-          caseSensitiveAnalyze(plan)
-        } else {
-          caseInsensitiveAnalyze(plan)
-        }
-      }
-      errorMessages.foreach(m => assert(error.getMessage.toLowerCase contains m.toLowerCase))
-    }
-  }
-
-  errorTest(
-    "unresolved window function",
-    testRelation2.select(
-      WindowExpression(
-        UnresolvedWindowFunction(
-          "lead",
-          UnresolvedAttribute("c") :: Nil),
-        WindowSpecDefinition(
-          UnresolvedAttribute("a") :: Nil,
-          SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil,
-          UnspecifiedFrame)).as('window)),
-    "lead" :: "window functions currently requires a HiveContext" :: Nil)
-
-  errorTest(
-    "too many generators",
-    listRelation.select(Explode('list).as('a), Explode('list).as('b)),
-    "only one generator" :: "explode" :: Nil)
-
-  errorTest(
-    "unresolved attributes",
-    testRelation.select('abcd),
-    "cannot resolve" :: "abcd" :: Nil)
-
-  errorTest(
-    "bad casts",
-    testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
-    "cannot cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
-
-  errorTest(
-    "non-boolean filters",
-    testRelation.where(Literal(1)),
-    "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)
-
-  errorTest(
-    "missing group by",
-    testRelation2.groupBy('a)('b),
-    "'b'" :: "group by" :: Nil)
-
-  errorTest(
-    "ambiguous field",
-    nestedRelation.select($"top.duplicateField"),
-    "Ambiguous reference to fields" :: "duplicateField" :: Nil,
-    caseSensitive = false)
-
-  errorTest(
-    "ambiguous field due to case insensitivity",
-    nestedRelation.select($"top.differentCase"),
-    "Ambiguous reference to fields" :: "differentCase" :: "differentcase" :: Nil,
-    caseSensitive = false)
-
-  errorTest(
-    "missing field",
-    nestedRelation2.select($"top.c"),
-    "No such struct field" :: "aField" :: "bField" :: "cField" :: Nil,
-    caseSensitive = false)
-
-  case class UnresolvedTestPlan() extends LeafNode {
-    override lazy val resolved = false
-    override def output: Seq[Attribute] = Nil
-  }
-
-  errorTest(
-    "catch all unresolved plan",
-    UnresolvedTestPlan(),
-    "unresolved" :: Nil)
-
   test("divide should be casted into fractional types") {
     val testRelation2 = LocalRelation(
@@ -258,22 +176,4 @@ class AnalysisSuite extends SparkFunSuite with BeforeAndAfter {
     assert(pl(3).dataType == DecimalType.Unlimited)
     assert(pl(4).dataType == DoubleType)
   }
-
-  test("SPARK-6452 regression test") {
-    // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s)
-    val plan =
-      Aggregate(
-        Nil,
-        Alias(Sum(AttributeReference("a", IntegerType)(exprId = ExprId(1))), "b")() :: Nil,
-        LocalRelation(
-          AttributeReference("a", IntegerType)(exprId = ExprId(2))))
-
-    assert(plan.resolved)
-
-    val message = intercept[AnalysisException] {
-      caseSensitiveAnalyze(plan)
-    }.getMessage
-
-    assert(message.contains("resolved attribute(s) a#1 missing from a#2"))
-  }
 }


@@ -77,6 +77,14 @@ class HiveTypeCoercionSuite extends PlanTest {
     shouldCast(DecimalType(10, 2), TypeCollection(IntegerType, DecimalType), DecimalType(10, 2))
     shouldCast(DecimalType(10, 2), TypeCollection(DecimalType, IntegerType), DecimalType(10, 2))
     shouldCast(IntegerType, TypeCollection(DecimalType(10, 2), StringType), DecimalType(10, 2))
+
+    shouldCast(StringType, NumericType, DoubleType)
+
+    // NumericType should not be changed when function accepts any of them.
+    Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType,
+      DecimalType.Unlimited, DecimalType(10, 2)).foreach { tpe =>
+      shouldCast(tpe, NumericType, tpe)
+    }
   }

   test("ineligible implicit type cast") {


@@ -359,7 +359,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
     hiveconf.set(key, value)
   }

-  private[sql] override def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
+  override private[sql] def setConf[T](entry: SQLConfEntry[T], value: T): Unit = {
     setConf(entry.key, entry.stringConverter(value))
   }