[SPARK-11797][SQL] collect, first, and take should use encoders for serialization
They were previously using Spark's default serializer for serialization. Author: Reynold Xin <rxin@databricks.com> Closes #9787 from rxin/SPARK-11797.
This commit is contained in:
parent
98be8169f0
commit
91f4b6f2db
|
@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
|
|||
import org.apache.spark.annotation.Experimental
|
||||
import org.apache.spark.rdd.RDD
|
||||
import org.apache.spark.api.java.function._
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
|
||||
import org.apache.spark.sql.catalyst.encoders._
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
|
@ -199,7 +200,6 @@ class Dataset[T] private[sql](
|
|||
* @since 1.6.0
|
||||
*/
|
||||
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
|
||||
encoderFor[T].assertUnresolved()
|
||||
new Dataset[U](
|
||||
sqlContext,
|
||||
MapPartitions[T, U](
|
||||
|
@ -519,7 +519,7 @@ class Dataset[T] private[sql](
|
|||
* Returns the first element in this [[Dataset]].
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def first(): T = rdd.first()
|
||||
def first(): T = take(1).head
|
||||
|
||||
/**
|
||||
* Returns an array that contains all the elements in this [[Dataset]].
|
||||
|
@ -530,7 +530,14 @@ class Dataset[T] private[sql](
|
|||
* For Java API, use [[collectAsList]].
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def collect(): Array[T] = rdd.collect()
|
||||
def collect(): Array[T] = {
|
||||
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
|
||||
// to convert the rows into objects of type T.
|
||||
val tEnc = resolvedTEncoder
|
||||
val input = queryExecution.analyzed.output
|
||||
val bound = tEnc.bind(input)
|
||||
queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns an array that contains all the elements in this [[Dataset]].
|
||||
|
@ -541,7 +548,7 @@ class Dataset[T] private[sql](
|
|||
* For Java API, use [[collectAsList]].
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
|
||||
def collectAsList(): java.util.List[T] = collect().toSeq.asJava
|
||||
|
||||
/**
|
||||
* Returns the first `num` elements of this [[Dataset]] as an array.
|
||||
|
@ -551,7 +558,7 @@ class Dataset[T] private[sql](
|
|||
*
|
||||
* @since 1.6.0
|
||||
*/
|
||||
def take(num: Int): Array[T] = rdd.take(num)
|
||||
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
|
||||
|
||||
/**
|
||||
* Returns the first `num` elements of this [[Dataset]] as an array.
|
||||
|
|
|
@ -17,6 +17,8 @@
|
|||
|
||||
package org.apache.spark.sql
|
||||
|
||||
import java.io.{ObjectInput, ObjectOutput, Externalizable}
|
||||
|
||||
import scala.language.postfixOps
|
||||
|
||||
import org.apache.spark.sql.functions._
|
||||
|
@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext
|
|||
|
||||
case class ClassData(a: String, b: Int)
|
||||
|
||||
/**
|
||||
* A class used to test serialization using encoders. This class throws exceptions when using
|
||||
* Java serialization -- so the only way it can be "serialized" is through our encoders.
|
||||
*/
|
||||
case class NonSerializableCaseClass(value: String) extends Externalizable {
|
||||
override def readExternal(in: ObjectInput): Unit = {
|
||||
throw new UnsupportedOperationException
|
||||
}
|
||||
|
||||
override def writeExternal(out: ObjectOutput): Unit = {
|
||||
throw new UnsupportedOperationException
|
||||
}
|
||||
}
|
||||
|
||||
class DatasetSuite extends QueryTest with SharedSQLContext {
|
||||
import testImplicits._
|
||||
|
||||
|
@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
1, 1, 1)
|
||||
}
|
||||
|
||||
test("collect, first, and take should use encoders for serialization") {
|
||||
val item = NonSerializableCaseClass("abcd")
|
||||
val ds = Seq(item).toDS()
|
||||
assert(ds.collect().head == item)
|
||||
assert(ds.collectAsList().get(0) == item)
|
||||
assert(ds.first() == item)
|
||||
assert(ds.take(1).head == item)
|
||||
assert(ds.takeAsList(1).get(0) == item)
|
||||
}
|
||||
|
||||
test("as tuple") {
|
||||
val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
|
||||
checkAnswer(
|
||||
|
@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
|
||||
ignore("Dataset should set the resolved encoders internally for maps") {
|
||||
// TODO: Enable this once we fix SPARK-11793.
|
||||
// We inject a group by here to make sure this test case is future proof
|
||||
// when we implement better pipelining and local execution mode.
|
||||
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
|
||||
.map(c => ClassData(c.a, c.b + 1))
|
||||
.groupBy(p => p).count()
|
||||
|
@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
|
|||
("a", 30), ("b", 3), ("c", 1))
|
||||
}
|
||||
|
||||
test("groupBy function, fatMap") {
|
||||
test("groupBy function, flatMap") {
|
||||
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
|
||||
val grouped = ds.groupBy(v => (v._1, "word"))
|
||||
val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }
|
||||
|
|
Loading…
Reference in a new issue