[SPARK-16285][SQL] Implement sentences SQL functions

## What changes were proposed in this pull request?

This PR implements `sentences` SQL function.

## How was this patch tested?

Pass the Jenkins tests with a new testcase.

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #14004 from dongjoon-hyun/SPARK_16285.
This commit is contained in:
Dongjoon Hyun 2016-07-08 17:05:24 +08:00 committed by Wenchen Fan
parent 8228b06303
commit a54438cb23
5 changed files with 111 additions and 3 deletions

View file

@ -296,6 +296,7 @@ object FunctionRegistry {
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
expression[Sentences]("sentences"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),

View file

@ -17,13 +17,15 @@
package org.apache.spark.sql.catalyst.expressions
import java.text.{DecimalFormat, DecimalFormatSymbols}
import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@ -1188,3 +1190,65 @@ case class FormatNumber(x: Expression, d: Expression)
override def prettyName: String = "format_number"
}
/**
* Splits a string into arrays of sentences, where each sentence is an array of words.
* The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
*/
@ExpressionDescription(
usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
case class Sentences(
str: Expression,
language: Expression = Literal(""),
country: Expression = Literal(""))
extends Expression with ImplicitCastInputTypes with CodegenFallback {
def this(str: Expression) = this(str, Literal(""), Literal(""))
def this(str: Expression, language: Expression) = this(str, language, Literal(""))
override def nullable: Boolean = true
override def dataType: DataType =
ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
override def children: Seq[Expression] = str :: language :: country :: Nil
override def eval(input: InternalRow): Any = {
val string = str.eval(input)
if (string == null) {
null
} else {
val languageStr = language.eval(input).asInstanceOf[UTF8String]
val countryStr = country.eval(input).asInstanceOf[UTF8String]
val locale = if (languageStr != null && countryStr != null) {
new Locale(languageStr.toString, countryStr.toString)
} else {
Locale.getDefault
}
getSentences(string.asInstanceOf[UTF8String].toString, locale)
}
}
private def getSentences(sentences: String, locale: Locale) = {
val bi = BreakIterator.getSentenceInstance(locale)
bi.setText(sentences)
var idx = 0
val result = new ArrayBuffer[GenericArrayData]
while (bi.next != BreakIterator.DONE) {
val sentence = sentences.substring(idx, bi.current)
idx = bi.current
val wi = BreakIterator.getWordInstance(locale)
var widx = 0
wi.setText(sentence)
val words = new ArrayBuffer[UTF8String]
while (wi.next != BreakIterator.DONE) {
val word = sentence.substring(widx, wi.current)
widx = wi.current
if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
}
result += new GenericArrayData(words)
}
new GenericArrayData(result)
}
}

View file

@ -725,4 +725,27 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}
test("Sentences") {
val nullString = Literal.create(null, StringType)
checkEvaluation(Sentences(nullString, nullString, nullString), null)
checkEvaluation(Sentences(nullString, nullString), null)
checkEvaluation(Sentences(nullString), null)
checkEvaluation(Sentences(Literal.create(null, NullType)), null)
checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
checkEvaluation(Sentences("", nullString), Seq.empty)
checkEvaluation(Sentences(""), Seq.empty)
val answer = Seq(
Seq("Hi", "there"),
Seq("The", "price", "was"),
Seq("But", "not", "now"))
checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now."), answer)
checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), answer)
checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"),
answer)
checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"),
answer)
}
}

View file

@ -349,4 +349,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
df2.filter("b>0").selectExpr("format_number(a, b)"),
Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil)
}
test("string sentences function") {
val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US"))
.toDF("str", "language", "country")
checkAnswer(
df.selectExpr("sentences(str, language, country)"),
Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))
// Type coercion
checkAnswer(
df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))
// Argument number exception
val m = intercept[AnalysisException] {
df.selectExpr("sentences()")
}.getMessage
assert(m.contains("Invalid number of arguments for function sentences"))
}
}

View file

@ -236,7 +236,7 @@ private[sql] class HiveSessionCatalog(
// str_to_map, windowingtablefunction.
private val hiveFunctions = Seq(
"hash", "java_method", "histogram_numeric",
"parse_url", "percentile", "percentile_approx", "reflect", "sentences", "str_to_map",
"parse_url", "percentile", "percentile_approx", "reflect", "str_to_map",
"xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
"xpath_number", "xpath_short", "xpath_string"
)