[SPARK-34094][SQL] Extends StringTranslate to support unicode characters whose code point >= U+10000
### What changes were proposed in this pull request? This PR extends `StringTranslate` to support unicode characters whose code point >= `U+10000`. ### Why are the changes needed? To make it work with wide variety of characters. ### Does this PR introduce _any_ user-facing change? Yes. Users can use `StringTranslate` with unicode characters whose code point >= `U+10000`. ### How was this patch tested? New assertion added to the existing test. Closes #31164 from sarutak/extends-translate. Authored-by: Kousuke Saruta <sarutak@oss.nttdata.com> Signed-off-by: Sean Owen <srowen@gmail.com>
This commit is contained in:
parent
28131a7794
commit
116f4cab6b
|
@ -1075,16 +1075,20 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
|
|||
return buf.build();
|
||||
}
|
||||
|
||||
// TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes
|
||||
public UTF8String translate(Map<Character, Character> dict) {
|
||||
public UTF8String translate(Map<String, String> dict) {
|
||||
String srcStr = this.toString();
|
||||
|
||||
StringBuilder sb = new StringBuilder();
|
||||
for(int k = 0; k< srcStr.length(); k++) {
|
||||
if (null == dict.get(srcStr.charAt(k))) {
|
||||
sb.append(srcStr.charAt(k));
|
||||
} else if ('\0' != dict.get(srcStr.charAt(k))){
|
||||
sb.append(dict.get(srcStr.charAt(k)));
|
||||
int charCount = 0;
|
||||
for (int k = 0; k < srcStr.length(); k += charCount) {
|
||||
int codePoint = srcStr.codePointAt(k);
|
||||
charCount = Character.charCount(codePoint);
|
||||
String subStr = srcStr.substring(k, k + charCount);
|
||||
String translated = dict.get(subStr);
|
||||
if (null == translated) {
|
||||
sb.append(subStr);
|
||||
} else if (!"\0".equals(translated)) {
|
||||
sb.append(translated);
|
||||
}
|
||||
}
|
||||
return fromString(sb.toString());
|
||||
|
|
|
@ -465,10 +465,10 @@ public class UTF8StringSuite {
|
|||
assertEquals(
|
||||
fromString("1a2s3ae"),
|
||||
fromString("translate").translate(ImmutableMap.of(
|
||||
'r', '1',
|
||||
'n', '2',
|
||||
'l', '3',
|
||||
't', '\0'
|
||||
"r", "1",
|
||||
"n", "2",
|
||||
"l", "3",
|
||||
"t", "\0"
|
||||
)));
|
||||
assertEquals(
|
||||
fromString("translate"),
|
||||
|
@ -476,16 +476,16 @@ public class UTF8StringSuite {
|
|||
assertEquals(
|
||||
fromString("asae"),
|
||||
fromString("translate").translate(ImmutableMap.of(
|
||||
'r', '\0',
|
||||
'n', '\0',
|
||||
'l', '\0',
|
||||
't', '\0'
|
||||
"r", "\0",
|
||||
"n", "\0",
|
||||
"l", "\0",
|
||||
"t", "\0"
|
||||
)));
|
||||
assertEquals(
|
||||
fromString("aa世b"),
|
||||
fromString("花花世界").translate(ImmutableMap.of(
|
||||
'花', 'a',
|
||||
'界', 'b'
|
||||
"花", "a",
|
||||
"界", "b"
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
|
@ -633,17 +633,29 @@ case class Overlay(input: Expression, replace: Expression, pos: Expression, len:
|
|||
object StringTranslate {
|
||||
|
||||
def buildDict(matchingString: UTF8String, replaceString: UTF8String)
|
||||
: JMap[Character, Character] = {
|
||||
: JMap[String, String] = {
|
||||
val matching = matchingString.toString()
|
||||
val replace = replaceString.toString()
|
||||
val dict = new HashMap[Character, Character]()
|
||||
val dict = new HashMap[String, String]()
|
||||
var i = 0
|
||||
while (i < matching.length()) {
|
||||
val rep = if (i < replace.length()) replace.charAt(i) else '\u0000'
|
||||
if (null == dict.get(matching.charAt(i))) {
|
||||
dict.put(matching.charAt(i), rep)
|
||||
var j = 0
|
||||
|
||||
while (i < matching.length) {
|
||||
val rep = if (j < replace.length) {
|
||||
val repCharCount = Character.charCount(replace.codePointAt(j))
|
||||
val repStr = replace.substring(j, j + repCharCount)
|
||||
j += repCharCount
|
||||
repStr
|
||||
} else {
|
||||
"\u0000"
|
||||
}
|
||||
i += 1
|
||||
|
||||
val matchCharCount = Character.charCount(matching.codePointAt(i))
|
||||
val matchStr = matching.substring(i, i + matchCharCount)
|
||||
if (null == dict.get(matchStr)) {
|
||||
dict.put(matchStr, rep)
|
||||
}
|
||||
i += matchCharCount
|
||||
}
|
||||
dict
|
||||
}
|
||||
|
@ -671,7 +683,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac
|
|||
|
||||
@transient private var lastMatching: UTF8String = _
|
||||
@transient private var lastReplace: UTF8String = _
|
||||
@transient private var dict: JMap[Character, Character] = _
|
||||
@transient private var dict: JMap[String, String] = _
|
||||
|
||||
override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
|
||||
if (matchingEval != lastMatching || replaceEval != lastReplace) {
|
||||
|
|
|
@ -550,6 +550,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
// scalastyle:off
|
||||
// non ascii characters are not allowed in the source code, so we disable the scalastyle.
|
||||
checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
|
||||
// test for unicode characters whose code point >= 0x10000
|
||||
checkEvaluation(
|
||||
StringTranslate(
|
||||
Literal("\uD840\uDC0Bxyza\uD867\uDE49c123b\uD842\uDFB7\uD867\uDE3D"),
|
||||
Literal("\uD867\uDE3Da\uD842\uDFB7b\uD840\uDC0Bc\uD867\uDE49c"),
|
||||
Literal("1\uD83C\uDF3B2\uD83C\uDF37\uD83D\uDC15\uD83D\uDC08\uD83C\uDF38")),
|
||||
"\uD83D\uDC15xyz\uD83C\uDF3B\uD83C\uDF38\uD83D\uDC08123\uD83C\uDF3721")
|
||||
// scalastyle:on
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in a new issue