[SPARK-16287][SQL] Implement str_to_map SQL function
## What changes were proposed in this pull request? This PR adds `str_to_map` SQL function in order to remove Hive fallback. ## How was this patch tested? Pass the Jenkins tests with newly added. Author: Sandeep Singh <sandeep@techaddict.me> Closes #13990 from techaddict/SPARK-16287.
This commit is contained in:
parent
46f80a3073
commit
df2c6d59d0
|
@ -228,6 +228,7 @@ object FunctionRegistry {
|
|||
expression[Signum]("signum"),
|
||||
expression[Sin]("sin"),
|
||||
expression[Sinh]("sinh"),
|
||||
expression[StringToMap]("str_to_map"),
|
||||
expression[Sqrt]("sqrt"),
|
||||
expression[Tan]("tan"),
|
||||
expression[Tanh]("tanh"),
|
||||
|
|
|
@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, TypeUtils}
|
||||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData, TypeUtils}
|
||||
import org.apache.spark.sql.types._
|
||||
import org.apache.spark.unsafe.types.UTF8String
|
||||
|
||||
|
@ -393,3 +393,53 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
|
|||
|
||||
override def prettyName: String = "named_struct_unsafe"
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a map after splitting the input text into key/value pairs using delimiters
|
||||
*/
|
||||
@ExpressionDescription(
|
||||
usage = "_FUNC_(text[, pairDelim, keyValueDelim]) - Creates a map after splitting the text " +
|
||||
"into key/value pairs using delimiters. " +
|
||||
"Default delimiters are ',' for pairDelim and ':' for keyValueDelim.",
|
||||
extended = """ > SELECT _FUNC_('a:1,b:2,c:3',',',':');\n map("a":"1","b":"2","c":"3") """)
|
||||
case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: Expression)
|
||||
extends TernaryExpression with CodegenFallback with ExpectsInputTypes {
|
||||
|
||||
def this(child: Expression, pairDelim: Expression) = {
|
||||
this(child, pairDelim, Literal(":"))
|
||||
}
|
||||
|
||||
def this(child: Expression) = {
|
||||
this(child, Literal(","), Literal(":"))
|
||||
}
|
||||
|
||||
override def children: Seq[Expression] = Seq(text, pairDelim, keyValueDelim)
|
||||
|
||||
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
|
||||
|
||||
override def dataType: DataType = MapType(StringType, StringType, valueContainsNull = false)
|
||||
|
||||
override def checkInputDataTypes(): TypeCheckResult = {
|
||||
if (Seq(pairDelim, keyValueDelim).exists(! _.foldable)) {
|
||||
TypeCheckResult.TypeCheckFailure(s"$prettyName's delimiters must be foldable.")
|
||||
} else {
|
||||
super.checkInputDataTypes()
|
||||
}
|
||||
}
|
||||
|
||||
override def nullSafeEval(str: Any, delim1: Any, delim2: Any): Any = {
|
||||
val array = str.asInstanceOf[UTF8String]
|
||||
.split(delim1.asInstanceOf[UTF8String], -1)
|
||||
.map { kv =>
|
||||
val arr = kv.split(delim2.asInstanceOf[UTF8String], 2)
|
||||
if (arr.length < 2) {
|
||||
Array(arr(0), null)
|
||||
} else {
|
||||
arr
|
||||
}
|
||||
}
|
||||
ArrayBasedMapData(array.map(_ (0)), array.map(_ (1)))
|
||||
}
|
||||
|
||||
override def prettyName: String = "str_to_map"
|
||||
}
|
||||
|
|
|
@ -246,4 +246,40 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkMetadata(CreateStructUnsafe(Seq(a, b)))
|
||||
checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
|
||||
}
|
||||
|
||||
test("StringToMap") {
|
||||
val s0 = Literal("a:1,b:2,c:3")
|
||||
val m0 = Map("a" -> "1", "b" -> "2", "c" -> "3")
|
||||
checkEvaluation(new StringToMap(s0), m0)
|
||||
|
||||
val s1 = Literal("a: ,b:2")
|
||||
val m1 = Map("a" -> " ", "b" -> "2")
|
||||
checkEvaluation(new StringToMap(s1), m1)
|
||||
|
||||
val s2 = Literal("a=1,b=2,c=3")
|
||||
val m2 = Map("a" -> "1", "b" -> "2", "c" -> "3")
|
||||
checkEvaluation(StringToMap(s2, Literal(","), Literal("=")), m2)
|
||||
|
||||
val s3 = Literal("")
|
||||
val m3 = Map[String, String]("" -> null)
|
||||
checkEvaluation(StringToMap(s3, Literal(","), Literal("=")), m3)
|
||||
|
||||
val s4 = Literal("a:1_b:2_c:3")
|
||||
val m4 = Map("a" -> "1", "b" -> "2", "c" -> "3")
|
||||
checkEvaluation(new StringToMap(s4, Literal("_")), m4)
|
||||
|
||||
// arguments checking
|
||||
assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
|
||||
assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure)
|
||||
assert(new StringToMap(Literal("a:1,b:2,c:3"), Literal(null)).checkInputDataTypes().isFailure)
|
||||
assert(StringToMap(Literal("a:1,b:2,c:3"), Literal(null), Literal(null))
|
||||
.checkInputDataTypes().isFailure)
|
||||
assert(new StringToMap(Literal(null), Literal(null)).checkInputDataTypes().isFailure)
|
||||
|
||||
assert(new StringToMap(Literal("a:1_b:2_c:3"), NonFoldableLiteral("_"))
|
||||
.checkInputDataTypes().isFailure)
|
||||
assert(
|
||||
new StringToMap(Literal("a=1_b=2_c=3"), Literal("_"), NonFoldableLiteral("="))
|
||||
.checkInputDataTypes().isFailure)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -384,4 +384,27 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
|
|||
}.getMessage
|
||||
assert(m.contains("Invalid number of arguments for function sentences"))
|
||||
}
|
||||
|
||||
test("str_to_map function") {
|
||||
val df1 = Seq(
|
||||
("a=1,b=2", "y"),
|
||||
("a=1,b=2,c=3", "y")
|
||||
).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df1.selectExpr("str_to_map(a,',','=')"),
|
||||
Seq(
|
||||
Row(Map("a" -> "1", "b" -> "2")),
|
||||
Row(Map("a" -> "1", "b" -> "2", "c" -> "3"))
|
||||
)
|
||||
)
|
||||
|
||||
val df2 = Seq(("a:1,b:2,c:3", "y")).toDF("a", "b")
|
||||
|
||||
checkAnswer(
|
||||
df2.selectExpr("str_to_map(a)"),
|
||||
Seq(Row(Map("a" -> "1", "b" -> "2", "c" -> "3")))
|
||||
)
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
@ -238,7 +238,6 @@ private[sql] class HiveSessionCatalog(
|
|||
"hash",
|
||||
"histogram_numeric",
|
||||
"percentile",
|
||||
"percentile_approx",
|
||||
"str_to_map"
|
||||
"percentile_approx"
|
||||
)
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue