[SPARK-34263][SQL] Simplify the code for treating unicode/octal/escaped characters in string literals
### What changes were proposed in this pull request? In the current master, the code for treating unicode/octal/escaped characters in string literals is a little bit complex so let's simplify it. ### Why are the changes needed? To keep it easy to maintain. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? `ParserUtilsSuite` passes. Closes #31362 from sarutak/refactor-unicode-escapes. Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Kousuke Saruta <sarutak@oss.nttdata.com>
This commit is contained in:
parent
79515b82f1
commit
d308794adb
|
@ -17,6 +17,7 @@
|
|||
package org.apache.spark.sql.catalyst.parser
|
||||
|
||||
import java.lang.{Long => JLong}
|
||||
import java.nio.CharBuffer
|
||||
import java.util
|
||||
|
||||
import scala.collection.mutable.StringBuilder
|
||||
|
@ -33,6 +34,12 @@ import org.apache.spark.sql.errors.QueryParsingErrors
|
|||
* A collection of utility methods for use during the parsing process.
|
||||
*/
|
||||
object ParserUtils {
|
||||
|
||||
val U16_CHAR_PATTERN = """\\u([a-fA-F0-9]{4})(?s).*""".r
|
||||
val U32_CHAR_PATTERN = """\\U([a-fA-F0-9]{8})(?s).*""".r
|
||||
val OCTAL_CHAR_PATTERN = """\\([01][0-7]{2})(?s).*""".r
|
||||
val ESCAPED_CHAR_PATTERN = """\\((?s).)(?s).*""".r
|
||||
|
||||
/** Get the command which created the token. */
|
||||
def command(ctx: ParserRuleContext): String = {
|
||||
val stream = ctx.getStart.getInputStream
|
||||
|
@ -131,7 +138,6 @@ object ParserUtils {
|
|||
|
||||
/** Unescape backslash-escaped string enclosed by quotes. */
|
||||
def unescapeSQLString(b: String): String = {
|
||||
var enclosure: Character = null
|
||||
val sb = new StringBuilder(b.length())
|
||||
|
||||
def appendEscapedChar(n: Char): Unit = {
|
||||
|
@ -152,34 +158,19 @@ object ParserUtils {
|
|||
}
|
||||
}
|
||||
|
||||
var i = 0
|
||||
val strLength = b.length
|
||||
while (i < strLength) {
|
||||
val currentChar = b.charAt(i)
|
||||
if (enclosure == null) {
|
||||
if (currentChar == '\'' || currentChar == '\"') {
|
||||
enclosure = currentChar
|
||||
}
|
||||
} else if (enclosure == currentChar) {
|
||||
enclosure = null
|
||||
} else if (currentChar == '\\') {
|
||||
// Skip the first and last quotations enclosing the string literal.
|
||||
val charBuffer = CharBuffer.wrap(b, 1, b.length - 1)
|
||||
|
||||
if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') {
|
||||
while (charBuffer.remaining() > 0) {
|
||||
charBuffer match {
|
||||
case U16_CHAR_PATTERN(cp) =>
|
||||
// \u0000 style 16-bit unicode character literals.
|
||||
|
||||
val base = i + 2
|
||||
val code = (0 until 4).foldLeft(0) { (mid, j) =>
|
||||
val digit = Character.digit(b.charAt(j + base), 16)
|
||||
(mid << 4) + digit
|
||||
}
|
||||
sb.append(code.asInstanceOf[Char])
|
||||
i += 5
|
||||
} else if ((i + 10 < strLength) && b.charAt(i + 1) == 'U' &&
|
||||
(2 until 10).forall(j => Character.digit(b.charAt(i + j), 16) != -1)) {
|
||||
sb.append(Integer.parseInt(cp, 16).toChar)
|
||||
charBuffer.position(charBuffer.position() + 6)
|
||||
case U32_CHAR_PATTERN(cp) =>
|
||||
// \U00000000 style 32-bit unicode character literals.
|
||||
|
||||
// Use Long to treat codePoint as unsigned in the range of 32-bit.
|
||||
val codePoint = JLong.parseLong(b.substring(i + 2, i + 10), 16)
|
||||
val codePoint = JLong.parseLong(cp, 16)
|
||||
if (codePoint < 0x10000) {
|
||||
sb.append((codePoint & 0xFFFF).toChar)
|
||||
} else {
|
||||
|
@ -188,33 +179,19 @@ object ParserUtils {
|
|||
sb.append(highSurrogate.toChar)
|
||||
sb.append(lowSurrogate.toChar)
|
||||
}
|
||||
i += 9
|
||||
} else if (i + 4 < strLength) {
|
||||
charBuffer.position(charBuffer.position() + 10)
|
||||
case OCTAL_CHAR_PATTERN(cp) =>
|
||||
// \000 style character literals.
|
||||
|
||||
val i1 = b.charAt(i + 1)
|
||||
val i2 = b.charAt(i + 2)
|
||||
val i3 = b.charAt(i + 3)
|
||||
|
||||
if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) {
|
||||
val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char]
|
||||
sb.append(tmp)
|
||||
i += 3
|
||||
} else {
|
||||
appendEscapedChar(i1)
|
||||
i += 1
|
||||
}
|
||||
} else if (i + 2 < strLength) {
|
||||
sb.append(Integer.parseInt(cp, 8).toChar)
|
||||
charBuffer.position(charBuffer.position() + 4)
|
||||
case ESCAPED_CHAR_PATTERN(c) =>
|
||||
// escaped character literals.
|
||||
val n = b.charAt(i + 1)
|
||||
appendEscapedChar(n)
|
||||
i += 1
|
||||
}
|
||||
} else {
|
||||
// non-escaped character literals.
|
||||
sb.append(currentChar)
|
||||
appendEscapedChar(c.charAt(0))
|
||||
charBuffer.position(charBuffer.position() + 2)
|
||||
case _ =>
|
||||
// non-escaped character literals.
|
||||
sb.append(charBuffer.get())
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
sb.toString()
|
||||
}
|
||||
|
|
|
@ -105,6 +105,13 @@ class ParserUtilsSuite extends SparkFunSuite {
|
|||
|
||||
// String including surrogate pair characters (U+1F408 is a cat and U+1F415 is a dog in Emoji).
|
||||
assert(unescapeSQLString("\"\\U0001F408 \\U0001F415\"") == "\uD83D\uDC08 \uD83D\uDC15")
|
||||
|
||||
// String including escaped normal characters.
|
||||
assert(unescapeSQLString(
|
||||
""""ab\
|
||||
|cd\ef"""".stripMargin) ==
|
||||
"""ab
|
||||
|cdef""".stripMargin)
|
||||
// scalastyle:on nonascii
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue