[SPARK-30625][SQL] Support escape
as third parameter of the like
function
### What changes were proposed in this pull request?
In the PR, I propose to transform the `Like` expression to `TernaryExpression`, and add third parameter `escape`. So, the `like` function will have feature parity with `LIKE ... ESCAPE` syntax supported by 187f3c1773
.
### Why are the changes needed?
The `like` functions can be called with 2 or 3 parameters, and functionally equivalent to `LIKE` and `LIKE ... ESCAPE` SQL expressions.
### Does this PR introduce any user-facing change?
Yes, before `like` fails with the exception:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
Error in query: Invalid number of arguments for function like. Expected: 2; Found: 3; line 1 pos 7
```
After:
```sql
spark-sql> SELECT like('_Apache Spark_', '__%Spark__', '_');
true
```
### How was this patch tested?
- Add new example for the `like` function which is checked by `SQLQuerySuite`
- Run `RegexpExpressionsSuite` and `ExpressionParserSuite`.
Closes #27355 from MaxGekk/like-3-args.
Authored-by: Maxim Gekk <max.gekk@gmail.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
parent
c5c580ba0d
commit
8aebc80e0e
|
@ -99,7 +99,7 @@ package object dsl {
|
||||||
}
|
}
|
||||||
|
|
||||||
def like(other: Expression, escapeChar: Char = '\\'): Expression =
|
def like(other: Expression, escapeChar: Char = '\\'): Expression =
|
||||||
Like(expr, other, escapeChar)
|
Like(expr, other, Literal(escapeChar.toString))
|
||||||
def rlike(other: Expression): Expression = RLike(expr, other)
|
def rlike(other: Expression): Expression = RLike(expr, other)
|
||||||
def contains(other: Expression): Expression = Contains(expr, other)
|
def contains(other: Expression): Expression = Contains(expr, other)
|
||||||
def startsWith(other: Expression): Expression = StartsWith(expr, other)
|
def startsWith(other: Expression): Expression = StartsWith(expr, other)
|
||||||
|
|
|
@ -22,6 +22,7 @@ import java.util.regex.{MatchResult, Pattern}
|
||||||
|
|
||||||
import org.apache.commons.text.StringEscapeUtils
|
import org.apache.commons.text.StringEscapeUtils
|
||||||
|
|
||||||
|
import org.apache.spark.sql.AnalysisException
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||||
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
|
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
|
||||||
|
@ -29,18 +30,20 @@ import org.apache.spark.sql.types._
|
||||||
import org.apache.spark.unsafe.types.UTF8String
|
import org.apache.spark.unsafe.types.UTF8String
|
||||||
|
|
||||||
|
|
||||||
abstract class StringRegexExpression extends BinaryExpression
|
trait StringRegexExpression extends Expression
|
||||||
with ImplicitCastInputTypes with NullIntolerant {
|
with ImplicitCastInputTypes with NullIntolerant {
|
||||||
|
|
||||||
|
def str: Expression
|
||||||
|
def pattern: Expression
|
||||||
|
|
||||||
def escape(v: String): String
|
def escape(v: String): String
|
||||||
def matches(regex: Pattern, str: String): Boolean
|
def matches(regex: Pattern, str: String): Boolean
|
||||||
|
|
||||||
override def dataType: DataType = BooleanType
|
override def dataType: DataType = BooleanType
|
||||||
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
|
||||||
|
|
||||||
// try cache the pattern for Literal
|
// try cache the pattern for Literal
|
||||||
private lazy val cache: Pattern = right match {
|
private lazy val cache: Pattern = pattern match {
|
||||||
case x @ Literal(value: String, StringType) => compile(value)
|
case Literal(value: String, StringType) => compile(value)
|
||||||
case _ => null
|
case _ => null
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -51,10 +54,9 @@ abstract class StringRegexExpression extends BinaryExpression
|
||||||
Pattern.compile(escape(str))
|
Pattern.compile(escape(str))
|
||||||
}
|
}
|
||||||
|
|
||||||
protected def pattern(str: String) = if (cache == null) compile(str) else cache
|
def nullSafeMatch(input1: Any, input2: Any): Any = {
|
||||||
|
val s = input2.asInstanceOf[UTF8String].toString
|
||||||
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
|
val regex = if (cache == null) compile(s) else cache
|
||||||
val regex = pattern(input2.asInstanceOf[UTF8String].toString)
|
|
||||||
if(regex == null) {
|
if(regex == null) {
|
||||||
null
|
null
|
||||||
} else {
|
} else {
|
||||||
|
@ -62,7 +64,7 @@ abstract class StringRegexExpression extends BinaryExpression
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
|
override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}"
|
||||||
}
|
}
|
||||||
|
|
||||||
// scalastyle:off line.contains.tab
|
// scalastyle:off line.contains.tab
|
||||||
|
@ -107,46 +109,65 @@ abstract class StringRegexExpression extends BinaryExpression
|
||||||
true
|
true
|
||||||
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
|
> SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/';
|
||||||
true
|
true
|
||||||
|
> SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_');
|
||||||
|
true
|
||||||
""",
|
""",
|
||||||
note = """
|
note = """
|
||||||
Use RLIKE to match with standard regular expressions.
|
Use RLIKE to match with standard regular expressions.
|
||||||
""",
|
""",
|
||||||
since = "1.0.0")
|
since = "1.0.0")
|
||||||
// scalastyle:on line.contains.tab
|
// scalastyle:on line.contains.tab
|
||||||
case class Like(left: Expression, right: Expression, escapeChar: Char)
|
case class Like(str: Expression, pattern: Expression, escape: Expression)
|
||||||
extends StringRegexExpression {
|
extends TernaryExpression with StringRegexExpression {
|
||||||
|
|
||||||
def this(left: Expression, right: Expression) = this(left, right, '\\')
|
def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\"))
|
||||||
|
|
||||||
|
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
|
||||||
|
override def children: Seq[Expression] = Seq(str, pattern, escape)
|
||||||
|
|
||||||
|
private lazy val escapeChar: Char = if (escape.foldable) {
|
||||||
|
escape.eval() match {
|
||||||
|
case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0)
|
||||||
|
case s => throw new AnalysisException(
|
||||||
|
s"The 'escape' parameter must be a string literal of one char but it is $s.")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new AnalysisException("The 'escape' parameter must be a string literal.")
|
||||||
|
}
|
||||||
|
|
||||||
override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
|
override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar)
|
||||||
|
|
||||||
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
|
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()
|
||||||
|
|
||||||
override def toString: String = escapeChar match {
|
override def toString: String = escapeChar match {
|
||||||
case '\\' => s"$left LIKE $right"
|
case '\\' => s"$str LIKE $pattern"
|
||||||
case c => s"$left LIKE $right ESCAPE '$c'"
|
case c => s"$str LIKE $pattern ESCAPE '$c'"
|
||||||
|
}
|
||||||
|
|
||||||
|
protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
|
||||||
|
nullSafeMatch(input1, input2)
|
||||||
}
|
}
|
||||||
|
|
||||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val patternClass = classOf[Pattern].getName
|
val patternClass = classOf[Pattern].getName
|
||||||
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
|
val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
|
||||||
|
|
||||||
if (right.foldable) {
|
if (pattern.foldable) {
|
||||||
val rVal = right.eval()
|
val patternVal = pattern.eval()
|
||||||
if (rVal != null) {
|
if (patternVal != null) {
|
||||||
val regexStr =
|
val regexStr =
|
||||||
StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
|
StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString()))
|
||||||
val pattern = ctx.addMutableState(patternClass, "patternLike",
|
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern",
|
||||||
v => s"""$v = $patternClass.compile("$regexStr");""")
|
v => s"""$v = $patternClass.compile("$regexStr");""")
|
||||||
|
|
||||||
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
|
// We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
|
||||||
val eval = left.genCode(ctx)
|
val eval = str.genCode(ctx)
|
||||||
ev.copy(code = code"""
|
ev.copy(code = code"""
|
||||||
${eval.code}
|
${eval.code}
|
||||||
boolean ${ev.isNull} = ${eval.isNull};
|
boolean ${ev.isNull} = ${eval.isNull};
|
||||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||||
if (!${ev.isNull}) {
|
if (!${ev.isNull}) {
|
||||||
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
|
${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches();
|
||||||
}
|
}
|
||||||
""")
|
""")
|
||||||
} else {
|
} else {
|
||||||
|
@ -164,18 +185,18 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
|
||||||
} else {
|
} else {
|
||||||
escapeChar
|
escapeChar
|
||||||
}
|
}
|
||||||
val rightStr = ctx.freshName("rightStr")
|
val patternStr = ctx.freshName("patternStr")
|
||||||
val pattern = ctx.addMutableState(patternClass, "pattern")
|
val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern")
|
||||||
val lastRightStr = ctx.addMutableState(classOf[String].getName, "lastRightStr")
|
val lastPatternStr = ctx.addMutableState(classOf[String].getName, "lastPatternStr")
|
||||||
|
|
||||||
nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
|
nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => {
|
||||||
s"""
|
s"""
|
||||||
String $rightStr = $eval2.toString();
|
String $patternStr = $eval2.toString();
|
||||||
if (!$rightStr.equals($lastRightStr)) {
|
if (!$patternStr.equals($lastPatternStr)) {
|
||||||
$pattern = $patternClass.compile($escapeFunc($rightStr, '$newEscapeChar'));
|
$compiledPattern = $patternClass.compile($escapeFunc($patternStr, '$newEscapeChar'));
|
||||||
$lastRightStr = $rightStr;
|
$lastPatternStr = $patternStr;
|
||||||
}
|
}
|
||||||
${ev.value} = $pattern.matcher($eval1.toString()).matches();
|
${ev.value} = $compiledPattern.matcher($eval1.toString()).matches();
|
||||||
"""
|
"""
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -214,12 +235,20 @@ case class Like(left: Expression, right: Expression, escapeChar: Char)
|
||||||
""",
|
""",
|
||||||
since = "1.0.0")
|
since = "1.0.0")
|
||||||
// scalastyle:on line.contains.tab
|
// scalastyle:on line.contains.tab
|
||||||
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {
|
case class RLike(left: Expression, right: Expression)
|
||||||
|
extends BinaryExpression with StringRegexExpression {
|
||||||
|
|
||||||
|
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)
|
||||||
|
|
||||||
|
override def str: Expression = left
|
||||||
|
override def pattern: Expression = right
|
||||||
|
|
||||||
override def escape(v: String): String = v
|
override def escape(v: String): String = v
|
||||||
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
|
override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0)
|
||||||
override def toString: String = s"$left RLIKE $right"
|
override def toString: String = s"$left RLIKE $right"
|
||||||
|
|
||||||
|
protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2)
|
||||||
|
|
||||||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||||
val patternClass = classOf[Pattern].getName
|
val patternClass = classOf[Pattern].getName
|
||||||
|
|
||||||
|
|
|
@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
|
||||||
throw new ParseException("Invalid escape string." +
|
throw new ParseException("Invalid escape string." +
|
||||||
"Escape string must contains only one character.", ctx)
|
"Escape string must contains only one character.", ctx)
|
||||||
}
|
}
|
||||||
str.charAt(0)
|
str
|
||||||
}.getOrElse('\\')
|
}.getOrElse('\\')
|
||||||
invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar))
|
invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar)))
|
||||||
case SqlBaseParser.RLIKE =>
|
case SqlBaseParser.RLIKE =>
|
||||||
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
|
invertIfNotDefined(RLike(e, expression(ctx.pattern)))
|
||||||
case SqlBaseParser.NULL if ctx.NOT != null =>
|
case SqlBaseParser.NULL if ctx.NOT != null =>
|
||||||
|
|
|
@ -3562,6 +3562,21 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
|
||||||
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
|
checkAnswer(df.select("x").filter("exists(i, x -> x % d == 0)"),
|
||||||
Seq(Row(1)))
|
Seq(Row(1)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("the like function with the escape parameter") {
|
||||||
|
val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape")
|
||||||
|
checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true))
|
||||||
|
|
||||||
|
val longEscapeError = intercept[AnalysisException] {
|
||||||
|
df.selectExpr("like(str, pattern, '@%')").collect()
|
||||||
|
}.getMessage
|
||||||
|
assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char"))
|
||||||
|
|
||||||
|
val nonFoldableError = intercept[AnalysisException] {
|
||||||
|
df.selectExpr("like(str, pattern, escape)").collect()
|
||||||
|
}.getMessage
|
||||||
|
assert(nonFoldableError.contains("The 'escape' parameter must be a string literal"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
object DataFrameFunctionsSuite {
|
object DataFrameFunctionsSuite {
|
||||||
|
|
Loading…
Reference in a new issue