[SPARK-28077][SQL] Support ANSI SQL OVERLAY function.
## What changes were proposed in this pull request? The `OVERLAY` function is an `ANSI` `SQL` function. For example: ``` SELECT OVERLAY('abcdef' PLACING '45' FROM 4); SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5); SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0); SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4); ``` The results of the above four `SQL` statements are: ``` abc45f yabadaba yabadabadoo bubba ``` Note: If the input string is null, then the result is null too. Several mainstream databases support this syntax. **PostgreSQL:** https://www.postgresql.org/docs/11/functions-string.html **Vertica:** https://www.vertica.com/docs/9.2.x/HTML/Content/Authoring/SQLReferenceManual/Functions/String/OVERLAY.htm?zoom_highlight=overlay **Oracle:** https://docs.oracle.com/en/database/oracle/oracle-database/19/arpls/UTL_RAW.html#GUID-342E37E7-FE43-4CE1-A0E9-7DAABD000369 **DB2:** https://www.ibm.com/support/knowledgecenter/SSGMCP_5.3.0/com.ibm.cics.rexx.doc/rexx/overlay.html Here is the output of this PR in my production environment. ``` spark-sql> SELECT OVERLAY('abcdef' PLACING '45' FROM 4); abc45f Time taken: 6.385 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5); yabadaba Time taken: 0.191 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY('yabadoo' PLACING 'daba' FROM 5 FOR 0); yabadabadoo Time taken: 0.186 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY('babosa' PLACING 'ubb' FROM 2 FOR 4); bubba Time taken: 0.151 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY(null PLACING '45' FROM 4); NULL Time taken: 0.22 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5); NULL Time taken: 0.157 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY(null PLACING 'daba' FROM 5 FOR 0); NULL Time taken: 0.254 seconds, Fetched 1 row(s) spark-sql> SELECT OVERLAY(null PLACING 'ubb' FROM 2 FOR 4); NULL Time taken: 0.159 seconds, Fetched 1 row(s) ``` ## How was this patch tested? Existing UTs and new UTs. 
Closes #24918 from beliefer/ansi-sql-overlay. Lead-authored-by: gengjiaan <gengjiaan@360.cn> Co-authored-by: Jiaan Geng <beliefer@163.com> Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
This commit is contained in:
parent
31e7c37354
commit
832ff87918
|
@ -194,12 +194,14 @@ Below is a list of all the keywords in Spark SQL.
|
|||
<tr><td>OUTPUTFORMAT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>OVER</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>OVERLAPS</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
|
||||
<tr><td>OVERLAY</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>OVERWRITE</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PARTITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
|
||||
<tr><td>PARTITIONED</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PARTITIONS</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PERCENT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PIVOT</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PLACING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>POSITION</td><td>non-reserved</td><td>non-reserved</td><td>reserved</td></tr>
|
||||
<tr><td>PRECEDING</td><td>non-reserved</td><td>non-reserved</td><td>non-reserved</td></tr>
|
||||
<tr><td>PRIMARY</td><td>reserved</td><td>non-reserved</td><td>reserved</td></tr>
|
||||
|
|
|
@ -705,6 +705,8 @@ primaryExpression
|
|||
((FOR | ',') len=valueExpression)? ')' #substring
|
||||
| TRIM '(' trimOption=(BOTH | LEADING | TRAILING)? (trimStr=valueExpression)?
|
||||
FROM srcStr=valueExpression ')' #trim
|
||||
| OVERLAY '(' input=valueExpression PLACING replace=valueExpression
|
||||
FROM position=valueExpression (FOR length=valueExpression)? ')' #overlay
|
||||
;
|
||||
|
||||
constant
|
||||
|
@ -1002,6 +1004,7 @@ ansiNonReserved
|
|||
| OUT
|
||||
| OUTPUTFORMAT
|
||||
| OVER
|
||||
| OVERLAY
|
||||
| OVERWRITE
|
||||
| PARTITION
|
||||
| PARTITIONED
|
||||
|
@ -1253,12 +1256,14 @@ nonReserved
|
|||
| OUTPUTFORMAT
|
||||
| OVER
|
||||
| OVERLAPS
|
||||
| OVERLAY
|
||||
| OVERWRITE
|
||||
| PARTITION
|
||||
| PARTITIONED
|
||||
| PARTITIONS
|
||||
| PERCENTLIT
|
||||
| PIVOT
|
||||
| PLACING
|
||||
| POSITION
|
||||
| PRECEDING
|
||||
| PRIMARY
|
||||
|
@ -1509,12 +1514,14 @@ OUTER: 'OUTER';
|
|||
OUTPUTFORMAT: 'OUTPUTFORMAT';
|
||||
OVER: 'OVER';
|
||||
OVERLAPS: 'OVERLAPS';
|
||||
OVERLAY: 'OVERLAY';
|
||||
OVERWRITE: 'OVERWRITE';
|
||||
PARTITION: 'PARTITION';
|
||||
PARTITIONED: 'PARTITIONED';
|
||||
PARTITIONS: 'PARTITIONS';
|
||||
PERCENTLIT: 'PERCENT';
|
||||
PIVOT: 'PIVOT';
|
||||
PLACING: 'PLACING';
|
||||
POSITION: 'POSITION';
|
||||
PRECEDING: 'PRECEDING';
|
||||
PRIMARY: 'PRIMARY';
|
||||
|
|
|
@ -348,6 +348,7 @@ object FunctionRegistry {
|
|||
expression[RegExpReplace]("regexp_replace"),
|
||||
expression[StringRepeat]("repeat"),
|
||||
expression[StringReplace]("replace"),
|
||||
expression[Overlay]("overlay"),
|
||||
expression[RLike]("rlike"),
|
||||
expression[StringRPad]("rpad"),
|
||||
expression[StringTrimRight]("rtrim"),
|
||||
|
|
|
@ -68,6 +68,7 @@ import org.apache.spark.sql.types._
|
|||
* - [[UnaryExpression]]: an expression that has one child.
|
||||
* - [[BinaryExpression]]: an expression that has two children.
|
||||
* - [[TernaryExpression]]: an expression that has three children.
|
||||
* - [[QuaternaryExpression]]: an expression that has four children.
|
||||
* - [[BinaryOperator]]: a special case of [[BinaryExpression]] that requires two children to have
|
||||
* the same output data type.
|
||||
*
|
||||
|
@ -757,6 +758,111 @@ abstract class TernaryExpression extends Expression {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An expression with four inputs and one output. The output is by default evaluated to null
|
||||
* if any input is evaluated to null.
|
||||
*/
|
||||
abstract class QuaternaryExpression extends Expression {
|
||||
|
||||
override def foldable: Boolean = children.forall(_.foldable)
|
||||
|
||||
override def nullable: Boolean = children.exists(_.nullable)
|
||||
|
||||
/**
|
||||
* Default behavior of evaluation according to the default nullability of QuaternaryExpression.
|
||||
* If a subclass of QuaternaryExpression overrides nullable, it probably should also override this.
|
||||
*/
|
||||
override def eval(input: InternalRow): Any = {
|
||||
val exprs = children
|
||||
val value1 = exprs(0).eval(input)
|
||||
if (value1 != null) {
|
||||
val value2 = exprs(1).eval(input)
|
||||
if (value2 != null) {
|
||||
val value3 = exprs(2).eval(input)
|
||||
if (value3 != null) {
|
||||
val value4 = exprs(3).eval(input)
|
||||
if (value4 != null) {
|
||||
return nullSafeEval(value1, value2, value3, value4)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
null
|
||||
}
|
||||
|
||||
/**
|
||||
* Called by default [[eval]] implementation. If subclasses of QuaternaryExpression keep the
|
||||
* default nullability, they can override this method to save null-check code. If we need
|
||||
* full control of evaluation process, we should override [[eval]].
|
||||
*/
|
||||
protected def nullSafeEval(input1: Any, input2: Any, input3: Any, input4: Any): Any =
|
||||
sys.error(s"QuaternaryExpressions must override either eval or nullSafeEval")
|
||||
|
||||
/**
|
||||
* Short hand for generating quaternary evaluation code.
|
||||
* If any of the sub-expressions is null, the result of this computation
|
||||
* is assumed to be null.
|
||||
*
|
||||
* @param f accepts four variable names and returns Java code to compute the output.
|
||||
*/
|
||||
protected def defineCodeGen(
|
||||
ctx: CodegenContext,
|
||||
ev: ExprCode,
|
||||
f: (String, String, String, String) => String): ExprCode = {
|
||||
nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3, eval4) => {
|
||||
s"${ev.value} = ${f(eval1, eval2, eval3, eval4)};"
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Short hand for generating quaternary evaluation code.
|
||||
* If any of the sub-expressions is null, the result of this computation
|
||||
* is assumed to be null.
|
||||
*
|
||||
* @param f function that accepts the 4 non-null evaluation result names of children
|
||||
* and returns Java code to compute the output.
|
||||
*/
|
||||
protected def nullSafeCodeGen(
|
||||
ctx: CodegenContext,
|
||||
ev: ExprCode,
|
||||
f: (String, String, String, String) => String): ExprCode = {
|
||||
val firstGen = children(0).genCode(ctx)
|
||||
val secondGen = children(1).genCode(ctx)
|
||||
val thridGen = children(2).genCode(ctx)
|
||||
val fourthGen = children(3).genCode(ctx)
|
||||
val resultCode = f(firstGen.value, secondGen.value, thridGen.value, fourthGen.value)
|
||||
|
||||
if (nullable) {
|
||||
val nullSafeEval =
|
||||
firstGen.code + ctx.nullSafeExec(children(0).nullable, firstGen.isNull) {
|
||||
secondGen.code + ctx.nullSafeExec(children(1).nullable, secondGen.isNull) {
|
||||
thridGen.code + ctx.nullSafeExec(children(2).nullable, thridGen.isNull) {
|
||||
fourthGen.code + ctx.nullSafeExec(children(3).nullable, fourthGen.isNull) {
|
||||
s"""
|
||||
${ev.isNull} = false; // resultCode could change nullability.
|
||||
$resultCode
|
||||
"""
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ev.copy(code = code"""
|
||||
boolean ${ev.isNull} = true;
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$nullSafeEval""")
|
||||
} else {
|
||||
ev.copy(code = code"""
|
||||
${firstGen.code}
|
||||
${secondGen.code}
|
||||
${thridGen.code}
|
||||
${fourthGen.code}
|
||||
${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
|
||||
$resultCode""", isNull = FalseLiteral)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A trait used for resolving nullable flags, including `nullable`, `containsNull` of [[ArrayType]]
|
||||
* and `valueContainsNull` of [[MapType]], containsNull, valueContainsNull flags of the output date
|
||||
|
|
|
@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
|
|||
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.UTF8StringBuilder
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -454,6 +455,69 @@ case class StringReplace(srcExpr: Expression, searchExpr: Expression, replaceExp
|
|||
override def prettyName: String = "replace"
|
||||
}
|
||||
|
||||
object Overlay {
|
||||
|
||||
def calculate(input: UTF8String, replace: UTF8String, pos: Int, len: Int): UTF8String = {
|
||||
val builder = new UTF8StringBuilder
|
||||
builder.append(input.substringSQL(1, pos - 1))
|
||||
builder.append(replace)
|
||||
// If you specify length, it must be a positive whole number or zero.
|
||||
// Otherwise it will be ignored.
|
||||
// The default value for length is the length of replace.
|
||||
val length = if (len >= 0) {
|
||||
len
|
||||
} else {
|
||||
replace.numChars
|
||||
}
|
||||
builder.append(input.substringSQL(pos + length, Int.MaxValue))
|
||||
builder.build()
|
||||
}
|
||||
}
|
||||
|
||||
// scalastyle:off line.size.limit
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(input, replace, pos[, len]) - Replace `input` with `replace` that starts at `pos` and is of length `len`.",
|
||||
examples = """
|
||||
Examples:
|
||||
> SELECT _FUNC_('Spark SQL' PLACING '_' FROM 6);
|
||||
Spark_SQL
|
||||
> SELECT _FUNC_('Spark SQL' PLACING 'CORE' FROM 7);
|
||||
Spark CORE
|
||||
> SELECT _FUNC_('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0);
|
||||
Spark ANSI SQL
|
||||
> SELECT _FUNC_('Spark SQL' PLACING 'tructured' FROM 2 FOR 4);
|
||||
Structured SQL
|
||||
""")
|
||||
// scalastyle:on line.size.limit
|
||||
case class Overlay(input: Expression, replace: Expression, pos: Expression, len: Expression)
|
||||
extends QuaternaryExpression with ImplicitCastInputTypes with NullIntolerant {
|
||||
|
||||
def this(str: Expression, replace: Expression, pos: Expression) = {
|
||||
this(str, replace, pos, Literal.create(-1, IntegerType))
|
||||
}
|
||||
|
||||
override def dataType: DataType = StringType
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] =
|
||||
Seq(StringType, StringType, IntegerType, IntegerType)
|
||||
|
||||
override def children: Seq[Expression] = input :: replace :: pos :: len :: Nil
|
||||
|
||||
override def nullSafeEval(inputEval: Any, replaceEval: Any, posEval: Any, lenEval: Any): Any = {
|
||||
val inputStr = inputEval.asInstanceOf[UTF8String]
|
||||
val replaceStr = replaceEval.asInstanceOf[UTF8String]
|
||||
val position = posEval.asInstanceOf[Int]
|
||||
val length = lenEval.asInstanceOf[Int]
|
||||
Overlay.calculate(inputStr, replaceStr, position, length)
|
||||
}
|
||||
|
||||
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
|
||||
defineCodeGen(ctx, ev, (input, replace, pos, len) =>
|
||||
"org.apache.spark.sql.catalyst.expressions.Overlay" +
|
||||
s".calculate($input, $replace, $pos, $len);")
|
||||
}
|
||||
}
|
||||
|
||||
object StringTranslate {
|
||||
|
||||
def buildDict(matchingString: UTF8String, replaceString: UTF8String)
|
||||
|
|
|
@ -1421,6 +1421,20 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create an Overlay expression.
|
||||
*/
|
||||
override def visitOverlay(ctx: OverlayContext): Expression = withOrigin(ctx) {
|
||||
val input = expression(ctx.input)
|
||||
val replace = expression(ctx.replace)
|
||||
val position = expression(ctx.position)
|
||||
val lengthOpt = Option(ctx.length).map(expression)
|
||||
lengthOpt match {
|
||||
case Some(length) => Overlay(input, replace, position, length)
|
||||
case None => new Overlay(input, replace, position)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a (windowed) Function expression.
|
||||
*/
|
||||
|
|
|
@ -428,6 +428,30 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
// scalastyle:on
|
||||
}
|
||||
|
||||
test("overlay") {
|
||||
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("_"),
|
||||
Literal.create(6, IntegerType)), "Spark_SQL")
|
||||
checkEvaluation(new Overlay(Literal("Spark SQL"), Literal("CORE"),
|
||||
Literal.create(7, IntegerType)), "Spark CORE")
|
||||
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("ANSI "),
|
||||
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), "Spark ANSI SQL")
|
||||
checkEvaluation(Overlay(Literal("Spark SQL"), Literal("tructured"),
|
||||
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), "Structured SQL")
|
||||
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("_"),
|
||||
Literal.create(6, IntegerType)), null)
|
||||
checkEvaluation(new Overlay(Literal.create(null, StringType), Literal("CORE"),
|
||||
Literal.create(7, IntegerType)), null)
|
||||
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("ANSI "),
|
||||
Literal.create(7, IntegerType), Literal.create(0, IntegerType)), null)
|
||||
checkEvaluation(Overlay(Literal.create(null, StringType), Literal("tructured"),
|
||||
Literal.create(2, IntegerType), Literal.create(4, IntegerType)), null)
|
||||
// scalastyle:off
|
||||
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
|
||||
checkEvaluation(new Overlay(Literal("Spark的SQL"), Literal("_"),
|
||||
Literal.create(6, IntegerType)), "Spark_SQL")
|
||||
// scalastyle:on
|
||||
}
|
||||
|
||||
test("translate") {
|
||||
checkEvaluation(
|
||||
StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
|
||||
|
|
|
@ -744,6 +744,35 @@ class PlanParserSuite extends AnalysisTest {
|
|||
)
|
||||
}
|
||||
|
||||
test("OVERLAY function") {
|
||||
def assertOverlayPlans(inputSQL: String, expectedExpression: Expression): Unit = {
|
||||
comparePlans(
|
||||
parsePlan(inputSQL),
|
||||
Project(Seq(UnresolvedAlias(expectedExpression)), OneRowRelation())
|
||||
)
|
||||
}
|
||||
|
||||
assertOverlayPlans(
|
||||
"SELECT OVERLAY('Spark SQL' PLACING '_' FROM 6)",
|
||||
new Overlay(Literal("Spark SQL"), Literal("_"), Literal(6))
|
||||
)
|
||||
|
||||
assertOverlayPlans(
|
||||
"SELECT OVERLAY('Spark SQL' PLACING 'CORE' FROM 7)",
|
||||
new Overlay(Literal("Spark SQL"), Literal("CORE"), Literal(7))
|
||||
)
|
||||
|
||||
assertOverlayPlans(
|
||||
"SELECT OVERLAY('Spark SQL' PLACING 'ANSI ' FROM 7 FOR 0)",
|
||||
Overlay(Literal("Spark SQL"), Literal("ANSI "), Literal(7), Literal(0))
|
||||
)
|
||||
|
||||
assertOverlayPlans(
|
||||
"SELECT OVERLAY('Spark SQL' PLACING 'tructured' FROM 2 FOR 4)",
|
||||
Overlay(Literal("Spark SQL"), Literal("tructured"), Literal(2), Literal(4))
|
||||
)
|
||||
}
|
||||
|
||||
test("precedence of set operations") {
|
||||
val a = table("a").select(star())
|
||||
val b = table("b").select(star())
|
||||
|
|
|
@ -2516,6 +2516,28 @@ object functions {
|
|||
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Overlay the specified portion of `src` with `replaceString`,
|
||||
* starting from character position `pos` of `src` and proceeding for `len` characters.
|
||||
*
|
||||
* @group string_funcs
|
||||
* @since 3.0.0
|
||||
*/
|
||||
def overlay(src: Column, replaceString: String, pos: Int, len: Int): Column = withExpr {
|
||||
Overlay(src.expr, lit(replaceString).expr, lit(pos).expr, lit(len).expr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Overlay the specified portion of `src` with `replaceString`,
|
||||
* starting from character position `pos` of `src`.
|
||||
*
|
||||
* @group string_funcs
|
||||
* @since 3.0.0
|
||||
*/
|
||||
def overlay(src: Column, replaceString: String, pos: Int): Column = withExpr {
|
||||
new Overlay(src.expr, lit(replaceString).expr, lit(pos).expr)
|
||||
}
|
||||
|
||||
/**
|
||||
* Translate any character in the src by a character in replaceString.
|
||||
* The characters in replaceString correspond to the characters in matchingString.
|
||||
|
|
|
@ -129,6 +129,18 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
Row("AQIDBA==", bytes))
|
||||
}
|
||||
|
||||
test("overlay function") {
|
||||
// scalastyle:off
|
||||
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
|
||||
val df = Seq(("Spark SQL", "Spark的SQL")).toDF("a", "b")
|
||||
checkAnswer(df.select(overlay($"a", "_", 6)), Row("Spark_SQL"))
|
||||
checkAnswer(df.select(overlay($"a", "CORE", 7)), Row("Spark CORE"))
|
||||
checkAnswer(df.select(overlay($"a", "ANSI ", 7, 0)), Row("Spark ANSI SQL"))
|
||||
checkAnswer(df.select(overlay($"a", "tructured", 2, 4)), Row("Structured SQL"))
|
||||
checkAnswer(df.select(overlay($"b", "_", 6)), Row("Spark_SQL"))
|
||||
// scalastyle:on
|
||||
}
|
||||
|
||||
test("string / binary substring function") {
|
||||
// scalastyle:off
|
||||
// non ascii characters are not allowed in the code, so we disable the scalastyle here.
|
||||
|
|
Loading…
Reference in a new issue