[SPARK-16287][SQL] Implement str_to_map SQL function

## What changes were proposed in this pull request?
This PR adds `str_to_map` SQL function in order to remove Hive fallback.

## How was this patch tested?
Pass the Jenkins tests with newly added test cases.

Author: Sandeep Singh <sandeep@techaddict.me>

Closes #13990 from techaddict/SPARK-16287.
This commit is contained in:
Sandeep Singh 2016-07-22 10:05:21 +08:00 committed by Wenchen Fan
parent 46f80a3073
commit df2c6d59d0
5 changed files with 112 additions and 3 deletions

View file

@@ -228,6 +228,7 @@ object FunctionRegistry {
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Sinh]("sinh"),
expression[StringToMap]("str_to_map"),
expression[Sqrt]("sqrt"),
expression[Tan]("tan"),
expression[Tanh]("tanh"),

View file

@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -393,3 +393,53 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
override def prettyName: String = "named_struct_unsafe"
}
/**
 * Creates a map after splitting the input text into key/value pairs using delimiters.
 */
@ExpressionDescription(
  usage = "_FUNC_(text[, pairDelim, keyValueDelim]) - Creates a map after splitting the text " +
    "into key/value pairs using delimiters. " +
    "Default delimiters are ',' for pairDelim and ':' for keyValueDelim.",
  extended = """ > SELECT _FUNC_('a:1,b:2,c:3',',',':');\n map("a":"1","b":"2","c":"3") """)
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
  extends TernaryExpression with CodegenFallback with ExpectsInputTypes {

  // Auxiliary constructor: default keyValueDelim is ':'.
  def this(child: Expression, pairDelim: Expression) = {
    this(child, pairDelim, Literal(":"))
  }

  // Auxiliary constructor: default pairDelim is ',' and keyValueDelim is ':'.
  def this(child: Expression) = {
    this(child, Literal(","), Literal(":"))
  }

  override def children: Seq[Expression] = Seq(text, pairDelim, keyValueDelim)

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)

  // BUG FIX: a pair that contains no keyValueDelim maps its key to null (see
  // nullSafeEval below), so the value type must be declared nullable. Declaring
  // valueContainsNull = false here contradicted the actual evaluation result.
  override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = true)

  override def checkInputDataTypes(): TypeCheckResult = {
    // Both delimiters must be foldable so they can be evaluated once, up front.
    if (Seq(pairDelim, keyValueDelim).exists(!_.foldable)) {
      TypeCheckResult.TypeCheckFailure(s"$prettyName's delimiters must be foldable.")
    } else {
      super.checkInputDataTypes()
    }
  }

  override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = {
    // Split on the pair delimiter with limit -1 so trailing empty pairs are kept.
    val array = str.asInstanceOf[UTF8String]
      .split(delim1.asInstanceOf[UTF8String], -1)
      .map { kv =>
        // Split each pair into at most two parts; a pair with no key/value
        // delimiter yields its key mapped to null.
        val arr = kv.split(delim2.asInstanceOf[UTF8String], 2)
        if (arr.length < 2) {
          Array(arr(0), null)
        } else {
          arr
        }
      }
    ArrayBasedMapData(array.map(_(0)), array.map(_(1)))
  }

  override def prettyName: String = "str_to_map"
}

View file

@@ -246,4 +246,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
checkMetadata(CreateStructUnsafe(Seq(a, b)))
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
}
test("StringToMap") {
  val abcMap = Map("a" -> "1", "b" -> "2", "c" -> "3")

  // Default delimiters: ',' between pairs and ':' within a pair.
  checkEvaluation(new StringToMap(Literal("a:1,b:2,c:3")), abcMap)

  // Whitespace values are preserved verbatim.
  checkEvaluation(new StringToMap(Literal("a: ,b:2")), Map("a" -> " ", "b" -> "2"))

  // Custom key/value delimiter '='.
  checkEvaluation(StringToMap(Literal("a=1,b=2,c=3"), Literal(","), Literal("=")), abcMap)

  // Empty input produces a single empty key mapped to null.
  checkEvaluation(
    StringToMap(Literal(""), Literal(","), Literal("=")),
    Map[String, String]("" -> null))

  // Custom pair delimiter '_' with the default key/value delimiter.
  checkEvaluation(new StringToMap(Literal("a:1_b:2_c:3"), Literal("_")), abcMap)

  // arguments checking: null or non-foldable delimiters must be rejected.
  assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
  assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure)
  assert(new StringToMap(Literal("a:1,b:2,c:3"), Literal(null)).checkInputDataTypes().isFailure)
  assert(StringToMap(Literal("a:1,b:2,c:3"), Literal(null), Literal(null))
    .checkInputDataTypes().isFailure)
  assert(new StringToMap(Literal(null), Literal(null)).checkInputDataTypes().isFailure)
  assert(new StringToMap(Literal("a:1_b:2_c:3"), NonFoldableLiteral("_"))
    .checkInputDataTypes().isFailure)
  assert(
    new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("="))
      .checkInputDataTypes().isFailure)
}
}

View file

@@ -384,4 +384,27 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
}.getMessage
assert(m.contains("Invalid number of arguments for function sentences"))
}
test("str_to_map function") {
  // Explicit pair delimiter ',' and key/value delimiter '='.
  val customDelimDF = Seq(
    ("a=1,b=2", "y"),
    ("a=1,b=2,c=3", "y")
  ).toDF("a", "b")
  checkAnswer(
    customDelimDF.selectExpr("str_to_map(a,',','=')"),
    Seq(
      Row(Map("a" -> "1", "b" -> "2")),
      Row(Map("a" -> "1", "b" -> "2", "c" -> "3"))
    )
  )

  // Default delimiters: ',' between pairs and ':' within a pair.
  val defaultDelimDF = Seq(("a:1,b:2,c:3", "y")).toDF("a", "b")
  checkAnswer(
    defaultDelimDF.selectExpr("str_to_map(a)"),
    Seq(Row(Map("a" -> "1", "b" -> "2", "c" -> "3")))
  )
}
}

View file

@@ -238,7 +238,6 @@ private[sql] class HiveSessionCatalog(
"hash",
"histogram_numeric",
"percentile",
"percentile_approx",
"str_to_map"
"percentile_approx"
)
}