[SPARK-32585][SQL] Support scala enumeration in ScalaReflection
### What changes were proposed in this pull request?

Add code in `ScalaReflection` to support scala enumeration and map the enumeration type to string type in Spark.

### Why are the changes needed?

We support java enum but fail with scala enum; it's better to keep the behavior consistent. Here is an example.

```
package test

object TestEnum extends Enumeration {
  type TestEnum = Value
  val E1, E2, E3 = Value
}

import TestEnum._
case class TestClass(i: Int, e: TestEnum) { }

import test._
Seq(TestClass(1, TestEnum.E1)).toDS
```

Before this PR:

```
Exception in thread "main" java.lang.UnsupportedOperationException: No Encoder found for test.TestEnum.TestEnum
- field (class: "scala.Enumeration.Value", name: "e")
- root class: "test.TestClass"
  at org.apache.spark.sql.catalyst.ScalaReflection$.$anonfun$serializerFor$1(ScalaReflection.scala:567)
  at scala.reflect.internal.tpe.TypeConstraints$UndoLog.undo(TypeConstraints.scala:69)
  at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects(ScalaReflection.scala:882)
  at org.apache.spark.sql.catalyst.ScalaReflection.cleanUpReflectionObjects$(ScalaReflection.scala:881)
```

After this PR: `org.apache.spark.sql.Dataset[test.TestClass] = [i: int, e: string]`

### Does this PR introduce _any_ user-facing change?

Yes, users can now create a Dataset from a case class that includes a scala enumeration field.

### How was this patch tested?

Add test.

Closes #29403 from ulysses-you/SPARK-32585.

Authored-by: ulysses <youxiduo@weidian.com>
Signed-off-by: Tathagata Das <tathagata.das1565@gmail.com>
parent 9c618b3308
commit e62d24717e
ScalaReflection.scala:

```
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
 
 /**
@@ -377,6 +378,23 @@ object ScalaReflection extends ScalaReflection {
           expressions.Literal.create(null, ObjectType(cls)),
           newInstance
         )
+
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        // package example
+        // object Foo extends Enumeration {
+        //  type Foo = Value
+        //  val E1, E2 = Value
+        // }
+        // the fullName of tpe is example.Foo.Foo, but we need example.Foo so that
+        // we can call example.Foo.withName to deserialize string to enumeration.
+        val parent = t.asInstanceOf[TypeRef].pre.typeSymbol.asClass
+        val cls = mirror.runtimeClass(parent)
+        StaticInvoke(
+          cls,
+          ObjectType(getClassFromType(t)),
+          "withName",
+          createDeserializerForString(path, false) :: Nil,
+          returnNullable = false)
     }
   }
```
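As a rough, standalone illustration of the reflection step in that case (names such as `Foo` are placeholders; this is a sketch, e.g. for a Scala REPL, not the Catalyst code itself): the field's type is the member type `Foo.Foo`, the `pre` of its `TypeRef` is the enclosing `Foo` object, and that object's `withName` reverses `toString` during deserialization.

```
import scala.reflect.runtime.{universe => ru}

object Foo extends Enumeration {
  type Foo = Value
  val E1, E2 = Value
}

// The TypeRef prefix of Foo.Foo is the type of the enclosing object Foo,
// which is the class that actually declares withName.
val tpe    = ru.typeOf[Foo.Foo]
val parent = tpe.asInstanceOf[ru.TypeRef].pre.typeSymbol.asClass
val mirror = ru.runtimeMirror(getClass.getClassLoader)
val cls    = mirror.runtimeClass(parent)       // runtime class of the Foo module

// The round trip the generated expressions rely on:
val asString = Foo.E1.toString                 // "E1"
assert(Foo.withName(asString) == Foo.E1)
```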
```
@@ -561,6 +579,14 @@ object ScalaReflection extends ScalaReflection {
         }
         createSerializerForObject(inputObject, fields)
 
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        createSerializerForString(
+          Invoke(
+            inputObject,
+            "toString",
+            ObjectType(classOf[java.lang.String]),
+            returnNullable = false))
+
       case _ =>
         throw new UnsupportedOperationException(
           s"No Encoder found for $tpe\n" + walkedTypePath)
```
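On the serialization side, the new case amounts to calling `toString` on the enumeration value and storing the result as a UTF8 string; roughly (reusing the hypothetical `Foo` object from the sketch above):

```
import org.apache.spark.unsafe.types.UTF8String

// What createSerializerForString(Invoke(input, "toString", ...)) boils down to at runtime.
val serialized = UTF8String.fromString(Foo.E1.toString)   // UTF8String "E1"
```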
```
@@ -738,6 +764,8 @@ object ScalaReflection extends ScalaReflection {
           val Schema(dataType, nullable) = schemaFor(fieldType)
           StructField(fieldName, dataType, nullable)
         }), nullable = true)
+      case t if isSubtype(t, localTypeOf[Enumeration#Value]) =>
+        Schema(StringType, nullable = true)
       case other =>
         throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }
```
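With the `schemaFor` case above, schema inference reports the enumeration field as a nullable string. A small check of that mapping, assuming the `FooClassWithEnum` fixture introduced in the test changes below:

```
import org.apache.spark.sql.Encoders

// Schema derived through ScalaReflection for a case class with a scala enum field.
val schema = Encoders.product[FooClassWithEnum].schema
schema.printTreeString()
// root
//  |-- i: integer (nullable = false)
//  |-- e: string (nullable = true)
```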
ScalaReflectionSuite.scala:

```
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.FooEnum.FooEnum
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, Expression, If, SpecificInternalRow, UpCast}
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, NewInstance}
@@ -90,6 +91,13 @@ case class FooWithAnnotation(f1: String @FooAnnotation, f2: Option[String] @FooAnnotation)
 
 case class SpecialCharAsFieldData(`field.1`: String, `field 2`: String)
 
+object FooEnum extends Enumeration {
+  type FooEnum = Value
+  val E1, E2 = Value
+}
+
+case class FooClassWithEnum(i: Int, e: FooEnum)
+
 object TestingUDT {
   @SQLUserDefinedType(udt = classOf[NestedStructUDT])
   class NestedStruct(val a: Integer, val b: Long, val c: Double)
@@ -437,4 +445,11 @@ class ScalaReflectionSuite extends SparkFunSuite {
       StructField("f2", StringType))))
     assert(deserializerFor[FooWithAnnotation].dataType == ObjectType(classOf[FooWithAnnotation]))
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    assert(serializerFor[FooClassWithEnum].dataType == StructType(Seq(
+      StructField("i", IntegerType, false),
+      StructField("e", StringType, true))))
+    assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
+  }
 }
```
ExpressionEncoderSuite.scala:

```
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.sql.{Encoder, Encoders}
-import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -389,6 +389,14 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest
     assert(e.getMessage.contains("tuple with more than 22 elements are not supported"))
   }
 
+  encodeDecodeTest((1, FooEnum.E1), "Tuple with Int and scala Enum")
+  encodeDecodeTest((null, FooEnum.E1, FooEnum.E2), "Tuple with Null and scala Enum")
+  encodeDecodeTest(Seq(FooEnum.E1, null), "Seq with scala Enum")
+  encodeDecodeTest(Map("key" -> FooEnum.E1), "Map with String key and scala Enum")
+  encodeDecodeTest(Map(FooEnum.E1 -> "value"), "Map with scala Enum key and String value")
+  encodeDecodeTest(FooClassWithEnum(1, FooEnum.E1), "case class with Int and scala Enum")
+  encodeDecodeTest(FooEnum.E1, "scala Enum")
+
   // Scala / Java big decimals ----------------------------------------------------------
 
   encodeDecodeTest(BigDecimal(("9" * 20) + "." + "9" * 18),
```
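A minimal sketch of the round trip these `encodeDecodeTest` cases exercise, using `ExpressionEncoder` directly (`FooEnum` and `FooClassWithEnum` come from the fixtures above; the serializer/deserializer API names are those of the Spark 3.x line):

```
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

val enc = ExpressionEncoder[FooClassWithEnum]()

// Serialize to an InternalRow: the enum value is stored as the UTF8 string "E1".
val toRow = enc.createSerializer()
val row   = toRow(FooClassWithEnum(1, FooEnum.E1))

// Deserialize back: the string is turned into a Value again via FooEnum.withName.
val fromRow = enc.resolveAndBind().createDeserializer()
assert(fromRow(row) == FooClassWithEnum(1, FooEnum.E1))
```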
DatasetSuite.scala:

```
@@ -25,7 +25,7 @@ import org.scalatest.exceptions.TestFailedException
 import org.scalatest.prop.TableDrivenPropertyChecks._
 
 import org.apache.spark.{SparkException, TaskContext}
-import org.apache.spark.sql.catalyst.ScroogeLikeExample
+import org.apache.spark.sql.catalyst.{FooClassWithEnum, FooEnum, ScroogeLikeExample}
 import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder}
 import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi}
 import org.apache.spark.sql.catalyst.util.sideBySide
@@ -1926,6 +1926,19 @@ class DatasetSuite extends QueryTest
       }
     }
   }
+
+  test("SPARK-32585: Support scala enumeration in ScalaReflection") {
+    checkDataset(
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, FooEnum.E1), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+
+    // test null
+    checkDataset(
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)).toDS(),
+      Seq(FooClassWithEnum(1, null), FooClassWithEnum(2, FooEnum.E2)): _*
+    )
+  }
 }
 
 object AssertExecutionId {
```