[SPARK-14061][SQL] implement CreateMap

## What changes were proposed in this pull request?

As we already have `CreateArray` and `CreateStruct`, we should also have `CreateMap`. This PR adds the `CreateMap` expression, along with the corresponding DataFrame API and Python API.
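For a quick sense of the new surface area, a minimal usage sketch of the Scala side (hypothetical data; assumes the usual `import sqlContext.implicits._` for `toDF` and `$`):

```scala
import org.apache.spark.sql.functions._

// Column arguments alternate key, value, key, value, ...
val df = Seq(("Alice", 2), ("Bob", 5)).toDF("name", "age")
val mapped = df.select(map($"name", $"age").as("m"))

// Entries can be read back with Column.getItem.
mapped.select($"m".getItem("Alice")).collect()
```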

## How was this patch tested?

Various new tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #11879 from cloud-fan/create_map.
Commit 43b15e01c4 (parent 6603d9f7e2), authored by Wenchen Fan on 2016-03-25 09:50:06 -07:00, committed by Yin Huai.
12 changed files with 277 additions and 19 deletions.


@@ -1498,6 +1498,26 @@ def translate(srcCol, matching, replace):
# ---------------------- Collection functions ------------------------------
@ignore_unicode_prefix
@since(2.0)
def create_map(*cols):
"""Creates a new map column.
:param cols: list of column names (string) or list of :class:`Column` expressions that are
grouped as key-value pairs, e.g. (key1, value1, key2, value2, ...).
>>> df.select(create_map('name', 'age').alias("map")).collect()
[Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
>>> df.select(create_map([df.name, df.age]).alias("map")).collect()
[Row(map={u'Alice': 2}), Row(map={u'Bob': 5})]
"""
sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], (list, set)):
cols = cols[0]
jc = sc._jvm.functions.map(_to_seq(sc, cols, _to_java_column))
return Column(jc)
@since(1.4)
def array(*cols):
"""Creates a new array column.


@@ -126,6 +126,7 @@ object FunctionRegistry {
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
expression[Coalesce]("nvl"),
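Since the expression is registered under the name `map`, it is callable from SQL as well; a hedged one-liner (assumes a `sqlContext` in scope):

```scala
// Same expression, invoked through SQL rather than the DataFrame API:
sqlContext.sql("SELECT map('k1', 1, 'k2', 2) AS m").collect()
```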


@@ -160,6 +160,9 @@ object HiveTypeCoercion {
})
}
private def haveSameType(exprs: Seq[Expression]): Boolean =
exprs.map(_.dataType).distinct.length == 1
/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
@@ -443,13 +446,37 @@ object HiveTypeCoercion {
// Skip nodes whose children have not been resolved yet.
case e if !e.childrenResolved => e
- case a @ CreateArray(children) if children.map(_.dataType).distinct.size > 1 =>
+ case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
case m @ CreateMap(children) if m.keys.length == m.values.length &&
(!haveSameType(m.keys) || !haveSameType(m.values)) =>
val newKeys = if (haveSameType(m.keys)) {
m.keys
} else {
val types = m.keys.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
}
val newValues = if (haveSameType(m.values)) {
m.values
} else {
val types = m.values.map(_.dataType)
findTightestCommonTypeAndPromoteToString(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
}
CreateMap(newKeys.zip(newValues).flatMap { case (k, v) => Seq(k, v) })
// Promote SUM, SUM DISTINCT and AVERAGE to largest types to prevent overflows.
case s @ Sum(e @ DecimalType()) => s // Decimal is already the biggest.
case Sum(e @ IntegralType()) if e.dataType != LongType => Sum(Cast(e, LongType))
@@ -468,21 +495,21 @@ object HiveTypeCoercion {
// Coalesce should return the first non-null value, which could be any column
// from the list. So we need to make sure the return type is deterministic and
// compatible with every child column.
- case c @ Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
+ case c @ Coalesce(es) if !haveSameType(es) =>
val types = es.map(_.dataType)
findWiderCommonType(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
- case g @ Greatest(children) if children.map(_.dataType).distinct.size > 1 =>
+ case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}
- case l @ Least(children) if children.map(_.dataType).distinct.size > 1 =>
+ case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
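The new `CreateMap` case above coerces keys and values independently and then re-interleaves them. A distilled sketch of that shape (a simplification of the rule, not the exact code):

```scala
// Find a common type for one group (keys or values) and cast everyone to it.
def coerceGroup(exprs: Seq[Expression]): Seq[Expression] =
  findTightestCommonTypeAndPromoteToString(exprs.map(_.dataType)) match {
    case Some(t) => exprs.map(Cast(_, t))
    case None    => exprs // no common type: left as-is for checkInputDataTypes to reject
  }

// Keys and values are coerced separately, then zipped back into the flat child list:
// CreateMap(coerceGroup(keys).zip(coerceGroup(values)).flatMap { case (k, v) => Seq(k, v) })
```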


@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
- import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
+ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -69,6 +69,87 @@ case class CreateArray(children: Seq[Expression]) extends Expression {
override def prettyName: String = "array"
}
/**
* Returns a catalyst Map containing the evaluation of all children expressions as keys and values.
* The children are a flattened sequence of key-value pairs, e.g. (key1, value1, key2, value2, ...).
*/
case class CreateMap(children: Seq[Expression]) extends Expression {
private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children)
private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children)
override def foldable: Boolean = children.forall(_.foldable)
override def checkInputDataTypes(): TypeCheckResult = {
if (children.size % 2 != 0) {
TypeCheckResult.TypeCheckFailure(s"$prettyName expects an positive even number of arguments.")
} else if (keys.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given keys of function map should all be the same " +
"type, but they are " + keys.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else if (values.map(_.dataType).distinct.length > 1) {
TypeCheckResult.TypeCheckFailure("The given values of function map should all be the same " +
"type, but they are " + values.map(_.dataType.simpleString).mkString("[", ", ", "]"))
} else {
TypeCheckResult.TypeCheckSuccess
}
}
override def dataType: DataType = {
MapType(
keyType = keys.headOption.map(_.dataType).getOrElse(NullType),
valueType = values.headOption.map(_.dataType).getOrElse(NullType),
valueContainsNull = values.exists(_.nullable))
}
override def nullable: Boolean = false
override def eval(input: InternalRow): Any = {
val keyArray = keys.map(_.eval(input)).toArray
if (keyArray.contains(null)) {
throw new RuntimeException("Cannot use null as map key!")
}
val valueArray = values.map(_.eval(input)).toArray
new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray))
}
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
val arrayClass = classOf[GenericArrayData].getName
val mapClass = classOf[ArrayBasedMapData].getName
val keyArray = ctx.freshName("keyArray")
val valueArray = ctx.freshName("valueArray")
val keyData = s"new $arrayClass($keyArray)"
val valueData = s"new $arrayClass($valueArray)"
s"""
final boolean ${ev.isNull} = false;
final Object[] $keyArray = new Object[${keys.size}];
final Object[] $valueArray = new Object[${values.size}];
""" + keys.zipWithIndex.map {
case (key, i) =>
val eval = key.gen(ctx)
s"""
${eval.code}
if (${eval.isNull}) {
throw new RuntimeException("Cannot use null as map key!");
} else {
$keyArray[$i] = ${eval.value};
}
"""
}.mkString("\n") + values.zipWithIndex.map {
case (value, i) =>
val eval = value.gen(ctx)
s"""
${eval.code}
if (${eval.isNull}) {
$valueArray[$i] = null;
} else {
$valueArray[$i] = ${eval.value};
}
"""
}.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);"
}
override def prettyName: String = "map"
}
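The even/odd split at the top of `CreateMap` leans on `Seq` being a `Function1[Int, A]`, so `.map(children)` means `.map(children.apply)`. A standalone plain-Scala illustration:

```scala
val children = Seq("k1", "v1", "k2", "v2")
val keys   = children.indices.filter(_ % 2 == 0).map(children) // "k1", "k2"
val values = children.indices.filter(_ % 2 != 0).map(children) // "v1", "v2"
```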
/**
* Returns a Row containing the evaluation of all children expressions.
*/


@@ -24,7 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())
// We need to check equality of map type in tests.
override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[ArrayBasedMapData]) {
return false
@@ -35,11 +34,11 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData {
return false
}
- ArrayBasedMapData.toScalaMap(this) == ArrayBasedMapData.toScalaMap(other)
+ this.keyArray == other.keyArray && this.valueArray == other.valueArray
}
override def hashCode: Int = {
- ArrayBasedMapData.toScalaMap(this).hashCode()
+ keyArray.hashCode() * 37 + valueArray.hashCode()
}
override def toString: String = {
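The new `equals`/`hashCode` make map comparison order-sensitive: two maps with the same entries in a different order were equal under the old `toScalaMap`-based comparison but differ now. A minimal sketch:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}

val m1 = new ArrayBasedMapData(
  new GenericArrayData(Array[Any](1, 2)), new GenericArrayData(Array[Any]("a", "b")))
val m2 = new ArrayBasedMapData(
  new GenericArrayData(Array[Any](2, 1)), new GenericArrayData(Array[Any]("b", "a")))

// Same logical entries, different order: unequal under array-based equality.
assert(m1 != m2)
```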


@@ -173,13 +173,23 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
CreateNamedStruct(Seq("a", "b", 2.0)), "even number of arguments")
assertError(
CreateNamedStruct(Seq(1, "a", "b", 2.0)),
"Only foldable StringType expressions are allowed to appear at odd position")
"Only foldable StringType expressions are allowed to appear at odd position")
assertError(
CreateNamedStruct(Seq('a.string.at(0), "a", "b", 2.0)),
"Only foldable StringType expressions are allowed to appear at odd position")
"Only foldable StringType expressions are allowed to appear at odd position")
assertError(
CreateNamedStruct(Seq(Literal.create(null, StringType), "a")),
"Field name should not be null")
"Field name should not be null")
}
test("check types for CreateMap") {
assertError(CreateMap(Seq("a", "b", 2.0)), "even number of arguments")
assertError(
CreateMap(Seq('intField, 'stringField, 'booleanField, 'stringField)),
"keys of function map should all be the same type")
assertError(
CreateMap(Seq('stringField, 'intField, 'stringField, 'booleanField)),
"values of function map should all be the same type")
}
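These messages come from `CreateMap.checkInputDataTypes` above; a hedged sketch of invoking the check directly on literal children, outside the analyzer:

```scala
import org.apache.spark.sql.catalyst.expressions.{CreateMap, Literal}

// An odd number of children is rejected before any key/value type comparison:
val odd = CreateMap(Seq(Literal("a"), Literal("b"), Literal(2.0)))
odd.checkInputDataTypes()
// => TypeCheckFailure("map expects a positive even number of arguments.")
```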
test("check types for ROUND") {


@@ -250,6 +250,67 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
test("CreateArray casts") {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal.create(1.0, FloatType)
:: Nil),
CreateArray(Cast(Literal(1.0), DoubleType)
:: Cast(Literal(1), DoubleType)
:: Cast(Literal.create(1.0, FloatType), DoubleType)
:: Nil))
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
CreateArray(Literal(1.0)
:: Literal(1)
:: Literal("a")
:: Nil),
CreateArray(Cast(Literal(1.0), StringType)
:: Cast(Literal(1), StringType)
:: Cast(Literal("a"), StringType)
:: Nil))
}
test("CreateMap casts") {
// type coercion for map keys
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal.create(2.0, FloatType)
:: Literal("b")
:: Nil),
CreateMap(Cast(Literal(1), FloatType)
:: Literal("a")
:: Cast(Literal.create(2.0, FloatType), FloatType)
:: Literal("b")
:: Nil))
// type coercion for map values
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2)
:: Literal(3.0)
:: Nil),
CreateMap(Literal(1)
:: Cast(Literal("a"), StringType)
:: Literal(2)
:: Cast(Literal(3.0), StringType)
:: Nil))
// type coercion for both map keys and values
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,
CreateMap(Literal(1)
:: Literal("a")
:: Literal(2.0)
:: Literal(3.0)
:: Nil),
CreateMap(Cast(Literal(1), DoubleType)
:: Cast(Literal("a"), StringType)
:: Cast(Literal(2.0), DoubleType)
:: Cast(Literal(3.0), StringType)
:: Nil))
}
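The expected casts follow from `findTightestCommonTypeAndPromoteToString`; a hedged sketch of the two promotions exercised above (assumes the helper is reachable from test code, as it is in this suite):

```scala
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._

// Numeric widening: Int and Float meet at Float.
HiveTypeCoercion.findTightestCommonTypeAndPromoteToString(Seq(IntegerType, FloatType))
// => Some(FloatType)

// Mixed numeric/string: everything is promoted to String.
HiveTypeCoercion.findTightestCommonTypeAndPromoteToString(Seq(IntegerType, StringType))
// => Some(StringType)
```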
test("greatest/least cast") {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
ruleTest(HiveTypeCoercion.FunctionArgumentConversion,


@@ -134,6 +134,46 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(CreateArray(Literal.create(null, IntegerType) :: Nil), null :: Nil)
}
test("CreateMap") {
def interlace(keys: Seq[Literal], values: Seq[Literal]): Seq[Literal] = {
keys.zip(values).flatMap { case (k, v) => Seq(k, v) }
}
def createMap(keys: Seq[Any], values: Seq[Any]): Map[Any, Any] = {
// catalyst maps are order-sensitive, so we create a ListMap here to preserve the element order.
scala.collection.immutable.ListMap(keys.zip(values): _*)
}
val intSeq = Seq(5, 10, 15, 20, 25)
val longSeq = intSeq.map(_.toLong)
val strSeq = intSeq.map(_.toString)
checkEvaluation(CreateMap(Nil), Map.empty)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), longSeq.map(Literal(_)))),
createMap(intSeq, longSeq))
checkEvaluation(
CreateMap(interlace(strSeq.map(Literal(_)), longSeq.map(Literal(_)))),
createMap(strSeq, longSeq))
checkEvaluation(
CreateMap(interlace(longSeq.map(Literal(_)), strSeq.map(Literal(_)))),
createMap(longSeq, strSeq))
val strWithNull = strSeq.drop(1).map(Literal(_)) :+ Literal.create(null, StringType)
checkEvaluation(
CreateMap(interlace(intSeq.map(Literal(_)), strWithNull)),
createMap(intSeq, strWithNull.map(_.value)))
intercept[RuntimeException] {
checkEvaluationWithoutCodegen(
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
null, null)
}
intercept[RuntimeException] {
checkEvalutionWithUnsafeProjection(
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
null, null)
}
}
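On the `ListMap` comment above: a plain-Scala illustration of why it is used (standard-library behavior, nothing Spark-specific):

```scala
import scala.collection.immutable.ListMap

// ListMap preserves insertion order; a plain Map makes no such guarantee,
// and the catalyst map comparison in these tests is order-sensitive.
val lm = ListMap(5 -> "5", 10 -> "10", 15 -> "15")
assert(lm.keys.toSeq == Seq(5, 10, 15))
```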
test("CreateStruct") {
val row = create_row(1, 2, 3)
val c1 = 'a.int.at(0)


@@ -904,6 +904,17 @@ object functions {
array((colName +: colNames).map(col) : _*)
}
/**
* Creates a new map column. The input columns must be grouped as key-value pairs, e.g.
* (key1, value1, key2, value2, ...). The key columns must all have the same data type, and can't
* be null. The value columns must all have the same data type.
*
* @group normal_funcs
* @since 2.0
*/
@scala.annotation.varargs
def map(cols: Column*): Column = withExpr { CreateMap(cols.map(_.expr)) }
/**
* Marks a DataFrame as small enough for use in broadcast joins.
*
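The scaladoc's "can't be null" for keys is a runtime contract enforced in `CreateMap.eval`/`genCode` above; a hedged sketch of a violation (hypothetical data, failing call left commented out):

```scala
import org.apache.spark.sql.functions._

val df = Seq((null.asInstanceOf[String], 1)).toDF("k", "v")
// Expected to fail at execution time with "Cannot use null as map key!":
// df.select(map($"k", $"v")).collect()
```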


@@ -41,7 +41,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
test("UDF on array") {
val f = udf((a: String) => a)
val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")
df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect()
df.select(array($"a").as("s")).select(f($"s".getItem(0))).collect()
}
test("UDF on map") {
val f = udf((a: String) => a)
val df = Seq("a" -> 1).toDF("a", "b")
df.select(map($"a", $"b").as("s")).select(f($"s".getItem("a"))).collect()
}
test("SPARK-12477 accessing null element in array field") {


@@ -44,15 +44,16 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
val expectedType = ArrayType(IntegerType, containsNull = false)
assert(row.schema(0).dataType === expectedType)
- assert(row.getAs[Seq[Int]](0) === Seq(0, 2))
+ assert(row.getSeq[Int](0) === Seq(0, 2))
}
- // Turn this on once we add a rule to the analyzer to throw a friendly exception
- ignore("array: throw exception if putting columns of different types into an array") {
-   val df = Seq((0, "str")).toDF("a", "b")
-   intercept[AnalysisException] {
-     df.select(array("a", "b"))
-   }
- }
+ test("map with column expressions") {
+   val df = Seq(1 -> "a").toDF("a", "b")
+   val row = df.select(map($"a" + 1, $"b")).first()
+   val expectedType = MapType(IntegerType, StringType, valueContainsNull = true)
+   assert(row.schema(0).dataType === expectedType)
+   assert(row.getMap[Int, String](0) === Map(2 -> "a"))
+ }
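A hedged sketch of the schema this test implies, following `CreateMap.dataType` above: the map column itself is never nullable, while value nullability is inherited from the value column (output format approximate):

```scala
val df = Seq(1 -> "a").toDF("a", "b")
df.select(map($"a" + 1, $"b").as("m")).printSchema()
// root
//  |-- m: map (nullable = false)
//  |    |-- key: integer
//  |    |-- value: string (valueContainsNull = true)
```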
test("struct with column name") {


@@ -100,6 +100,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT isnull(null), isnull('a')")
checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
checkSqlGeneration("SELECT least(1,null,3)")
checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
checkSqlGeneration("SELECT nvl(null, 1, 2)")