[SPARK-21281][SQL] Use string types by default if array and map have no argument

## What changes were proposed in this pull request?
This pr modified code to use string types by default if `array` and `map` in functions have no argument. This behaviour is the same with Hive one;
```
hive> CREATE TEMPORARY TABLE t1 AS SELECT map();
hive> DESCRIBE t1;
_c0   map<string,string>

hive> CREATE TEMPORARY TABLE t2 AS SELECT array();
hive> DESCRIBE t2;
_c0   array<string>
```

## How was this patch tested?
Added tests in `DataFrameFunctionsSuite`.

Author: Takeshi Yamamuro <yamamuro@apache.org>

Closes #18516 from maropu/SPARK-21281.
This commit is contained in:
Takeshi Yamamuro 2017-07-07 23:05:38 -07:00 committed by gatorsmile
parent e1a172c201
commit 7896e7b99d
7 changed files with 74 additions and 33 deletions

View file

@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}
@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.length <= 1) {
TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least two arguments")
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
TypeCheckResult.TypeCheckFailure(
s"The expressions should all have the same type," +
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
} else {
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
}
}

View file

@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
override def checkInputDataTypes(): TypeCheckResult =
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
override def checkInputDataTypes(): TypeCheckResult = {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
override def dataType: ArrayType = {
ArrayType(
children.headOption.map(_.dataType).getOrElse(NullType),
children.headOption.map(_.dataType).getOrElse(StringType),
containsNull = children.exists(_.nullable))
}
@ -93,7 +94,7 @@ private [sql] object GenArrayData {
if (!ctx.isPrimitiveType(elementType)) {
val genericArrayClass = classOf[GenericArrayData].getName
ctx.addMutableState("Object[]", arrayName,
s"$arrayName = new Object[${numElements}];")
s"$arrayName = new Object[$numElements];")
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
val isNullAssignment = if (!isMapKey) {
@ -119,7 +120,7 @@ private [sql] object GenArrayData {
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
val baseOffset = Platform.BYTE_ARRAY_OFFSET
ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
ctx.addMutableState("UnsafeArrayData", arrayDataName, "")
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
TypeCheckResult.TypeCheckFailure(
s"$prettyName expects a positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
"type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
TypeCheckResult.TypeCheckFailure(
"The given keys of function map should all be the same type, but they are " +
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
"type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
TypeCheckResult.TypeCheckFailure(
"The given values of function map should all be the same type, but they are " +
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else {
TypeCheckResult.TypeCheckSuccess
}
@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
override def dataType: DataType = {
MapType(
keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
valueType = values.headOption.map(_.dataType).getOrElse(NullType),
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
valueContainsNull = values.exists(_.nullable))
}
@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression {
}
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
} else if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
} else {
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
if (invalidNames.nonEmpty) {
TypeCheckResult.TypeCheckFailure(
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
s" ${invalidNames.mkString(",")}")
s" ${invalidNames.mkString(",")}")
} else if (!names.contains(null)) {
TypeCheckResult.TypeCheckSuccess
} else {

View file

@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression {
override def nullable: Boolean = false
override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
} else {
TypeCheckResult.TypeCheckSuccess
}

View file

@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
override def checkInputDataTypes(): TypeCheckResult = {
if (children == Nil) {
TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
if (children.length < 1) {
TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName requires at least one argument")
} else {
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
}
}

View file

@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
"input to function array should all be the same type")
assertError(Coalesce(Seq('intField, 'booleanField)),
"input to function coalesce should all be the same type")
assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
assertError(Coalesce(Nil), "function coalesce requires at least one argument")
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
test("check types for Greatest/Least") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
assertError(operator(Seq('booleanField)), "requires at least two arguments")
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
}

View file

@ -1565,10 +1565,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def greatest(exprs: Column*): Column = withExpr {
require(exprs.length > 1, "greatest requires at least 2 arguments.")
Greatest(exprs.map(_.expr))
}
def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }
/**
* Returns the greatest value of the list of column names, skipping null values.
@ -1672,10 +1669,7 @@ object functions {
* @since 1.5.0
*/
@scala.annotation.varargs
def least(exprs: Column*): Column = withExpr {
require(exprs.length > 1, "least requires at least 2 arguments.")
Least(exprs.map(_.expr))
}
def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }
/**
* Returns the least value of the list of column names, skipping null values.

View file

@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
rand(Random.nextLong()), randn(Random.nextLong())
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
}
test("SPARK-21281 use string types by default if array and map have no argument") {
val ds = spark.range(1)
var expectedSchema = new StructType()
.add("x", ArrayType(StringType, containsNull = false), nullable = false)
assert(ds.select(array().as("x")).schema == expectedSchema)
expectedSchema = new StructType()
.add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false)
assert(ds.select(map().as("x")).schema == expectedSchema)
}
test("SPARK-21281 fails if functions have no argument") {
val df = Seq(1).toDF("a")
val funcsMustHaveAtLeastOneArg =
("coalesce", (df: DataFrame) => df.select(coalesce())) ::
("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) ::
("named_struct", (df: DataFrame) => df.select(struct())) ::
("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) ::
("hash", (df: DataFrame) => df.select(hash())) ::
("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil
funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
assert(errMsg.contains(s"input to function $name requires at least one argument"))
}
val funcsMustHaveAtLeastTwoArgs =
("greatest", (df: DataFrame) => df.select(greatest())) ::
("greatest", (df: DataFrame) => df.selectExpr("greatest()")) ::
("least", (df: DataFrame) => df.select(least())) ::
("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil
funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) =>
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
assert(errMsg.contains(s"input to function $name requires at least two arguments"))
}
}
}
object DataFrameFunctionsSuite {