[SPARK-8255] [SPARK-8256] [SQL] Add regex_extract/regex_replace

Add expressions `regex_extract` & `regex_replace`

Author: Cheng Hao <hao.cheng@intel.com>

Closes #7468 from chenghao-intel/regexp and squashes the following commits:

e5ea476 [Cheng Hao] minor update for documentation
ef96fd6 [Cheng Hao] update the code gen
72cf28f [Cheng Hao] Add more log for compilation error
4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support
This commit is contained in:
Cheng Hao 2015-07-21 00:48:07 -07:00 committed by Davies Liu
parent d38c5029a2
commit 8c8f0ef59e
8 changed files with 323 additions and 4 deletions

View file

@ -46,6 +46,8 @@ __all__ = [
'monotonicallyIncreasingId',
'rand',
'randn',
'regexp_extract',
'regexp_replace',
'sha1',
'sha2',
'sparkPartitionId',
@ -343,6 +345,34 @@ def levenshtein(left, right):
return Column(jc)
@ignore_unicode_prefix
@since(1.5)
def regexp_extract(str, pattern, idx):
"""Extract a specific(idx) group identified by a java regex, from the specified string column.
>>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
[Row(d=u'100')]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
return Column(jc)
@ignore_unicode_prefix
@since(1.5)
def regexp_replace(str, pattern, replacement):
"""Replace all substrings of the specified string value that match regexp with rep.
>>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
>>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
[Row(d=u'##-##')]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
return Column(jc)
@ignore_unicode_prefix
@since(1.5)
def md5(col):

View file

@ -161,6 +161,8 @@ object FunctionRegistry {
expression[Lower]("lower"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),

View file

@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
evaluator.cook(code)
} catch {
case e: Exception =>
logError(s"failed to compile:\n $code", e)
throw e
val msg = s"failed to compile:\n $code"
logError(msg, e)
throw new Exception(msg, e)
}
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}

View file

@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
import java.util.Locale
import java.util.regex.Pattern
import java.util.regex.{MatchResult, Pattern}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
@ -876,6 +876,221 @@ case class Encode(value: Expression, charset: Expression)
}
}
/**
* Replace all substrings of str that match regexp with rep.
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
extends Expression with ImplicitCastInputTypes {
// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _
// last replacement string, we don't want to convert a UTF8String => java.langString every time.
@transient private var lastReplacement: String = _
@transient private var lastReplacementInUTF8: UTF8String = _
// result buffer write by Matcher
@transient private val result: StringBuffer = new StringBuffer
override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable
override def eval(input: InternalRow): Any = {
val s = subject.eval(input)
if (null != s) {
val p = regexp.eval(input)
if (null != p) {
val r = rep.eval(input)
if (null != r) {
if (!p.equals(lastRegex)) {
// regex value changed
lastRegex = p.asInstanceOf[UTF8String]
pattern = Pattern.compile(lastRegex.toString)
}
if (!r.equals(lastReplacementInUTF8)) {
// replacement string changed
lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
lastReplacement = lastReplacementInUTF8.toString
}
val m = pattern.matcher(s.toString())
result.delete(0, result.length())
while (m.find) {
m.appendReplacement(result, lastReplacement)
}
m.appendTail(result)
return UTF8String.fromString(result.toString)
}
}
}
null
}
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
override def children: Seq[Expression] = subject :: regexp :: rep :: Nil
override def prettyName: String = "regexp_replace"
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val termLastRegex = ctx.freshName("lastRegex")
val termPattern = ctx.freshName("pattern")
val termLastReplacement = ctx.freshName("lastReplacement")
val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
val termResult = ctx.freshName("result")
val classNameUTF8String = classOf[UTF8String].getCanonicalName
val classNamePattern = classOf[Pattern].getCanonicalName
val classNameString = classOf[java.lang.String].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
ctx.addMutableState(classNameUTF8String,
termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern,
termPattern, s"${termPattern} = null;")
ctx.addMutableState(classNameString,
termLastReplacement, s"${termLastReplacement} = null;")
ctx.addMutableState(classNameUTF8String,
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
ctx.addMutableState(classNameStringBuffer,
termResult, s"${termResult} = new $classNameStringBuffer();")
val evalSubject = subject.gen(ctx)
val evalRegexp = regexp.gen(ctx)
val evalRep = rep.gen(ctx)
s"""
${evalSubject.code}
boolean ${ev.isNull} = ${evalSubject.isNull};
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
if (!${evalSubject.isNull}) {
${evalRegexp.code}
if (!${evalRegexp.isNull}) {
${evalRep.code}
if (!${evalRep.isNull}) {
if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
// regex value changed
${termLastRegex} = ${evalRegexp.primitive};
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
// replacement string changed
${termLastReplacementInUTF8} = ${evalRep.primitive};
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
${classOf[java.util.regex.Matcher].getCanonicalName} m =
${termPattern}.matcher(${evalSubject.primitive}.toString());
while (m.find()) {
m.appendReplacement(${termResult}, ${termLastReplacement});
}
m.appendTail(${termResult});
${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
${ev.isNull} = false;
}
}
}
"""
}
}
/**
* Extract a specific(idx) group identified by a Java regex.
*
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
extends Expression with ImplicitCastInputTypes {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _
override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable
override def eval(input: InternalRow): Any = {
val s = subject.eval(input)
if (null != s) {
val p = regexp.eval(input)
if (null != p) {
val r = idx.eval(input)
if (null != r) {
if (!p.equals(lastRegex)) {
// regex value changed
lastRegex = p.asInstanceOf[UTF8String]
pattern = Pattern.compile(lastRegex.toString)
}
val m = pattern.matcher(s.toString())
if (m.find) {
val mr: MatchResult = m.toMatchResult
return UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
}
return UTF8String.EMPTY_UTF8
}
}
}
null
}
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
override def prettyName: String = "regexp_extract"
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val termLastRegex = ctx.freshName("lastRegex")
val termPattern = ctx.freshName("pattern")
val classNameUTF8String = classOf[UTF8String].getCanonicalName
val classNamePattern = classOf[Pattern].getCanonicalName
ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
val evalSubject = subject.gen(ctx)
val evalRegexp = regexp.gen(ctx)
val evalIdx = idx.gen(ctx)
s"""
${ctx.javaType(dataType)} ${ev.primitive} = null;
boolean ${ev.isNull} = true;
${evalSubject.code}
if (!${evalSubject.isNull}) {
${evalRegexp.code}
if (!${evalRegexp.isNull}) {
${evalIdx.code}
if (!${evalIdx.isNull}) {
if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
// regex value changed
${termLastRegex} = ${evalRegexp.primitive};
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
${classOf[java.util.regex.Matcher].getCanonicalName} m =
${termPattern}.matcher(${evalSubject.primitive}.toString());
if (m.find()) {
${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
${ev.isNull} = false;
} else {
${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
${ev.isNull} = false;
}
}
}
}
"""
}
}
/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or

View file

@ -79,7 +79,6 @@ trait ExpressionEvalHelper {
fail(
s"""
|Code generation of $expression failed:
|${evaluated.code}
|$e
""".stripMargin)
}

View file

@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringSpace(s1), null, row2)
}
test("RegexReplace") {
val row1 = create_row("100-200", "(\\d+)", "num")
val row2 = create_row("100-200", "(\\d+)", "###")
val row3 = create_row("100-200", "(-)", "###")
val s = 's.string.at(0)
val p = 'p.string.at(1)
val r = 'r.string.at(2)
val expr = RegExpReplace(s, p, r)
checkEvaluation(expr, "num-num", row1)
checkEvaluation(expr, "###-###", row2)
checkEvaluation(expr, "100###200", row3)
}
test("RegexExtract") {
val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
val row3 = create_row("100-200", "(\\d+).*", 1)
val row4 = create_row("100-200", "([a-z])", 1)
val s = 's.string.at(0)
val p = 'p.string.at(1)
val r = 'r.int.at(2)
val expr = RegExpExtract(s, p, r)
checkEvaluation(expr, "100", row1)
checkEvaluation(expr, "200", row2)
checkEvaluation(expr, "100", row3)
checkEvaluation(expr, "", row4) // will not match anything, empty string get
val expr1 = new RegExpExtract(s, p)
checkEvaluation(expr1, "100", row1)
}
test("SPLIT") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)

View file

@ -1781,6 +1781,27 @@ object functions {
StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
}
/**
* Extract a specific(idx) group identified by a java regex, from the specified string column.
*
* @group string_funcs
* @since 1.5.0
*/
def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = {
RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
}
/**
* Replace all substrings of the specified string value that match regexp with rep.
*
* @group string_funcs
* @since 1.5.0
*/
def regexp_replace(e: Column, pattern: String, replacement: String): Column = {
RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr)
}
/**
* Computes the BASE64 encoding of a binary column and returns it as a string column.
* This is the reverse of unbase64.

View file

@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
}
test("string regex_replace / regex_extract") {
val df = Seq(("100-200", "")).toDF("a", "b")
checkAnswer(
df.select(
regexp_replace($"a", "(\\d+)", "num"),
regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
Row("num-num", "100"))
checkAnswer(
df.selectExpr(
"regexp_replace(a, '(\\d+)', 'num')",
"regexp_extract(a, '(\\d+)-(\\d+)', 2)"),
Row("num-num", "200"))
}
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(