[SPARK-30625][SQL] Support escape as third parameter of the like function

### What changes were proposed in this pull request?
In this PR, I propose to convert the `Like` expression into a `TernaryExpression` and add a third parameter, `escape`. As a result, the `like` function will have feature parity with the `LIKE ... ESCAPE` syntax supported by 187f3c1773.

### Why are the changes needed?
So that the `like` function can be called with 2 or 3 parameters, and is functionally equivalent to the `LIKE` and `LIKE ... ESCAPE` SQL expressions.

### Does this PR introduce any user-facing change?
Yes. Before this change, calling `like` with three arguments fails with the exception:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
Error in query: Invalid number of arguments for function like. Expected: 2; Found: 3; line 1 pos 7
```
After:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
true
```

### How was this patch tested?
- Added a new example for the `like` function, which is checked by `SQLQuerySuite`.
- Run `RegexpExpressionsSuite` and `ExpressionParserSuite`.

Closes #27355 from MaxGekk/like-3-args.

Authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Maxim Gekk 2020-01-27 11:19:32 -08:00 committed by Dongjoon Hyun
parent c5c580ba0d
commit 8aebc80e0e
4 changed files with 78 additions and 34 deletions

View file

@ -99,7 +99,7 @@ package object dsl {
} }
def like(other: Expression, escapeChar: Char = '\\'): Expression = def like(other: Expression, escapeChar: Char = '\\'): Expression =
Like(expr, other, escapeChar) Like(expr, other, Literal(escapeChar.toString))
def rlike(other: Expression): Expression = RLike(expr, other) def rlike(other: Expression): Expression = RLike(expr, other)
def contains(other: Expression): Expression = Contains(expr, other) def contains(other: Expression): Expression = Contains(expr, other)
def startsWith(other: Expression): Expression = StartsWith(expr, other) def startsWith(other: Expression): Expression = StartsWith(expr, other)

View file

@ -22,6 +22,7 @@ import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.text.StringEscapeUtils import org.apache.commons.text.StringEscapeUtils
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
@ -29,18 +30,20 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.unsafe.types.UTF8String
abstract class StringRegexExpression extends BinaryExpression trait StringRegexExpression extends Expression
with ImplicitCastInputTypes with NullIntolerant { with ImplicitCastInputTypes with NullIntolerant {
def str: Expression
def pattern: Expression
def escape(v: String): String def escape(v: String): String
def matches(regex: Pattern, str: String): Boolean def matches(regex: Pattern, str: String): Boolean
override def dataType: DataType = BooleanType override def dataType: DataType = BooleanType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
// try cache the pattern for Literal // try cache the pattern for Literal
private lazy val cache: Pattern = right match { private lazy val cache: Pattern = pattern match {
case x @ Literal(value: String, StringType) => compile(value) case Literal(value: String, StringType) => compile(value)
case _ => null case _ => null
} }
@ -51,10 +54,9 @@ abstract class StringRegexExpression extends BinaryExpression
Pattern.compile(escape(str)) Pattern.compile(escape(str))
} }
protected def pattern(str: String) = if (cache == null) compile(str) else cache def nullSafeMatch(input1: Any, input2: Any): Any = {
val s = input2.asInstanceOf[UTF8String].toString
protected override def nullSafeEval(input1: Any, input2: Any): Any = { val regex = if (cache == null) compile(s) else cache
val regex = pattern(input2.asInstanceOf[UTF8String].toString)
if(regex == null) { if(regex == null) {
null null
} else { } else {
@ -62,7 +64,7 @@ abstract class StringRegexExpression extends BinaryExpression
} }
} }
override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
} }
// scalastyle:off line.contains.tab // scalastyle:off line.contains.tab
@ -107,46 +109,65 @@ abstract class StringRegexExpression extends BinaryExpression
true true
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/'; > SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
true true
> SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
true
""", """,
note = """ note = """
Use RLIKE to match with standard regular expressions. Use RLIKE to match with standard regular expressions.
""", """,
since = "1.0.0") since = "1.0.0")
// scalastyle:on line.contains.tab // scalastyle:on line.contains.tab
case class Like(left: Expression, right: Expression, escapeChar: Char) case class Like(str: Expression, pattern: Expression, escape: Expression)
extends StringRegexExpression { extends TernaryExpression with StringRegexExpression {
def this(left: Expression, right: Expression) = this(left, right, '\\') def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
override def children: Seq[Expression] = Seq(str, pattern, escape)
// Resolves the single escape character from the `escape` child expression. Because this is a
// lazy val, the resolution (and any resulting error) happens on first access.
// The expression must be foldable and must evaluate to a one-character string; otherwise an
// AnalysisException is thrown.
private lazy val escapeChar: Char = if (escape.foldable) {
escape.eval() match {
// A typed pattern never matches null, so a null escape value falls through to the error case.
case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
case s => throw new AnalysisException(
s"The 'escape' parameter must be a string literal of one char but it is $s.")
}
} else {
// A non-foldable escape (for example, one read from a column) cannot be reduced to a constant.
throw new AnalysisException("The 'escape' parameter must be a string literal.")
}
override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
override def toString: String = escapeChar match { override def toString: String = escapeChar match {
case '\\' => s"$left LIKE $right" case '\\' => s"$str LIKE $pattern"
case c => s"$left LIKE $right ESCAPE '$c'" case c => s"$str LIKE $pattern ESCAPE '$c'"
}
// Evaluation for the three-child form: the third input (the escape value) is not consulted
// here; matching is delegated to nullSafeMatch with only the string and pattern inputs.
protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
nullSafeMatch(input1, input2)
}
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName val patternClass = classOf[Pattern].getName
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
if (right.foldable) { if (pattern.foldable) {
val rVal = right.eval() val patternVal = pattern.eval()
if (rVal != null) { if (patternVal != null) {
val regexStr = val regexStr =
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
val pattern = ctx.addMutableState(patternClass, "patternLike", val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
v => s"""$v = $patternClass.compile("$regexStr");""") v => s"""$v = $patternClass.compile("$regexStr");""")
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
val eval = left.genCode(ctx) val eval = str.genCode(ctx)
ev.copy(code = code""" ev.copy(code = code"""
${eval.code} ${eval.code}
boolean ${ev.isNull} = ${eval.isNull}; boolean ${ev.isNull} = ${eval.isNull};
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
if (!${ev.isNull}) { if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
} }
""") """)
} else { } else {
@ -164,18 +185,18 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
} else { } else {
escapeChar escapeChar
} }
val rightStr = ctx.freshName("rightStr") val patternStr = ctx.freshName("patternStr")
val pattern = ctx.addMutableState(patternClass, "pattern") val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern")
val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr") val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr")
nullSafeCodeGen(ctx, ev, (eval1, eval2) => { nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
s""" s"""
String $rightStr = $eval2.toString(); String $patternStr = $eval2.toString();
if (!$rightStr.equals($lastRightStr)) { if (!$patternStr.equals($lastPatternStr)) {
$pattern = $patternClass.compile($escapeFunc($rightStr, '$newEscapeChar')); $compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar'));
$lastRightStr = $rightStr; $lastPatternStr = $patternStr;
} }
${ev.value} = $pattern.matcher($eval1.toString()).matches(); ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
""" """
}) })
} }
@ -214,12 +235,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
""", """,
since = "1.0.0") since = "1.0.0")
// scalastyle:on line.contains.tab // scalastyle:on line.contains.tab
case class RLike(left: Expression, right: Expression) extends StringRegexExpression { case class RLike(left: Expression, right: Expression)
extends BinaryExpression with StringRegexExpression {
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
override def str: Expression = left
override def pattern: Expression = right
override def escape(v: String): String = v override def escape(v: String): String = v
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
override def toString: String = s"$left RLIKE $right" override def toString: String = s"$left RLIKE $right"
protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val patternClass = classOf[Pattern].getName val patternClass = classOf[Pattern].getName

View file

@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
throw new ParseException("Invalid escape string." + throw new ParseException("Invalid escape string." +
"Escape string must contains only one character.", ctx) "Escape string must contains only one character.", ctx)
} }
str.charAt(0) str
}.getOrElse('\\') }.getOrElse('\\')
invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
case SqlBaseParser.RLIKE => case SqlBaseParser.RLIKE =>
invertIfNotDefined(RLike(e, expression(ctx.pattern))) invertIfNotDefined(RLike(e, expression(ctx.pattern)))
case SqlBaseParser.NULL if ctx.NOT != null => case SqlBaseParser.NULL if ctx.NOT != null =>

View file

@ -3562,6 +3562,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"), checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
Seq(Row(1))) Seq(Row(1)))
} }
// Exercises the three-argument form of `like`: a valid one-character escape literal, an escape
// literal longer than one character, and an escape taken from a column instead of a literal.
test("the like function with the escape parameter") {
val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
// A single-character escape literal is accepted and the pattern matches.
checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))
// A two-character escape literal must be rejected with an AnalysisException.
val longEscapeError = intercept[AnalysisException] {
df.selectExpr("like(str, pattern, '@%')").collect()
}.getMessage
assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))
// An escape read from the `escape` column is not a literal and must also be rejected.
val nonFoldableError = intercept[AnalysisException] {
df.selectExpr("like(str, pattern, escape)").collect()
}.getMessage
assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
}
} }
object DataFrameFunctionsSuite { object DataFrameFunctionsSuite {