[SPARK-34263][SQL] Simplify the code for treating unicode/octal/escaped characters in string literals

### What changes were proposed in this pull request?

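On the current master, the code in `ParserUtils.unescapeSQLString` that handles Unicode, octal, and escaped characters in string literals is more complex than it needs to be, so this PR simplifies it. The hand-written index arithmetic is replaced with regular-expression patterns matched against a `CharBuffer` view of the literal.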

### Why are the changes needed?

To keep the string-literal unescaping code easy to maintain.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

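The existing tests in `ParserUtilsSuite` pass, and a new test case covering escaped ordinary characters (including a backslash followed by a line break) was added.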

Closes #31362 from sarutak/refactor-unicode-escapes.

Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com>
Signed-off-by: Kousuke Saruta <sarutak@oss.nttdata.com>
This commit is contained in:
Kousuke Saruta 2021-02-03 01:07:12 +09:00
parent 79515b82f1
commit d308794adb
2 changed files with 33 additions and 49 deletions

View file

@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.parser
import java.lang.{Long => JLong}
import java.nio.CharBuffer
import java.util
import scala.collection.mutable.StringBuilder
@ -33,6 +34,12 @@ import org.apache.spark.sql.errors.QueryParsingErrors
* A collection of utility methods for use during the parsing process.
*/
object ParserUtils {
// Patterns recognising one backslash escape at the very start of the input.
// Every pattern ends with `(?s).*`, which consumes whatever follows the escape
// (the `(?s)` flag lets `.` match line terminators as well), so a whole-input
// match succeeds as soon as the input begins with the escape. The single capture
// group holds the payload of the escape.
//
// A backslash, a lowercase `u`, then exactly four hexadecimal digits (captured).
val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
// A backslash, an uppercase `U`, then exactly eight hexadecimal digits (captured).
val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
// A backslash, then three octal digits whose first digit is 0 or 1 (captured),
// i.e. octal values 000 through 177.
val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
// A backslash followed by any single character, line terminators included (captured).
val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r
/** Get the command which created the token. */
def command(ctx: ParserRuleContext): String = {
val stream = ctx.getStart.getInputStream
@ -131,7 +138,6 @@ object ParserUtils {
/** Unescape backslash-escaped string enclosed by quotes. */
def unescapeSQLString(b: String): String = {
var enclosure: Character = null
val sb = new StringBuilder(b.length())
def appendEscapedChar(n: Char): Unit = {
@ -152,34 +158,19 @@ object ParserUtils {
}
}
var i = 0
val strLength = b.length
while (i < strLength) {
val currentChar = b.charAt(i)
if (enclosure == null) {
if (currentChar == '\'' || currentChar == '\"') {
enclosure = currentChar
}
} else if (enclosure == currentChar) {
enclosure = null
} else if (currentChar == '\\') {
// Skip the first and last quotations enclosing the string literal.
val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)
if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') {
while (charBuffer.remaining() > 0) {
charBuffer match {
case U16_CHAR_PATTERN(cp) =>
// \u0000 style 16-bit unicode character literals.
val base = i + 2
val code = (0 until 4).foldLeft(0) { (mid, j) =>
val digit = Character.digit(b.charAt(j + base), 16)
(mid << 4) + digit
}
sb.append(code.asInstanceOf[Char])
i += 5
} else if ((i + 10 < strLength) && b.charAt(i + 1) == 'U' &&
(2 until 10).forall(j => Character.digit(b.charAt(i + j), 16) != -1)) {
sb.append(Integer.parseInt(cp, 16).toChar)
charBuffer.position(charBuffer.position() + 6)
case U32_CHAR_PATTERN(cp) =>
// \U00000000 style 32-bit unicode character literals.
// Use Long to treat codePoint as unsigned in the range of 32-bit.
val codePoint = JLong.parseLong(b.substring(i + 2, i + 10), 16)
val codePoint = JLong.parseLong(cp, 16)
if (codePoint < 0x10000) {
sb.append((codePoint & 0xFFFF).toChar)
} else {
@ -188,33 +179,19 @@ object ParserUtils {
sb.append(highSurrogate.toChar)
sb.append(lowSurrogate.toChar)
}
i += 9
} else if (i + 4 < strLength) {
charBuffer.position(charBuffer.position() + 10)
case OCTAL_CHAR_PATTERN(cp) =>
// \000 style character literals.
val i1 = b.charAt(i + 1)
val i2 = b.charAt(i + 2)
val i3 = b.charAt(i + 3)
if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char]
sb.append(tmp)
i += 3
} else {
appendEscapedChar(i1)
i += 1
}
} else if (i + 2 < strLength) {
sb.append(Integer.parseInt(cp, 8).toChar)
charBuffer.position(charBuffer.position() + 4)
case ESCAPED_CHAR_PATTERN(c) =>
// escaped character literals.
val n = b.charAt(i + 1)
appendEscapedChar(n)
i += 1
}
} else {
// non-escaped character literals.
sb.append(currentChar)
appendEscapedChar(c.charAt(0))
charBuffer.position(charBuffer.position() + 2)
case _ =>
// non-escaped character literals.
sb.append(charBuffer.get())
}
i += 1
}
sb.toString()
}

View file

@ -105,6 +105,13 @@ class ParserUtilsSuite extends SparkFunSuite {
// String including surrogate pair characters (U+1F408 is a cat and U+1F415 is a dog in Emoji).
assert(unescapeSQLString("\"\\U0001F408 \\U0001F415\"") == "\uD83D\uDC08 \uD83D\uDC15")
// String including escaped normal characters.
assert(unescapeSQLString(
""""ab\
|cd\ef"""".stripMargin) ==
"""ab
|cdef""".stripMargin)
// scalastyle:on nonascii
}