[SPARK-14061][SQL] implement CreateMap
## What changes were proposed in this pull request? We already have `CreateArray` and `CreateStruct`, so we should also have `CreateMap`. This PR adds the `CreateMap` expression together with the corresponding DataFrame API and Python API. ## How was this patch tested? Various new tests. Author: Wenchen Fan <wenchen@databricks.com> Closes #11879 from cloud-fan/create_map.
This commit is contained in:
parent
6603d9f7e2
commit
43b15e01c4
|
@ -1498,6 +1498,26 @@ def translate(srcCol, matching, replace):
|
|||
|
||||
# ---------------------- Collection functions ------------------------------
|
||||
|
||||
@ignore_unicode_prefix
@since(2.0)
def create_map(*cols):
    """Creates a new map column.

    :param cols: list of column names (string) or list of :class:`Column` expressions that are
        grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).

    >>> df.select(create_map('name', 'age').alias("map")).collect()
    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
    >>> df.select(create_map([df.name, df.age]).alias("map")).collect()
    [Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
    """
    # A lone list/set argument is treated as the collection of columns itself.
    wrapped = len(cols) == 1 and isinstance(cols[0], (list, set))
    columns = cols[0] if wrapped else cols
    ctx = SparkContext._active_spark_context
    return Column(ctx._jvm.functions.map(_to_seq(ctx, columns, _to_java_column)))
|
||||
|
||||
|
||||
@since(1.4)
|
||||
def array(*cols):
|
||||
"""Creates a new array column.
|
||||
|
|
|
@ -126,6 +126,7 @@ object FunctionRegistry {
|
|||
expression[IsNull]("isnull"),
|
||||
expression[IsNotNull]("isnotnull"),
|
||||
expression[Least]("least"),
|
||||
expression[CreateMap]("map"),
|
||||
expression[CreateNamedStruct]("named_struct"),
|
||||
expression[NaNvl]("nanvl"),
|
||||
expression[Coalesce]("nvl"),
|
||||
|
|
|
@ -160,6 +160,9 @@ object HiveTypeCoercion {
|
|||
})
|
||||
}
|
||||
|
||||
/**
 * Reports whether the given expressions have exactly one distinct data type among them.
 * An empty sequence yields false, since it contains no data type at all.
 */
private def haveSameType(exprs: Seq[Expression]): Boolean = {
  val distinctTypes = exprs.map(_.dataType).distinct
  distinctTypes.size == 1
}
|
||||
|
||||
/**
|
||||
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
|
||||
* instances higher in the query tree.
|
||||
|
@ -443,13 +446,37 @@ object HiveTypeCoercion {
|
|||
// Skip nodes whose children have not been resolved yet.
|
||||
case e if !e.childrenResolved => e
|
||||
|
||||
case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
|
||||
case a @ CreateArray(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
|
||||
case None => a
|
||||
}
|
||||
|
||||
case m @ CreateMap(children) if m.keys.length == m.values.length &&
|
||||
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
|
||||
val newKeys = if (haveSameType(m.keys)) {
|
||||
m.keys
|
||||
} else {
|
||||
val types = m.keys.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
|
||||
case None => m.keys
|
||||
}
|
||||
}
|
||||
|
||||
val newValues = if (haveSameType(m.values)) {
|
||||
m.values
|
||||
} else {
|
||||
val types = m.values.map(_.dataType)
|
||||
findTightestCommonTypeAndPromoteToString(types) match {
|
||||
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
|
||||
case None => m.values
|
||||
}
|
||||
}
|
||||
|
||||
CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
|
||||
|
||||
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
|
||||
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
|
||||
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
|
||||
|
@ -468,21 +495,21 @@ object HiveTypeCoercion {
|
|||
// Coalesce should return the first non-null value, which could be any column
|
||||
// from the list. So we need to make sure the return type is deterministic and
|
||||
// compatible with every child column.
|
||||
case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
|
||||
case c @ Coalesce(es) if !haveSameType(es) =>
|
||||
val types = es.map(_.dataType)
|
||||
findWiderCommonType(types) match {
|
||||
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
|
||||
case None => c
|
||||
}
|
||||
|
||||
case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
|
||||
case g @ Greatest(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
|
||||
case None => g
|
||||
}
|
||||
|
||||
case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
|
||||
case l @ Least(children) if !haveSameType(children) =>
|
||||
val types = children.map(_.dataType)
|
||||
findTightestCommonType(types) match {
|
||||
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
@ -69,6 +69,87 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
|
|||
override def prettyName: String = "array"
|
||||
}
|
||||
|
||||
/**
 * Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
 * The children are a flattened sequence of kv pairs, e.g. (key1, value1, key2, value2, ...)
 *
 * Type checking requires an even number of children, a single data type across all keys and a
 * single data type across all values. A key that evaluates to null is rejected at runtime with
 * a RuntimeException, while null values are stored as-is.
 */
case class CreateMap(children: Seq[Expression]) extends Expression {
  // Children at even positions are the keys, children at odd positions are the values.
  private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
  private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)

  override def foldable: Boolean = children.forall(_.foldable)

  /**
   * Fails when the number of children is odd, or when the keys (respectively the values) do not
   * all have the same data type; succeeds otherwise, including for zero children.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    if (children.size % 2 != 0) {
      // Zero children is accepted (it builds an empty map), so only evenness is required here.
      TypeCheckResult.TypeCheckFailure(s"$prettyName expects an even number of arguments.")
    } else if (keys.map(_.dataType).distinct.length > 1) {
      TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
        "type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
    } else if (values.map(_.dataType).distinct.length > 1) {
      TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
        "type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }

  /**
   * The map type is derived from the first key and the first value; an empty map falls back to
   * NullType for both. Values may be null as soon as any value expression is nullable.
   */
  override def dataType: DataType = {
    MapType(
      keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
      valueType = values.headOption.map(_.dataType).getOrElse(NullType),
      valueContainsNull = values.exists(_.nullable))
  }

  // The map itself is never null; only its values can be.
  override def nullable: Boolean = false

  /** Evaluates every key, rejects null keys, then evaluates every value and builds the map. */
  override def eval(input: InternalRow): Any = {
    val keyArray = keys.map(_.eval(input)).toArray
    if (keyArray.contains(null)) {
      throw new RuntimeException("Cannot use null as map key!")
    }
    val valueArray = values.map(_.eval(input)).toArray
    new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
  }

  /**
   * Generates Java code mirroring `eval`: two Object[] arrays are filled from the key and value
   * expressions (throwing on a null key, storing null for a null value) and then wrapped into
   * the resulting map data.
   */
  override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
    val arrayClass = classOf[GenericArrayData].getName
    val mapClass = classOf[ArrayBasedMapData].getName
    val keyArray = ctx.freshName("keyArray")
    val valueArray = ctx.freshName("valueArray")
    val keyData = s"new $arrayClass($keyArray)"
    val valueData = s"new $arrayClass($valueArray)"
    s"""
      final boolean ${ev.isNull} = false;
      final Object[] $keyArray = new Object[${keys.size}];
      final Object[] $valueArray = new Object[${values.size}];
    """ + keys.zipWithIndex.map {
      case (key, i) =>
        val eval = key.gen(ctx)
        s"""
          ${eval.code}
          if (${eval.isNull}) {
            throw new RuntimeException("Cannot use null as map key!");
          } else {
            $keyArray[$i] = ${eval.value};
          }
        """
    }.mkString("\n") + values.zipWithIndex.map {
      case (value, i) =>
        val eval = value.gen(ctx)
        s"""
          ${eval.code}
          if (${eval.isNull}) {
            $valueArray[$i] = null;
          } else {
            $valueArray[$i] = ${eval.value};
          }
        """
    }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);"
  }

  override def prettyName: String = "map"
}
|
||||
|
||||
/**
|
||||
* Returns a Row containing the evaluation of all children expressions.
|
||||
*/
|
||||
|
|
|
@ -24,7 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
|
|||
|
||||
override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
|
||||
|
||||
// We need to check equality of map type in tests.
|
||||
override def equals(o: Any): Boolean = {
|
||||
if (!o.isInstanceOf[ArrayBasedMapData]) {
|
||||
return false
|
||||
|
@ -35,11 +34,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte
|
|||
return false
|
||||
}
|
||||
|
||||
ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
|
||||
this.keyArray == other.keyArray && this.valueArray == other.valueArray
|
||||
}
|
||||
|
||||
override def hashCode: Int = {
|
||||
ArrayBasedMapData.toScalaMap(this).hashCode()
|
||||
keyArray.hashCode() * 37 + valueArray.hashCode()
|
||||
}
|
||||
|
||||
override def toString: String = {
|
||||
|
|
|
@ -173,13 +173,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
|
||||
assertError(
|
||||
CreateNamedStruct(Seq(1, "a", "b", 2.0)),
|
||||
"Only foldable StringType expressions are allowed to appear at odd position")
|
||||
"Only foldable StringType expressions are allowed to appear at odd position")
|
||||
assertError(
|
||||
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
|
||||
"Only foldable StringType expressions are allowed to appear at odd position")
|
||||
"Only foldable StringType expressions are allowed to appear at odd position")
|
||||
assertError(
|
||||
CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
|
||||
"Field name should not be null")
|
||||
"Field name should not be null")
|
||||
}
|
||||
|
||||
test("check types for CreateMap") {
  // Each ill-typed map expression is paired with a fragment of the error it must produce:
  // an odd argument count, mismatched key types, and mismatched value types, in that order.
  val failingCases = Seq(
    (CreateMap(Seq("a", "b", 2.0)),
      "even number of arguments"),
    (CreateMap(Seq('intField, 'stringField, 'booleanField, 'stringField)),
      "keys of function map should all be the same type"),
    (CreateMap(Seq('stringField, 'intField, 'stringField, 'booleanField)),
      "values of function map should all be the same type"))

  failingCases.foreach { case (mapExpr, expectedMessage) =>
    assertError(mapExpr, expectedMessage)
  }
}
|
||||
|
||||
test("check types for ROUND") {
|
||||
|
|
|
@ -250,6 +250,67 @@ class HiveTypeCoercionSuite extends PlanTest {
|
|||
:: Nil))
|
||||
}
|
||||
|
||||
test("CreateArray casts") {
  // Double, int and float elements: every element is expected to be cast to DoubleType.
  val numericElements = Seq(Literal(1.0), Literal(1), Literal.create(1.0, FloatType))
  ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
    CreateArray(numericElements),
    CreateArray(numericElements.map(Cast(_, DoubleType))))

  // Numeric elements mixed with a string: every element is expected to be cast to StringType.
  val mixedElements = Seq(Literal(1.0), Literal(1), Literal("a"))
  ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
    CreateArray(mixedElements),
    CreateArray(mixedElements.map(Cast(_, StringType))))
}
|
||||
|
||||
test("CreateMap casts") {
  // Keys differ in type (int vs. float) while values agree: only the keys get cast.
  ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
    CreateMap(Seq(
      Literal(1), Literal("a"),
      Literal.create(2.0, FloatType), Literal("b"))),
    CreateMap(Seq(
      Cast(Literal(1), FloatType), Literal("a"),
      Cast(Literal.create(2.0, FloatType), FloatType), Literal("b"))))

  // Keys agree while values differ in type (string vs. double): only the values get cast.
  ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
    CreateMap(Seq(
      Literal(1), Literal("a"),
      Literal(2), Literal(3.0))),
    CreateMap(Seq(
      Literal(1), Cast(Literal("a"), StringType),
      Literal(2), Cast(Literal(3.0), StringType))))

  // Both keys and values differ in type: keys and values are cast independently.
  ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
    CreateMap(Seq(
      Literal(1), Literal("a"),
      Literal(2.0), Literal(3.0))),
    CreateMap(Seq(
      Cast(Literal(1), DoubleType), Cast(Literal("a"), StringType),
      Cast(Literal(2.0), DoubleType), Cast(Literal(3.0), StringType))))
}
|
||||
|
||||
test("greatest/least cast") {
|
||||
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
|
||||
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
|
||||
|
|
|
@ -134,6 +134,46 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
|
||||
}
|
||||
|
||||
test("CreateMap") {
  // Flattens parallel key/value literals into the (k1, v1, k2, v2, ...) layout CreateMap takes.
  def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
    keys.zip(values).flatMap(pair => Seq(pair._1, pair._2))
  }

  // Builds the expected result. The catalyst map is order-sensitive, so a ListMap is used to
  // keep the entries in insertion order.
  def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
    val entries = keys.zip(values)
    scala.collection.immutable.ListMap(entries: _*)
  }

  val ints = Seq(5, 10, 15, 20, 25)
  val longs = ints.map(_.toLong)
  val strings = ints.map(_.toString)

  val intLiterals = ints.map(Literal(_))
  val longLiterals = longs.map(Literal(_))
  val stringLiterals = strings.map(Literal(_))

  checkEvaluation(CreateMap(Nil), Map.empty)
  checkEvaluation(CreateMap(interlace(intLiterals, longLiterals)), createMap(ints, longs))
  checkEvaluation(CreateMap(interlace(stringLiterals, longLiterals)), createMap(strings, longs))
  checkEvaluation(CreateMap(interlace(longLiterals, stringLiterals)), createMap(longs, strings))

  // Four string literals followed by a null string literal.
  val strWithNull = strings.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)

  // A null is acceptable as a map value ...
  checkEvaluation(
    CreateMap(interlace(intLiterals, strWithNull)),
    createMap(ints, strWithNull.map(_.value)))

  // ... but a null key must raise, both in interpreted mode and with unsafe projection.
  val mapWithNullKey = CreateMap(interlace(strWithNull, intLiterals))
  intercept[RuntimeException] {
    checkEvaluationWithoutCodegen(mapWithNullKey, null, null)
  }
  intercept[RuntimeException] {
    checkEvalutionWithUnsafeProjection(mapWithNullKey, null, null)
  }
}
|
||||
|
||||
test("CreateStruct") {
|
||||
val row = create_row(1, 2, 3)
|
||||
val c1 = 'a.int.at(0)
|
||||
|
|
|
@ -904,6 +904,17 @@ object functions {
|
|||
array((colName +: colNames).map(col) : _*)
|
||||
}
|
||||
|
||||
/**
 * Creates a new map column from columns laid out as alternating keys and values, e.g.
 * (key1, value1, key2, value2, ...). All key columns must share a single data type and must
 * not be null; all value columns must share a single data type as well.
 *
 * @group normal_funcs
 * @since 2.0
 */
@scala.annotation.varargs
def map(cols: Column*): Column = {
  val keyValueExprs = cols.map(_.expr)
  withExpr(CreateMap(keyValueExprs))
}
|
||||
|
||||
/**
|
||||
* Marks a DataFrame as small enough for use in broadcast joins.
|
||||
*
|
||||
|
|
|
@ -41,7 +41,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
|
|||
test("UDF on array") {
|
||||
val f = udf((a: String) => a)
|
||||
val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
|
||||
df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
|
||||
df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
|
||||
}
|
||||
|
||||
test("UDF on map") {
  // An identity UDF applied to a value looked up from a freshly built map column.
  val identityUdf = udf((s: String) => s)
  val source = Seq("a" -> 1).toDF("a", "b")
  val withMap = source.select(map($"a", $"b").as("s"))
  withMap.select(identityUdf($"s".getItem("a"))).collect()
}
|
||||
|
||||
test("SPARK-12477 accessing null element in array field") {
|
||||
|
|
|
@ -44,15 +44,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
val expectedType = ArrayType(IntegerType, containsNull = false)
|
||||
assert(row.schema(0).dataType === expectedType)
|
||||
assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
|
||||
assert(row.getSeq[Int](0) === Seq(0, 2))
|
||||
}
|
||||
|
||||
// Turn this on once we add a rule to the analyzer to throw a friendly exception
|
||||
ignore("array: throw exception if putting columns of different types into an array") {
|
||||
val df = Seq((0, "str")).toDF("a", "b")
|
||||
intercept[AnalysisException] {
|
||||
df.select(array("a", "b"))
|
||||
}
|
||||
test("map with column expressions") {
  val input = Seq(1 -> "a").toDF("a", "b")
  val mapColumn = map($"a" + 1, $"b")
  val result = input.select(mapColumn).first()

  // The key is the arithmetic expression a + 1, so the single expected entry is 2 -> "a".
  val expectedMapType = MapType(IntegerType, StringType, valueContainsNull = true)
  assert(result.schema(0).dataType === expectedMapType)
  assert(result.getMap[Int, String](0) === Map(2 -> "a"))
}
|
||||
|
||||
test("struct with column name") {
|
||||
|
|
|
@ -100,6 +100,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
|
|||
checkSqlGeneration("SELECT isnull(null), isnull('a')")
|
||||
checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
|
||||
checkSqlGeneration("SELECT least(1,null,3)")
|
||||
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
|
||||
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
|
||||
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
|
||||
checkSqlGeneration("SELECT nvl(null, 1, 2)")
|
||||
|
|
Loading…
Reference in a new issue