[SPARK-25202][SQL] Implements split with limit sql function
## What changes were proposed in this pull request? Adds support for setting a limit in the SQL split function. ## How was this patch tested? 1. Updated unit tests 2. Tested using Scala spark shell Please review http://spark.apache.org/contributing.html before opening a pull request. Closes #22227 from phegstrom/master. Authored-by: Parker Hegstrom <phegstrom@palantir.com> Signed-off-by: hyukjinkwon <gurwls223@apache.org>
This commit is contained in:
parent
44cf800c83
commit
17781d7530
|
@ -3473,13 +3473,21 @@ setMethod("collect_set",
|
|||
|
||||
#' @details
|
||||
#' \code{split_string}: Splits string on regular expression.
|
||||
#' Equivalent to \code{split} SQL function.
|
||||
#' Equivalent to \code{split} SQL function. Optionally a
|
||||
#' \code{limit} can be specified.
|
||||
#'
|
||||
#' @rdname column_string_functions
|
||||
#' @param limit determines the length of the returned array.
|
||||
#' \itemize{
|
||||
#' \item \code{limit > 0}: length of the array will be at most \code{limit}
|
||||
#' \item \code{limit <= 0}: the returned array can have any length
|
||||
#' }
|
||||
#'
|
||||
#' @aliases split_string split_string,Column-method
|
||||
#' @examples
|
||||
#'
|
||||
#' \dontrun{
|
||||
#' head(select(df, split_string(df$Class, "\\d", 2)))
|
||||
#' head(select(df, split_string(df$Sex, "a")))
|
||||
#' head(select(df, split_string(df$Class, "\\d")))
|
||||
#' # This is equivalent to the following SQL expression
|
||||
|
@ -3487,8 +3495,9 @@ setMethod("collect_set",
|
|||
#' @note split_string 2.3.0
|
||||
setMethod("split_string",
|
||||
signature(x = "Column", pattern = "character"),
|
||||
function(x, pattern) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions", "split", x@jc, pattern)
|
||||
function(x, pattern, limit = -1) {
|
||||
jc <- callJStatic("org.apache.spark.sql.functions",
|
||||
"split", x@jc, pattern, as.integer(limit))
|
||||
column(jc)
|
||||
})
|
||||
|
||||
|
|
|
@ -1258,7 +1258,7 @@ setGeneric("sort_array", function(x, asc = TRUE) { standardGeneric("sort_array")
|
|||
|
||||
#' @rdname column_string_functions
|
||||
#' @name NULL
|
||||
setGeneric("split_string", function(x, pattern) { standardGeneric("split_string") })
|
||||
setGeneric("split_string", function(x, pattern, ...) { standardGeneric("split_string") })
|
||||
|
||||
#' @rdname column_string_functions
|
||||
#' @name NULL
|
||||
|
|
|
@ -1819,6 +1819,14 @@ test_that("string operators", {
|
|||
collect(select(df4, split_string(df4$a, "\\\\")))[1, 1],
|
||||
list(list("a.b@c.d 1", "b"))
|
||||
)
|
||||
expect_equal(
|
||||
collect(select(df4, split_string(df4$a, "\\.", 2)))[1, 1],
|
||||
list(list("a", "b@c.d 1\\b"))
|
||||
)
|
||||
expect_equal(
|
||||
collect(select(df4, split_string(df4$a, "b", 0)))[1, 1],
|
||||
list(list("a.", "@c.d 1\\", ""))
|
||||
)
|
||||
|
||||
l5 <- list(list(a = "abc"))
|
||||
df5 <- createDataFrame(l5)
|
||||
|
|
|
@ -958,6 +958,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
|
|||
}
|
||||
|
||||
public UTF8String[] split(UTF8String pattern, int limit) {
|
||||
// Java String's split method supports "ignore empty string" behavior when the limit is 0
|
||||
// whereas other languages do not. To avoid this Java-specific behavior, we fall back to
|
||||
// -1 when the limit is 0.
|
||||
if (limit == 0) {
|
||||
limit = -1;
|
||||
}
|
||||
String[] splits = toString().split(pattern.toString(), limit);
|
||||
UTF8String[] res = new UTF8String[splits.length];
|
||||
for (int i = 0; i < res.length; i++) {
|
||||
|
|
|
@ -393,12 +393,14 @@ public class UTF8StringSuite {
|
|||
|
||||
@Test
|
||||
public void split() {
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), -1),
|
||||
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi")}));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
|
||||
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi").split(fromString(","), 2),
|
||||
new UTF8String[]{fromString("ab"), fromString("def,ghi")}));
|
||||
UTF8String[] negativeAndZeroLimitCase =
|
||||
new UTF8String[]{fromString("ab"), fromString("def"), fromString("ghi"), fromString("")};
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 0),
|
||||
negativeAndZeroLimitCase));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), -1),
|
||||
negativeAndZeroLimitCase));
|
||||
assertTrue(Arrays.equals(fromString("ab,def,ghi,").split(fromString(","), 2),
|
||||
new UTF8String[]{fromString("ab"), fromString("def,ghi,")}));
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
|
@ -1691,18 +1691,32 @@ def repeat(col, n):
|
|||
|
||||
@since(1.5)
|
||||
@ignore_unicode_prefix
|
||||
def split(str, pattern):
|
||||
def split(str, pattern, limit=-1):
|
||||
"""
|
||||
Splits str around pattern (pattern is a regular expression).
|
||||
Splits str around matches of the given pattern.
|
||||
|
||||
.. note:: pattern is a string represent the regular expression.
|
||||
:param str: a string expression to split
|
||||
:param pattern: a string representing a regular expression. The regex string should be
|
||||
a Java regular expression.
|
||||
:param limit: an integer which controls the number of times `pattern` is applied.
|
||||
|
||||
>>> df = spark.createDataFrame([('ab12cd',)], ['s',])
|
||||
>>> df.select(split(df.s, '[0-9]+').alias('s')).collect()
|
||||
[Row(s=[u'ab', u'cd'])]
|
||||
* ``limit > 0``: The resulting array's length will not be more than `limit`, and the
|
||||
resulting array's last entry will contain all input beyond the last
|
||||
matched pattern.
|
||||
* ``limit <= 0``: `pattern` will be applied as many times as possible, and the resulting
|
||||
array can be of any size.
|
||||
|
||||
.. versionchanged:: 3.0
|
||||
`split` now takes an optional `limit` argument. If not provided, the default limit value is -1.
|
||||
|
||||
>>> df = spark.createDataFrame([('oneAtwoBthreeC',)], ['s',])
|
||||
>>> df.select(split(df.s, '[ABC]', 2).alias('s')).collect()
|
||||
[Row(s=[u'one', u'twoBthreeC'])]
|
||||
>>> df.select(split(df.s, '[ABC]', -1).alias('s')).collect()
|
||||
[Row(s=[u'one', u'two', u'three', u''])]
|
||||
"""
|
||||
sc = SparkContext._active_spark_context
|
||||
return Column(sc._jvm.functions.split(_to_java_column(str), pattern))
|
||||
return Column(sc._jvm.functions.split(_to_java_column(str), pattern, limit))
|
||||
|
||||
|
||||
@ignore_unicode_prefix
|
||||
|
|
|
@ -157,7 +157,7 @@ case class Like(left: Expression, right: Expression) extends StringRegexExpressi
|
|||
arguments = """
|
||||
Arguments:
|
||||
* str - a string expression
|
||||
* regexp - a string expression. The pattern string should be a Java regular expression.
|
||||
* regexp - a string expression. The regex string should be a Java regular expression.
|
||||
|
||||
Since Spark 2.0, string literals (including regex patterns) are unescaped in our SQL
|
||||
parser. For example, to match "\abc", a regular expression for `regexp` can be
|
||||
|
@ -229,33 +229,53 @@ case class RLike(left: Expression, right: Expression) extends StringRegexExpress
|
|||
|
||||
|
||||
/**
|
||||
* Splits str around pat (pattern is a regular expression).
|
||||
* Splits str around matches of the given regex.
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.",
|
||||
usage = "_FUNC_(str, regex, limit) - Splits `str` around occurrences that match `regex`" +
|
||||
" and returns an array with a length of at most `limit`",
|
||||
arguments = """
|
||||
Arguments:
|
||||
* str - a string expression to split.
|
||||
* regex - a string representing a regular expression. The regex string should be a
|
||||
Java regular expression.
|
||||
* limit - an integer expression which controls the number of times the regex is applied.
|
||||
* limit > 0: The resulting array's length will not be more than `limit`,
|
||||
and the resulting array's last entry will contain all input
|
||||
beyond the last matched regex.
|
||||
* limit <= 0: `regex` will be applied as many times as possible, and
|
||||
the resulting array can be of any size.
|
||||
""",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');
|
||||
["one","two","three",""]
|
||||
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', -1);
|
||||
["one","two","three",""]
|
||||
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]', 2);
|
||||
["one","twoBthreeC"]
|
||||
""")
|
||||
case class StringSplit(str: Expression, pattern: Expression)
|
||||
extends BinaryExpression with ImplicitCastInputTypes {
|
||||
case class StringSplit(str: Expression, regex: Expression, limit: Expression)
|
||||
extends TernaryExpression with ImplicitCastInputTypes {
|
||||
|
||||
override def left: Expression = str
|
||||
override def right: Expression = pattern
|
||||
override def dataType: DataType = ArrayType(StringType)
|
||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
|
||||
override def children: Seq[Expression] = str :: regex :: limit :: Nil
|
||||
|
||||
override def nullSafeEval(string: Any, regex: Any): Any = {
|
||||
val strings = string.asInstanceOf[UTF8String].split(regex.asInstanceOf[UTF8String], -1)
|
||||
def this(exp: Expression, regex: Expression) = this(exp, regex, Literal(-1));
|
||||
|
||||
override def nullSafeEval(string: Any, regex: Any, limit: Any): Any = {
|
||||
val strings = string.asInstanceOf[UTF8String].split(
|
||||
regex.asInstanceOf[UTF8String], limit.asInstanceOf[Int])
|
||||
new GenericArrayData(strings.asInstanceOf[Array[Any]])
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val arrayClass = classOf[GenericArrayData].getName
|
||||
nullSafeCodeGen(ctx, ev, (str, pattern) =>
|
||||
nullSafeCodeGen(ctx, ev, (str, regex, limit) => {
|
||||
// Arrays in Java are covariant, so we don't need to cast UTF8String[] to Object[].
|
||||
s"""${ev.value} = new $arrayClass($str.split($pattern, -1));""")
|
||||
s"""${ev.value} = new $arrayClass($str.split($regex,$limit));""".stripMargin
|
||||
})
|
||||
}
|
||||
|
||||
override def prettyName: String = "split"
|
||||
|
|
|
@ -225,11 +225,18 @@ class RegexpExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
val row3 = create_row("aa2bb3cc", null)
|
||||
|
||||
checkEvaluation(
|
||||
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+")), Seq("aa", "bb", "cc"), row1)
|
||||
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), -1), Seq("aa", "bb", "cc"), row1)
|
||||
checkEvaluation(
|
||||
StringSplit(s1, s2), Seq("aa", "bb", "cc"), row1)
|
||||
checkEvaluation(StringSplit(s1, s2), null, row2)
|
||||
checkEvaluation(StringSplit(s1, s2), null, row3)
|
||||
StringSplit(Literal("aa2bb3cc"), Literal("[1-9]+"), 2), Seq("aa", "bb3cc"), row1)
|
||||
// limit = 0 should behave just like limit = -1
|
||||
checkEvaluation(
|
||||
StringSplit(Literal("aacbbcddc"), Literal("c"), 0), Seq("aa", "bb", "dd", ""), row1)
|
||||
checkEvaluation(
|
||||
StringSplit(Literal("aacbbcddc"), Literal("c"), -1), Seq("aa", "bb", "dd", ""), row1)
|
||||
checkEvaluation(
|
||||
StringSplit(s1, s2, -1), Seq("aa", "bb", "cc"), row1)
|
||||
checkEvaluation(StringSplit(s1, s2, -1), null, row2)
|
||||
checkEvaluation(StringSplit(s1, s2, -1), null, row3)
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -2546,15 +2546,39 @@ object functions {
|
|||
def soundex(e: Column): Column = withExpr { SoundEx(e.expr) }
|
||||
|
||||
/**
|
||||
* Splits str around pattern (pattern is a regular expression).
|
||||
* Splits str around matches of the given regex.
|
||||
*
|
||||
* @note Pattern is a string representation of the regular expression.
|
||||
* @param str a string expression to split
|
||||
* @param regex a string representing a regular expression. The regex string should be
|
||||
* a Java regular expression.
|
||||
*
|
||||
* @group string_funcs
|
||||
* @since 1.5.0
|
||||
*/
|
||||
def split(str: Column, pattern: String): Column = withExpr {
|
||||
StringSplit(str.expr, lit(pattern).expr)
|
||||
def split(str: Column, regex: String): Column = withExpr {
|
||||
StringSplit(str.expr, Literal(regex), Literal(-1))
|
||||
}
|
||||
|
||||
/**
|
||||
* Splits str around matches of the given regex.
|
||||
*
|
||||
* @param str a string expression to split
|
||||
* @param regex a string representing a regular expression. The regex string should be
|
||||
* a Java regular expression.
|
||||
* @param limit an integer expression which controls the number of times the regex is applied.
|
||||
* <ul>
|
||||
* <li>limit greater than 0: The resulting array's length will not be more than limit,
|
||||
* and the resulting array's last entry will contain all input beyond the last
|
||||
* matched regex.</li>
|
||||
* <li>limit less than or equal to 0: `regex` will be applied as many times as
|
||||
* possible, and the resulting array can be of any size.</li>
|
||||
* </ul>
|
||||
*
|
||||
* @group string_funcs
|
||||
* @since 3.0.0
|
||||
*/
|
||||
def split(str: Column, regex: String, limit: Int): Column = withExpr {
|
||||
StringSplit(str.expr, Literal(regex), Literal(limit))
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@ -46,4 +46,8 @@ FROM (
|
|||
encode(string(id + 2), 'utf-8') col3,
|
||||
encode(string(id + 3), 'utf-8') col4
|
||||
FROM range(10)
|
||||
)
|
||||
);
|
||||
|
||||
-- split function
|
||||
SELECT split('aa1cc2ee3', '[1-9]+');
|
||||
SELECT split('aa1cc2ee3', '[1-9]+', 2);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
-- Automatically generated by SQLQueryTestSuite
|
||||
-- Number of queries: 15
|
||||
-- Number of queries: 17
|
||||
|
||||
|
||||
-- !query 0
|
||||
|
@ -161,3 +161,19 @@ struct<plan:string>
|
|||
== Physical Plan ==
|
||||
*Project [concat(cast(id#xL as string), cast(encode(cast((id#xL + 2) as string), utf-8) as string), cast(encode(cast((id#xL + 3) as string), utf-8) as string)) AS col#x]
|
||||
+- *Range (0, 10, step=1, splits=2)
|
||||
|
||||
|
||||
-- !query 15
|
||||
SELECT split('aa1cc2ee3', '[1-9]+')
|
||||
-- !query 15 schema
|
||||
struct<split(aa1cc2ee3, [1-9]+, -1):array<string>>
|
||||
-- !query 15 output
|
||||
["aa","cc","ee",""]
|
||||
|
||||
|
||||
-- !query 16
|
||||
SELECT split('aa1cc2ee3', '[1-9]+', 2)
|
||||
-- !query 16 schema
|
||||
struct<split(aa1cc2ee3, [1-9]+, 2):array<string>>
|
||||
-- !query 16 output
|
||||
["aa","cc2ee3"]
|
||||
|
|
|
@ -329,16 +329,52 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
Row(" "))
|
||||
}
|
||||
|
||||
test("string split function") {
|
||||
val df = Seq(("aa2bb3cc", "[1-9]+")).toDF("a", "b")
|
||||
test("string split function with no limit") {
|
||||
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df.select(split($"a", "[1-9]+")),
|
||||
Row(Seq("aa", "bb", "cc")))
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr("split(a, '[1-9]+')"),
|
||||
Row(Seq("aa", "bb", "cc")))
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
}
|
||||
|
||||
test("string split function with limit explicitly set to 0") {
|
||||
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df.select(split($"a", "[1-9]+", 0)),
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr("split(a, '[1-9]+', 0)"),
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
}
|
||||
|
||||
test("string split function with positive limit") {
|
||||
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df.select(split($"a", "[1-9]+", 2)),
|
||||
Row(Seq("aa", "bb3cc4")))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr("split(a, '[1-9]+', 2)"),
|
||||
Row(Seq("aa", "bb3cc4")))
|
||||
}
|
||||
|
||||
test("string split function with negative limit") {
|
||||
val df = Seq(("aa2bb3cc4", "[1-9]+")).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df.select(split($"a", "[1-9]+", -2)),
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
|
||||
checkAnswer(
|
||||
df.selectExpr("split(a, '[1-9]+', -2)"),
|
||||
Row(Seq("aa", "bb", "cc", "")))
|
||||
}
|
||||
|
||||
test("string / binary length function") {
|
||||
|
|
Loading…
Reference in a new issue