[SPARK-16853][SQL] fixes encoder error in DataSet typed select
## What changes were proposed in this pull request?

For DataSet typed select:
```
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1]
```
If type T is a case class or a tuple class that is not atomic, the resulting logical plan's schema will mismatch with `Dataset[T]` encoder's schema, which will cause encoder error and throw AnalysisException.

### Before change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A])
org.apache.spark.sql.AnalysisException: cannot resolve '`a`' given input columns: [_2]; ..
```

### After change:
```
scala> case class A(a: Int, b: Int)
scala> Seq((0, A(1,2))).toDS.select($"_2".as[A]).show
+---+---+
|  a|  b|
+---+---+
|  1|  2|
+---+---+
```

## How was this patch tested?

Unit test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14474 from clockfly/SPARK-16853.
This commit is contained in:
parent
43f4fd6f9b
commit
9d7a47406e
|
@@ -38,7 +38,9 @@ object MimaExcludes {
|
|||
lazy val v21excludes = v20excludes ++ {
|
||||
Seq(
|
||||
// [SPARK-16199][SQL] Add a method to list the referenced columns in data source Filter
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references")
|
||||
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.Filter.references"),
|
||||
// [SPARK-16853][SQL] Fixes encoder error in DataSet typed select
|
||||
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.Dataset.select")
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@@ -169,6 +169,10 @@ object ExpressionEncoder {
|
|||
ClassTag(cls))
|
||||
}
|
||||
|
||||
// Tuple1
|
||||
def tuple[T](e: ExpressionEncoder[T]): ExpressionEncoder[Tuple1[T]] =
|
||||
tuple(Seq(e)).asInstanceOf[ExpressionEncoder[Tuple1[T]]]
|
||||
|
||||
def tuple[T1, T2](
|
||||
e1: ExpressionEncoder[T1],
|
||||
e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
|
||||
|
|
|
@@ -1061,15 +1061,17 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
@Experimental
|
||||
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
|
||||
new Dataset[U1](
|
||||
sparkSession,
|
||||
Project(
|
||||
c1.withInputType(
|
||||
exprEnc.deserializer,
|
||||
logicalPlan.output).named :: Nil,
|
||||
logicalPlan),
|
||||
implicitly[Encoder[U1]])
|
||||
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
|
||||
implicit val encoder = c1.encoder
|
||||
val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
|
||||
logicalPlan)
|
||||
|
||||
if (encoder.flat) {
|
||||
new Dataset[U1](sparkSession, project, encoder)
|
||||
} else {
|
||||
// Flattens inner fields of U1
|
||||
new Dataset[Tuple1[U1]](sparkSession, project, ExpressionEncoder.tuple(encoder)).map(_._1)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
@@ -184,6 +184,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
2, 3, 4)
|
||||
}
|
||||
|
||||
test("SPARK-16853: select, case class and tuple") {
|
||||
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
|
||||
checkDataset(
|
||||
ds.select(expr("struct(_2, _2)").as[(Int, Int)]): Dataset[(Int, Int)],
|
||||
(1, 1), (2, 2), (3, 3))
|
||||
|
||||
checkDataset(
|
||||
ds.select(expr("named_struct('a', _1, 'b', _2)").as[ClassData]): Dataset[ClassData],
|
||||
ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))
|
||||
}
|
||||
|
||||
test("select 2") {
|
||||
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
|
||||
checkDataset(
|
||||
|
|
Loading…
Reference in a new issue