[SPARK-8230][SQL] Add array/map size method
Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230 Primary issue resolved is to implement array/map size for Spark SQL. Code is ready for review by a committer. Chen Hao is on the JIRA ticket, but I don't know his username on github, rxin is also on JIRA ticket. Things to review: 1. Where to put added functions namespace wise, they seem to be part of a few operations on collections which includes `sort_array` and `array_contains`. Hence the name given `collectionOperations.scala` and `_collection_functions` in python. 2. In Python code, should it be in a `1.5.0` function array or in a collections array? 3. Are there any missing methods on the `Size` case class? Looks like many of these functions have generated Java code, is that also needed in this case? 4. Something else? Author: Pedro Rodriguez <ski.rodriguez@gmail.com> Author: Pedro Rodriguez <prodriguez@trulia.com> Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits: 9a442ae [Pedro Rodriguez] fixed functions and sorted __all__ 9aea3bb [Pedro Rodriguez] removed imports from python docs 15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen d88247c [Pedro Rodriguez] removed python code bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge 59931b4 [Pedro Rodriguez] fixed compile bug instroduced when merging c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundent pretty print 130839f [Pedro Rodriguez] fixed failing test aa9bade [Pedro Rodriguez] fix style e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations 9a1a2ff [Pedro Rodriguez] added unit tests for map size 2bfbcb6 [Pedro Rodriguez] added unit test for size 20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays 99a6a5c [Pedro Rodriguez] fixed failing test cac75ac [Pedro Rodriguez] fix style 933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations f9c3b8a [Pedro Rodriguez] added unit tests for map size 2515d9f [Pedro Rodriguez] added documentation 0e60541 [Pedro Rodriguez] added unit test for size acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python 84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
This commit is contained in:
parent
8c8f0ef59e
commit
560c658a74
|
@ -50,6 +50,7 @@ __all__ = [
|
|||
'regexp_replace',
|
||||
'sha1',
|
||||
'sha2',
|
||||
'size',
|
||||
'sparkPartitionId',
|
||||
'struct',
|
||||
'udf',
|
||||
|
@ -825,6 +826,20 @@ def weekofyear(col):
|
|||
return Column(sc._jvm.functions.weekofyear(col))
|
||||
|
||||
|
||||
@since(1.5)
|
||||
def size(col):
|
||||
"""
|
||||
Collection function: returns the length of the array or map stored in the column.
|
||||
:param col: name of column or expression
|
||||
|
||||
>>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
|
||||
>>> df.select(size(df.data)).collect()
|
||||
[Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
return Column(sc._jvm.functions.size(_to_java_column(col)))
|
||||
|
||||
|
||||
class UserDefinedFunction(object):
|
||||
"""
|
||||
User defined function in Python
|
||||
|
|
|
@ -195,8 +195,10 @@ object FunctionRegistry {
|
|||
expression[Quarter]("quarter"),
|
||||
expression[Second]("second"),
|
||||
expression[WeekOfYear]("weekofyear"),
|
||||
expression[Year]("year")
|
||||
expression[Year]("year"),
|
||||
|
||||
// collection functions
|
||||
expression[Size]("size")
|
||||
)
|
||||
|
||||
val builtin: FunctionRegistry = {
|
||||
|
|
|
@ -0,0 +1,37 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
/**
|
||||
* Given an array or map, returns its size.
|
||||
*/
|
||||
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
|
||||
override def dataType: DataType = IntegerType
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
|
||||
|
||||
override def nullSafeEval(value: Any): Int = child.dataType match {
|
||||
case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
|
||||
case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
|
||||
}
|
||||
|
||||
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
|
||||
nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
|
||||
}
|
||||
}
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Licensed to the Apache Software Foundation (ASF) under one or more
|
||||
* contributor license agreements. See the NOTICE file distributed with
|
||||
* this work for additional information regarding copyright ownership.
|
||||
* The ASF licenses this file to You under the Apache License, Version 2.0
|
||||
* (the "License"); you may not use this file except in compliance with
|
||||
* the License. You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.types._
|
||||
|
||||
|
||||
class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
||||
|
||||
test("Array and Map Size") {
|
||||
val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
|
||||
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
|
||||
val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
|
||||
|
||||
checkEvaluation(Size(a0), 3)
|
||||
checkEvaluation(Size(a1), 0)
|
||||
checkEvaluation(Size(a2), 2)
|
||||
|
||||
val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
|
||||
val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
|
||||
val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
|
||||
|
||||
checkEvaluation(Size(m0), 2)
|
||||
checkEvaluation(Size(m1), 0)
|
||||
checkEvaluation(Size(m2), 1)
|
||||
|
||||
checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null)
|
||||
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
|
||||
}
|
||||
}
|
|
@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
|
|||
* @groupname misc_funcs Misc functions
|
||||
* @groupname window_funcs Window functions
|
||||
* @groupname string_funcs String functions
|
||||
* @groupname collection_funcs Collection functions
|
||||
* @groupname Ungrouped Support functions for DataFrames.
|
||||
* @since 1.3.0
|
||||
*/
|
||||
|
@ -2053,6 +2054,25 @@ object functions {
|
|||
*/
|
||||
def weekofyear(columnName: String): Column = weekofyear(Column(columnName))
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Collection functions
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/**
|
||||
* Returns length of array or map
|
||||
* @group collection_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def size(columnName: String): Column = size(Column(columnName))
|
||||
|
||||
/**
|
||||
* Returns length of array or map
|
||||
* @group collection_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def size(column: Column): Column = Size(column.expr)
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
|
|
|
@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest {
|
|||
)
|
||||
}
|
||||
|
||||
test("array size function") {
|
||||
val df = Seq(
|
||||
(Array[Int](1, 2), "x"),
|
||||
(Array[Int](), "y"),
|
||||
(Array[Int](1, 2, 3), "z")
|
||||
).toDF("a", "b")
|
||||
checkAnswer(
|
||||
df.select(size("a")),
|
||||
Seq(Row(2), Row(0), Row(3))
|
||||
)
|
||||
checkAnswer(
|
||||
df.selectExpr("size(a)"),
|
||||
Seq(Row(2), Row(0), Row(3))
|
||||
)
|
||||
}
|
||||
|
||||
test("map size function") {
|
||||
val df = Seq(
|
||||
(Map[Int, Int](1 -> 1, 2 -> 2), "x"),
|
||||
(Map[Int, Int](), "y"),
|
||||
(Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
|
||||
).toDF("a", "b")
|
||||
checkAnswer(
|
||||
df.select(size("a")),
|
||||
Seq(Row(2), Row(0), Row(3))
|
||||
)
|
||||
checkAnswer(
|
||||
df.selectExpr("size(a)"),
|
||||
Seq(Row(2), Row(0), Row(3))
|
||||
)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue