[SPARK-21281][SQL] Use string types by default if array and map have no argument
## What changes were proposed in this pull request? This pr modified code to use string types by default if `array` and `map` in functions have no argument. This behaviour is the same with Hive one; ``` hive> CREATE TEMPORARY TABLE t1 AS SELECT map(); hive> DESCRIBE t1; _c0 map<string,string> hive> CREATE TEMPORARY TABLE t2 AS SELECT array(); hive> DESCRIBE t2; _c0 array<string> ``` ## How was this patch tested? Added tests in `DataFrameFunctionsSuite`. Author: Takeshi Yamamuro <yamamuro@apache.org> Closes #18516 from maropu/SPARK-21281.
This commit is contained in:
parent
e1a172c201
commit
7896e7b99d
|
@ -527,13 +527,14 @@ case class Least(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children.length <= 1) {
|
||||
TypeCheckResult.TypeCheckFailure(s"LEAST requires at least 2 arguments")
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"input to function $prettyName requires at least two arguments")
|
||||
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"The expressions should all have the same type," +
|
||||
s" got LEAST(${children.map(_.dataType.simpleString).mkString(", ")}).")
|
||||
} else {
|
||||
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
|
||||
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -592,13 +593,14 @@ case class Greatest(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children.length <= 1) {
|
||||
TypeCheckResult.TypeCheckFailure(s"GREATEST requires at least 2 arguments")
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"input to function $prettyName requires at least two arguments")
|
||||
} else if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"The expressions should all have the same type," +
|
||||
s" got GREATEST(${children.map(_.dataType.simpleString).mkString(", ")}).")
|
||||
} else {
|
||||
TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
|
||||
TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -41,12 +41,13 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def foldable: Boolean = children.forall(_.foldable)
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult =
|
||||
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function array")
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
|
||||
}
|
||||
|
||||
override def dataType: ArrayType = {
|
||||
ArrayType(
|
||||
children.headOption.map(_.dataType).getOrElse(NullType),
|
||||
children.headOption.map(_.dataType).getOrElse(StringType),
|
||||
containsNull = children.exists(_.nullable))
|
||||
}
|
||||
|
||||
|
@ -93,7 +94,7 @@ private [sql] object GenArrayData {
|
|||
if (!ctx.isPrimitiveType(elementType)) {
|
||||
val genericArrayClass = classOf[GenericArrayData].getName
|
||||
ctx.addMutableState("Object[]", arrayName,
|
||||
s"$arrayName = new Object[${numElements}];")
|
||||
s"$arrayName = new Object[$numElements];")
|
||||
|
||||
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
|
||||
val isNullAssignment = if (!isMapKey) {
|
||||
|
@ -119,7 +120,7 @@ private [sql] object GenArrayData {
|
|||
UnsafeArrayData.calculateHeaderPortionInBytes(numElements) +
|
||||
ByteArrayMethods.roundNumberOfBytesToNearestWord(elementType.defaultSize * numElements)
|
||||
val baseOffset = Platform.BYTE_ARRAY_OFFSET
|
||||
ctx.addMutableState("UnsafeArrayData", arrayDataName, "");
|
||||
ctx.addMutableState("UnsafeArrayData", arrayDataName, "")
|
||||
|
||||
val primitiveValueTypeName = ctx.primitiveTypeName(elementType)
|
||||
val assignments = elementsCode.zipWithIndex.map { case (eval, i) =>
|
||||
|
@ -169,13 +170,16 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children.size % 2 != 0) {
|
||||
TypeCheckResult.TypeCheckFailure(s"$prettyName expects a positive even number of arguments.")
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"$prettyName expects a positive even number of arguments.")
|
||||
} else if (keys.map(_.dataType).distinct.length > 1) {
|
||||
TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
|
||||
"type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
"The given keys of function map should all be the same type, but they are " +
|
||||
keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
|
||||
} else if (values.map(_.dataType).distinct.length > 1) {
|
||||
TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
|
||||
"type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
"The given values of function map should all be the same type, but they are " +
|
||||
values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
}
|
||||
|
@ -183,8 +187,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression {
|
|||
|
||||
override def dataType: DataType = {
|
||||
MapType(
|
||||
keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
|
||||
valueType = values.headOption.map(_.dataType).getOrElse(NullType),
|
||||
keyType = keys.headOption.map(_.dataType).getOrElse(StringType),
|
||||
valueType = values.headOption.map(_.dataType).getOrElse(StringType),
|
||||
valueContainsNull = values.exists(_.nullable))
|
||||
}
|
||||
|
||||
|
@ -292,14 +296,17 @@ trait CreateNamedStructLike extends Expression {
|
|||
}
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children.size % 2 != 0) {
|
||||
if (children.length < 1) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"input to function $prettyName requires at least one argument")
|
||||
} else if (children.size % 2 != 0) {
|
||||
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
|
||||
} else {
|
||||
val invalidNames = nameExprs.filterNot(e => e.foldable && e.dataType == StringType)
|
||||
if (invalidNames.nonEmpty) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
"Only foldable StringType expressions are allowed to appear at odd position, got:" +
|
||||
s" ${invalidNames.mkString(",")}")
|
||||
s" ${invalidNames.mkString(",")}")
|
||||
} else if (!names.contains(null)) {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
} else {
|
||||
|
|
|
@ -247,8 +247,9 @@ abstract class HashExpression[E] extends Expression {
|
|||
override def nullable: Boolean = false
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children.isEmpty) {
|
||||
TypeCheckResult.TypeCheckFailure("function hash requires at least one argument")
|
||||
if (children.length < 1) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"input to function $prettyName requires at least one argument")
|
||||
} else {
|
||||
TypeCheckResult.TypeCheckSuccess
|
||||
}
|
||||
|
|
|
@ -52,10 +52,11 @@ case class Coalesce(children: Seq[Expression]) extends Expression {
|
|||
override def foldable: Boolean = children.forall(_.foldable)
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (children == Nil) {
|
||||
TypeCheckResult.TypeCheckFailure("input to function coalesce cannot be empty")
|
||||
if (children.length < 1) {
|
||||
TypeCheckResult.TypeCheckFailure(
|
||||
s"input to function $prettyName requires at least one argument")
|
||||
} else {
|
||||
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), "function coalesce")
|
||||
TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), s"function $prettyName")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -155,7 +155,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
"input to function array should all be the same type")
|
||||
assertError(Coalesce(Seq('intField, 'booleanField)),
|
||||
"input to function coalesce should all be the same type")
|
||||
assertError(Coalesce(Nil), "input to function coalesce cannot be empty")
|
||||
assertError(Coalesce(Nil), "function coalesce requires at least one argument")
|
||||
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
|
||||
assertError(Explode('intField),
|
||||
"input to function explode should be array or map type")
|
||||
|
@ -207,7 +207,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
|
||||
test("check types for Greatest/Least") {
|
||||
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
|
||||
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
|
||||
assertError(operator(Seq('booleanField)), "requires at least two arguments")
|
||||
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
|
||||
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
|
||||
}
|
||||
|
|
|
@ -1565,10 +1565,7 @@ object functions {
|
|||
* @since 1.5.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def greatest(exprs: Column*): Column = withExpr {
|
||||
require(exprs.length > 1, "greatest requires at least 2 arguments.")
|
||||
Greatest(exprs.map(_.expr))
|
||||
}
|
||||
def greatest(exprs: Column*): Column = withExpr { Greatest(exprs.map(_.expr)) }
|
||||
|
||||
/**
|
||||
* Returns the greatest value of the list of column names, skipping null values.
|
||||
|
@ -1672,10 +1669,7 @@ object functions {
|
|||
* @since 1.5.0
|
||||
*/
|
||||
@scala.annotation.varargs
|
||||
def least(exprs: Column*): Column = withExpr {
|
||||
require(exprs.length > 1, "least requires at least 2 arguments.")
|
||||
Least(exprs.map(_.expr))
|
||||
}
|
||||
def least(exprs: Column*): Column = withExpr { Least(exprs.map(_.expr)) }
|
||||
|
||||
/**
|
||||
* Returns the least value of the list of column names, skipping null values.
|
||||
|
|
|
@ -448,6 +448,42 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
rand(Random.nextLong()), randn(Random.nextLong())
|
||||
).foreach(assertValuesDoNotChangeAfterCoalesceOrUnion(_))
|
||||
}
|
||||
|
||||
test("SPARK-21281 use string types by default if array and map have no argument") {
|
||||
val ds = spark.range(1)
|
||||
var expectedSchema = new StructType()
|
||||
.add("x", ArrayType(StringType, containsNull = false), nullable = false)
|
||||
assert(ds.select(array().as("x")).schema == expectedSchema)
|
||||
expectedSchema = new StructType()
|
||||
.add("x", MapType(StringType, StringType, valueContainsNull = false), nullable = false)
|
||||
assert(ds.select(map().as("x")).schema == expectedSchema)
|
||||
}
|
||||
|
||||
test("SPARK-21281 fails if functions have no argument") {
|
||||
val df = Seq(1).toDF("a")
|
||||
|
||||
val funcsMustHaveAtLeastOneArg =
|
||||
("coalesce", (df: DataFrame) => df.select(coalesce())) ::
|
||||
("coalesce", (df: DataFrame) => df.selectExpr("coalesce()")) ::
|
||||
("named_struct", (df: DataFrame) => df.select(struct())) ::
|
||||
("named_struct", (df: DataFrame) => df.selectExpr("named_struct()")) ::
|
||||
("hash", (df: DataFrame) => df.select(hash())) ::
|
||||
("hash", (df: DataFrame) => df.selectExpr("hash()")) :: Nil
|
||||
funcsMustHaveAtLeastOneArg.foreach { case (name, func) =>
|
||||
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
|
||||
assert(errMsg.contains(s"input to function $name requires at least one argument"))
|
||||
}
|
||||
|
||||
val funcsMustHaveAtLeastTwoArgs =
|
||||
("greatest", (df: DataFrame) => df.select(greatest())) ::
|
||||
("greatest", (df: DataFrame) => df.selectExpr("greatest()")) ::
|
||||
("least", (df: DataFrame) => df.select(least())) ::
|
||||
("least", (df: DataFrame) => df.selectExpr("least()")) :: Nil
|
||||
funcsMustHaveAtLeastTwoArgs.foreach { case (name, func) =>
|
||||
val errMsg = intercept[AnalysisException] { func(df) }.getMessage
|
||||
assert(errMsg.contains(s"input to function $name requires at least two arguments"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object DataFrameFunctionsSuite {
|
||||
|
|
Loading…
Reference in a new issue