[SPARK-16289][SQL] Implement posexplode table generating function

## What changes were proposed in this pull request?

This PR implements the `posexplode` table-generating function. Currently, the master branch raises the following exception for a `map` argument, which differs from Hive's behavior.

**Before**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
org.apache.spark.sql.AnalysisException: No handler for Hive UDF ... posexplode() takes an array as a parameter; line 1 pos 7
```

**After**
```scala
scala> sql("select posexplode(map('a', 1, 'b', 2))").show
+---+---+-----+
|pos|key|value|
+---+---+-----+
|  0|  a|    1|
|  1|  b|    2|
+---+---+-----+
```

For an `array` argument, the behavior after this change is the same as before.
```
scala> sql("select posexplode(array(1, 2, 3))").show
+---+---+
|pos|col|
+---+---+
|  0|  1|
|  1|  2|
|  2|  3|
+---+---+
```

## How was this patch tested?

Passes the Jenkins tests, including the newly added test cases.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13971 from dongjoon-hyun/SPARK-16289.
This commit is contained in:
Dongjoon Hyun 2016-06-30 12:03:54 -07:00 committed by Reynold Xin
parent fdf9f94f8c
commit 46395db80e
14 changed files with 276 additions and 72 deletions

View file

@ -234,6 +234,7 @@ exportMethods("%in%",
"over",
"percent_rank",
"pmod",
"posexplode",
"quarter",
"rand",
"randn",

View file

@ -2934,3 +2934,20 @@ setMethod("sort_array",
jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
column(jc)
})
#' posexplode
#'
#' Generates one output row per element of the given array or map column, with each
#' element accompanied by its position.
#'
#' @rdname posexplode
#' @name posexplode
#' @family collection_funcs
#' @export
#' @examples \dontrun{posexplode(df$c)}
#' @note posexplode since 2.1.0
setMethod("posexplode",
          signature(x = "Column"),
          function(x) {
            # Delegate to the JVM-side function and wrap the returned handle as a Column.
            column(callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc))
          })

View file

@ -1050,6 +1050,10 @@ setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
#' @export
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
# S4 generic for posexplode; it takes a single argument `x` on which methods dispatch.
#' @rdname posexplode
#' @export
setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
#' @rdname quarter
#' @export
setGeneric("quarter", function(x) { standardGeneric("quarter") })

View file

@ -1065,7 +1065,7 @@ test_that("column functions", {
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)

View file

@ -1637,6 +1637,27 @@ def explode(col):
return Column(jc)
@since(2.1)
def posexplode(col):
    """Returns a new row for each element with position in the given array or map.

    For an array column the result has the columns ``pos`` and ``col``; for a map
    column it has ``pos``, ``key`` and ``value``. Positions start at 0.

    >>> from pyspark.sql import Row
    >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
    >>> eDF.select(posexplode(eDF.intlist)).collect()
    [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]

    >>> eDF.select(posexplode(eDF.mapfield)).show()
    +---+---+-----+
    |pos|key|value|
    +---+---+-----+
    |  0|  a|    b|
    +---+---+-----+
    """
    # Resolve the JVM-side `functions.posexplode` through the active SparkContext's
    # gateway and wrap the resulting Java column in a Python Column.
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.posexplode(_to_java_column(col))
    return Column(jc)
@ignore_unicode_prefix
@since(1.6)
def get_json_object(col, path):

View file

@ -176,6 +176,7 @@ object FunctionRegistry {
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),

View file

@ -94,13 +94,10 @@ case class UserDefinedGenerator(
}
/**
* Given an input array produces a sequence of rows for each value in the array.
* A base class for Explode and PosExplode
*/
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
// scalastyle:on line.size.limit
case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
abstract class ExplodeBase(child: Expression, position: Boolean)
extends UnaryExpression with Generator with CodegenFallback with Serializable {
override def children: Seq[Expression] = child :: Nil
@ -115,9 +112,26 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
case ArrayType(et, containsNull) =>
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("col", et, containsNull)
} else {
new StructType()
.add("col", et, containsNull)
}
case MapType(kt, vt, valueContainsNull) =>
new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
if (position) {
new StructType()
.add("pos", IntegerType, false)
.add("key", kt, false)
.add("value", vt, valueContainsNull)
} else {
new StructType()
.add("key", kt, false)
.add("value", vt, valueContainsNull)
}
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@ -129,7 +143,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
} else {
val rows = new Array[InternalRow](inputArray.numElements())
inputArray.foreach(et, (i, e) => {
rows(i) = InternalRow(e)
rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
})
rows
}
@ -141,7 +155,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
val rows = new Array[InternalRow](inputMap.numElements())
var i = 0
inputMap.foreach(kt, vt, (k, v) => {
rows(i) = InternalRow(k, v)
rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
i += 1
})
rows
@ -149,3 +163,35 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
}
}
}
/**
 * Given an input array produces a sequence of rows for each value in the array.
 * Per the usage string below, a map input is likewise separated into multiple rows,
 * with its elements spread across columns.
 *
 * This is the variant of [[ExplodeBase]] constructed with `position = false`, so no
 * position is included in the generated rows.
 *
 * {{{
 *   SELECT explode(array(10,20)) ->
 *   10
 *   20
 * }}}
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
// scalastyle:on line.size.limit
case class Explode(child: Expression) extends ExplodeBase(child, position = false)
/**
 * Given an input array produces a sequence of rows for each position and value in the array.
 * Per the usage string below, a map input is likewise separated into multiple rows and
 * columns, each row carrying its position.
 *
 * This is the variant of [[ExplodeBase]] constructed with `position = true`. As the
 * example shows, positions start at 0.
 *
 * {{{
 *   SELECT posexplode(array(10,20)) ->
 *   0  10
 *   1  20
 * }}}
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
// scalastyle:on line.size.limit
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)

View file

@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
assertError(PosExplode('intField),
"input to function explode should be array or map type")
}
test("check types for CreateNamedStruct") {

View file

@ -0,0 +1,71 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String
/**
 * Evaluates the `Explode` and `PosExplode` generator expressions directly (without a
 * query plan) and compares the rows they emit with hand-built expectations.
 */
class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {

  private final val intValues = Seq(1, 2, 3)
  private final val strValues = Seq("a", "b", "c")

  /** Evaluates the generator against a null input row and checks the emitted rows. */
  private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
    val produced = actual.eval(null).toSeq
    assert(produced === expected)
  }

  test("explode") {
    // One single-column row per array element.
    val expectedInts: Seq[InternalRow] = intValues.map(v => InternalRow(v))
    val expectedStrs: Seq[InternalRow] =
      strValues.map(s => InternalRow(UTF8String.fromString(s)))

    checkTuple(Explode(CreateArray(Seq.empty)), Seq.empty)
    checkTuple(Explode(CreateArray(intValues.map(Literal(_)))), expectedInts)
    checkTuple(Explode(CreateArray(strValues.map(Literal(_)))), expectedStrs)
  }

  test("posexplode") {
    // One (position, element) row per array element, positions counted from 0.
    val expectedInts: Seq[InternalRow] =
      intValues.zipWithIndex.map { case (v, pos) => InternalRow(pos, v) }
    val expectedStrs: Seq[InternalRow] =
      strValues.zipWithIndex.map { case (s, pos) => InternalRow(pos, UTF8String.fromString(s)) }

    checkTuple(PosExplode(CreateArray(Seq.empty)), Seq.empty)
    checkTuple(PosExplode(CreateArray(intValues.map(Literal(_)))), expectedInts)
    checkTuple(PosExplode(CreateArray(strValues.map(Literal(_)))), expectedStrs)
  }
}

View file

@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
case explode: PosExplode => MultiAlias(explode, Nil)
case jt: JsonTuple => MultiAlias(jt, Nil)

View file

@ -2721,6 +2721,14 @@ object functions {
*/
def explode(e: Column): Column = withExpr { Explode(e.expr) }
/**
 * Produces one output row per element of the given array or map column, pairing each
 * element with its position.
 *
 * @group collection_funcs
 * @since 2.1.0
 */
def posexplode(e: Column): Column = withExpr(PosExplode(e.expr))
/**
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.

View file

@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
}
test("single explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select(explode('intList)),
Row(1) :: Row(2) :: Row(3) :: Nil)
}
test("explode and other columns") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select($"a", explode('intList)),
Row(1, 1) ::
Row(1, 2) ::
Row(1, 3) :: Nil)
checkAnswer(
df.select($"*", explode('intList)),
Row(1, Seq(1, 2, 3), 1) ::
Row(1, Seq(1, 2, 3), 2) ::
Row(1, Seq(1, 2, 3), 3) :: Nil)
}
test("aliased explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
checkAnswer(
df.select(explode('intList).as('int)).select('int),
Row(1) :: Row(2) :: Row(3) :: Nil)
checkAnswer(
df.select(explode('intList).as('int)).select(sum('int)),
Row(6) :: Nil)
}
test("explode on map") {
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
checkAnswer(
df.select(explode('map)),
Row("a", "b"))
}
test("explode on map with aliases") {
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
checkAnswer(
df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
Row("a", "b"))
}
test("self join explode") {
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
val exploded = df.select(explode('intList).as('i))
checkAnswer(
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
Row(3) :: Nil)
}
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))

View file

@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
/**
 * End-to-end DataFrame tests for the generator functions `explode` and `posexplode`,
 * covering array and map inputs, aliasing, and combination with other selected columns.
 */
class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("single explode") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
    checkAnswer(
      df.select(explode('intList)),
      Row(1) :: Row(2) :: Row(3) :: Nil)
  }

  test("single posexplode") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
    checkAnswer(
      df.select(posexplode('intList)),
      Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
  }

  test("explode and other columns") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")

    checkAnswer(
      df.select($"a", explode('intList)),
      Row(1, 1) ::
      Row(1, 2) ::
      Row(1, 3) :: Nil)

    checkAnswer(
      df.select($"*", explode('intList)),
      Row(1, Seq(1, 2, 3), 1) ::
      Row(1, Seq(1, 2, 3), 2) ::
      Row(1, Seq(1, 2, 3), 3) :: Nil)
  }

  test("posexplode and other columns") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")

    // The generated (pos, col) pair is appended after the regular column.
    checkAnswer(
      df.select($"a", posexplode('intList)),
      Row(1, 0, 1) ::
      Row(1, 1, 2) ::
      Row(1, 2, 3) :: Nil)
  }

  test("aliased explode") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")

    checkAnswer(
      df.select(explode('intList).as('int)).select('int),
      Row(1) :: Row(2) :: Row(3) :: Nil)

    checkAnswer(
      df.select(explode('intList).as('int)).select(sum('int)),
      Row(6) :: Nil)
  }

  test("explode on map") {
    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")

    checkAnswer(
      df.select(explode('map)),
      Row("a", "b"))
  }

  test("posexplode on map") {
    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")

    // Map input is the case this function was introduced for: expect (pos, key, value).
    checkAnswer(
      df.select(posexplode('map)),
      Row(0, "a", "b"))
  }

  test("explode on map with aliases") {
    val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")

    checkAnswer(
      df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
      Row("a", "b"))
  }

  test("self join explode") {
    val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
    val exploded = df.select(explode('intList).as('i))

    checkAnswer(
      exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
      Row(3) :: Nil)
  }
}

View file

@ -245,6 +245,6 @@ private[sql] class HiveSessionCatalog(
"xpath_number", "xpath_short", "xpath_string",
// table generating function
"inline", "posexplode"
"inline"
)
}