diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 40998080bc..b4a8bafe22 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -99,7 +99,7 @@ package object dsl { } def like(other: Expression, escapeChar: Char = '\\'): Expression = - Like(expr, other, Literal(escapeChar.toString)) + Like(expr, other, escapeChar) def rlike(other: Expression): Expression = RLike(expr, other) def contains(other: Expression): Expression = Contains(expr, other) def startsWith(other: Expression): Expression = StartsWith(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index f8d328bf60..e5ee0edfcf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -22,7 +22,6 @@ import java.util.regex.{MatchResult, Pattern} import org.apache.commons.text.StringEscapeUtils -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils} @@ -30,19 +29,17 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -trait StringRegexExpression extends Expression +abstract class StringRegexExpression extends BinaryExpression with ImplicitCastInputTypes with NullIntolerant { - def str: Expression - def pattern: Expression - def escape(v: String): String def matches(regex: Pattern, str: String): Boolean override def dataType: DataType = BooleanType + override def inputTypes: Seq[DataType] = Seq(StringType, StringType) // try cache foldable pattern - private lazy val cache: Pattern = pattern match { + private lazy val cache: Pattern = right match { case p: Expression if p.foldable => compile(p.eval().asInstanceOf[UTF8String].toString) case _ => null @@ -55,9 +52,10 @@ trait StringRegexExpression extends Expression Pattern.compile(escape(str)) } - def nullSafeMatch(input1: Any, input2: Any): Any = { - val s = input2.asInstanceOf[UTF8String].toString - val regex = if (cache == null) compile(s) else cache + protected def pattern(str: String) = if (cache == null) compile(str) else cache + + protected override def nullSafeEval(input1: Any, input2: Any): Any = { + val regex = pattern(input2.asInstanceOf[UTF8String].toString) if(regex == null) { null } else { @@ -65,7 +63,7 @@ trait StringRegexExpression extends Expression } } - override def sql: String = s"${str.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${pattern.sql}" + override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}" } // scalastyle:off line.contains.tab @@ -110,65 +108,46 @@ trait StringRegexExpression extends Expression true > SELECT '%SystemDrive%/Users/John' _FUNC_ '/%SystemDrive/%//Users%' ESCAPE '/'; true - > SELECT _FUNC_('_Apache Spark_', '__%Spark__', '_'); - true """, note = """ Use RLIKE to match with standard regular expressions. """, since = "1.0.0") // scalastyle:on line.contains.tab -case class Like(str: Expression, pattern: Expression, escape: Expression) - extends TernaryExpression with StringRegexExpression { +case class Like(left: Expression, right: Expression, escapeChar: Char) + extends StringRegexExpression { - def this(str: Expression, pattern: Expression) = this(str, pattern, Literal("\\")) - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType) - override def children: Seq[Expression] = Seq(str, pattern, escape) - - private lazy val escapeChar: Char = if (escape.foldable) { - escape.eval() match { - case s: UTF8String if s != null && s.numChars() == 1 => s.toString.charAt(0) - case s => throw new AnalysisException( - s"The 'escape' parameter must be a string literal of one char but it is $s.") - } - } else { - throw new AnalysisException("The 'escape' parameter must be a string literal.") - } + def this(left: Expression, right: Expression) = this(left, right, '\\') override def escape(v: String): String = StringUtils.escapeLikeRegex(v, escapeChar) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches() override def toString: String = escapeChar match { - case '\\' => s"$str LIKE $pattern" - case c => s"$str LIKE $pattern ESCAPE '$c'" - } - - protected override def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = { - nullSafeMatch(input1, input2) + case '\\' => s"$left LIKE $right" + case c => s"$left LIKE $right ESCAPE '$c'" } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" - if (pattern.foldable) { - val patternVal = pattern.eval() - if (patternVal != null) { + if (right.foldable) { + val rVal = right.eval() + if (rVal != null) { val regexStr = - StringEscapeUtils.escapeJava(escape(patternVal.asInstanceOf[UTF8String].toString())) - val compiledPattern = ctx.addMutableState(patternClass, "compiledPattern", + StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString())) + val pattern = ctx.addMutableState(patternClass, "patternLike", v => s"""$v = $patternClass.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = str.genCode(ctx) + val eval = left.genCode(ctx) ev.copy(code = code""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; if (!${ev.isNull}) { - ${ev.value} = $compiledPattern.matcher(${eval.value}.toString()).matches(); + ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } """) } else { @@ -178,8 +157,8 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """) } } else { - val patternStr = ctx.freshName("patternStr") - val compiledPattern = ctx.freshName("compiledPattern") + val pattern = ctx.freshName("pattern") + val rightStr = ctx.freshName("rightStr") // We need double escape to avoid org.codehaus.commons.compiler.CompileException. // '\\' will cause exception 'Single quote must be backslash-escaped in character literal'. // '\"' will cause exception 'Line break in literal not allowed'. @@ -188,12 +167,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) } else { escapeChar } - nullSafeCodeGen(ctx, ev, (eval1, eval2, _) => { + nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - String $patternStr = $eval2.toString(); - $patternClass $compiledPattern = $patternClass.compile( - $escapeFunc($patternStr, '$newEscapeChar')); - ${ev.value} = $compiledPattern.matcher($eval1.toString()).matches(); + String $rightStr = $eval2.toString(); + $patternClass $pattern = $patternClass.compile( + $escapeFunc($rightStr, '$newEscapeChar')); + ${ev.value} = $pattern.matcher($eval1.toString()).matches(); """ }) } @@ -232,20 +211,12 @@ case class Like(str: Expression, pattern: Expression, escape: Expression) """, since = "1.0.0") // scalastyle:on line.contains.tab -case class RLike(left: Expression, right: Expression) - extends BinaryExpression with StringRegexExpression { - - override def inputTypes: Seq[DataType] = Seq(StringType, StringType) - - override def str: Expression = left - override def pattern: Expression = right +case class RLike(left: Expression, right: Expression) extends StringRegexExpression { override def escape(v: String): String = v override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - protected override def nullSafeEval(input1: Any, input2: Any): Any = nullSafeMatch(input1, input2) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 6fc65e1486..62e568587f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1392,9 +1392,9 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging throw new ParseException("Invalid escape string." + "Escape string must contains only one character.", ctx) } - str + str.charAt(0) }.getOrElse('\\') - invertIfNotDefined(Like(e, expression(ctx.pattern), Literal(escapeChar))) + invertIfNotDefined(Like(e, expression(ctx.pattern), escapeChar)) case SqlBaseParser.RLIKE => invertIfNotDefined(RLike(e, expression(ctx.pattern))) case SqlBaseParser.NULL if ctx.NOT != null => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 9e9d8c3e9a..6012678341 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3560,21 +3560,6 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { Seq(Row(1))) } - test("the like function with the escape parameter") { - val df = Seq(("abc", "a_c", "!")).toDF("str", "pattern", "escape") - checkAnswer(df.selectExpr("like(str, pattern, '@')"), Row(true)) - - val longEscapeError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, '@%')").collect() - }.getMessage - assert(longEscapeError.contains("The 'escape' parameter must be a string literal of one char")) - - val nonFoldableError = intercept[AnalysisException] { - df.selectExpr("like(str, pattern, escape)").collect() - }.getMessage - assert(nonFoldableError.contains("The 'escape' parameter must be a string literal")) - } - test("SPARK-29462: Empty array of NullType for array function with no arguments") { Seq((true, StringType), (false, NullType)).foreach { case (arrayDefaultToString, expectedType) =>