[SPARK-28077][SQL] Support ANSI SQL OVERLAY function.

## What changes were proposed in this pull request?

The `OVERLAY` function is part of ANSI SQL. It replaces a portion of a string, starting at a given position, with another string.
For example:
```
SELECT OVERLAY('abcdef' PLACING '45' FROM 4);

SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);

SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);

SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
```
The results of the four queries above are:
```
abc45f
yabadaba
yabadabadoo
bubba
```

Note: If the input string is null, then the result is null too.
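
The semantics shown by the four examples above can be sketched in plain Scala. This is an illustrative model only, not the code in this PR: `OverlaySketch` and its `overlay` helper are hypothetical names, the sketch works on `String` (UTF-16 code units) rather than Spark's `UTF8String`, and it assumes `pos >= 1`. A negative `len` stands for an omitted `FOR` clause, in which case the length of the replacement string is used.

```scala
object OverlaySketch {
  // Hypothetical helper modelling OVERLAY(input PLACING replace FROM pos [FOR len]).
  // Assumptions: pos is 1-based and >= 1; a negative len means "FOR was omitted".
  def overlay(input: String, replace: String, pos: Int, len: Int = -1): String = {
    require(pos >= 1, "this sketch only handles pos >= 1")
    // Without FOR, the number of replaced characters defaults to the length of `replace`.
    val length = if (len >= 0) len else replace.length
    val head = input.take(pos - 1)          // characters before position `pos`
    val tail = input.drop(pos - 1 + length) // characters after the replaced span
    head + replace + tail
  }

  def main(args: Array[String]): Unit = {
    println(overlay("abcdef", "45", 4))       // abc45f
    println(overlay("yabadoo", "daba", 5))    // yabadaba
    println(overlay("yabadoo", "daba", 5, 0)) // yabadabadoo
    println(overlay("babosa", "ubb", 2, 4))   // bubba
  }
}
```

With `FOR 0` nothing is removed, so the call behaves as a pure insertion at `pos`.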

Several mainstream databases support this syntax.
**PostgreSQL:**
https://www.postgresql.org/docs/11/functions-string.html

**Vertica:** https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/SQLReferenceManual/Functions/String/OVERLAY.htm?zoom_highlight=overlay

**Oracle:**
https://docs.oracle.com/en/database/oracle/oracle-database/19/arpls/UTL_RAW.html#GUID-342E37E7-FE43-4CE1-A0E9-7DAABD000369

**DB2:**
https://www.ibm.com/support/knowledgecenter/SSGMCP_5.3.0/com.ibm.cics.rexx.doc/rexx/overlay.html

The following output shows this PR running in my production environment:
```
spark-sql> SELECT OVERLAY('abcdef' PLACING '45' FROM 4);
abc45f
Time taken: 6.385 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5);
yabadaba
Time taken: 0.191 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0);
yabadabadoo
Time taken: 0.186 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4);
bubba
Time taken: 0.151 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING '45' FROM 4);
NULL
Time taken: 0.22 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5);
NULL
Time taken: 0.157 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5 FOR 0);
NULL
Time taken: 0.254 seconds, Fetched 1 row(s)
spark-sql> SELECT OVERLAY(null PLACING 'ubb' FROM 2 FOR 4);
NULL
Time taken: 0.159 seconds, Fetched 1 row(s)
```
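
The `NULL` results above follow from the default evaluation rule of the new four-input expression: the result is null as soon as any input evaluates to null. A minimal sketch of that rule in plain Scala, using `Option` in place of SQL `NULL`, might look as follows. `NullPropagationSketch`, `nullSafe4`, and the simplified `overlay` are hypothetical names introduced only for illustration; they are not part of Spark.

```scala
object NullPropagationSketch {
  // Hypothetical helper: None plays the role of SQL NULL.
  // The function f runs only when all four inputs are present.
  def nullSafe4[A, B, C, D, R](a: Option[A], b: Option[B], c: Option[C], d: Option[D])(
      f: (A, B, C, D) => R): Option[R] =
    for (w <- a; x <- b; y <- c; z <- d) yield f(w, x, y, z)

  // Simplified stand-in for the overlay computation, assuming pos >= 1 and len >= 0.
  def overlay(input: String, replace: String, pos: Int, len: Int): String =
    input.take(pos - 1) + replace + input.drop(pos - 1 + len)

  def main(args: Array[String]): Unit = {
    // All inputs present: the computation runs.
    println(nullSafe4(Some("babosa"), Some("ubb"), Some(2), Some(4))(overlay)) // Some(bubba)
    // Null input string: the result is null (None) and overlay is never called.
    println(nullSafe4(Option.empty[String], Some("ubb"), Some(2), Some(4))(overlay)) // None
  }
}
```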

## How was this patch tested?

Existing and new unit tests.

Closes #24918 from beliefer/ansi-sql-overlay.

Lead-authored-by: gengjiaan <gengjiaan@360.cn>
Co-authored-by: Jiaan Geng <beliefer@163.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
gengjiaan 2019-06-28 19:13:08 +09:00 committed by Takuya UESHIN
parent 31e7c37354
commit 832ff87918
10 changed files with 281 additions and 0 deletions

@ -194,12 +194,14 @@ Below is a list of all the keywords in Spark SQL.
<tr><td>OUTPUTFORMAT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVER</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVERLAPS</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>OVERLAY</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>OVERWRITE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PARTITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>PARTITIONED</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PARTITIONS</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PERCENT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PIVOT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PLACING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>POSITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
<tr><td>PRECEDING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
<tr><td>PRIMARY</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>

@ -705,6 +705,8 @@ primaryExpression
((FOR | ',') len=valueExpression)? ')' #substring
| TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
FROM srcStr=valueExpression ')' #trim
| OVERLAY '(' input=valueExpression PLACING replace=valueExpression
FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
;
constant
@ -1002,6 +1004,7 @@ ansiNonReserved
| OUT
| OUTPUTFORMAT
| OVER
| OVERLAY
| OVERWRITE
| PARTITION
| PARTITIONED
@ -1253,12 +1256,14 @@ nonReserved
| OUTPUTFORMAT
| OVER
| OVERLAPS
| OVERLAY
| OVERWRITE
| PARTITION
| PARTITIONED
| PARTITIONS
| PERCENTLIT
| PIVOT
| PLACING
| POSITION
| PRECEDING
| PRIMARY
@ -1509,12 +1514,14 @@ OUTER: 'OUTER';
OUTPUTFORMAT: 'OUTPUTFORMAT';
OVER: 'OVER';
OVERLAPS: 'OVERLAPS';
OVERLAY: 'OVERLAY';
OVERWRITE: 'OVERWRITE';
PARTITION: 'PARTITION';
PARTITIONED: 'PARTITIONED';
PARTITIONS: 'PARTITIONS';
PERCENTLIT: 'PERCENT';
PIVOT: 'PIVOT';
PLACING: 'PLACING';
POSITION: 'POSITION';
PRECEDING: 'PRECEDING';
PRIMARY: 'PRIMARY';

@ -348,6 +348,7 @@ object FunctionRegistry {
expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReplace]("replace"),
expression[Overlay]("overlay"),
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),

@ -68,6 +68,7 @@ import org.apache.spark.sql.types._
* - [[UnaryExpression]]: an expression that has one child.
* - [[BinaryExpression]]: an expression that has two children.
* - [[TernaryExpression]]: an expression that has three children.
* - [[QuaternaryExpression]]: an expression that has four children.
* - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
* the same output data type.
*
@ -757,6 +758,111 @@ abstract class TernaryExpression extends Expression {
}
}
/**
* An expression with four inputs and one output. The output is by default evaluated to null
* if any input is evaluated to null.
*/
abstract class QuaternaryExpression extends Expression {
override def foldable: Boolean = children.forall(_.foldable)
override def nullable: Boolean = children.exists(_.nullable)
/**
* Default behavior of evaluation according to the default nullability of QuaternaryExpression.
* If a subclass of QuaternaryExpression overrides nullable, it should probably also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
val value1 = exprs(0).eval(input)
if (value1 != null) {
val value2 = exprs(1).eval(input)
if (value2 != null) {
val value3 = exprs(2).eval(input)
if (value3 != null) {
val value4 = exprs(3).eval(input)
if (value4 != null) {
return nullSafeEval(value1, value2, value3, value4)
}
}
}
}
null
}
/**
* Called by the default [[eval]] implementation. If a subclass of QuaternaryExpression keeps
* the default nullability, it can override this method to save null-check code. If full
* control of the evaluation process is needed, override [[eval]] instead.
*/
protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any): Any =
sys.error("QuaternaryExpressions must override either eval or nullSafeEval")
/**
* Short hand for generating quaternary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f accepts four variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String, String) => String): ExprCode = {
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4) => {
s"${ev.value} = ${f(eval1, eval2, eval3, eval4)};"
})
}
/**
* Short hand for generating quaternary evaluation code.
* If any of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
* @param f function that accepts the 4 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String, String) => String): ExprCode = {
val firstGen = children(0).genCode(ctx)
val secondGen = children(1).genCode(ctx)
val thridGen = children(2).genCode(ctx)
val fourthGen = children(3).genCode(ctx)
val resultCode = f(firstGen.value, secondGen.value, thridGen.value, fourthGen.value)
if (nullable) {
val nullSafeEval =
firstGen.code + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
secondGen.code + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
thridGen.code + ctx.nullSafeExec(children(2).nullable, thridGen.isNull) {
fourthGen.code + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
s"""
${ev.isNull} = false; // resultCode could change nullability.
$resultCode
"""
}
}
}
}
ev.copy(code = code"""
boolean ${ev.isNull} = true;
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$nullSafeEval""")
} else {
ev.copy(code = code"""
${firstGen.code}
${secondGen.code}
${thridGen.code}
${fourthGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
$resultCode""", isNull = FalseLiteral)
}
}
}
/**
* A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]]
* and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date

@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -454,6 +455,69 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
override def prettyName: String = "replace"
}
object Overlay {
def calculate(input: UTF8String, replace: UTF8String, pos: Int, len: Int): UTF8String = {
val builder = new UTF8StringBuilder
builder.append(input.substringSQL(1, pos - 1))
builder.append(replace)
// If you specify length, it must be a positive whole number or zero.
// Otherwise it will be ignored.
// The default value for length is the length of replace.
val length = if (len >= 0) {
len
} else {
replace.numChars
}
builder.append(input.substringSQL(pos + length, Int.MaxValue))
builder.build()
}
}
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_(input, replace, pos[, len]) - Replace the portion of `input` that starts at `pos` and is of length `len` with `replace`.",
examples = """
Examples:
> SELECT _FUNC_('Spark SQL' PLACING '_' FROM 6);
Spark_SQL
> SELECT _FUNC_('Spark SQL' PLACING 'CORE' FROM 7);
Spark CORE
> SELECT _FUNC_('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0);
Spark ANSI SQL
> SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
Structured SQL
""")
// scalastyle:on line.size.limit
case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression)
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
def this(str: Expression, replace: Expression, pos: Expression) = {
this(str, replace, pos, Literal.create(-1, IntegerType))
}
override def dataType: DataType = StringType
override def inputTypes: Seq[AbstractDataType] =
Seq(StringType, StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = {
val inputStr = inputEval.asInstanceOf[UTF8String]
val replaceStr = replaceEval.asInstanceOf[UTF8String]
val position = posEval.asInstanceOf[Int]
val length = lenEval.asInstanceOf[Int]
Overlay.calculate(inputStr, replaceStr, position, length)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (input, replace, pos, len) =>
"org.apache.spark.sql.catalyst.expressions.Overlay" +
s".calculate($input, $replace, $pos, $len)")
}
}
object StringTranslate {
def buildDict(matchingString: UTF8String, replaceString: UTF8String)

@ -1421,6 +1421,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
}
}
/**
* Create an Overlay expression.
*/
override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
val input = expression(ctx.input)
val replace = expression(ctx.replace)
val position = expression(ctx.position)
val lengthOpt = Option(ctx.length).map(expression)
lengthOpt match {
case Some(length) => Overlay(input, replace, position, length)
case None => new Overlay(input, replace, position)
}
}
/**
* Create a (windowed) Function expression.
*/

@ -428,6 +428,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
test("overlay") {
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
Literal.create(7, IntegerType)), "Spark CORE")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("ANSI "),
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), "Spark ANSI SQL")
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("tructured"),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), "Structured SQL")
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("_"),
Literal.create(6, IntegerType)), null)
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("CORE"),
Literal.create(7, IntegerType)), null)
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("ANSI "),
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("tructured"),
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
// scalastyle:off
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
Literal.create(6, IntegerType)), "Spark_SQL")
// scalastyle:on
}
test("translate") {
checkEvaluation(
StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")

@ -744,6 +744,35 @@ class PlanParserSuite extends AnalysisTest {
)
}
test("OVERLAY function") {
def assertOverlayPlans(inputSQL: String, expectedExpression: Expression): Unit = {
comparePlans(
parsePlan(inputSQL),
Project(Seq(UnresolvedAlias(expectedExpression)), OneRowRelation())
)
}
assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING '_' FROM 6)",
new Overlay(Literal("Spark SQL"), Literal("_"), Literal(6))
)
assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'CORE' FROM 7)",
new Overlay(Literal("Spark SQL"), Literal("CORE"), Literal(7))
)
assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0)",
Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal(7), Literal(0))
)
assertOverlayPlans(
"SELECT OVERLAY('Spark SQL' PLACING 'tructured' FROM 2 FOR 4)",
Overlay(Literal("Spark SQL"), Literal("tructured"), Literal(2), Literal(4))
)
}
test("precedence of set operations") {
val a = table("a").select(star())
val b = table("b").select(star())

@ -2516,6 +2516,28 @@ object functions {
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
}
/**
* Overlay the specified portion of `src` with `replaceString`,
* starting from character position `pos` of `src` and proceeding for `len` characters.
*
* @group string_funcs
* @since 3.0.0
*/
def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr {
Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
}
/**
* Overlay the specified portion of `src` with `replaceString`,
* starting from character position `pos` of `src`.
*
* @group string_funcs
* @since 3.0.0
*/
def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr {
new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
}
/**
* Translate any character in the src by a character in replaceString.
* The characters in replaceString correspond to the characters in matchingString.

@ -129,6 +129,18 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("AQIDBA==", bytes))
}
test("overlay function") {
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
val df = Seq(("Spark SQL", "Spark的SQL")).toDF("a", "b")
checkAnswer(df.select(overlay($"a", "_", 6)), Row("Spark_SQL"))
checkAnswer(df.select(overlay($"a", "CORE", 7)), Row("Spark CORE"))
checkAnswer(df.select(overlay($"a", "ANSI ", 7, 0)), Row("Spark ANSI SQL"))
checkAnswer(df.select(overlay($"a", "tructured", 2, 4)), Row("Structured SQL"))
checkAnswer(df.select(overlay($"b", "_", 6)), Row("Spark_SQL"))
// scalastyle:on
}
test("string / binary substring function") {
// scalastyle:off
// non ascii characters are not allowed in the code, so we disable the scalastyle here.