[SPARK-16853][SQL] fixes encoder error in DataSet typed select

## What changes were proposed in this pull request?

For the `Dataset` typed select:
```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```
If the result type `U1` is a case class or a tuple, i.e. not an atomic type, the resulting logical plan's schema (a single struct column) mismatches the schema expected by the `Dataset[U1]` encoder (the struct's fields flattened to the top level), which causes an encoder error and throws an `AnalysisException`.
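
The mismatch can be seen by comparing the two schemas directly. A minimal spark-shell sketch (Spark 2.x, with `spark.implicits._` in scope; `Encoders.product` is the public helper for obtaining a case-class encoder):
```
import org.apache.spark.sql.Encoders

case class A(a: Int, b: Int)

// The encoder for A expects A's fields at the top level of the row:
Encoders.product[A].schema
// e.g. StructType(StructField(a,IntegerType,false), StructField(b,IntegerType,false))

// But the typed select projects a single struct column named `_2`:
Seq((0, A(1, 2))).toDS.select($"_2").schema
// e.g. StructType(StructField(_2,StructType(a,b),...))
```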

### Before change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2];
...
```

### After change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
|  a|  b|
+---+---+
|  1|  2|
+---+---+
```

## How was this patch tested?

Unit test (a new test case added to `DatasetSuite`; see the last hunk below).

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14474 from clockfly/SPARK-16853.
Authored by Sean Zhong on 2016-08-04 19:45:47 +08:00; committed by Wenchen Fan. Commit 9d7a47406e (parent 43f4fd6f9b). 4 changed files with 29 additions and 10 deletions.

#### project/MimaExcludes.scala

```diff
@@ -38,7 +38,9 @@ object MimaExcludes {
   lazy val v21excludes = v20excludes ++ {
     Seq(
       // [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
-      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
+      ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
+      // [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
+      ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
     )
   }
```

#### sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala

```diff
@@ -169,6 +169,10 @@ object ExpressionEncoder {
       ClassTag(cls))
   }
 
+  // Tuple1
+  def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
+    tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
+
   def tuple[T1, T2](
       e1: ExpressionEncoder[T1],
       e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
```
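
A hedged usage sketch of the new single-encoder `tuple` overload (the `case class A` is the running example from above, not part of the patch):
```
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

case class A(a: Int, b: Int)

// Wrapping A's encoder: the struct for A becomes the tuple's single
// top-level field `_1`, so it matches a one-column plan.
val wrapped: ExpressionEncoder[Tuple1[A]] = ExpressionEncoder.tuple(ExpressionEncoder[A]())
wrapped.schema // a single `_1` field of struct type, not flattened a, b
```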

#### sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala

```diff
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   @Experimental
-  def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](
-      sparkSession,
-      Project(
-        c1.withInputType(
-          exprEnc.deserializer,
-          logicalPlan.output).named :: Nil,
-        logicalPlan),
-      implicitly[Encoder[U1]])
+  def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
+    implicit val encoder = c1.encoder
+    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+      logicalPlan)
+
+    if (encoder.flat) {
+      new Dataset[U1](sparkSession, project, encoder)
+    } else {
+      // Flattens inner fields of U1
+      new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
+    }
   }
 
   /**
```
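
Both branches can be exercised from the same `Dataset`; a quick sketch (reusing the example case class, with implicits in scope):
```
case class A(a: Int, b: Int)
val ds = Seq((0, A(1, 2))).toDS

ds.select($"_1".as[Int]) // Int encoder is flat: direct Dataset[U1] path
ds.select($"_2".as[A])   // A's encoder is not flat: Tuple1 wrap, then map(_._1)
```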

#### sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala

```diff
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
+  test("SPARK-16853: select, case class and tuple") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
+      (1, 1), (2, 2), (3, 3))
+
+    checkDataset(
+      ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
+      ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
+  }
+
   test("select 2") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     checkDataset(
```
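
The first assertion can also be checked interactively; a spark-shell sketch mirroring it (assuming `spark.implicits._` and `org.apache.spark.sql.functions.expr` are in scope):
```
import org.apache.spark.sql.functions.expr

val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
ds.select(expr("struct(_2, _2)").as[(Int, Int)]).collect()
// Array((1,1), (2,2), (3,3))
```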