[SPARK-30625][SQL] Support escape
as third parameter of the like
function
### What changes were proposed in this pull request?
In the PR, I propose to transform the `Like` expression to `TernaryExpression`, and add third parameter `escape`. So, the `like` function will have feature parity with `LIKE ... ESCAPE` syntax supported by 187f3c1773
.
### Why are the changes needed?
The `like` functions can be called with 2 or 3 parameters, and functionally equivalent to `LIKE` and `LIKE ... ESCAPE` SQL expressions.
### Does this PR introduce any user-facing change?
Yes, before `like` fails with the exception:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
Error in query: Invalid number of arguments for function like. Expected: 2; Found: 3; line 1 pos 7
```
After:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
true
```
### How was this patch tested?
- Add new example for the `like` function which is checked by `SQLQuerySuite`
- Run `RegexpExpressionsSuite` and `ExpressionParserSuite`.
Closes #27355 from MaxGekk/like-3-args.
Authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
c5c580ba0d
commit
8aebc80e0e
|
@ -99,7 +99,7 @@ package object dsl {
|
|||
}
|
||||
|
||||
def like(other: Expression, escapeChar: Char = '\\'): Expression =
|
||||
Like(expr, other, escapeChar)
|
||||
Like(expr, other, Literal(escapeChar.toString))
|
||||
def rlike(other: Expression): Expression = RLike(expr, other)
|
||||
def contains(other: Expression): Expression = Contains(expr, other)
|
||||
def startsWith(other: Expression): Expression = StartsWith(expr, other)
|
||||
|
|
|
@ -22,6 +22,7 @@ import java.util.regex.{MatchResult, Pattern}
|
|||
|
||||
import org.apache.commons.text.StringEscapeUtils
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
|
||||
|
@ -29,18 +30,20 @@ import org.apache.spark.sql.types._
|
|||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
||||
abstract class StringRegexExpression extends BinaryExpression
|
||||
trait StringRegexExpression extends Expression
|
||||
with ImplicitCastInputTypes with NullIntolerant {
|
||||
|
||||
def str: Expression
|
||||
def pattern: Expression
|
||||
|
||||
def escape(v: String): String
|
||||
def matches(regex: Pattern, str: String): Boolean
|
||||
|
||||
override def dataType: DataType = BooleanType
|
||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
||||
|
||||
// try cache the pattern for Literal
|
||||
private lazy val cache: Pattern = right match {
|
||||
case x @ Literal(value: String, StringType) => compile(value)
|
||||
private lazy val cache: Pattern = pattern match {
|
||||
case Literal(value: String, StringType) => compile(value)
|
||||
case _ => null
|
||||
}
|
||||
|
||||
|
@ -51,10 +54,9 @@ abstract class StringRegexExpression extends BinaryExpression
|
|||
Pattern.compile(escape(str))
|
||||
}
|
||||
|
||||
protected def pattern(str: String) = if (cache == null) compile(str) else cache
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
||||
val regex = pattern(input2.asInstanceOf[UTF8String].toString)
|
||||
def nullSafeMatch(input1: Any, input2: Any): Any = {
|
||||
val s = input2.asInstanceOf[UTF8String].toString
|
||||
val regex = if (cache == null) compile(s) else cache
|
||||
if(regex == null) {
|
||||
null
|
||||
} else {
|
||||
|
@ -62,7 +64,7 @@ abstract class StringRegexExpression extends BinaryExpression
|
|||
}
|
||||
}
|
||||
|
||||
override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
|
||||
override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
|
||||
}
|
||||
|
||||
// scalastyle:off line.contains.tab
|
||||
|
@ -107,46 +109,65 @@ abstract class StringRegexExpression extends BinaryExpression
|
|||
true
|
||||
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
|
||||
true
|
||||
> SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
|
||||
true
|
||||
""",
|
||||
note = """
|
||||
Use RLIKE to match with standard regular expressions.
|
||||
""",
|
||||
since = "1.0.0")
|
||||
// scalastyle:on line.contains.tab
|
||||
case class Like(left: Expression, right: Expression, escapeChar: Char)
|
||||
extends StringRegexExpression {
|
||||
case class Like(str: Expression, pattern: Expression, escape: Expression)
|
||||
extends TernaryExpression with StringRegexExpression {
|
||||
|
||||
def this(left: Expression, right: Expression) = this(left, right, '\\')
|
||||
def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))
|
||||
|
||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
|
||||
override def children: Seq[Expression] = Seq(str, pattern, escape)
|
||||
|
||||
private lazy val escapeChar: Char = if (escape.foldable) {
|
||||
escape.eval() match {
|
||||
case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
|
||||
case s => throw new AnalysisException(
|
||||
s"The 'escape' parameter must be a string literal of one char but it is $s.")
|
||||
}
|
||||
} else {
|
||||
throw new AnalysisException("The 'escape' parameter must be a string literal.")
|
||||
}
|
||||
|
||||
override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
|
||||
|
||||
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
|
||||
|
||||
override def toString: String = escapeChar match {
|
||||
case '\\' => s"$left LIKE $right"
|
||||
case c => s"$left LIKE $right ESCAPE '$c'"
|
||||
case '\\' => s"$str LIKE $pattern"
|
||||
case c => s"$str LIKE $pattern ESCAPE '$c'"
|
||||
}
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
|
||||
nullSafeMatch(input1, input2)
|
||||
}
|
||||
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val patternClass = classOf[Pattern].getName
|
||||
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
|
||||
|
||||
if (right.foldable) {
|
||||
val rVal = right.eval()
|
||||
if (rVal != null) {
|
||||
if (pattern.foldable) {
|
||||
val patternVal = pattern.eval()
|
||||
if (patternVal != null) {
|
||||
val regexStr =
|
||||
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
|
||||
val pattern = ctx.addMutableState(patternClass, "patternLike",
|
||||
StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
|
||||
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
|
||||
v => s"""$v = $patternClass.compile("$regexStr");""")
|
||||
|
||||
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
|
||||
val eval = left.genCode(ctx)
|
||||
val eval = str.genCode(ctx)
|
||||
ev.copy(code = code"""
|
||||
${eval.code}
|
||||
boolean ${ev.isNull} = ${eval.isNull};
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
if (!${ev.isNull}) {
|
||||
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
|
||||
${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
|
||||
}
|
||||
""")
|
||||
} else {
|
||||
|
@ -164,18 +185,18 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
|
|||
} else {
|
||||
escapeChar
|
||||
}
|
||||
val rightStr = ctx.freshName("rightStr")
|
||||
val pattern = ctx.addMutableState(patternClass, "pattern")
|
||||
val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr")
|
||||
val patternStr = ctx.freshName("patternStr")
|
||||
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern")
|
||||
val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr")
|
||||
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
|
||||
s"""
|
||||
String $rightStr = $eval2.toString();
|
||||
if (!$rightStr.equals($lastRightStr)) {
|
||||
$pattern = $patternClass.compile($escapeFunc($rightStr, '$newEscapeChar'));
|
||||
$lastRightStr = $rightStr;
|
||||
String $patternStr = $eval2.toString();
|
||||
if (!$patternStr.equals($lastPatternStr)) {
|
||||
$compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar'));
|
||||
$lastPatternStr = $patternStr;
|
||||
}
|
||||
${ev.value} = $pattern.matcher($eval1.toString()).matches();
|
||||
${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
|
||||
"""
|
||||
})
|
||||
}
|
||||
|
@ -214,12 +235,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
|
|||
""",
|
||||
since = "1.0.0")
|
||||
// scalastyle:on line.contains.tab
|
||||
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
|
||||
case class RLike(left: Expression, right: Expression)
|
||||
extends BinaryExpression with StringRegexExpression {
|
||||
|
||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
||||
|
||||
override def str: Expression = left
|
||||
override def pattern: Expression = right
|
||||
|
||||
override def escape(v: String): String = v
|
||||
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
|
||||
override def toString: String = s"$left RLIKE $right"
|
||||
|
||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)
|
||||
|
||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
val patternClass = classOf[Pattern].getName
|
||||
|
||||
|
|
|
@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
|
|||
throw new ParseException("Invalid escape string." +
|
||||
"Escape string must contains only one character.", ctx)
|
||||
}
|
||||
str.charAt(0)
|
||||
str
|
||||
}.getOrElse('\\')
|
||||
invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
|
||||
invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
|
||||
case SqlBaseParser.RLIKE =>
|
||||
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
|
||||
case SqlBaseParser.NULL if ctx.NOT != null =>
|
||||
|
|
|
@ -3562,6 +3562,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
|
|||
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
|
||||
Seq(Row(1)))
|
||||
}
|
||||
|
||||
test("the like function with the escape parameter") {
|
||||
val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
|
||||
checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))
|
||||
|
||||
val longEscapeError = intercept[AnalysisException] {
|
||||
df.selectExpr("like(str, pattern, '@%')").collect()
|
||||
}.getMessage
|
||||
assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))
|
||||
|
||||
val nonFoldableError = intercept[AnalysisException] {
|
||||
df.selectExpr("like(str, pattern, escape)").collect()
|
||||
}.getMessage
|
||||
assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
|
||||
}
|
||||
}
|
||||
|
||||
object DataFrameFunctionsSuite {
|
||||
|
|
Loading…
Reference in a new issue