From 9d7a47406ed538f0005cdc7a62bc6e6f20634815 Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Thu, 4 Aug 2016 19:45:47 +0800
Subject: [PATCH] [SPARK-16853][SQL] fixes encoder error in Dataset typed
 select

## What changes were proposed in this pull request?

For the `Dataset` typed select:

```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```

If the result type `U1` is a case class or a tuple class (i.e. not an atomic type), the schema of the resulting logical plan does not match the schema of `U1`'s encoder, which causes an encoder error and throws an `AnalysisException`.

### Before change:

```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2];
..
```

### After change:

```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
|  a|  b|
+---+---+
|  1|  2|
+---+---+
```

## How was this patch tested?

A new unit test in `DatasetSuite`.

Author: Sean Zhong

Closes #14474 from clockfly/SPARK-16853.
---
 project/MimaExcludes.scala                        |  4 +++-
 .../catalyst/encoders/ExpressionEncoder.scala     |  4 ++++
 .../scala/org/apache/spark/sql/Dataset.scala      | 20 +++++++++++---------
 .../org/apache/spark/sql/DatasetSuite.scala       | 11 +++++++++++
 4 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 56061559fe..a201d7f838 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v21excludes = v20excludes ++ {
     Seq(
       // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
-      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+      // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+      ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
     )
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 1fac26c438..b96b744b4f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -169,6 +169,10 @@ object ExpressionEncoder {
       ClassTag(cls))
   }
 
+  // Tuple1
+  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
   def tuple[T1, T2](
       e1: ExpressionEncoder[T1],
       e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 8b6443c8b9..306ca773d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sparkSession,
-      Project(
-        c1.withInputType(
-          exprEnc.deserializer,
-          logicalPlan.output).named :: Nil,
-        logicalPlan),
-      implicitly[Encoder[U1]])
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    implicit val encoder = c1.encoder
+    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+      logicalPlan)
+
+    if (encoder.flat) {
+      new Dataset[U1](sparkSession, project, encoder)
+    } else {
+      // Flattens inner fields of U1
+      new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 7e3b7b63d8..8a756fd474 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
+  test("SPARK-16853: select, case class and tuple") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+      (1, 1), (2, 2), (3, 3))
+
+    checkDataset(
+      ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+      ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+  }
+
   test("select 2") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(
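
For reviewers who want to see the flat/non-flat distinction concretely, here is a minimal standalone sketch. It is not part of the patch: the `Tuple1Demo` object and the local-mode session are illustrative assumptions, and it presumes a Spark build with this patch applied (it uses the `Tuple1` overload of `ExpressionEncoder.tuple` added above, plus the existing `flat` and `schema` members of `ExpressionEncoder`):

```
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class A(a: Int, b: Int)

// Hypothetical demo object, not part of the patch.
object Tuple1Demo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-16853").getOrCreate()
    import spark.implicits._

    // A flat encoder (e.g. Int) serializes to a single top-level column,
    // so it always lines up with the one column produced by a typed select.
    val intEnc = ExpressionEncoder[Int]()
    println(intEnc.flat)     // true

    // A case-class encoder is not flat: its schema has one top-level field
    // per constructor parameter (a, b), while the projected plan has a
    // single struct column -- the mismatch this patch fixes.
    val aEnc = ExpressionEncoder[A]()
    println(aEnc.flat)       // false
    println(aEnc.schema)     // StructType(StructField(a,...), StructField(b,...))

    // Wrapping in Tuple1 yields a schema with exactly one struct field,
    // which matches the single projected column; map(_._1) then unwraps it.
    println(ExpressionEncoder.tuple(aEnc).schema)

    // End-to-end: the query that threw AnalysisException before this patch.
    Seq((0, A(1, 2))).toDS().select($"_2".as[A]).show()

    spark.stop()
  }
}
```

This is also why only the non-flat branch of the new `select[U1]` pays the extra `map(_._1)`: a flat encoder's schema already matches the single projected column, so it can be used directly.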