[SPARK-23862][SQL] Support Java enums from Scala Dataset API

### What changes were proposed in this pull request?
Add support for Java Enums (`java.lang.Enum`) from the Scala typed Dataset APIs. This involves adding an implicit for `Encoder` creation in `SQLImplicits`, and updating `ScalaReflection` to handle Java Enums on the serialization and deserialization pathways.

Enums are mapped to a `StringType` which is just the name of the Enum value.

### Why are the changes needed?
In [SPARK-21255](https://issues.apache.org/jira/browse/SPARK-21255), support for (de)serialization of Java Enums was added, but only when called from Java code. It is common for Scala code to rely on Java libraries that are out of control of the Scala developer. Today, if there is a dependency on some Java code which defines an Enum, it would be necessary to define a corresponding Scala class. This change brings closer feature parity between Scala and Java APIs.

### Does this PR introduce _any_ user-facing change?
Yes, previously something like:
```
val ds = Seq(MyJavaEnum.VALUE1, MyJavaEnum.VALUE2).toDS
// or
val ds = Seq(CaseClass(MyJavaEnum.VALUE1), CaseClass(MyJavaEnum.VALUE2)).toDS
```
would fail. Now, it will succeed.

### How was this patch tested?
Additional unit tests are added in `DatasetSuite`. Tests include validating top-level enums, enums inside of case classes, enums inside of arrays, and validating that the Enum is stored as the expected string.

Closes #30877 from xkrogen/xkrogen-SPARK-23862-scalareflection-java-enums.

Lead-authored-by: Erik Krogen <xkrogen@apache.org>
Co-authored-by: Fangshi Li <fli@linkedin.com>
Signed-off-by: Dongjoon Hyun <dhyun@apple.com>
This commit is contained in:
Erik Krogen 2020-12-22 09:55:33 -08:00 committed by Dongjoon Hyun
parent 1d450250eb
commit 303b8c8773
4 changed files with 47 additions and 0 deletions

View file

@ -232,6 +232,11 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.time.Instant]) =>
createDeserializerForInstant(path)
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
createDeserializerForTypesSupportValueOf(
Invoke(path, "toString", ObjectType(classOf[String]), returnNullable = false),
getClassFromType(t))
case t if isSubtype(t, localTypeOf[java.sql.Timestamp]) =>
createDeserializerForSqlTimestamp(path)
@ -526,6 +531,9 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.math.BigInteger]) =>
createSerializerForJavaBigInteger(inputObject)
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) =>
createSerializerForJavaEnum(inputObject)
case t if isSubtype(t, localTypeOf[scala.math.BigInt]) =>
createSerializerForScalaBigInt(inputObject)
@ -749,6 +757,7 @@ object ScalaReflection extends ScalaReflection {
case t if isSubtype(t, localTypeOf[java.lang.Short]) => Schema(ShortType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Byte]) => Schema(ByteType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Boolean]) => Schema(BooleanType, nullable = true)
case t if isSubtype(t, localTypeOf[java.lang.Enum[_]]) => Schema(StringType, nullable = true)
case t if isSubtype(t, definitions.IntTpe) => Schema(IntegerType, nullable = false)
case t if isSubtype(t, definitions.LongTpe) => Schema(LongType, nullable = false)
case t if isSubtype(t, definitions.DoubleTpe) => Schema(DoubleType, nullable = false)

View file

@ -74,6 +74,9 @@ object SerializerBuildHelper {
returnNullable = false)
}
def createSerializerForJavaEnum(inputObject: Expression): Expression =
createSerializerForString(Invoke(inputObject, "name", ObjectType(classOf[String])))
def createSerializerForSqlTimestamp(inputObject: Expression): Expression = {
StaticInvoke(
DateTimeUtils.getClass,

View file

@ -88,6 +88,10 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits {
/** @since 3.0.0 */
implicit def newInstantEncoder: Encoder[java.time.Instant] = Encoders.INSTANT
/** @since 3.2.0 */
implicit def newJavaEnumEncoder[A <: java.lang.Enum[_] : TypeTag]: Encoder[A] =
ExpressionEncoder()
// Boxed primitives
/** @since 2.0.0 */

View file

@ -1693,6 +1693,33 @@ class DatasetSuite extends QueryTest
checkDataset(ds1.select("_2._2"), ds2.select("_2._2").collect(): _*)
}
test("SPARK-23862: Spark ExpressionEncoder should support Java Enum type from Scala") {
val saveModeSeq =
Seq(SaveMode.Append, SaveMode.Overwrite, SaveMode.ErrorIfExists, SaveMode.Ignore, null)
assert(saveModeSeq.toDS().collect().toSeq === saveModeSeq)
assert(saveModeSeq.toDS().schema === new StructType().add("value", StringType, nullable = true))
val saveModeCaseSeq = saveModeSeq.map(SaveModeCase.apply)
assert(saveModeCaseSeq.toDS().collect().toSet === saveModeCaseSeq.toSet)
assert(saveModeCaseSeq.toDS().schema ===
new StructType().add("mode", StringType, nullable = true))
val saveModeArrayCaseSeq =
Seq(SaveModeArrayCase(Array()), SaveModeArrayCase(saveModeSeq.toArray))
val collected = saveModeArrayCaseSeq.toDS().collect()
assert(collected.length === 2)
val sortedByLength = collected.sortBy(_.modes.length)
assert(sortedByLength(0).modes === Array())
assert(sortedByLength(1).modes === saveModeSeq.toArray)
assert(saveModeArrayCaseSeq.toDS().schema ===
new StructType().add("modes", ArrayType(StringType, containsNull = true), nullable = true))
// Enum is stored as string, so it is possible to convert to/from string
val stringSeq = saveModeSeq.map(Option.apply).map(_.map(_.toString).orNull)
assert(stringSeq.toDS().as[SaveMode].collect().toSet === saveModeSeq.toSet)
assert(saveModeSeq.toDS().as[String].collect().toSet === stringSeq.toSet)
}
test("SPARK-24571: filtering of string values by char literal") {
val df = Seq("Amsterdam", "San Francisco", "X").toDF("city")
checkAnswer(df.where($"city" === 'X'), Seq(Row("X")))
@ -2053,3 +2080,7 @@ case class CircularReferenceClassD(map: Map[String, CircularReferenceClassE])
case class CircularReferenceClassE(id: String, list: List[CircularReferenceClassD])
case class SpecialCharClass(`field.1`: String, `field 2`: String)
/** Used to test Java Enums from Scala code */
case class SaveModeCase(mode: SaveMode)
case class SaveModeArrayCase(modes: Array[SaveMode])