[SPARK-16285][SQL] Implement sentences SQL functions
## What changes were proposed in this pull request?

This PR implements the `sentences` SQL function.

## How was this patch tested?

Pass the Jenkins tests with a new test case.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14004 from dongjoon-hyun/SPARK_16285.
This commit is contained in:
parent
8228b06303
commit
a54438cb23
|
@ -296,6 +296,7 @@ object FunctionRegistry {
|
|||
expression[RLike]("rlike"),
|
||||
expression[StringRPad]("rpad"),
|
||||
expression[StringTrimRight]("rtrim"),
|
||||
expression[Sentences]("sentences"),
|
||||
expression[SoundEx]("soundex"),
|
||||
expression[StringSpace]("space"),
|
||||
expression[StringSplit]("split"),
|
||||
|
|
|
@ -17,13 +17,15 @@
|
|||
|
||||
package org.apache.spark.sql.catalyst.expressions
|
||||
|
||||
import java.text.{DecimalFormat, DecimalFormatSymbols}
|
||||
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
|
||||
import java.util.{HashMap, Locale, Map => JMap}
|
||||
|
||||
import scala.collection.mutable.ArrayBuffer
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.ArrayData
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
|
||||
|
||||
|
@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression)
|
|||
|
||||
override def prettyName: String = "format_number"
|
||||
}
|
||||
|
||||
/**
 * Splits a string into arrays of sentences, where each sentence is an array of words.
 * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
 */
@ExpressionDescription(
  usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
  extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
case class Sentences(
    str: Expression,
    language: Expression = Literal(""),
    country: Expression = Literal(""))
  extends Expression with ImplicitCastInputTypes with CodegenFallback {

  def this(str: Expression) = this(str, Literal(""), Literal(""))
  def this(str: Expression, language: Expression) = this(str, language, Literal(""))

  // Returns null iff the input string evaluates to null.
  override def nullable: Boolean = true
  override def dataType: DataType =
    ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
  override def children: Seq[Expression] = str :: language :: country :: Nil

  override def eval(input: InternalRow): Any = {
    val string = str.eval(input)
    if (string == null) {
      null
    } else {
      // Use the caller-supplied locale only when BOTH language and country are
      // non-null; otherwise fall back to the JVM default locale.
      val languageStr = language.eval(input).asInstanceOf[UTF8String]
      val countryStr = country.eval(input).asInstanceOf[UTF8String]
      val locale = if (languageStr != null && countryStr != null) {
        new Locale(languageStr.toString, countryStr.toString)
      } else {
        Locale.getDefault
      }
      getSentences(string.asInstanceOf[UTF8String].toString, locale)
    }
  }

  /**
   * Splits `sentences` into sentences using a locale-aware `BreakIterator`, then splits
   * each sentence into word tokens, keeping only tokens whose first character is a letter
   * or digit (dropping punctuation and whitespace tokens).
   *
   * Returns a `GenericArrayData` of `GenericArrayData[UTF8String]`, matching `dataType`.
   */
  private def getSentences(sentences: String, locale: Locale) = {
    val sentenceIterator = BreakIterator.getSentenceInstance(locale)
    sentenceIterator.setText(sentences)
    // The word iterator depends only on the locale, not on the sentence text,
    // so create it once outside the loop and rebind it with setText per sentence.
    val wordIterator = BreakIterator.getWordInstance(locale)
    var idx = 0
    val result = new ArrayBuffer[GenericArrayData]
    while (sentenceIterator.next != BreakIterator.DONE) {
      val sentence = sentences.substring(idx, sentenceIterator.current)
      idx = sentenceIterator.current

      wordIterator.setText(sentence)
      var widx = 0
      val words = new ArrayBuffer[UTF8String]
      while (wordIterator.next != BreakIterator.DONE) {
        val word = sentence.substring(widx, wordIterator.current)
        widx = wordIterator.current
        // BreakIterator emits punctuation/whitespace runs as tokens too; keep real words only.
        if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
      }
      result += new GenericArrayData(words)
    }
    new GenericArrayData(result)
  }

  // Explicit for consistency with sibling expressions (e.g. FormatNumber's "format_number").
  override def prettyName: String = "sentences"
}
|
||||
|
|
|
@ -725,4 +725,27 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
|
||||
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
|
||||
}
|
||||
|
||||
test("Sentences") {
  // A null input string yields a null result, regardless of the locale arguments.
  val nullString = Literal.create(null, StringType)
  checkEvaluation(Sentences(nullString, nullString, nullString), null)
  checkEvaluation(Sentences(nullString, nullString), null)
  checkEvaluation(Sentences(nullString), null)
  checkEvaluation(Sentences(Literal.create(null, NullType)), null)

  // An empty input string yields an empty array of sentences.
  checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
  checkEvaluation(Sentences("", nullString), Seq.empty)
  checkEvaluation(Sentences(""), Seq.empty)

  val text = "Hi there! The price was $1,234.56.... But, not now."
  val expected = Seq(
    Seq("Hi", "there"),
    Seq("The", "price", "was"),
    Seq("But", "not", "now"))

  // Locale arguments are optional, and an unrecognized language/country pair
  // produces the same result as the default locale here.
  checkEvaluation(Sentences(text), expected)
  checkEvaluation(Sentences(text, "en"), expected)
  checkEvaluation(Sentences(text, "en", "US"), expected)
  checkEvaluation(Sentences(text, "XXX", "YYY"), expected)
}
|
||||
}
|
||||
|
|
|
@ -349,4 +349,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
df2.filter("b>0").selectExpr("format_number(a, b)"),
|
||||
Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil)
|
||||
}
|
||||
|
||||
test("string sentences function") {
  val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US"))
    .toDF("str", "language", "country")

  // Wrong arity must be rejected during analysis.
  val errorMessage = intercept[AnalysisException] {
    df.selectExpr("sentences()")
  }.getMessage
  assert(errorMessage.contains("Invalid number of arguments for function sentences"))

  // Full three-argument form: string plus language and country columns.
  checkAnswer(
    df.selectExpr("sentences(str, language, country)"),
    Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))

  // Non-string arguments are implicitly cast to string; null stays null.
  checkAnswer(
    df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
    Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))
}
|
||||
}
|
||||
|
|
|
@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
|
|||
// str_to_map, windowingtablefunction.
|
||||
// Functions still delegated to Hive. "sentences" is no longer listed here
// because this commit implements it natively (SPARK-16285).
private val hiveFunctions = Seq(
  "hash", "java_method", "histogram_numeric",
  "parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
  "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
  "xpath_number", "xpath_short", "xpath_string"
)
|
||||
|
|
Loading…
Reference in a new issue