From 116f4cab6b05f4286a08b9f03b8ddfb48ec464cf Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Thu, 21 Jan 2021 08:15:55 -0600 Subject: [PATCH] [SPARK-34094][SQL] Extends StringTranslate to support unicode characters whose code point >= U+10000 ### What changes were proposed in this pull request? This PR extends `StringTranslate` to support Unicode characters whose code point >= `U+10000`. ### Why are the changes needed? To make it work with a wide variety of characters. ### Does this PR introduce _any_ user-facing change? Yes. Users can use `StringTranslate` with Unicode characters whose code point >= `U+10000`. ### How was this patch tested? New assertion added to the existing test. Closes #31164 from sarutak/extends-translate. Authored-by: Kousuke Saruta Signed-off-by: Sean Owen --- .../apache/spark/unsafe/types/UTF8String.java | 18 +++++++----- .../spark/unsafe/types/UTF8StringSuite.java | 20 ++++++------- .../expressions/stringExpressions.scala | 28 +++++++++++++------ .../expressions/StringExpressionsSuite.scala | 7 +++++ 4 files changed, 48 insertions(+), 25 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index cdc6e23f7c..8f4ccb9055 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -1075,16 +1075,20 @@ public final class UTF8String implements Comparable, Externalizable, return buf.build(); } - // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes - public UTF8String translate(Map dict) { + public UTF8String translate(Map dict) { String srcStr = this.toString(); StringBuilder sb = new StringBuilder(); - for(int k = 0; k< srcStr.length(); k++) { - if (null == dict.get(srcStr.charAt(k))) { - sb.append(srcStr.charAt(k)); - } else if ('\0' != dict.get(srcStr.charAt(k))){ - 
sb.append(dict.get(srcStr.charAt(k))); + int charCount = 0; + for (int k = 0; k < srcStr.length(); k += charCount) { + int codePoint = srcStr.codePointAt(k); + charCount = Character.charCount(codePoint); + String subStr = srcStr.substring(k, k + charCount); + String translated = dict.get(subStr); + if (null == translated) { + sb.append(subStr); + } else if (!"\0".equals(translated)) { + sb.append(translated); } } return fromString(sb.toString()); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 70e276f7e5..ba3e4269e9 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -465,10 +465,10 @@ public class UTF8StringSuite { assertEquals( fromString("1a2s3ae"), fromString("translate").translate(ImmutableMap.of( - 'r', '1', - 'n', '2', - 'l', '3', - 't', '\0' + "r", "1", + "n", "2", + "l", "3", + "t", "\0" ))); assertEquals( fromString("translate"), @@ -476,16 +476,16 @@ public class UTF8StringSuite { assertEquals( fromString("asae"), fromString("translate").translate(ImmutableMap.of( - 'r', '\0', - 'n', '\0', - 'l', '\0', - 't', '\0' + "r", "\0", + "n", "\0", + "l", "\0", + "t", "\0" ))); assertEquals( fromString("aa世b"), fromString("花花世界").translate(ImmutableMap.of( - '花', 'a', - '界', 'b' + "花", "a", + "界", "b" ))); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 9317684d03..974dfc743c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -633,17 +633,29 @@ case class Overlay(input: Expression, replace: 
Expression, pos: Expression, len: object StringTranslate { def buildDict(matchingString: UTF8String, replaceString: UTF8String) - : JMap[Character, Character] = { + : JMap[String, String] = { val matching = matchingString.toString() val replace = replaceString.toString() - val dict = new HashMap[Character, Character]() + val dict = new HashMap[String, String]() var i = 0 - while (i < matching.length()) { - val rep = if (i < replace.length()) replace.charAt(i) else '\u0000' - if (null == dict.get(matching.charAt(i))) { - dict.put(matching.charAt(i), rep) + var j = 0 + + while (i < matching.length) { + val rep = if (j < replace.length) { + val repCharCount = Character.charCount(replace.codePointAt(j)) + val repStr = replace.substring(j, j + repCharCount) + j += repCharCount + repStr + } else { + "\u0000" } - i += 1 + + val matchCharCount = Character.charCount(matching.codePointAt(i)) + val matchStr = matching.substring(i, i + matchCharCount) + if (null == dict.get(matchStr)) { + dict.put(matchStr, rep) + } + i += matchCharCount } dict } @@ -671,7 +683,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac @transient private var lastMatching: UTF8String = _ @transient private var lastReplace: UTF8String = _ - @transient private var dict: JMap[Character, Character] = _ + @transient private var dict: JMap[String, String] = _ override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = { if (matchingEval != lastMatching || replaceEval != lastReplace) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 11ef1e98c8..bd06c1ca31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -550,6 
+550,13 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { // scalastyle:off // non ascii characters are not allowed in the source code, so we disable the scalastyle. checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b") + // test for unicode characters whose code point >= 0x10000 + checkEvaluation( + StringTranslate( + Literal("\uD840\uDC0Bxyza\uD867\uDE49c123b\uD842\uDFB7\uD867\uDE3D"), + Literal("\uD867\uDE3Da\uD842\uDFB7b\uD840\uDC0Bc\uD867\uDE49c"), + Literal("1\uD83C\uDF3B2\uD83C\uDF37\uD83D\uDC15\uD83D\uDC08\uD83C\uDF38")), + "\uD83D\uDC15xyz\uD83C\uDF3B\uD83C\uDF38\uD83D\uDC08123\uD83C\uDF3721") // scalastyle:on }