From 560c658a7462844c698b5bda09a4cfb4094fd65b Mon Sep 17 00:00:00 2001 From: Pedro Rodriguez Date: Tue, 21 Jul 2015 00:53:20 -0700 Subject: [PATCH] [SPARK-8230][SQL] Add array/map size method Pull Request for: https://issues.apache.org/jira/browse/SPARK-8230 Primary issue resolved is to implement array/map size for Spark SQL. Code is ready for review by a committer. Chen Hao is on the JIRA ticket, but I don't know his username on github, rxin is also on JIRA ticket. Things to review: 1. Where to put added functions namespace wise, they seem to be part of a few operations on collections which includes `sort_array` and `array_contains`. Hence the name given `collectionOperations.scala` and `_collection_functions` in python. 2. In Python code, should it be in a `1.5.0` function array or in a collections array? 3. Are there any missing methods on the `Size` case class? Looks like many of these functions have generated Java code, is that also needed in this case? 4. Something else? Author: Pedro Rodriguez Author: Pedro Rodriguez Closes #7462 from EntilZha/SPARK-8230 and squashes the following commits: 9a442ae [Pedro Rodriguez] fixed functions and sorted __all__ 9aea3bb [Pedro Rodriguez] removed imports from python docs 15d4bf1 [Pedro Rodriguez] Added null test case and changed to nullSafeCodeGen d88247c [Pedro Rodriguez] removed python code bd5f0e4 [Pedro Rodriguez] removed duplicate function from rebase/merge 59931b4 [Pedro Rodriguez] fixed compile bug instroduced when merging c187175 [Pedro Rodriguez] updated code to add size to __all__ directly and removed redundent pretty print 130839f [Pedro Rodriguez] fixed failing test aa9bade [Pedro Rodriguez] fix style e093473 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 0449377 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations 9a1a2ff [Pedro Rodriguez] added unit tests for map size 2bfbcb6 [Pedro Rodriguez] added unit test for size 20df2b4 [Pedro Rodriguez] Finished working version of size function and added it to python b503e75 [Pedro Rodriguez] First attempt at implementing size for maps and arrays 99a6a5c [Pedro Rodriguez] fixed failing test cac75ac [Pedro Rodriguez] fix style 933d843 [Pedro Rodriguez] updated python code with docs, switched classes/traits implemented, added (failing) expression tests 42bb7d4 [Pedro Rodriguez] refactored code to use better abstract classes/traits and implementations f9c3b8a [Pedro Rodriguez] added unit tests for map size 2515d9f [Pedro Rodriguez] added documentation 0e60541 [Pedro Rodriguez] added unit test for size acf9853 [Pedro Rodriguez] Finished working version of size function and added it to python 84a5d38 [Pedro Rodriguez] First attempt at implementing size for maps and arrays --- python/pyspark/sql/functions.py | 15 ++++++ .../catalyst/analysis/FunctionRegistry.scala | 4 +- .../expressions/collectionOperations.scala | 37 +++++++++++++++ .../CollectionFunctionsSuite.scala | 46 +++++++++++++++++++ .../org/apache/spark/sql/functions.scala | 20 ++++++++ .../spark/sql/DataFrameFunctionsSuite.scala | 31 +++++++++++++ 6 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3c134faa0a..719e623a1a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -50,6 +50,7 @@ __all__ = [ 'regexp_replace', 'sha1', 'sha2', + 'size', 'sparkPartitionId', 'struct', 'udf', @@ -825,6 +826,20 @@ def weekofyear(col): return Column(sc._jvm.functions.weekofyear(col)) +@since(1.5) +def size(col): + """ + Collection function: returns the length of the array or map stored in the column. + :param col: name of column or expression + + >>> df = sqlContext.createDataFrame([([1, 2, 3],),([1],),([],)], ['data']) + >>> df.select(size(df.data)).collect() + [Row(size(data)=3), Row(size(data)=1), Row(size(data)=0)] + """ + sc = SparkContext._active_spark_context + return Column(sc._jvm.functions.size(_to_java_column(col))) + + class UserDefinedFunction(object): """ User defined function in Python diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index aec392379c..13523720da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -195,8 +195,10 @@ object FunctionRegistry { expression[Quarter]("quarter"), expression[Second]("second"), expression[WeekOfYear]("weekofyear"), - expression[Year]("year") + expression[Year]("year"), + // collection functions + expression[Size]("size") ) val builtin: FunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala new file mode 100644 index 0000000000..2d92dcf23a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types._ + +/** + * Given an array or map, returns its size. + */ +case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def dataType: DataType = IntegerType + override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) + + override def nullSafeEval(value: Any): Int = child.dataType match { + case ArrayType(_, _) => value.asInstanceOf[Seq[Any]].size + case MapType(_, _, _) => value.asInstanceOf[Map[Any, Any]].size + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).size();") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala new file mode 100644 index 0000000000..28c41b5716 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.types._ + + +class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("Array and Map Size") { + val a0 = Literal.create(Seq(1, 2, 3), ArrayType(IntegerType)) + val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType)) + val a2 = Literal.create(Seq(1, 2), ArrayType(IntegerType)) + + checkEvaluation(Size(a0), 3) + checkEvaluation(Size(a1), 0) + checkEvaluation(Size(a2), 2) + + val m0 = Literal.create(Map("a" -> "a", "b" -> "b"), MapType(StringType, StringType)) + val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType)) + val m2 = Literal.create(Map("a" -> "a"), MapType(StringType, StringType)) + + checkEvaluation(Size(m0), 2) + checkEvaluation(Size(m1), 0) + checkEvaluation(Size(m2), 1) + + checkEvaluation(Literal.create(null, MapType(StringType, StringType)), null) + checkEvaluation(Literal.create(null, ArrayType(StringType)), null) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 6d60dae624..60b089180c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -42,6 +42,7 @@ import org.apache.spark.util.Utils * @groupname misc_funcs Misc functions * @groupname window_funcs Window functions * @groupname string_funcs String functions + * @groupname collection_funcs Collection functions * @groupname Ungrouped Support functions for DataFrames. * @since 1.3.0 */ @@ -2053,6 +2054,25 @@ object functions { */ def weekofyear(columnName: String): Column = weekofyear(Column(columnName)) + ////////////////////////////////////////////////////////////////////////////////////////////// + // Collection functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(columnName: String): Column = size(Column(columnName)) + + /** + * Returns length of array or map + * @group collection_funcs + * @since 1.5.0 + */ + def size(column: Column): Column = Size(column.expr) + + ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 8d2ff2f969..1baec5d376 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -267,4 +267,35 @@ class DataFrameFunctionsSuite extends QueryTest { ) } + test("array size function") { + val df = Seq( + (Array[Int](1, 2), "x"), + (Array[Int](), "y"), + (Array[Int](1, 2, 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } + + test("map size function") { + val df = Seq( + (Map[Int, Int](1 -> 1, 2 -> 2), "x"), + (Map[Int, Int](), "y"), + (Map[Int, Int](1 -> 1, 2 -> 2, 3 -> 3), "z") + ).toDF("a", "b") + checkAnswer( + df.select(size("a")), + Seq(Row(2), Row(0), Row(3)) + ) + checkAnswer( + df.selectExpr("size(a)"), + Seq(Row(2), Row(0), Row(3)) + ) + } }