[SPARK-1845] [SQL] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.

When I execute `orderBy` or `limit` on a `SchemaRDD` that contains `ArrayType` or `MapType`, `SparkSqlSerializer` throws exceptions like the following:

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.$colon$colon
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.Vector
```

or

```
com.esotericsoftware.kryo.KryoException: Class cannot be created (missing no-arg constructor): scala.collection.immutable.HashMap$HashTrieMap
```

and so on.

This is because `SparkSqlSerializer` does not register serializers for the concrete Scala collection classes.
I believe it should use `AllScalaRegistrar`, which registers serializers for the concrete `Seq` and `Map` implementations that back `ArrayType` and `MapType`.
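
For context, chill's `AllScalaRegistrar` installs Kryo serializers for the standard Scala collection classes, so Kryo no longer tries to instantiate them through a no-arg constructor. A minimal standalone sketch of that behavior (the demo object, buffer size, and sample data are made up for illustration, not part of this change):

```scala
import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input, Output}
import com.twitter.chill.AllScalaRegistrar

object AllScalaRegistrarDemo {
  def main(args: Array[String]): Unit = {
    val kryo = new Kryo()
    kryo.setRegistrationRequired(false)
    // Without this line, round-tripping List/Vector/HashMap fails with
    // "Class cannot be created (missing no-arg constructor)".
    new AllScalaRegistrar().apply(kryo)

    val data = Map(1 -> Seq("a", "b"), 2 -> Vector("c"))
    val output = new Output(4096)
    kryo.writeClassAndObject(output, data)

    val input = new Input(output.toBytes)
    val roundTripped = kryo.readClassAndObject(input)
    assert(roundTripped == data)
  }
}
```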

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #790 from ueshin/issues/SPARK-1845 and squashes the following commits:

d1ed992 [Takuya UESHIN] Use AllScalaRegistrar for SparkSqlSerializer to register serializers of Scala collections.
Authored by Takuya UESHIN on 2014-05-15 11:20:21 -07:00; committed by Reynold Xin
commit db8cc6f28a (parent 3abe2b734a)
4 changed files with 66 additions and 26 deletions

Changes to `SparkSqlSerializer.scala`:

```diff
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
 import com.clearspring.analytics.stream.cardinality.HyperLogLog
 import com.esotericsoftware.kryo.io.{Input, Output}
 import com.esotericsoftware.kryo.{Serializer, Kryo}
+import com.twitter.chill.AllScalaRegistrar
 
 import org.apache.spark.{SparkEnv, SparkConf}
 import org.apache.spark.serializer.KryoSerializer
@@ -35,22 +36,14 @@ private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(co
     val kryo = new Kryo()
     kryo.setRegistrationRequired(false)
     kryo.register(classOf[MutablePair[_, _]])
-    kryo.register(classOf[Array[Any]])
-    // This is kinda hacky...
-    kryo.register(classOf[scala.collection.immutable.Map$Map1], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map2], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map3], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map$Map4], new MapSerializer)
-    kryo.register(classOf[scala.collection.immutable.Map[_,_]], new MapSerializer)
-    kryo.register(classOf[scala.collection.Map[_,_]], new MapSerializer)
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericRow])
     kryo.register(classOf[org.apache.spark.sql.catalyst.expressions.GenericMutableRow])
     kryo.register(classOf[com.clearspring.analytics.stream.cardinality.HyperLogLog],
                   new HyperLogLogSerializer)
-    kryo.register(classOf[scala.collection.mutable.ArrayBuffer[_]])
     kryo.register(classOf[scala.math.BigDecimal], new BigDecimalSerializer)
     kryo.setReferences(false)
     kryo.setClassLoader(Utils.getSparkClassLoader)
+    new AllScalaRegistrar().apply(kryo)
     kryo
   }
 }
@@ -97,20 +90,3 @@ private[sql] class HyperLogLogSerializer extends Serializer[HyperLogLog] {
     HyperLogLog.Builder.build(bytes)
   }
 }
-
-/**
- * Maps do not have a no arg constructor and so cannot be serialized by default. So, we serialize
- * them as `Array[(k,v)]`.
- */
-private[sql] class MapSerializer extends Serializer[Map[_,_]] {
-  def write(kryo: Kryo, output: Output, map: Map[_,_]) {
-    kryo.writeObject(output, map.flatMap(e => Seq(e._1, e._2)).toArray)
-  }
-
-  def read(kryo: Kryo, input: Input, tpe: Class[Map[_,_]]): Map[_,_] = {
-    kryo.readObject(input, classOf[Array[Any]])
-      .sliding(2,2)
-      .map { case Array(k,v) => (k,v) }
-      .toMap
-  }
-}
```
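
Not part of the commit, but for illustration: since `SparkSqlSerializer` is `private[sql]`, a quick round-trip check would have to live under that package; the object name and sample values below are assumptions.

```scala
package org.apache.spark.sql.execution

import org.apache.spark.SparkConf

// Hypothetical round-trip check: with AllScalaRegistrar applied in newKryo(),
// concrete collection classes (List, Vector, HashMap, ...) serialize without
// the "missing no-arg constructor" failure.
object SerializerRoundTrip {
  def main(args: Array[String]): Unit = {
    val instance = new SparkSqlSerializer(new SparkConf()).newInstance()
    val value = Seq(Vector(1, 2, 3), Map(1 -> "a", 2 -> "b"))
    val bytes = instance.serialize(value)
    assert(instance.deserialize[Seq[Any]](bytes) == value)
  }
}
```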

Changes to `DslQuerySuite.scala`:

```diff
@@ -69,12 +69,36 @@ class DslQuerySuite extends QueryTest {
     checkAnswer(
       testData2.orderBy('a.desc, 'b.asc),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).asc),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      arrayData.orderBy(GetItem('data, 0).desc),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).asc),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      mapData.orderBy(GetItem('data, 1).desc),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
+  test("limit") {
+    checkAnswer(
+      testData.limit(10),
+      testData.take(10).toSeq)
+
+    checkAnswer(
+      arrayData.limit(1),
+      arrayData.take(1).toSeq)
+
+    checkAnswer(
+      mapData.limit(1),
+      mapData.take(1).toSeq)
+  }
+
   test("average") {
```

Changes to `SQLQuerySuite.scala`:

```diff
@@ -85,6 +85,36 @@ class SQLQuerySuite extends QueryTest {
     checkAnswer(
       sql("SELECT * FROM testData2 ORDER BY a DESC, b ASC"),
       Seq((3,1), (3,2), (2,1), (2,2), (1,1), (1,2)))
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] ASC"),
+      arrayData.collect().sortBy(_.data(0)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData ORDER BY data[0] DESC"),
+      arrayData.collect().sortBy(_.data(0)).reverse.toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] ASC"),
+      mapData.collect().sortBy(_.data(1)).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData ORDER BY data[1] DESC"),
+      mapData.collect().sortBy(_.data(1)).reverse.toSeq)
   }
 
+  test("limit") {
+    checkAnswer(
+      sql("SELECT * FROM testData LIMIT 10"),
+      testData.take(10).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM arrayData LIMIT 1"),
+      arrayData.collect().take(1).toSeq)
+
+    checkAnswer(
+      sql("SELECT * FROM mapData LIMIT 1"),
+      mapData.collect().take(1).toSeq)
+  }
+
   test("average") {
```

Changes to `TestData.scala`:

```diff
@@ -74,6 +74,16 @@ object TestData {
       ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
   arrayData.registerAsTable("arrayData")
 
+  case class MapData(data: Map[Int, String])
+  val mapData =
+    TestSQLContext.sparkContext.parallelize(
+      MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+      MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+      MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+      MapData(Map(1 -> "a4", 2 -> "b4")) ::
+      MapData(Map(1 -> "a5")) :: Nil)
+  mapData.registerAsTable("mapData")
+
   case class StringData(s: String)
   val repeatedData =
     TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
```