[SPARK-25202][SQL] Implements split with limit sql function

## What changes were proposed in this pull request?

Adds support for setting the `limit` parameter in the SQL `split` function

## How was this patch tested?

1. Updated unit tests
2. Tested using Scala spark shell

Please review http://spark.apache.org/contributing.html before opening a pull request.

Closes #22227 from phegstrom/master.

Authored-by: Parker Hegstrom <phegstrom@palantir.com>
Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
Parker Hegstrom 2018-10-06 14:30:43 +08:00 committed by hyukjinkwon
parent 44cf800c83
commit 17781d7530
12 changed files with 189 additions and 43 deletions

View file

@@ -3473,13 +3473,21 @@ setMethod("collect_set",
#' @details
#' \code{split_string}: Splits string on regular expression.
#' Equivalent to \code{split} SQL function.
#' Equivalent to \code{split} SQL function. Optionally a
#' \code{limit} can be specified.
#'
#' @rdname column_string_functions
#' @param limit determines the length of the returned array.
#' \itemize{
#' \item \code{limit > 0}: length of the array will be at most \code{limit}
#' \item \code{limit <= 0}: the returned array can have any length
#' }
#'
#' @aliases split_string split_string,Column-method
#' @examples
#'
#' \dontrun{
#' head(select(df, split_string(df$Class, "\\d", 2)))
#' head(select(df, split_string(df$Sex, "a")))
#' head(select(df, split_string(df$Class, "\\d")))
#' # This is equivalent to the following SQL expression
@@ -3487,8 +3495,9 @@ setMethod("collect_set",
#' @note split_string 2.3.0
setMethod("split_string",
signature(x = "Column", pattern = "character"),
function(x, pattern) {
jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern)
function(x, pattern, limit = -1) {
jc <- callJStatic("org.apache.spark.sql.functions",
"split", x@jc, pattern, as.integer(limit))
column(jc)
})

View file

@@ -1258,7 +1258,7 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array")
#' @rdname column_string_functions
#' @name NULL
setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") })
setGeneric("split_string", function(x, pattern, ...) { standardGeneric("split_string") })
#' @rdname column_string_functions
#' @name NULL

View file

@@ -1819,6 +1819,14 @@ test_that("string operators", {
collect(select(df4, split_string(df4$a, "\\\\")))[1, 1],
list(list("a.b@c.d 1", "b"))
)
expect_equal(
collect(select(df4, split_string(df4$a, "\\.", 2)))[1, 1],
list(list("a", "b@c.d 1\\b"))
)
expect_equal(
collect(select(df4, split_string(df4$a, "b", 0)))[1, 1],
list(list("a.", "@c.d 1\\", ""))
)
l5 <- list(list(a = "abc"))
df5 <- createDataFrame(l5)

View file

@@ -958,6 +958,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
}
public UTF8String[] split(UTF8String pattern, int limit) {
// Java String's split method supports "ignore empty string" behavior when the limit is 0
// whereas other languages do not. To avoid this Java-specific behavior, we fall back to
// -1 when the limit is 0.
if (limit == 0) {
limit = -1;
}
String[] splits = toString().split(pattern.toString(), limit);
UTF8String[] res = new UTF8String[splits.length];
for (int i = 0; i < res.length; i++) {

View file

@@ -393,12 +393,14 @@ public class UTF8StringSuite {
@Test
public void split() {
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
UTF8String[] negativeAndZeroLimitCase =
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi"), fromString("")};
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 0),
negativeAndZeroLimitCase));
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), -1),
negativeAndZeroLimitCase));
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 2),
new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
}
@Test

View file

@@ -1691,18 +1691,32 @@ def repeat(col, n):
@since(1.5)
@ignore_unicode_prefix
def split(str, pattern):
def split(str, pattern, limit=-1):
"""
Splits str around pattern (pattern is a regular expression).
Splits str around matches of the given pattern.
.. note:: pattern is a string representing the regular expression.
:param str: a string expression to split
:param pattern: a string representing a regular expression. The regex string should be
a Java regular expression.
:param limit: an integer which controls the number of times `pattern` is applied.
>>> df = spark.createDataFrame([('ab12cd',)], ['s',])
>>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
[Row(s=[u'ab', u'cd'])]
* ``limit > 0``: The resulting array's length will not be more than `limit`, and the
resulting array's last entry will contain all input beyond the last
matched pattern.
* ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting
array can be of any size.
.. versionchanged:: 3.0
`split` now takes an optional `limit` field. If not provided, the default limit value is -1.
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
>>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
[Row(s=[u'one', u'twoBthreeC'])]
>>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
[Row(s=[u'one', u'two', u'three', u''])]
"""
sc = SparkContext._active_spark_context
return Column(sc._jvm.functions.split(_to_java_column(str), pattern))
return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit))
@ignore_unicode_prefix

View file

@@ -157,7 +157,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
arguments = """
Arguments:
* str - a string expression
* regexp - a string expression. The pattern string should be a Java regular expression.
* regexp - a string expression. The regex string should be a Java regular expression.
Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
parser. For example, to match "\abc", a regular expression for `regexp` can be
@@ -229,33 +229,53 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
/**
* Splits str around pat (pattern is a regular expression).
* Splits str around matches of the given regex.
*/
@ExpressionDescription(
usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.",
usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" +
" and returns an array with a length of at most `limit`",
arguments = """
Arguments:
* str - a string expression to split.
* regex - a string representing a regular expression. The regex string should be a
Java regular expression.
* limit - an integer expression which controls the number of times the regex is applied.
* limit > 0: The resulting array's length will not be more than `limit`,
and the resulting array's last entry will contain all input
beyond the last matched regex.
* limit <= 0: `regex` will be applied as many times as possible, and
the resulting array can be of any size.
""",
examples = """
Examples:
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');
["one","two","three",""]
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1);
["one","two","three",""]
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2);
["one","twoBthreeC"]
""")
case class StringSplit(str: Expression, pattern: Expression)
extends BinaryExpression with ImplicitCastInputTypes {
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
extends TernaryExpression with ImplicitCastInputTypes {
override def left: Expression = str
override def right: Expression = pattern
override def dataType: DataType = ArrayType(StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = str :: regex :: limit :: Nil
override def nullSafeEval(string: Any, regex: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1));
override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
val strings = string.asInstanceOf[UTF8String].split(
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
new GenericArrayData(strings.asInstanceOf[Array[Any]])
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val arrayClass = classOf[GenericArrayData].getName
nullSafeCodeGen(ctx, ev, (str, pattern) =>
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
// Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""")
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
})
}
override def prettyName: String = "split"

View file

@@ -225,11 +225,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val row3 = create_row("aa2bb3cc", null)
checkEvaluation(
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1)
checkEvaluation(
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
checkEvaluation(StringSplit(s1, s2), null, row2)
checkEvaluation(StringSplit(s1, s2), null, row3)
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1)
// limit = 0 should behave just like limit = -1
checkEvaluation(
StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd", ""), row1)
checkEvaluation(
StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1)
checkEvaluation(
StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1)
checkEvaluation(StringSplit(s1, s2, -1), null, row2)
checkEvaluation(StringSplit(s1, s2, -1), null, row3)
}
}

View file

@@ -2546,15 +2546,39 @@ object functions {
def soundex(e: Column): Column = withExpr { SoundEx(e.expr) }
/**
* Splits str around pattern (pattern is a regular expression).
* Splits str around matches of the given regex.
*
* @note Pattern is a string representation of the regular expression.
* @param str a string expression to split
* @param regex a string representing a regular expression. The regex string should be
* a Java regular expression.
*
* @group string_funcs
* @since 1.5.0
*/
def split(str: Column, pattern: String): Column = withExpr {
StringSplit(str.expr, lit(pattern).expr)
def split(str: Column, regex: String): Column = withExpr {
StringSplit(str.expr, Literal(regex), Literal(-1))
}
/**
* Splits str around matches of the given regex.
*
* @param str a string expression to split
* @param regex a string representing a regular expression. The regex string should be
* a Java regular expression.
* @param limit an integer expression which controls the number of times the regex is applied.
* <ul>
* <li>limit greater than 0: The resulting array's length will not be more than limit,
* and the resulting array's last entry will contain all input beyond the last
* matched regex.</li>
* <li>limit less than or equal to 0: `regex` will be applied as many times as
* possible, and the resulting array can be of any size.</li>
* </ul>
*
* @group string_funcs
* @since 3.0.0
*/
def split(str: Column, regex: String, limit: Int): Column = withExpr {
StringSplit(str.expr, Literal(regex), Literal(limit))
}
/**

View file

@@ -46,4 +46,8 @@ FROM (
encode(string(id + 2), 'utf-8') col3,
encode(string(id + 3), 'utf-8') col4
FROM range(10)
)
);
-- split function
SELECT split('aa1cc2ee3', '[1-9]+');
SELECT split('aa1cc2ee3', '[1-9]+', 2);

View file

@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 15
-- Number of queries: 17
-- !query 0
@@ -161,3 +161,19 @@ struct<plan:string>
== Physical Plan ==
*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
+- *Range (0, 10, step=1, splits=2)
-- !query 15
SELECT split('aa1cc2ee3', '[1-9]+')
-- !query 15 schema
struct<split(aa1cc2ee3, [1-9]+, -1):array<string>>
-- !query 15 output
["aa","cc","ee",""]
-- !query 16
SELECT split('aa1cc2ee3', '[1-9]+', 2)
-- !query 16 schema
struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
-- !query 16 output
["aa","cc2ee3"]

View file

@@ -329,16 +329,52 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row(" "))
}
test("string split function") {
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
test("string split function with no limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(split($"a", "[1-9]+")),
Row(Seq("aa", "bb", "cc")))
Row(Seq("aa", "bb", "cc", "")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc")))
Row(Seq("aa", "bb", "cc", "")))
}
test("string split function with limit explicitly set to 0") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(split($"a", "[1-9]+", 0)),
Row(Seq("aa", "bb", "cc", "")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+', 0)"),
Row(Seq("aa", "bb", "cc", "")))
}
test("string split function with positive limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(split($"a", "[1-9]+", 2)),
Row(Seq("aa", "bb3cc4")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+', 2)"),
Row(Seq("aa", "bb3cc4")))
}
test("string split function with negative limit") {
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
checkAnswer(
df.select(split($"a", "[1-9]+", -2)),
Row(Seq("aa", "bb", "cc", "")))
checkAnswer(
df.selectExpr("split(a, '[1-9]+', -2)"),
Row(Seq("aa", "bb", "cc", "")))
}
test("string / binary length function") {