[SPARK-23908][SQL][FOLLOW-UP] Rename inputs to arguments, and add argument type check.

## What changes were proposed in this pull request?

This is a follow-up pr of #21954 to address comments.

- Rename ambiguous name `inputs` to `arguments`.
- Add argument type check and remove hacky workaround.
- Address other small comments.

## How was this patch tested?

Existing tests and some additional tests.

Closes #22075 from ueshin/issues/SPARK-23908/fup1.

Authored-by: Takuya UESHIN <ueshin@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
This commit is contained in:
Takuya UESHIN 2018-08-13 20:58:29 +08:00 committed by Wenchen Fan
parent 2e3abdff23
commit b804ca5771
6 changed files with 152 additions and 98 deletions

View file

@ -90,6 +90,20 @@ trait CheckAnalysis extends PredicateHelper {
u.failAnalysis(s"Table or view not found: ${u.tableIdentifier}")
case operator: LogicalPlan =>
// Check argument data types of higher-order functions downwards first.
// If the arguments of the higher-order functions are resolved but the type check fails,
// the argument functions will not get resolved, but we should report the argument type
// check failure instead of claiming the argument functions are unresolved.
operator transformExpressionsDown {
case hof: HigherOrderFunction
if hof.argumentsResolved && hof.checkArgumentDataTypes().isFailure =>
hof.checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckFailure(message) =>
hof.failAnalysis(
s"cannot resolve '${hof.sql}' due to argument data type mismatch: $message")
}
}
operator transformExpressionsUp {
case a: Attribute if !a.resolved =>
val from = operator.inputSet.map(_.qualifiedName).mkString(", ")

View file

@ -95,15 +95,15 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
*/
private def createLambda(
e: Expression,
partialArguments: Seq[(DataType, Boolean)]): LambdaFunction = e match {
argInfo: Seq[(DataType, Boolean)]): LambdaFunction = e match {
case f: LambdaFunction if f.bound => f
case LambdaFunction(function, names, _) =>
if (names.size != partialArguments.size) {
if (names.size != argInfo.size) {
e.failAnalysis(
s"The number of lambda function arguments '${names.size}' does not " +
"match the number of arguments expected by the higher order function " +
s"'${partialArguments.size}'.")
s"'${argInfo.size}'.")
}
if (names.map(a => canonicalizer(a.name)).distinct.size < names.size) {
@ -111,7 +111,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
"Lambda function arguments should not have names that are semantically the same.")
}
val arguments = partialArguments.zip(names).map {
val arguments = argInfo.zip(names).map {
case ((dataType, nullable), ne) =>
NamedLambdaVariable(ne.name, dataType, nullable)
}
@ -122,7 +122,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
// create a lambda function with default parameters because this is expected by the higher
// order function. Note that we hide the lambda variables produced by this function in order
// to prevent accidental naming collisions.
val arguments = partialArguments.zipWithIndex.map {
val arguments = argInfo.zipWithIndex.map {
case ((dataType, nullable), i) =>
NamedLambdaVariable(s"col$i", dataType, nullable)
}
@ -135,7 +135,7 @@ case class ResolveLambdaVariables(conf: SQLConf) extends Rule[LogicalPlan] {
private def resolve(e: Expression, parentLambdaMap: LambdaVariableMap): Expression = e match {
case _ if e.resolved => e
case h: HigherOrderFunction if h.inputResolved =>
case h: HigherOrderFunction if h.argumentsResolved && h.checkArgumentDataTypes().isSuccess =>
h.bind(createLambda).mapChildren(resolve(_, parentLambdaMap))
case l: LambdaFunction if !l.bound =>

View file

@ -41,10 +41,19 @@ trait ExpectsInputTypes extends Expression {
def inputTypes: Seq[AbstractDataType]
override def checkInputDataTypes(): TypeCheckResult = {
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
ExpectsInputTypes.checkInputDataTypes(children, inputTypes)
}
}
object ExpectsInputTypes {
def checkInputDataTypes(
inputs: Seq[Expression],
inputTypes: Seq[AbstractDataType]): TypeCheckResult = {
val mismatches = inputs.zip(inputTypes).zipWithIndex.collect {
case ((input, expected), idx) if !expected.acceptsType(input.dataType) =>
s"argument ${idx + 1} requires ${expected.simpleString} type, " +
s"however, '${child.sql}' is of ${child.dataType.catalogString} type."
s"however, '${input.sql}' is of ${input.dataType.catalogString} type."
}
if (mismatches.isEmpty) {
@ -55,7 +64,6 @@ trait ExpectsInputTypes extends Expression {
}
}
/**
* A mixin for the analyzer to perform implicit type casting using
* [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].

View file

@ -35,8 +35,8 @@ case class NamedLambdaVariable(
name: String,
dataType: DataType,
nullable: Boolean,
value: AtomicReference[Any] = new AtomicReference(),
exprId: ExprId = NamedExpression.newExprId)
exprId: ExprId = NamedExpression.newExprId,
value: AtomicReference[Any] = new AtomicReference())
extends LeafExpression
with NamedExpression
with CodegenFallback {
@ -44,7 +44,7 @@ case class NamedLambdaVariable(
override def qualifier: Seq[String] = Seq.empty
override def newInstance(): NamedExpression =
copy(value = new AtomicReference(), exprId = NamedExpression.newExprId)
copy(exprId = NamedExpression.newExprId, value = new AtomicReference())
override def toAttribute: Attribute = {
AttributeReference(name, dataType, nullable, Metadata.empty)(exprId, Seq.empty)
@ -88,30 +88,45 @@ object LambdaFunction {
* A higher order function takes one or more (lambda) functions and applies these to some objects.
* The function produces a number of variables which can be consumed by some lambda function.
*/
trait HigherOrderFunction extends Expression {
trait HigherOrderFunction extends Expression with ExpectsInputTypes {
override def children: Seq[Expression] = inputs ++ functions
override def children: Seq[Expression] = arguments ++ functions
/**
* Inputs to the higher ordered function.
* Arguments of the higher ordered function.
*/
def inputs: Seq[Expression]
def arguments: Seq[Expression]
def argumentTypes: Seq[AbstractDataType]
/**
* All inputs have been resolved. This means that the types and nullabilty of (most of) the
* All arguments have been resolved. This means that the types and nullabilty of (most of) the
* lambda function arguments is known, and that we can start binding the lambda functions.
*/
lazy val inputResolved: Boolean = inputs.forall(_.resolved)
lazy val argumentsResolved: Boolean = arguments.forall(_.resolved)
/**
* Checks the argument data types, returns `TypeCheckResult.success` if it's valid,
* or returns a `TypeCheckResult` with an error message if invalid.
* Note: it's not valid to call this method until `argumentsResolved == true`.
*/
def checkArgumentDataTypes(): TypeCheckResult = {
ExpectsInputTypes.checkInputDataTypes(arguments, argumentTypes)
}
/**
* Functions applied by the higher order function.
*/
def functions: Seq[Expression]
def functionTypes: Seq[AbstractDataType]
override def inputTypes: Seq[AbstractDataType] = argumentTypes ++ functionTypes
/**
* All inputs must be resolved and all functions must be resolved lambda functions.
*/
override lazy val resolved: Boolean = inputResolved && functions.forall {
override lazy val resolved: Boolean = argumentsResolved && functions.forall {
case l: LambdaFunction => l.resolved
case _ => false
}
@ -123,6 +138,8 @@ trait HigherOrderFunction extends Expression {
*/
def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): HigherOrderFunction
// Make sure the lambda variables refer the same instances as of arguments for case that the
// variables in instantiated separately during serialization or for some reason.
@transient lazy val functionsForEval: Seq[Expression] = functions.map {
case LambdaFunction(function, arguments, hidden) =>
val argumentMap = arguments.map { arg => arg.exprId -> arg }.toMap
@ -133,51 +150,38 @@ trait HigherOrderFunction extends Expression {
}
}
object HigherOrderFunction {
def arrayArgumentType(dt: DataType): (DataType, Boolean) = {
dt match {
case ArrayType(elementType, containsNull) => (elementType, containsNull)
case _ =>
val ArrayType(elementType, containsNull) = ArrayType.defaultConcreteType
(elementType, containsNull)
}
}
def mapKeyValueArgumentType(dt: DataType): (DataType, DataType, Boolean) = dt match {
case MapType(kType, vType, vContainsNull) => (kType, vType, vContainsNull)
case _ =>
val MapType(kType, vType, vContainsNull) = MapType.defaultConcreteType
(kType, vType, vContainsNull)
}
}
/**
* Trait for functions having as input one argument and one function.
*/
trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTypes {
trait SimpleHigherOrderFunction extends HigherOrderFunction {
def input: Expression
def argument: Expression
override def inputs: Seq[Expression] = input :: Nil
override def arguments: Seq[Expression] = argument :: Nil
def argumentType: AbstractDataType
override def argumentTypes(): Seq[AbstractDataType] = argumentType :: Nil
def function: Expression
override def functions: Seq[Expression] = function :: Nil
def expectingFunctionType: AbstractDataType = AnyDataType
def functionType: AbstractDataType = AnyDataType
@transient lazy val functionForEval: Expression = functionsForEval.head
override def functionTypes: Seq[AbstractDataType] = functionType :: Nil
def functionForEval: Expression = functionsForEval.head
/**
* Called by [[eval]]. If a subclass keeps the default nullability, it can override this method
* in order to save null-check code.
*/
protected def nullSafeEval(inputRow: InternalRow, input: Any): Any =
protected def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any =
sys.error(s"UnaryHigherOrderFunction must override either eval or nullSafeEval")
override def eval(inputRow: InternalRow): Any = {
val value = input.eval(inputRow)
val value = argument.eval(inputRow)
if (value == null) {
null
} else {
@ -187,11 +191,11 @@ trait SimpleHigherOrderFunction extends HigherOrderFunction with ExpectsInputTyp
}
trait ArrayBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, expectingFunctionType)
override def argumentType: AbstractDataType = ArrayType
}
trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType, expectingFunctionType)
override def argumentType: AbstractDataType = MapType
}
/**
@ -209,21 +213,21 @@ trait MapBasedSimpleHigherOrderFunction extends SimpleHigherOrderFunction {
""",
since = "2.4.0")
case class ArrayTransform(
input: Expression,
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
override def nullable: Boolean = input.nullable
override def nullable: Boolean = argument.nullable
override def dataType: ArrayType = ArrayType(function.dataType, function.nullable)
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayTransform = {
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
val ArrayType(elementType, containsNull) = argument.dataType
function match {
case LambdaFunction(_, arguments, _) if arguments.size == 2 =>
copy(function = f(function, elem :: (IntegerType, false) :: Nil))
copy(function = f(function, (elementType, containsNull) :: (IntegerType, false) :: Nil))
case _ =>
copy(function = f(function, elem :: Nil))
copy(function = f(function, (elementType, containsNull) :: Nil))
}
}
@ -237,8 +241,8 @@ case class ArrayTransform(
(elementVar, indexVar)
}
override def nullSafeEval(inputRow: InternalRow, inputValue: Any): Any = {
val arr = inputValue.asInstanceOf[ArrayData]
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
val f = functionForEval
val result = new GenericArrayData(new Array[Any](arr.numElements))
var i = 0
@ -268,7 +272,7 @@ examples = """
""",
since = "2.4.0")
case class MapFilter(
input: Expression,
argument: Expression,
function: Expression)
extends MapBasedSimpleHigherOrderFunction with CodegenFallback {
@ -277,17 +281,16 @@ case class MapFilter(
(args.head.asInstanceOf[NamedLambdaVariable], args.tail.head.asInstanceOf[NamedLambdaVariable])
}
@transient val (keyType, valueType, valueContainsNull) =
HigherOrderFunction.mapKeyValueArgumentType(input.dataType)
@transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): MapFilter = {
copy(function = f(function, (keyType, false) :: (valueType, valueContainsNull) :: Nil))
}
override def nullable: Boolean = input.nullable
override def nullable: Boolean = argument.nullable
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val m = value.asInstanceOf[MapData]
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val m = argumentValue.asInstanceOf[MapData]
val f = functionForEval
val retKeys = new mutable.ListBuffer[Any]
val retValues = new mutable.ListBuffer[Any]
@ -302,9 +305,9 @@ case class MapFilter(
ArrayBasedMapData(retKeys.toArray, retValues.toArray)
}
override def dataType: DataType = input.dataType
override def dataType: DataType = argument.dataType
override def expectingFunctionType: AbstractDataType = BooleanType
override def functionType: AbstractDataType = BooleanType
override def prettyName: String = "map_filter"
}
@ -321,25 +324,25 @@ case class MapFilter(
""",
since = "2.4.0")
case class ArrayFilter(
input: Expression,
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
override def nullable: Boolean = input.nullable
override def nullable: Boolean = argument.nullable
override def dataType: DataType = input.dataType
override def dataType: DataType = argument.dataType
override def expectingFunctionType: AbstractDataType = BooleanType
override def functionType: AbstractDataType = BooleanType
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayFilter = {
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
copy(function = f(function, elem :: Nil))
val ArrayType(elementType, containsNull) = argument.dataType
copy(function = f(function, (elementType, containsNull) :: Nil))
}
@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val arr = value.asInstanceOf[ArrayData]
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
val f = functionForEval
val buffer = new mutable.ArrayBuffer[Any](arr.numElements)
var i = 0
@ -368,25 +371,25 @@ case class ArrayFilter(
""",
since = "2.4.0")
case class ArrayExists(
input: Expression,
argument: Expression,
function: Expression)
extends ArrayBasedSimpleHigherOrderFunction with CodegenFallback {
override def nullable: Boolean = input.nullable
override def nullable: Boolean = argument.nullable
override def dataType: DataType = BooleanType
override def expectingFunctionType: AbstractDataType = BooleanType
override def functionType: AbstractDataType = BooleanType
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayExists = {
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
copy(function = f(function, elem :: Nil))
val ArrayType(elementType, containsNull) = argument.dataType
copy(function = f(function, (elementType, containsNull) :: Nil))
}
@transient lazy val LambdaFunction(_, Seq(elementVar: NamedLambdaVariable), _) = function
override def nullSafeEval(inputRow: InternalRow, value: Any): Any = {
val arr = value.asInstanceOf[ArrayData]
override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = {
val arr = argumentValue.asInstanceOf[ArrayData]
val f = functionForEval
var exists = false
var i = 0
@ -422,45 +425,49 @@ case class ArrayExists(
""",
since = "2.4.0")
case class ArrayAggregate(
input: Expression,
argument: Expression,
zero: Expression,
merge: Expression,
finish: Expression)
extends HigherOrderFunction with CodegenFallback {
def this(input: Expression, zero: Expression, merge: Expression) = {
this(input, zero, merge, LambdaFunction.identity)
def this(argument: Expression, zero: Expression, merge: Expression) = {
this(argument, zero, merge, LambdaFunction.identity)
}
override def inputs: Seq[Expression] = input :: zero :: Nil
override def arguments: Seq[Expression] = argument :: zero :: Nil
override def argumentTypes: Seq[AbstractDataType] = ArrayType :: AnyDataType :: Nil
override def functions: Seq[Expression] = merge :: finish :: Nil
override def nullable: Boolean = input.nullable || finish.nullable
override def functionTypes: Seq[AbstractDataType] = zero.dataType :: AnyDataType :: Nil
override def nullable: Boolean = argument.nullable || finish.nullable
override def dataType: DataType = finish.dataType
override def checkInputDataTypes(): TypeCheckResult = {
if (!ArrayType.acceptsType(input.dataType)) {
TypeCheckResult.TypeCheckFailure(
s"argument 1 requires ${ArrayType.simpleString} type, " +
s"however, '${input.sql}' is of ${input.dataType.catalogString} type.")
} else if (!DataType.equalsStructurally(
zero.dataType, merge.dataType, ignoreNullability = true)) {
TypeCheckResult.TypeCheckFailure(
s"argument 3 requires ${zero.dataType.simpleString} type, " +
s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
} else {
TypeCheckResult.TypeCheckSuccess
checkArgumentDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (!DataType.equalsStructurally(
zero.dataType, merge.dataType, ignoreNullability = true)) {
TypeCheckResult.TypeCheckFailure(
s"argument 3 requires ${zero.dataType.simpleString} type, " +
s"however, '${merge.sql}' is of ${merge.dataType.catalogString} type.")
} else {
TypeCheckResult.TypeCheckSuccess
}
case failure => failure
}
}
override def bind(f: (Expression, Seq[(DataType, Boolean)]) => LambdaFunction): ArrayAggregate = {
// Be very conservative with nullable. We cannot be sure that the accumulator does not
// evaluate to null. So we always set nullable to true here.
val elem = HigherOrderFunction.arrayArgumentType(input.dataType)
val ArrayType(elementType, containsNull) = argument.dataType
val acc = zero.dataType -> true
val newMerge = f(merge, acc :: elem :: Nil)
val newMerge = f(merge, acc :: (elementType, containsNull) :: Nil)
val newFinish = f(finish, acc :: Nil)
copy(merge = newMerge, finish = newFinish)
}
@ -470,7 +477,7 @@ case class ArrayAggregate(
@transient lazy val LambdaFunction(_, Seq(accForFinishVar: NamedLambdaVariable), _) = finish
override def eval(input: InternalRow): Any = {
val arr = this.input.eval(input).asInstanceOf[ArrayData]
val arr = argument.eval(input).asInstanceOf[ArrayData]
if (arr == null) {
null
} else {

View file

@ -81,7 +81,7 @@ trait PlanTestBase extends PredicateHelper { self: Suite =>
case ae: AggregateExpression =>
ae.copy(resultId = ExprId(0))
case lv: NamedLambdaVariable =>
lv.copy(value = null, exprId = ExprId(0))
lv.copy(exprId = ExprId(0), value = null)
}
}

View file

@ -1852,6 +1852,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("transform(i, x -> x)")
}
assert(ex2.getMessage.contains("data type mismatch: argument 1 requires array type"))
val ex3 = intercept[AnalysisException] {
df.selectExpr("transform(a, x -> x)")
}
assert(ex3.getMessage.contains("cannot resolve '`a`'"))
}
test("map_filter") {
@ -1898,6 +1903,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("map_filter(i, (k, v) -> k > v)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 1 requires map type"))
val ex4 = intercept[AnalysisException] {
df.selectExpr("map_filter(a, (k, v) -> k > v)")
}
assert(ex4.getMessage.contains("cannot resolve '`a`'"))
}
test("filter function - array for primitive type not containing null") {
@ -1994,6 +2004,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("filter(s, x -> x)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
val ex4 = intercept[AnalysisException] {
df.selectExpr("filter(a, x -> x)")
}
assert(ex4.getMessage.contains("cannot resolve '`a`'"))
}
test("exists function - array for primitive type not containing null") {
@ -2090,6 +2105,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("exists(s, x -> x)")
}
assert(ex3.getMessage.contains("data type mismatch: argument 2 requires boolean type"))
val ex4 = intercept[AnalysisException] {
df.selectExpr("exists(a, x -> x)")
}
assert(ex4.getMessage.contains("cannot resolve '`a`'"))
}
test("aggregate function - array for primitive type not containing null") {
@ -2211,6 +2231,11 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
df.selectExpr("aggregate(s, 0, (acc, x) -> x)")
}
assert(ex4.getMessage.contains("data type mismatch: argument 3 requires int type"))
val ex5 = intercept[AnalysisException] {
df.selectExpr("aggregate(a, 0, (acc, x) -> x)")
}
assert(ex5.getMessage.contains("cannot resolve '`a`'"))
}
private def assertValuesDoNotChangeAfterCoalesceOrUnion(v: Column): Unit = {