[SPARK-8255] [SPARK-8256] [SQL] Add regexp_extract/regexp_replace
Add expressions `regexp_extract` & `regexp_replace`. Author: Cheng Hao <hao.cheng@intel.com> Closes #7468 from chenghao-intel/regexp and squashes the following commits: e5ea476 [Cheng Hao] minor update for documentation ef96fd6 [Cheng Hao] update the code gen 72cf28f [Cheng Hao] Add more log for compilation error 4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support
This commit is contained in:
parent
d38c5029a2
commit
8c8f0ef59e
|
@ -46,6 +46,8 @@ __all__ = [
|
|||
'monotonicallyIncreasingId',
|
||||
'rand',
|
||||
'randn',
|
||||
'regexp_extract',
|
||||
'regexp_replace',
|
||||
'sha1',
|
||||
'sha2',
|
||||
'sparkPartitionId',
|
||||
|
@ -343,6 +345,34 @@ def levenshtein(left, right):
|
|||
return Column(jc)
|
||||
|
||||
|
||||
@ignore_unicode_prefix
@since(1.5)
def regexp_extract(str, pattern, idx):
    r"""Extract a specific group matched by a Java regex from the specified string column.

    :param str: the string column to search (the doctest below passes a column name)
    :param pattern: the Java regular expression, as a string
    :param idx: index of the capture group to return

    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    >>> df.select(regexp_extract('str', r'(\d+)-(\d+)', 1).alias('d')).collect()
    [Row(d=u'100')]
    """
    # The docstring is raw so that the regex backslashes reach doctest untouched instead of
    # being parsed as (invalid) string escape sequences.
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
    return Column(jc)
|
||||
|
||||
|
||||
@ignore_unicode_prefix
@since(1.5)
def regexp_replace(str, pattern, replacement):
    r"""Replace all substrings of the specified string value that match `pattern`
    with `replacement`.

    :param str: the string column to rewrite (the doctest below passes a column name)
    :param pattern: the Java regular expression, as a string
    :param replacement: the replacement string

    >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
    >>> df.select(regexp_replace('str', r'(\d+)', '##').alias('d')).collect()
    [Row(d=u'##-##')]
    """
    # The docstring is raw so that the regex backslashes reach doctest untouched instead of
    # being parsed as string escape sequences.
    sc = SparkContext._active_spark_context
    jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
    return Column(jc)
|
||||
|
||||
|
||||
@ignore_unicode_prefix
|
||||
@since(1.5)
|
||||
def md5(col):
|
||||
|
|
|
@ -161,6 +161,8 @@ object FunctionRegistry {
|
|||
expression[Lower]("lower"),
|
||||
expression[Length]("length"),
|
||||
expression[Levenshtein]("levenshtein"),
|
||||
expression[RegExpExtract]("regexp_extract"),
|
||||
expression[RegExpReplace]("regexp_replace"),
|
||||
expression[StringInstr]("instr"),
|
||||
expression[StringLocate]("locate"),
|
||||
expression[StringLPad]("lpad"),
|
||||
|
|
|
@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
|
|||
evaluator.cook(code)
|
||||
} catch {
|
||||
case e: Exception =>
|
||||
logError(s"failed to compile:\n $code", e)
|
||||
throw e
|
||||
val msg = s"failed to compile:\n $code"
|
||||
logError(msg, e)
|
||||
throw new Exception(msg, e)
|
||||
}
|
||||
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import java.text.DecimalFormat
|
||||
import java.util.Locale
|
||||
import java.util.regex.Pattern
|
||||
import java.util.regex.{MatchResult, Pattern}
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
|
||||
|
@ -876,6 +876,221 @@ case class Encode(value: Expression, charset: Expression)
|
|||
}
|
||||
}
|
||||
|
||||
/**
 * Replace all substrings of str that match regexp with rep.
 *
 * Evaluates to null when any of the subject, regexp or rep inputs evaluates to null.
 *
 * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
 */
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
  extends Expression with ImplicitCastInputTypes {

  // last regex in string, we will update the pattern iff regexp value changed.
  // NOTE(review): this keeps a reference to the evaluated UTF8String; if that value can be
  // backed by a buffer that is reused between rows it should be copied first -- confirm.
  @transient private var lastRegex: UTF8String = _
  // last regex pattern, we cache it for performance concern
  @transient private var pattern: Pattern = _
  // last replacement string, we don't want to convert a UTF8String => java.lang.String every time.
  @transient private var lastReplacement: String = _
  @transient private var lastReplacementInUTF8: UTF8String = _
  // result buffer written by Matcher. It is lazy because a @transient val is not restored by
  // deserialization and would otherwise be null on the receiving side.
  @transient private lazy val result: StringBuffer = new StringBuffer

  override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
  override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable

  override def eval(input: InternalRow): Any = {
    // Children are evaluated lazily, left to right: a null input short-circuits the rest.
    val s = subject.eval(input)
    if (s == null) {
      null
    } else {
      val p = regexp.eval(input)
      if (p == null) {
        null
      } else {
        val r = rep.eval(input)
        if (r == null) null else replaceAll(s, p, r)
      }
    }
  }

  /** Applies the (cached) pattern and replacement to a non-null subject. */
  private def replaceAll(s: Any, p: Any, r: Any): UTF8String = {
    if (!p.equals(lastRegex)) {
      // regex value changed
      lastRegex = p.asInstanceOf[UTF8String]
      pattern = Pattern.compile(lastRegex.toString)
    }
    if (!r.equals(lastReplacementInUTF8)) {
      // replacement string changed
      lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
      lastReplacement = lastReplacementInUTF8.toString
    }
    val m = pattern.matcher(s.toString())
    // the buffer is shared between calls, so drop whatever the previous row left in it
    result.delete(0, result.length())

    while (m.find) {
      m.appendReplacement(result, lastReplacement)
    }
    m.appendTail(result)

    UTF8String.fromString(result.toString)
  }

  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
  override def children: Seq[Expression] = subject :: regexp :: rep :: Nil
  override def prettyName: String = "regexp_replace"

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val termLastRegex = ctx.freshName("lastRegex")
    val termPattern = ctx.freshName("pattern")

    val termLastReplacement = ctx.freshName("lastReplacement")
    val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")

    val termResult = ctx.freshName("result")
    // fresh name so the local cannot clash with a local declared by other generated code
    val termMatcher = ctx.freshName("matcher")

    val classNameUTF8String = classOf[UTF8String].getCanonicalName
    val classNamePattern = classOf[Pattern].getCanonicalName
    val classNameMatcher = classOf[java.util.regex.Matcher].getCanonicalName
    val classNameString = classOf[java.lang.String].getCanonicalName
    val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName

    ctx.addMutableState(classNameUTF8String,
      termLastRegex, s"${termLastRegex} = null;")
    ctx.addMutableState(classNamePattern,
      termPattern, s"${termPattern} = null;")
    ctx.addMutableState(classNameString,
      termLastReplacement, s"${termLastReplacement} = null;")
    ctx.addMutableState(classNameUTF8String,
      termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
    ctx.addMutableState(classNameStringBuffer,
      termResult, s"${termResult} = new $classNameStringBuffer();")

    val evalSubject = subject.gen(ctx)
    val evalRegexp = regexp.gen(ctx)
    val evalRep = rep.gen(ctx)

    // The result starts out as null and is only marked non-null once all three inputs are
    // known to be non-null; this mirrors eval(), which returns null if any input is null.
    s"""
      ${evalSubject.code}
      boolean ${ev.isNull} = true;
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      if (!${evalSubject.isNull}) {
        ${evalRegexp.code}
        if (!${evalRegexp.isNull}) {
          ${evalRep.code}
          if (!${evalRep.isNull}) {
            if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
              // regex value changed
              ${termLastRegex} = ${evalRegexp.primitive};
              ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
            }
            if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
              // replacement string changed
              ${termLastReplacementInUTF8} = ${evalRep.primitive};
              ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
            }
            ${termResult}.delete(0, ${termResult}.length());
            ${classNameMatcher} ${termMatcher} =
              ${termPattern}.matcher(${evalSubject.primitive}.toString());

            while (${termMatcher}.find()) {
              ${termMatcher}.appendReplacement(${termResult}, ${termLastReplacement});
            }
            ${termMatcher}.appendTail(${termResult});
            ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
            ${ev.isNull} = false;
          }
        }
      }
    """
  }
}
|
||||
|
||||
/**
 * Extract a specific(idx) group identified by a Java regex.
 *
 * Evaluates to null when any of the subject, regexp or idx inputs evaluates to null, and to
 * the empty string when the regex does not match or the requested group did not take part in
 * the match.
 *
 * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
 */
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
  extends Expression with ImplicitCastInputTypes {
  // Two-argument form: extracts group 1.
  def this(s: Expression, r: Expression) = this(s, r, Literal(1))

  // last regex in string, we will update the pattern iff regexp value changed.
  // NOTE(review): this keeps a reference to the evaluated UTF8String; if that value can be
  // backed by a buffer that is reused between rows it should be copied first -- confirm.
  @transient private var lastRegex: UTF8String = _
  // last regex pattern, we cache it for performance concern
  @transient private var pattern: Pattern = _

  override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
  override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable

  override def eval(input: InternalRow): Any = {
    // Children are evaluated lazily, left to right: a null input short-circuits the rest.
    val s = subject.eval(input)
    if (s == null) {
      null
    } else {
      val p = regexp.eval(input)
      if (p == null) {
        null
      } else {
        val r = idx.eval(input)
        if (r == null) null else extractGroup(s, p, r.asInstanceOf[Int])
      }
    }
  }

  /** Finds the first match of the (cached) pattern in a non-null subject and returns a group. */
  private def extractGroup(s: Any, p: Any, groupIdx: Int): UTF8String = {
    if (!p.equals(lastRegex)) {
      // regex value changed
      lastRegex = p.asInstanceOf[UTF8String]
      pattern = Pattern.compile(lastRegex.toString)
    }
    val m = pattern.matcher(s.toString())
    if (m.find) {
      val mr: MatchResult = m.toMatchResult
      // group() returns null when the match succeeded but this group matched nothing
      // (e.g. an optional group); treat that like "no match" instead of leaking a null.
      val group = mr.group(groupIdx)
      if (group == null) UTF8String.EMPTY_UTF8 else UTF8String.fromString(group)
    } else {
      UTF8String.EMPTY_UTF8
    }
  }

  override def dataType: DataType = StringType
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
  override def prettyName: String = "regexp_extract"

  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
    val termLastRegex = ctx.freshName("lastRegex")
    val termPattern = ctx.freshName("pattern")
    // fresh names so the locals cannot clash with locals declared by other generated code
    val termMatcher = ctx.freshName("matcher")
    val termMatchResult = ctx.freshName("matchResult")
    val termGroup = ctx.freshName("group")

    val classNameUTF8String = classOf[UTF8String].getCanonicalName
    val classNamePattern = classOf[Pattern].getCanonicalName
    val classNameMatcher = classOf[java.util.regex.Matcher].getCanonicalName
    val classNameMatchResult = classOf[java.util.regex.MatchResult].getCanonicalName
    val classNameString = classOf[java.lang.String].getCanonicalName

    ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
    ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")

    val evalSubject = subject.gen(ctx)
    val evalRegexp = regexp.gen(ctx)
    val evalIdx = idx.gen(ctx)

    s"""
      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
      boolean ${ev.isNull} = true;
      ${evalSubject.code}
      if (!${evalSubject.isNull}) {
        ${evalRegexp.code}
        if (!${evalRegexp.isNull}) {
          ${evalIdx.code}
          if (!${evalIdx.isNull}) {
            if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
              // regex value changed
              ${termLastRegex} = ${evalRegexp.primitive};
              ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
            }
            ${classNameMatcher} ${termMatcher} =
              ${termPattern}.matcher(${evalSubject.primitive}.toString());
            // empty string unless a match is found and the requested group took part in it
            ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
            if (${termMatcher}.find()) {
              ${classNameMatchResult} ${termMatchResult} = ${termMatcher}.toMatchResult();
              ${classNameString} ${termGroup} = ${termMatchResult}.group(${evalIdx.primitive});
              if (${termGroup} != null) {
                ${ev.primitive} = ${classNameUTF8String}.fromString(${termGroup});
              }
            }
            ${ev.isNull} = false;
          }
        }
      }
    """
  }
}
|
||||
|
||||
/**
|
||||
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
|
||||
* and returns the result as a string. If D is 0, the result has no decimal point or
|
||||
|
|
|
@ -79,7 +79,6 @@ trait ExpressionEvalHelper {
|
|||
fail(
|
||||
s"""
|
||||
|Code generation of $expression failed:
|
||||
|${evaluated.code}
|
||||
|$e
|
||||
""".stripMargin)
|
||||
}
|
||||
|
|
|
@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(StringSpace(s1), null, row2)
|
||||
}
|
||||
|
||||
test("RegexReplace") {
  // Bound references to the three columns of each input row: subject, regex, replacement.
  val subject = 's.string.at(0)
  val regex = 'p.string.at(1)
  val replacement = 'r.string.at(2)
  val replaceExpr = RegExpReplace(subject, regex, replacement)

  // (input row, expected result) pairs, checked in order.
  val cases = Seq(
    create_row("100-200", "(\\d+)", "num") -> "num-num",
    create_row("100-200", "(\\d+)", "###") -> "###-###",
    create_row("100-200", "(-)", "###") -> "100###200")

  cases.foreach { case (row, expected) =>
    checkEvaluation(replaceExpr, expected, row)
  }
}
|
||||
|
||||
test("RegexExtract") {
  // Bound references to the three columns of each input row: subject, regex, group index.
  val subject = 's.string.at(0)
  val regex = 'p.string.at(1)
  val groupIdx = 'r.int.at(2)

  val firstGroupRow = create_row("100-200", "(\\d+)-(\\d+)", 1)

  // (input row, expected result) pairs for the three-argument form, checked in order.
  val cases = Seq(
    firstGroupRow -> "100",
    create_row("100-200", "(\\d+)-(\\d+)", 2) -> "200",
    create_row("100-200", "(\\d+).*", 1) -> "100",
    // will not match anything, empty string get
    create_row("100-200", "([a-z])", 1) -> "")

  val extractExpr = RegExpExtract(subject, regex, groupIdx)
  cases.foreach { case (row, expected) =>
    checkEvaluation(extractExpr, expected, row)
  }

  // The two-argument constructor defaults the group index to 1.
  val defaultGroupExpr = new RegExpExtract(subject, regex)
  checkEvaluation(defaultGroupExpr, "100", firstGroupRow)
}
|
||||
|
||||
test("SPLIT") {
|
||||
val s1 = 'a.string.at(0)
|
||||
val s2 = 'b.string.at(1)
|
||||
|
|
|
@ -1781,6 +1781,27 @@ object functions {
|
|||
StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
|
||||
}
|
||||
|
||||
|
||||
/**
 * Extracts the capture group with index `groupIdx` from the first match of the Java regex
 * `exp` in the given string column.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = {
  // The regex and the group index are passed to the expression as literals.
  val regex = lit(exp).expr
  val group = lit(groupIdx).expr
  RegExpExtract(e.expr, regex, group)
}
|
||||
|
||||
/**
 * Replaces every substring of the given string column that matches the Java regex `pattern`
 * with `replacement`.
 *
 * @group string_funcs
 * @since 1.5.0
 */
def regexp_replace(e: Column, pattern: String, replacement: String): Column = {
  // The regex and the replacement are passed to the expression as literals.
  val regex = lit(pattern).expr
  val substitute = lit(replacement).expr
  RegExpReplace(e.expr, regex, substitute)
}
|
||||
|
||||
/**
|
||||
* Computes the BASE64 encoding of a binary column and returns it as a string column.
|
||||
* This is the reverse of unbase64.
|
||||
|
|
|
@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest {
|
|||
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
|
||||
}
|
||||
|
||||
test("string regex_replace / regex_extract") {
  val df = Seq(("100-200", "")).toDF("a", "b")

  // Through the Column-based functions API.
  val viaFunctions = df.select(
    regexp_replace($"a", "(\\d+)", "num"),
    regexp_extract($"a", "(\\d+)-(\\d+)", 1))
  checkAnswer(viaFunctions, Row("num-num", "100"))

  // Through SQL expression strings, resolved via the function registry.
  val viaSqlExpr = df.selectExpr(
    "regexp_replace(a, '(\\d+)', 'num')",
    "regexp_extract(a, '(\\d+)-(\\d+)', 2)")
  checkAnswer(viaSqlExpr, Row("num-num", "200"))
}
|
||||
|
||||
test("string ascii function") {
|
||||
val df = Seq(("abc", "")).toDF("a", "b")
|
||||
checkAnswer(
|
||||
|
|
Loading…
Reference in a new issue