[SPARK-16289][SQL] Implement posexplode table generating function
## What changes were proposed in this pull request? This PR implements the `posexplode` table-generating function. Currently, the master branch raises the following exception for a `map` argument, which differs from Hive's behavior. **Before** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show org.apache.spark.sql.AnalysisException: No handler for Hive UDF ... posexplode() takes an array as a parameter; line 1 pos 7 ``` **After** ```scala scala> sql("select posexplode(map('a', 1, 'b', 2))").show +---+---+-----+ |pos|key|value| +---+---+-----+ | 0| a| 1| | 1| b| 2| +---+---+-----+ ``` For an `array` argument, the result after this change is the same as before. ``` scala> sql("select posexplode(array(1, 2, 3))").show +---+---+ |pos|col| +---+---+ | 0| 1| | 1| 2| | 2| 3| +---+---+ ``` ## How was this patch tested? Pass the Jenkins tests with newly added test cases. Author: Dongjoon Hyun <dongjoon@apache.org> Closes #13971 from dongjoon-hyun/SPARK-16289.
This commit is contained in:
parent
fdf9f94f8c
commit
46395db80e
|
@ -234,6 +234,7 @@ exportMethods("%in%",
|
|||
"over",
|
||||
"percent_rank",
|
||||
"pmod",
|
||||
"posexplode",
|
||||
"quarter",
|
||||
"rand",
|
||||
"randn",
|
||||
|
|
|
@ -2934,3 +2934,20 @@ setMethod("sort_array",
|
|||
jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
|
||||
column(jc)
|
||||
})
|
||||
|
||||
#' posexplode
|
||||
#'
|
||||
#' Creates a new row for each element with position in the given array or map column.
|
||||
#'
|
||||
#' @rdname posexplode
|
||||
#' @name posexplode
|
||||
#' @family collection_funcs
|
||||
#' @export
|
||||
#' @examples \dontrun{posexplode(df$c)}
|
||||
#' @note posexplode since 2.1.0
|
||||
setMethod("posexplode",
|
||||
signature(x = "Column"),
|
||||
function(x) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc)
|
||||
column(jc)
|
||||
})
|
||||
|
|
|
@ -1050,6 +1050,10 @@ setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
|
|||
#' @export
|
||||
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
|
||||
|
||||
#' @rdname posexplode
|
||||
#' @export
|
||||
setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
|
||||
|
||||
#' @rdname quarter
|
||||
#' @export
|
||||
setGeneric("quarter", function(x) { standardGeneric("quarter") })
|
||||
|
|
|
@ -1065,7 +1065,7 @@ test_that("column functions", {
|
|||
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
|
||||
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
|
||||
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
|
||||
c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
|
||||
c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
|
||||
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
|
||||
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
|
||||
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
|
||||
|
|
|
@ -1637,6 +1637,27 @@ def explode(col):
|
|||
return Column(jc)
|
||||
|
||||
|
||||
@since(2.1)
|
||||
def posexplode(col):
|
||||
"""Returns a new row for each element with position in the given array or map.
|
||||
|
||||
>>> from pyspark.sql import Row
|
||||
>>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
|
||||
>>> eDF.select(posexplode(eDF.intlist)).collect()
|
||||
[Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
|
||||
|
||||
>>> eDF.select(posexplode(eDF.mapfield)).show()
|
||||
+---+---+-----+
|
||||
|pos|key|value|
|
||||
+---+---+-----+
|
||||
| 0| a| b|
|
||||
+---+---+-----+
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
jc = sc._jvm.functions.posexplode(_to_java_column(col))
|
||||
return Column(jc)
|
||||
|
||||
|
||||
@ignore_unicode_prefix
|
||||
@since(1.6)
|
||||
def get_json_object(col, path):
|
||||
|
|
|
@ -176,6 +176,7 @@ object FunctionRegistry {
|
|||
expression[NullIf]("nullif"),
|
||||
expression[Nvl]("nvl"),
|
||||
expression[Nvl2]("nvl2"),
|
||||
expression[PosExplode]("posexplode"),
|
||||
expression[Rand]("rand"),
|
||||
expression[Randn]("randn"),
|
||||
expression[CreateStruct]("struct"),
|
||||
|
|
|
@ -94,13 +94,10 @@ case class UserDefinedGenerator(
|
|||
}
|
||||
|
||||
/**
|
||||
* Given an input array produces a sequence of rows for each value in the array.
|
||||
* A base class for Explode and PosExplode
|
||||
*/
|
||||
// scalastyle:off line.size.limit
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
|
||||
abstract class ExplodeBase(child: Expression, position: Boolean)
|
||||
extends UnaryExpression with Generator with CodegenFallback with Serializable {
|
||||
|
||||
override def children: Seq[Expression] = child :: Nil
|
||||
|
||||
|
@ -115,9 +112,26 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
|
|||
|
||||
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
|
||||
override def elementSchema: StructType = child.dataType match {
|
||||
case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
|
||||
case ArrayType(et, containsNull) =>
|
||||
if (position) {
|
||||
new StructType()
|
||||
.add("pos", IntegerType, false)
|
||||
.add("col", et, containsNull)
|
||||
} else {
|
||||
new StructType()
|
||||
.add("col", et, containsNull)
|
||||
}
|
||||
case MapType(kt, vt, valueContainsNull) =>
|
||||
new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
|
||||
if (position) {
|
||||
new StructType()
|
||||
.add("pos", IntegerType, false)
|
||||
.add("key", kt, false)
|
||||
.add("value", vt, valueContainsNull)
|
||||
} else {
|
||||
new StructType()
|
||||
.add("key", kt, false)
|
||||
.add("value", vt, valueContainsNull)
|
||||
}
|
||||
}
|
||||
|
||||
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
|
||||
|
@ -129,7 +143,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
|
|||
} else {
|
||||
val rows = new Array[InternalRow](inputArray.numElements())
|
||||
inputArray.foreach(et, (i, e) => {
|
||||
rows(i) = InternalRow(e)
|
||||
rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
|
||||
})
|
||||
rows
|
||||
}
|
||||
|
@ -141,7 +155,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
|
|||
val rows = new Array[InternalRow](inputMap.numElements())
|
||||
var i = 0
|
||||
inputMap.foreach(kt, vt, (k, v) => {
|
||||
rows(i) = InternalRow(k, v)
|
||||
rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
|
||||
i += 1
|
||||
})
|
||||
rows
|
||||
|
@ -149,3 +163,35 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Given an input array or map, produces a sequence of rows, one for each element.
|
||||
*
|
||||
* {{{
|
||||
* SELECT explode(array(10,20)) ->
|
||||
* 10
|
||||
* 20
|
||||
* }}}
|
||||
*/
|
||||
// scalastyle:off line.size.limit
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
|
||||
extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Explode(child: Expression) extends ExplodeBase(child, position = false)
|
||||
|
||||
/**
|
||||
* Given an input array or map, produces a sequence of rows, one for each element together with its position.
|
||||
*
|
||||
* {{{
|
||||
* SELECT posexplode(array(10,20)) ->
|
||||
* 0 10
|
||||
* 1 20
|
||||
* }}}
|
||||
*/
|
||||
// scalastyle:off line.size.limit
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
|
||||
extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
|
||||
// scalastyle:on line.size.limit
|
||||
case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
|
||||
|
|
|
@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
|
|||
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
|
||||
assertError(Explode('intField),
|
||||
"input to function explode should be array or map type")
|
||||
assertError(PosExplode('intField),
|
||||
"input to function explode should be array or map type")
|
||||
}
|
||||
|
||||
test("check types for CreateNamedStruct") {
|
||||
|
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
private def checkTuple(actual: ExplodeBase, expected: Seq[InternalRow]): Unit = {
|
||||
assert(actual.eval(null).toSeq === expected)
|
||||
}
|
||||
|
||||
private final val int_array = Seq(1, 2, 3)
|
||||
private final val str_array = Seq("a", "b", "c")
|
||||
|
||||
test("explode") {
|
||||
val int_correct_answer = Seq(Seq(1), Seq(2), Seq(3))
|
||||
val str_correct_answer = Seq(
|
||||
Seq(UTF8String.fromString("a")),
|
||||
Seq(UTF8String.fromString("b")),
|
||||
Seq(UTF8String.fromString("c")))
|
||||
|
||||
checkTuple(
|
||||
Explode(CreateArray(Seq.empty)),
|
||||
Seq.empty)
|
||||
|
||||
checkTuple(
|
||||
Explode(CreateArray(int_array.map(Literal(_)))),
|
||||
int_correct_answer.map(InternalRow.fromSeq(_)))
|
||||
|
||||
checkTuple(
|
||||
Explode(CreateArray(str_array.map(Literal(_)))),
|
||||
str_correct_answer.map(InternalRow.fromSeq(_)))
|
||||
}
|
||||
|
||||
test("posexplode") {
|
||||
val int_correct_answer = Seq(Seq(0, 1), Seq(1, 2), Seq(2, 3))
|
||||
val str_correct_answer = Seq(
|
||||
Seq(0, UTF8String.fromString("a")),
|
||||
Seq(1, UTF8String.fromString("b")),
|
||||
Seq(2, UTF8String.fromString("c")))
|
||||
|
||||
checkTuple(
|
||||
PosExplode(CreateArray(Seq.empty)),
|
||||
Seq.empty)
|
||||
|
||||
checkTuple(
|
||||
PosExplode(CreateArray(int_array.map(Literal(_)))),
|
||||
int_correct_answer.map(InternalRow.fromSeq(_)))
|
||||
|
||||
checkTuple(
|
||||
PosExplode(CreateArray(str_array.map(Literal(_)))),
|
||||
str_correct_answer.map(InternalRow.fromSeq(_)))
|
||||
}
|
||||
}
|
|
@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
|
|||
// Leave an unaliased generator with an empty list of names since the analyzer will generate
|
||||
// the correct defaults after the nested expression's type has been resolved.
|
||||
case explode: Explode => MultiAlias(explode, Nil)
|
||||
case explode: PosExplode => MultiAlias(explode, Nil)
|
||||
|
||||
case jt: JsonTuple => MultiAlias(jt, Nil)
|
||||
|
||||
|
|
|
@ -2721,6 +2721,14 @@ object functions {
|
|||
*/
|
||||
def explode(e: Column): Column = withExpr { Explode(e.expr) }
|
||||
|
||||
/**
|
||||
* Creates a new row for each element with position in the given array or map column.
|
||||
*
|
||||
* @group collection_funcs
|
||||
* @since 2.1.0
|
||||
*/
|
||||
def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
|
||||
|
||||
/**
|
||||
* Extracts json object from a json string based on json path specified, and returns json string
|
||||
* of the extracted json object. It will return null if the input json string is invalid.
|
||||
|
|
|
@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
|
|||
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
|
||||
}
|
||||
|
||||
test("single explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
checkAnswer(
|
||||
df.select(explode('intList)),
|
||||
Row(1) :: Row(2) :: Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("explode and other columns") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
|
||||
checkAnswer(
|
||||
df.select($"a", explode('intList)),
|
||||
Row(1, 1) ::
|
||||
Row(1, 2) ::
|
||||
Row(1, 3) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
df.select($"*", explode('intList)),
|
||||
Row(1, Seq(1, 2, 3), 1) ::
|
||||
Row(1, Seq(1, 2, 3), 2) ::
|
||||
Row(1, Seq(1, 2, 3), 3) :: Nil)
|
||||
}
|
||||
|
||||
test("aliased explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('intList).as('int)).select('int),
|
||||
Row(1) :: Row(2) :: Row(3) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('intList).as('int)).select(sum('int)),
|
||||
Row(6) :: Nil)
|
||||
}
|
||||
|
||||
test("explode on map") {
|
||||
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('map)),
|
||||
Row("a", "b"))
|
||||
}
|
||||
|
||||
test("explode on map with aliases") {
|
||||
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
|
||||
Row("a", "b"))
|
||||
}
|
||||
|
||||
test("self join explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
val exploded = df.select(explode('intList).as('i))
|
||||
|
||||
checkAnswer(
|
||||
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
|
||||
Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("collect on column produced by a binary operator") {
|
||||
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
|
||||
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
import org.apache.spark.sql.test.SharedSQLContext
|
||||
|
||||
class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
test("single explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
checkAnswer(
|
||||
df.select(explode('intList)),
|
||||
Row(1) :: Row(2) :: Row(3) :: Nil)
|
||||
}
|
||||
|
||||
test("single posexplode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
checkAnswer(
|
||||
df.select(posexplode('intList)),
|
||||
Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
|
||||
}
|
||||
|
||||
test("explode and other columns") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
|
||||
checkAnswer(
|
||||
df.select($"a", explode('intList)),
|
||||
Row(1, 1) ::
|
||||
Row(1, 2) ::
|
||||
Row(1, 3) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
df.select($"*", explode('intList)),
|
||||
Row(1, Seq(1, 2, 3), 1) ::
|
||||
Row(1, Seq(1, 2, 3), 2) ::
|
||||
Row(1, Seq(1, 2, 3), 3) :: Nil)
|
||||
}
|
||||
|
||||
test("aliased explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('intList).as('int)).select('int),
|
||||
Row(1) :: Row(2) :: Row(3) :: Nil)
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('intList).as('int)).select(sum('int)),
|
||||
Row(6) :: Nil)
|
||||
}
|
||||
|
||||
test("explode on map") {
|
||||
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('map)),
|
||||
Row("a", "b"))
|
||||
}
|
||||
|
||||
test("explode on map with aliases") {
|
||||
val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
|
||||
|
||||
checkAnswer(
|
||||
df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
|
||||
Row("a", "b"))
|
||||
}
|
||||
|
||||
test("self join explode") {
|
||||
val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
|
||||
val exploded = df.select(explode('intList).as('i))
|
||||
|
||||
checkAnswer(
|
||||
exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
|
||||
Row(3) :: Nil)
|
||||
}
|
||||
}
|
|
@ -245,6 +245,6 @@ private[sql] class HiveSessionCatalog(
|
|||
"xpath_number", "xpath_short", "xpath_string",
|
||||
|
||||
// table generating function
|
||||
"inline", "posexplode"
|
||||
"inline"
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue