[SPARK-8230][SQL] Add array/map size method

Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230

The primary issue resolved is implementing an array/map `size` method for Spark SQL. Code is ready for review by a committer. Cheng Hao is on the JIRA ticket, but I don't know his GitHub username; rxin is also on the JIRA ticket. A quick usage sketch follows the review list below.

Things to review:
1. Where to put the added functions namespace-wise: they seem to belong with a small set of collection operations that includes `sort_array` and `array_contains`, hence the names `collectionOperations.scala` and `_collection_functions` in Python.
2. In the Python code, should `size` go in a `1.5.0` function list or in a collections list?
3. Are there any missing methods on the `Size` case class? Many of these functions appear to have generated Java code; is that also needed in this case?
4. Something else?
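
For concreteness, a rough usage sketch of what the change enables once merged. This is illustrative only (not part of the diff) and assumes a Spark 1.5-era shell with a `SQLContext` bound to `sqlContext`:

import sqlContext.implicits._
import org.apache.spark.sql.functions.size

val df = Seq(
  (Seq(1, 2, 3), Map("a" -> 1, "b" -> 2)),
  (Seq.empty[Int], Map.empty[String, Int])
).toDF("arr", "m")

// DataFrame API: the same function covers array and map columns.
df.select(size($"arr"), size($"m")).collect()
// expected: Array(Row(3, 2), Row(0, 0))

// SQL path, resolved through the new FunctionRegistry entry.
df.registerTempTable("t")
sqlContext.sql("SELECT size(arr), size(m) FROM t").collect()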

Author: Pedro Rodriguez <ski.rodriguez@gmail.com>
Author: Pedro Rodriguez <prodriguez@trulia.com>

Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits:

9a442ae [Pedro Rodriguez] fixed functions and sorted __all__
9aea3bb [Pedro Rodriguez] removed imports from python docs
15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen
d88247c [Pedro Rodriguez] removed python code
bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge
59931b4 [Pedro Rodriguez] fixed compile bug introduced when merging
c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundant pretty print
130839f [Pedro Rodriguez] fixed failing test
aa9bade [Pedro Rodriguez] fix style
e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests
0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations
9a1a2ff [Pedro Rodriguez] added unit tests for map size
2bfbcb6 [Pedro Rodriguez] added unit test for size
20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python
b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
99a6a5c [Pedro Rodriguez] fixed failing test
cac75ac [Pedro Rodriguez] fix style
933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests
42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations
f9c3b8a [Pedro Rodriguez] added unit tests for map size
2515d9f [Pedro Rodriguez] added documentation
0e60541 [Pedro Rodriguez] added unit test for size
acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python
84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays
Pedro Rodriguez authored 2015-07-21 00:53:20 -07:00; committed by Reynold Xin
parent 8c8f0ef59e
commit 560c658a74
6 changed files with 152 additions and 1 deletion

python/pyspark/sql/functions.py

@@ -50,6 +50,7 @@ __all__ = [
     'regexp_replace',
     'sha1',
     'sha2',
+    'size',
     'sparkPartitionId',
     'struct',
     'udf',
@@ -825,6 +826,20 @@ def weekofyear(col):
     return Column(sc._jvm.functions.weekofyear(col))


+@since(1.5)
+def size(col):
+    """
+    Collection function: returns the length of the array or map stored in the column.
+    :param col: name of column or expression
+
+    >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data'])
+    >>> df.select(size(df.data)).collect()
+    [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)]
+    """
+    sc = SparkContext._active_spark_context
+    return Column(sc._jvm.functions.size(_to_java_column(col)))
+
+
 class UserDefinedFunction(object):
     """
     User defined function in Python

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

@@ -195,8 +195,10 @@ object FunctionRegistry {
     expression[Quarter]("quarter"),
     expression[Second]("second"),
     expression[WeekOfYear]("weekofyear"),
-    expression[Year]("year")
+    expression[Year]("year"),
+
+    // collection functions
+    expression[Size]("size")
   )

   val builtin: FunctionRegistry = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
+import org.apache.spark.sql.types._
+
+/**
+ * Given an array or map, returns its size.
+ */
+case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+  override def dataType: DataType = IntegerType
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))
+  override def nullSafeEval(value: Any): Int = child.dataType match {
+    case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size
+    case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();")
+  }
+}
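
A note on the interpreted path: `UnaryExpression.eval` short-circuits a null child, so `nullSafeEval` above only ever sees non-null input. A rough spark-shell sketch of the expression in isolation (illustrative, not part of the commit; assumes the catalyst classes are on the classpath):

import org.apache.spark.sql.catalyst.expressions.{Literal, Size}
import org.apache.spark.sql.types._

val arr = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
Size(arr).eval()  // 3

// A null child never reaches nullSafeEval; eval returns null directly.
Size(Literal.create(null, ArrayType(IntegerType))).eval()  // null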

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala

@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.types._
+
+class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("Array and Map Size") {
+    val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType))
+    val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
+    val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType))
+
+    checkEvaluation(Size(a0), 3)
+    checkEvaluation(Size(a1), 0)
+    checkEvaluation(Size(a2), 2)
+
+    val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType))
+    val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+    val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType))
+
+    checkEvaluation(Size(m0), 2)
+    checkEvaluation(Size(m1), 0)
+    checkEvaluation(Size(m2), 1)
+
+    // null input should evaluate to null, exercised through Size itself
+    checkEvaluation(Size(Literal.create(null, MapType(StringType, StringType))), null)
+    checkEvaluation(Size(Literal.create(null, ArrayType(StringType))), null)
+  }
+}

sql/core/src/main/scala/org/apache/spark/sql/functions.scala

@@ -42,6 +42,7 @@ import org.apache.spark.util.Utils
  * @groupname misc_funcs Misc functions
  * @groupname window_funcs Window functions
  * @groupname string_funcs String functions
+ * @groupname collection_funcs Collection functions
  * @groupname Ungrouped Support functions for DataFrames.
  * @since 1.3.0
  */
@@ -2053,6 +2054,25 @@ object functions {
    */
   def weekofyear(columnName: String): Column = weekofyear(Column(columnName))

+  //////////////////////////////////////////////////////////////////////////////////////////////
+  // Collection functions
+  //////////////////////////////////////////////////////////////////////////////////////////////
+
+  /**
+   * Returns length of array or map.
+   *
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def size(columnName: String): Column = size(Column(columnName))
+
+  /**
+   * Returns length of array or map.
+   *
+   * @group collection_funcs
+   * @since 1.5.0
+   */
+  def size(column: Column): Column = Size(column.expr)
+
   //////////////////////////////////////////////////////////////////////////////////////////////
   //////////////////////////////////////////////////////////////////////////////////////////////

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

@@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest {
     )
   }

+  test("array size function") {
+    val df = Seq(
+      (Array[Int](1, 2), "x"),
+      (Array[Int](), "y"),
+      (Array[Int](1, 2, 3), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.select(size("a")),
+      Seq(Row(2), Row(0), Row(3))
+    )
+    checkAnswer(
+      df.selectExpr("size(a)"),
+      Seq(Row(2), Row(0), Row(3))
+    )
+  }
+
+  test("map size function") {
+    val df = Seq(
+      (Map[Int, Int](1 -> 1, 2 -> 2), "x"),
+      (Map[Int, Int](), "y"),
+      (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z")
+    ).toDF("a", "b")
+    checkAnswer(
+      df.select(size("a")),
+      Seq(Row(2), Row(0), Row(3))
+    )
+    checkAnswer(
+      df.selectExpr("size(a)"),
+      Seq(Row(2), Row(0), Row(3))
+    )
+  }
 }
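
One behavioral consequence of using `nullSafeEval`/`nullSafeCodeGen`: a null array or map column yields null, not 0 (Hive's `size`, by contrast, returns -1 for NULL). A hypothetical extra case in the style of this suite, not part of the commit:

// Hypothetical test body: null input propagates as null through size().
val dfWithNull = Seq(
  (Seq(1, 2), "x"),
  (null.asInstanceOf[Seq[Int]], "y")
).toDF("a", "b")
checkAnswer(
  dfWithNull.select(size("a")),
  Seq(Row(2), Row(null))
)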