[SPARK-11797][SQL] collect, first, and take should use encoders for serialization

They were previously using Spark's default serializer for serialization.

Author: Reynold Xin <rxin@databricks.com>

Closes #9787 from rxin/SPARK-11797.
This commit is contained in:
Reynold Xin 2015-11-17 21:40:58 -08:00
parent 98be8169f0
commit 91f4b6f2db
2 changed files with 41 additions and 6 deletions

View file

@ -22,6 +22,7 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
@ -199,7 +200,6 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def mapPartitions[U : Encoder](func: Iterator[T] => Iterator[U]): Dataset[U] = {
encoderFor[T].assertUnresolved()
new Dataset[U](
sqlContext,
MapPartitions[T, U](
@ -519,7 +519,7 @@ class Dataset[T] private[sql](
* Returns the first element in this [[Dataset]].
* @since 1.6.0
*/
def first(): T = rdd.first()
def first(): T = take(1).head
/**
* Returns an array that contains all the elements in this [[Dataset]].
@ -530,7 +530,14 @@ class Dataset[T] private[sql](
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collect(): Array[T] = rdd.collect()
def collect(): Array[T] = {
// This is different from Dataset.rdd in that it collects Rows, and then runs the encoders
// to convert the rows into objects of type T.
val tEnc = resolvedTEncoder
val input = queryExecution.analyzed.output
val bound = tEnc.bind(input)
queryExecution.toRdd.map(_.copy()).collect().map(bound.fromRow)
}
/**
* Returns an array that contains all the elements in this [[Dataset]].
@ -541,7 +548,7 @@ class Dataset[T] private[sql](
* For Java API, use [[collectAsList]].
* @since 1.6.0
*/
def collectAsList(): java.util.List[T] = rdd.collect().toSeq.asJava
def collectAsList(): java.util.List[T] = collect().toSeq.asJava
/**
* Returns the first `num` elements of this [[Dataset]] as an array.
@ -551,7 +558,7 @@ class Dataset[T] private[sql](
*
* @since 1.6.0
*/
def take(num: Int): Array[T] = rdd.take(num)
def take(num: Int): Array[T] = withPlan(Limit(Literal(num), _)).collect()
/**
* Returns the first `num` elements of this [[Dataset]] as an array.

View file

@ -17,6 +17,8 @@
package org.apache.spark.sql
import java.io.{ObjectInput, ObjectOutput, Externalizable}
import scala.language.postfixOps
import org.apache.spark.sql.functions._
@ -24,6 +26,20 @@ import org.apache.spark.sql.test.SharedSQLContext
case class ClassData(a: String, b: Int)
/**
* A class used to test serialization using encoders. This class throws exceptions when using
* Java serialization -- so the only way it can be "serialized" is through our encoders.
*/
case class NonSerializableCaseClass(value: String) extends Externalizable {
override def readExternal(in: ObjectInput): Unit = {
throw new UnsupportedOperationException
}
override def writeExternal(out: ObjectOutput): Unit = {
throw new UnsupportedOperationException
}
}
class DatasetSuite extends QueryTest with SharedSQLContext {
import testImplicits._
@ -41,6 +57,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1, 1, 1)
}
test("collect, first, and take should use encoders for serialization") {
val item = NonSerializableCaseClass("abcd")
val ds = Seq(item).toDS()
assert(ds.collect().head == item)
assert(ds.collectAsList().get(0) == item)
assert(ds.first() == item)
assert(ds.take(1).head == item)
assert(ds.takeAsList(1).get(0) == item)
}
test("as tuple") {
val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
checkAnswer(
@ -75,6 +101,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
ignore("Dataset should set the resolved encoders internally for maps") {
// TODO: Enable this once we fix SPARK-11793.
// We inject a group by here to make sure this test case is future proof
// when we implement better pipelining and local execution mode.
val ds: Dataset[(ClassData, Long)] = Seq(ClassData("one", 1), ClassData("two", 2)).toDS()
.map(c => ClassData(c.a, c.b + 1))
.groupBy(p => p).count()
@ -219,7 +247,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 30), ("b", 3), ("c", 1))
}
test("groupBy function, fatMap") {
test("groupBy function, flatMap") {
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy(v => (v._1, "word"))
val agged = grouped.flatMap { case (g, iter) => Iterator(g._1, iter.map(_._2).sum.toString) }