From 17781d75308c328b11cab3658ca4f358539414f2 Mon Sep 17 00:00:00 2001 From: Parker Hegstrom Date: Sat, 6 Oct 2018 14:30:43 +0800 Subject: [PATCH] [SPARK-25202][SQL] Implements split with limit sql function ## What changes were proposed in this pull request? Adds support for the setting limit in the sql split function ## How was this patch tested? 1. Updated unit tests 2. Tested using Scala spark shell Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22227 from phegstrom/master. Authored-by: Parker Hegstrom Signed-off-by: hyukjinkwon --- R/pkg/R/functions.R | 15 +++++-- R/pkg/R/generics.R | 2 +- R/pkg/tests/fulltests/test_sparkSQL.R | 8 ++++ .../apache/spark/unsafe/types/UTF8String.java | 6 +++ .../spark/unsafe/types/UTF8StringSuite.java | 14 +++--- python/pyspark/sql/functions.py | 28 +++++++++--- .../expressions/regexpExpressions.scala | 44 ++++++++++++++----- .../expressions/RegexpExpressionsSuite.scala | 15 +++++-- .../org/apache/spark/sql/functions.scala | 32 ++++++++++++-- .../sql-tests/inputs/string-functions.sql | 6 ++- .../results/string-functions.sql.out | 18 +++++++- .../spark/sql/StringFunctionsSuite.scala | 44 +++++++++++++++++-- 12 files changed, 189 insertions(+), 43 deletions(-) diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 2cb4cb8d53..6a8fef5aa7 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3473,13 +3473,21 @@ setMethod("collect_set", #' @details #' \code{split_string}: Splits string on regular expression. -#' Equivalent to \code{split} SQL function. +#' Equivalent to \code{split} SQL function. Optionally a +#' \code{limit} can be specified #' #' @rdname column_string_functions +#' @param limit determines the length of the returned array. +#' \itemize{ +#' \item \code{limit > 0}: length of the array will be at most \code{limit} +#' \item \code{limit <= 0}: the returned array can have any length +#' } +#' #' @aliases split_string split_string,Column-method #' @examples #' #' \dontrun{ +#' head(select(df, split_string(df$Class, "\\d", 2))) #' head(select(df, split_string(df$Sex, "a"))) #' head(select(df, split_string(df$Class, "\\d"))) #' # This is equivalent to the following SQL expression @@ -3487,8 +3495,9 @@ setMethod("collect_set", #' @note split_string 2.3.0 setMethod("split_string", signature(x = "Column", pattern = "character"), - function(x, pattern) { - jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern) + function(x, pattern, limit = -1) { + jc <- callJStatic("org.apache.spark.sql.functions", + "split", x@jc, pattern, as.integer(limit)) column(jc) }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 27c1b312d6..697d124095 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -1258,7 +1258,7 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array") #' @rdname column_string_functions #' @name NULL -setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") }) +setGeneric("split_string", function(x, pattern, ...) 
{ standardGeneric("split_string") }) #' @rdname column_string_functions #' @name NULL diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 50eff3755e..5cc75aa3f3 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -1819,6 +1819,14 @@ test_that("string operators", { collect(select(df4, split_string(df4$a, "\\\\")))[1, 1], list(list("a.b@c.d 1", "b")) ) + expect_equal( + collect(select(df4, split_string(df4$a, "\\.", 2)))[1, 1], + list(list("a", "b@c.d 1\\b")) + ) + expect_equal( + collect(select(df4, split_string(df4$a, "b", 0)))[1, 1], + list(list("a.", "@c.d 1\\", "")) + ) l5 <- list(list(a = "abc")) df5 <- createDataFrame(l5) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index dff4a73f3e..3a3bfc4a94 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -958,6 +958,12 @@ public final class UTF8String implements Comparable, Externalizable, } public UTF8String[] split(UTF8String pattern, int limit) { + // Java String's split method supports "ignore empty string" behavior when the limit is 0 + // whereas other languages do not. To avoid this java specific behavior, we fall back to + // -1 when the limit is 0. + if (limit == 0) { + limit = -1; + } String[] splits = toString().split(pattern.toString(), limit); UTF8String[] res = new UTF8String[splits.length]; for (int i = 0; i < res.length; i++) { diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index dae13f03b0..cf9cc6b180 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -393,12 +393,14 @@ public class UTF8StringSuite { @Test public void split() { - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1), - new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")})); - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), - new UTF8String[]{fromString("ab"), fromString("def,ghi")})); - assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2), - new UTF8String[]{fromString("ab"), fromString("def,ghi")})); + UTF8String[] negativeAndZeroLimitCase = + new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi"), fromString("")}; + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 0), + negativeAndZeroLimitCase)); + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), -1), + negativeAndZeroLimitCase)); + assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 2), + new UTF8String[]{fromString("ab"), fromString("def,ghi,")})); } @Test diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3128d5792e..7685264b2d 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1691,18 +1691,32 @@ def repeat(col, n): @since(1.5) @ignore_unicode_prefix -def split(str, pattern): +def split(str, pattern, limit=-1): """ - Splits str around pattern (pattern is a regular expression). + Splits str around matches of the given pattern. - .. 
note:: pattern is a string represent the regular expression. + :param str: a string expression to split + :param pattern: a string representing a regular expression. The regex string should be + a Java regular expression. + :param limit: an integer which controls the number of times `pattern` is applied. - >>> df = spark.createDataFrame([('ab12cd',)], ['s',]) - >>> df.select(split(df.s, '[0-9]+').alias('s')).collect() - [Row(s=[u'ab', u'cd'])] + * ``limit > 0``: The resulting array's length will not be more than `limit`, and the + resulting array's last entry will contain all input beyond the last + matched pattern. + * ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting + array can be of any size. + + .. versionchanged:: 3.0 + `split` now takes an optional `limit` field. If not provided, default limit value is -1. + + >>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',]) + >>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect() + [Row(s=[u'one', u'twoBthreeC'])] + >>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect() + [Row(s=[u'one', u'two', u'three', u''])] """ sc = SparkContext._active_spark_context - return Column(sc._jvm.functions.split(_to_java_column(str), pattern)) + return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit)) @ignore_unicode_prefix diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index bf0c35fe61..4f5ea1e95f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -157,7 +157,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi arguments = """ Arguments: * str - a string expression - * regexp - a string expression. The pattern string should be a Java regular expression. + * regexp - a string expression. The regex string should be a Java regular expression. Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL parser. For example, to match "\abc", a regular expression for `regexp` can be @@ -229,33 +229,53 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress /** - * Splits str around pat (pattern is a regular expression). + * Splits str around matches of the given regex. */ @ExpressionDescription( - usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.", + usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" + + " and returns an array with a length of at most `limit`", + arguments = """ + Arguments: + * str - a string expression to split. + * regex - a string representing a regular expression. The regex string should be a + Java regular expression. + * limit - an integer expression which controls the number of times the regex is applied. + * limit > 0: The resulting array's length will not be more than `limit`, + and the resulting array's last entry will contain all input + beyond the last matched regex. + * limit <= 0: `regex` will be applied as many times as possible, and + the resulting array can be of any size. 
+ """, examples = """ Examples: > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]'); ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1); + ["one","two","three",""] + > SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2); + ["one","twoBthreeC"] """) -case class StringSplit(str: Expression, pattern: Expression) - extends BinaryExpression with ImplicitCastInputTypes { +case class StringSplit(str: Expression, regex: Expression, limit: Expression) + extends TernaryExpression with ImplicitCastInputTypes { - override def left: Expression = str - override def right: Expression = pattern override def dataType: DataType = ArrayType(StringType) - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = str :: regex :: limit :: Nil - override def nullSafeEval(string: Any, regex: Any): Any = { - val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1) + def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1)); + + override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = { + val strings = string.asInstanceOf[UTF8String].split( + regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int]) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName - nullSafeCodeGen(ctx, ev, (str, pattern) => + nullSafeCodeGen(ctx, ev, (str, regex, limit) => { // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. - s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""") + s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin + }) } override def prettyName: String = "split" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala index d532dc4f77..06fb73ad83 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/RegexpExpressionsSuite.scala @@ -225,11 +225,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row3 = create_row("aa2bb3cc", null) checkEvaluation( - StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1) checkEvaluation( - StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1) - checkEvaluation(StringSplit(s1, s2), null, row2) - checkEvaluation(StringSplit(s1, s2), null, row3) + StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1) + // limit = 0 should behave just like limit = -1 + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1) + checkEvaluation( + StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1) + checkEvaluation(StringSplit(s1, s2, -1), null, row2) + checkEvaluation(StringSplit(s1, s2, -1), null, row3) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 367ac66dd7..4247d3110f 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2546,15 +2546,39 @@ object functions { def soundex(e: Column): Column = withExpr { SoundEx(e.expr) } /** - * Splits str around pattern (pattern is a regular expression). + * Splits str around matches of the given regex. * - * @note Pattern is a string representation of the regular expression. + * @param str a string expression to split + * @param regex a string representing a regular expression. The regex string should be + * a Java regular expression. * * @group string_funcs * @since 1.5.0 */ - def split(str: Column, pattern: String): Column = withExpr { - StringSplit(str.expr, lit(pattern).expr) + def split(str: Column, regex: String): Column = withExpr { + StringSplit(str.expr, Literal(regex), Literal(-1)) + } + + /** + * Splits str around matches of the given regex. + * + * @param str a string expression to split + * @param regex a string representing a regular expression. The regex string should be + * a Java regular expression. + * @param limit an integer expression which controls the number of times the regex is applied. + *
<ul> + * <li>limit greater than 0: The resulting array's length will not be more than limit, + * and the resulting array's last entry will contain all input beyond the last + * matched regex.</li> + * <li>limit less than or equal to 0: `regex` will be applied as many times as + * possible, and the resulting array can be of any size.</li> + * </ul>
+ * + * @group string_funcs + * @since 3.0.0 + */ + def split(str: Column, regex: String, limit: Int): Column = withExpr { + StringSplit(str.expr, Literal(regex), Literal(limit)) } /** diff --git a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql index 4113734e17..2effb43183 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/string-functions.sql @@ -46,4 +46,8 @@ FROM ( encode(string(id + 2), 'utf-8') col3, encode(string(id + 3), 'utf-8') col4 FROM range(10) -) +); + +-- split function +SELECT split('aa1cc2ee3', '[1-9]+'); +SELECT split('aa1cc2ee3', '[1-9]+', 2); diff --git a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out index 7b3dc84388..e8f2e0a814 100644 --- a/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/string-functions.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 17 -- !query 0 @@ -161,3 +161,19 @@ struct == Physical Plan == *Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x] +- *Range (0, 10, step=1, splits=2) + + +-- !query 15 +SELECT split('aa1cc2ee3', '[1-9]+') +-- !query 15 schema +struct> +-- !query 15 output +["aa","cc","ee",""] + + +-- !query 16 +SELECT split('aa1cc2ee3', '[1-9]+', 2) +-- !query 16 schema +struct> +-- !query 16 output +["aa","cc2ee3"] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 3d76b9ac33..bb19fde2b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -329,16 +329,52 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { Row(" ")) } - test("string split function") { - val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b") + test("string split function with no limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") checkAnswer( df.select(split($"a", "[1-9]+")), - Row(Seq("aa", "bb", "cc"))) + Row(Seq("aa", "bb", "cc", ""))) checkAnswer( df.selectExpr("split(a, '[1-9]+')"), - Row(Seq("aa", "bb", "cc"))) + Row(Seq("aa", "bb", "cc", ""))) + } + + test("string split function with limit explicitly set to 0") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", 0)), + Row(Seq("aa", "bb", "cc", ""))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', 0)"), + Row(Seq("aa", "bb", "cc", ""))) + } + + test("string split function with positive limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", 2)), + Row(Seq("aa", "bb3cc4"))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', 2)"), + Row(Seq("aa", "bb3cc4"))) + } + + test("string split function with negative limit") { + val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b") + + checkAnswer( + df.select(split($"a", "[1-9]+", -2)), + Row(Seq("aa", "bb", "cc", ""))) + + checkAnswer( + df.selectExpr("split(a, '[1-9]+', -2)"), + Row(Seq("aa", "bb", "cc", ""))) } test("string / binary length function") {
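
For reference, the new optional `limit` can be exercised from the Scala spark-shell roughly as follows. This is a sketch only: the sample DataFrame and its column name `s` are invented for illustration, and the "expected" values are the ones documented in the patch's docstrings and tests rather than output captured from a real session.

```scala
// Sketch of a spark-shell session exercising the new optional limit.
// (spark-shell pre-imports spark.implicits._, which provides toDF and $"...".)
import org.apache.spark.sql.functions.split

val df = Seq("oneAtwoBthreeC").toDF("s")

// Existing two-argument form is unchanged: the regex is applied as many times as possible.
df.select(split($"s", "[ABC]")).collect()
// expected array contents: ["one", "two", "three", ""]

// New three-argument form with limit > 0: at most `limit` elements, and the last
// element keeps all input beyond the last matched regex.
df.select(split($"s", "[ABC]", 2)).collect()
// expected array contents: ["one", "twoBthreeC"]

// limit <= 0 (including 0, which UTF8String.split internally maps to -1) behaves like
// the unlimited case, so Java's "drop trailing empty strings" quirk is avoided.
df.select(split($"s", "[ABC]", 0)).collect()
// expected array contents: ["one", "two", "three", ""]

// The same optional limit is available from SQL.
spark.sql("SELECT split('oneAtwoBthreeC', '[ABC]', 2)").collect()
// expected array contents: ["one", "twoBthreeC"]
```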