[SPARK-19254][SQL] Support Seq, Map, and Struct in functions.lit
## What changes were proposed in this pull request? This pr is to support Seq, Map, and Struct in functions.lit; it adds a new IF named `lit2` with `TypeTag` for avoiding type erasure. ## How was this patch tested? Added tests in `LiteralExpressionSuite` Author: Takeshi Yamamuro <yamamuro@apache.org> Author: Takeshi YAMAMURO <linguin.m.s@gmail.com> Closes #16610 from maropu/SPARK-19254.
This commit is contained in:
parent
f48461ab2b
commit
14bb398fae
|
@ -32,11 +32,13 @@ import java.util.Objects
|
|||
import javax.xml.bind.DatatypeConverter
|
||||
|
||||
import scala.math.{BigDecimal, BigInt}
|
||||
import scala.reflect.runtime.universe.TypeTag
|
||||
import scala.util.Try
|
||||
|
||||
import org.json4s.JsonAST._
|
||||
|
||||
import org.apache.spark.sql.AnalysisException
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.expressions.codegen._
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -153,6 +155,14 @@ object Literal {
|
|||
Literal(CatalystTypeConverters.convertToCatalyst(v), dataType)
|
||||
}
|
||||
|
||||
def create[T : TypeTag](v: T): Literal = Try {
|
||||
val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
|
||||
val convert = CatalystTypeConverters.createToCatalystConverter(dataType)
|
||||
Literal(convert(v), dataType)
|
||||
}.getOrElse {
|
||||
Literal(v)
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a literal with default value for given DataType
|
||||
*/
|
||||
|
|
|
@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions
|
|||
|
||||
import java.nio.charset.StandardCharsets
|
||||
|
||||
import scala.reflect.runtime.universe.{typeTag, TypeTag}
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.catalyst.CatalystTypeConverters
|
||||
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection}
|
||||
import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT
|
||||
import org.apache.spark.sql.catalyst.util.DateTimeUtils
|
||||
import org.apache.spark.sql.types._
|
||||
|
@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
test("boolean literals") {
|
||||
checkEvaluation(Literal(true), true)
|
||||
checkEvaluation(Literal(false), false)
|
||||
|
||||
checkEvaluation(Literal.create(true), true)
|
||||
checkEvaluation(Literal.create(false), false)
|
||||
}
|
||||
|
||||
test("int literals") {
|
||||
|
@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
checkEvaluation(Literal(d.toLong), d.toLong)
|
||||
checkEvaluation(Literal(d.toShort), d.toShort)
|
||||
checkEvaluation(Literal(d.toByte), d.toByte)
|
||||
|
||||
checkEvaluation(Literal.create(d), d)
|
||||
checkEvaluation(Literal.create(d.toLong), d.toLong)
|
||||
checkEvaluation(Literal.create(d.toShort), d.toShort)
|
||||
checkEvaluation(Literal.create(d.toByte), d.toByte)
|
||||
}
|
||||
checkEvaluation(Literal(Long.MinValue), Long.MinValue)
|
||||
checkEvaluation(Literal(Long.MaxValue), Long.MaxValue)
|
||||
|
||||
checkEvaluation(Literal.create(Long.MinValue), Long.MinValue)
|
||||
checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue)
|
||||
}
|
||||
|
||||
test("double literals") {
|
||||
List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d =>
|
||||
checkEvaluation(Literal(d), d)
|
||||
checkEvaluation(Literal(d.toFloat), d.toFloat)
|
||||
|
||||
checkEvaluation(Literal.create(d), d)
|
||||
checkEvaluation(Literal.create(d.toFloat), d.toFloat)
|
||||
}
|
||||
checkEvaluation(Literal(Double.MinValue), Double.MinValue)
|
||||
checkEvaluation(Literal(Double.MaxValue), Double.MaxValue)
|
||||
checkEvaluation(Literal(Float.MinValue), Float.MinValue)
|
||||
checkEvaluation(Literal(Float.MaxValue), Float.MaxValue)
|
||||
|
||||
checkEvaluation(Literal.create(Double.MinValue), Double.MinValue)
|
||||
checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue)
|
||||
checkEvaluation(Literal.create(Float.MinValue), Float.MinValue)
|
||||
checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue)
|
||||
|
||||
}
|
||||
|
||||
test("string literals") {
|
||||
checkEvaluation(Literal(""), "")
|
||||
checkEvaluation(Literal("test"), "test")
|
||||
checkEvaluation(Literal("\u0000"), "\u0000")
|
||||
|
||||
checkEvaluation(Literal.create(""), "")
|
||||
checkEvaluation(Literal.create("test"), "test")
|
||||
checkEvaluation(Literal.create("\u0000"), "\u0000")
|
||||
}
|
||||
|
||||
test("sum two literals") {
|
||||
checkEvaluation(Add(Literal(1), Literal(1)), 2)
|
||||
checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2)
|
||||
}
|
||||
|
||||
test("binary literals") {
|
||||
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0))
|
||||
checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2))
|
||||
|
||||
checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0))
|
||||
checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2))
|
||||
}
|
||||
|
||||
test("decimal") {
|
||||
|
@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
|
|||
Decimal((d * 1000L).toLong, 10, 3))
|
||||
checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d))
|
||||
checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d))
|
||||
|
||||
checkEvaluation(Literal.create(Decimal(d)), Decimal(d))
|
||||
checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt))
|
||||
checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong))
|
||||
checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)),
|
||||
Decimal((d * 1000L).toLong, 10, 3))
|
||||
checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d))
|
||||
checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d))
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
private def toCatalyst[T: TypeTag](value: T): Any = {
|
||||
val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T]
|
||||
CatalystTypeConverters.createToCatalystConverter(dataType)(value)
|
||||
}
|
||||
|
||||
test("array") {
|
||||
def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = {
|
||||
val toCatalyst = (a: Array[_], elementType: DataType) => {
|
||||
CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a)
|
||||
}
|
||||
checkEvaluation(Literal(a), toCatalyst(a, elementType))
|
||||
def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = {
|
||||
checkEvaluation(Literal(a), toCatalyst(a))
|
||||
checkEvaluation(Literal.create(a), toCatalyst(a))
|
||||
}
|
||||
checkArrayLiteral(Array(1, 2, 3), IntegerType)
|
||||
checkArrayLiteral(Array("a", "b", "c"), StringType)
|
||||
checkArrayLiteral(Array(1.0, 4.0), DoubleType)
|
||||
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
|
||||
checkArrayLiteral(Array(1, 2, 3))
|
||||
checkArrayLiteral(Array("a", "b", "c"))
|
||||
checkArrayLiteral(Array(1.0, 4.0))
|
||||
checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR))
|
||||
}
|
||||
|
||||
test("seq") {
|
||||
def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = {
|
||||
checkEvaluation(Literal.create(a), toCatalyst(a))
|
||||
}
|
||||
checkSeqLiteral(Seq(1, 2, 3), IntegerType)
|
||||
checkSeqLiteral(Seq("a", "b", "c"), StringType)
|
||||
checkSeqLiteral(Seq(1.0, 4.0), DoubleType)
|
||||
checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR),
|
||||
CalendarIntervalType)
|
||||
}
|
||||
|
||||
test("unsupported types (map and struct) in literals") {
|
||||
test("map") {
|
||||
def checkMapLiteral[T: TypeTag](m: T): Unit = {
|
||||
checkEvaluation(Literal.create(m), toCatalyst(m))
|
||||
}
|
||||
checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3))
|
||||
checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0))
|
||||
}
|
||||
|
||||
test("struct") {
|
||||
def checkStructLiteral[T: TypeTag](s: T): Unit = {
|
||||
checkEvaluation(Literal.create(s), toCatalyst(s))
|
||||
}
|
||||
checkStructLiteral((1, 3.0, "abcde"))
|
||||
checkStructLiteral(("de", 1, 2.0f))
|
||||
checkStructLiteral((1, ("fgh", 3.0)))
|
||||
}
|
||||
|
||||
test("unsupported types (map and struct) in Literal.apply") {
|
||||
def checkUnsupportedTypeInLiteral(v: Any): Unit = {
|
||||
val errMsgMap = intercept[RuntimeException] {
|
||||
Literal(v)
|
||||
|
|
|
@ -91,15 +91,24 @@ object functions {
|
|||
* @group normal_funcs
|
||||
* @since 1.3.0
|
||||
*/
|
||||
def lit(literal: Any): Column = {
|
||||
literal match {
|
||||
case c: Column => return c
|
||||
case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name)
|
||||
case _ => // continue
|
||||
}
|
||||
def lit(literal: Any): Column = typedLit(literal)
|
||||
|
||||
val literalExpr = Literal(literal)
|
||||
Column(literalExpr)
|
||||
/**
|
||||
* Creates a [[Column]] of literal value.
|
||||
*
|
||||
* The passed in object is returned directly if it is already a [[Column]].
|
||||
* If the object is a Scala Symbol, it is converted into a [[Column]] also.
|
||||
* Otherwise, a new [[Column]] is created to represent the literal value.
|
||||
* The difference between this function and [[lit]] is that this function
|
||||
* can handle parameterized scala types e.g.: List, Seq and Map.
|
||||
*
|
||||
* @group normal_funcs
|
||||
* @since 2.2.0
|
||||
*/
|
||||
def typedLit[T : TypeTag](literal: T): Column = literal match {
|
||||
case c: Column => c
|
||||
case s: Symbol => new ColumnName(s.name)
|
||||
case _ => Column(Literal.create(literal))
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
|
@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
|
|||
testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)),
|
||||
testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39)))
|
||||
}
|
||||
|
||||
test("typedLit") {
|
||||
val df = Seq(Tuple1(0)).toDF("a")
|
||||
// Only check the types `lit` cannot handle
|
||||
checkAnswer(
|
||||
df.select(typedLit(Seq(1, 2, 3))),
|
||||
Row(Seq(1, 2, 3)) :: Nil)
|
||||
checkAnswer(
|
||||
df.select(typedLit(Map("a" -> 1, "b" -> 2))),
|
||||
Row(Map("a" -> 1, "b" -> 2)) :: Nil)
|
||||
checkAnswer(
|
||||
df.select(typedLit(("a", 2, 1.0))),
|
||||
Row(Row("a", 2, 1.0)) :: Nil)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue