[SPARK-16409][SQL] regexp_extract with optional groups causes NPE

## What changes were proposed in this pull request?

regexp_extract actually returns null when it shouldn't when a regex matches but the requested optional group did not. This makes it return an empty string, as apparently designed.

## How was this patch tested?

Additional unit test

Author: Sean Owen <sowen@cloudera.com>

Closes #14504 from srowen/SPARK-16409.
This commit is contained in:
Sean Owen 2016-08-07 12:20:07 +01:00
parent bdfab9f942
commit 8d87252087
3 changed files with 22 additions and 2 deletions

View file

@ -1445,6 +1445,9 @@ def regexp_extract(str, pattern, idx):
>>> df = spark.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
[Row(d=u'100')]
>>> df = spark.createDataFrame([('aaaac',)], ['str'])
>>> df.select(regexp_extract('str', '(a+)(b)?(c)', 2).alias('d')).collect()
[Row(d=u'')]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)

View file

@ -329,7 +329,12 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
val m = pattern.matcher(s.toString)
if (m.find) {
val mr: MatchResult = m.toMatchResult
UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
val group = mr.group(r.asInstanceOf[Int])
if (group == null) { // Pattern matched, but not optional group
UTF8String.EMPTY_UTF8
} else {
UTF8String.fromString(group)
}
} else {
UTF8String.EMPTY_UTF8
}
@ -367,7 +372,11 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
${termPattern}.matcher($subject.toString());
if (${matcher}.find()) {
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
if (${matchResult}.group($idx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
${ev.value} = UTF8String.fromString(${matchResult}.group($idx));
}
$setEvNotNull
} else {
${ev.value} = UTF8String.EMPTY_UTF8;

View file

@ -94,6 +94,14 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("300", "100") :: Row("400", "100") :: Row("400-400", "100") :: Nil)
}
test("non-matching optional group") {
val df = Seq("aaaac").toDF("s")
checkAnswer(
df.select(regexp_extract($"s", "(a+)(b)?(c)", 2)),
Row("")
)
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(