[SPARK-6986] [SQL] Use Serializer2 in more cases.

With 0a2b15ce43, the serialization stream and deserialization stream have enough information to determine whether they are handling a key-value pair, a key, or a value. It is therefore safe to use `SparkSqlSerializer2` in more cases.

Author: Yin Huai <yhuai@databricks.com>

Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits:

53a5eaa [Yin Huai] Josh's comments.
487f540 [Yin Huai] Use BufferedOutputStream.
8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join.
c7e2129 [Yin Huai] Update tests.
4513d13 [Yin Huai] Use Serializer2 in more places.

(cherry picked from commit 3af423c92f)
Signed-off-by: Yin Huai <yhuai@databricks.com>
Committed by Yin Huai on 2015-05-07 20:59:42 -07:00
parent 28d4238708
commit 9d0d28940f
3 changed files with 69 additions and 58 deletions
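
The change builds on the key/value-aware stream API referenced above (0a2b15ce43): DeserializationStream exposes readKey and readValue alongside readObject, so a serializer can apply the key schema or the value schema per call instead of assuming every object it sees is a Product2[Row, Row]. Below is a minimal sketch of such a stream against the 1.4-era org.apache.spark.serializer API; the class name and the decodeKey/decodeValue parameters are invented for illustration, and only readObject, readKey, readValue, and close come from the real interface.

import java.io.InputStream
import scala.reflect.ClassTag

import org.apache.spark.serializer.DeserializationStream

// Illustration only: a stream that is told whether the caller wants a key, a value,
// or a whole key-value pair, and so knows which schema to decode with.
class KeyValueAwareDeserializationStream(
    in: InputStream,
    decodeKey: () => AnyRef,    // hypothetical: reads one key using the key schema
    decodeValue: () => AnyRef)  // hypothetical: reads one value using the value schema
  extends DeserializationStream {

  // Callers that need keys and values separately (e.g. when merging sorted shuffle spills).
  override def readKey[T: ClassTag](): T = decodeKey().asInstanceOf[T]
  override def readValue[T: ClassTag](): T = decodeValue().asInstanceOf[T]

  // Callers that consume whole records still get a (key, value) pair.
  override def readObject[T: ClassTag](): T =
    (decodeKey(), decodeValue()).asInstanceOf[T]

  override def close(): Unit = in.close()
}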

sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala

@@ -84,18 +84,8 @@ case class Exchange(
   def serializer(
       keySchema: Array[DataType],
       valueSchema: Array[DataType],
+      hasKeyOrdering: Boolean,
       numPartitions: Int): Serializer = {
-    // In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
-    // through write(key) and then write(value) instead of write((key, value)). Because
-    // SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
-    // it when spillToMergeableFile in ExternalSorter will be used.
-    // So, we will not use SparkSqlSerializer2 when
-    //  - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
-    //    then the bypassMergeThreshold; or
-    //  - newOrdering is defined.
-    val cannotUseSqlSerializer2 =
-      (sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
-
     // It is true when there is no field that needs to be write out.
     // For now, we will not use SparkSqlSerializer2 when noField is true.
     val noField =
@@ -104,14 +94,13 @@ case class Exchange(
     val useSqlSerializer2 =
       child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
-        !cannotUseSqlSerializer2 && // Safe to use Serializer2.
         SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
         SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
         !noField

     val serializer = if (useSqlSerializer2) {
       logInfo("Using SparkSqlSerializer2.")
-      new SparkSqlSerializer2(keySchema, valueSchema)
+      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
     } else {
       logInfo("Using SparkSqlSerializer.")
       new SparkSqlSerializer(sparkConf)
     }
@@ -154,7 +143,8 @@ case class Exchange(
         }
         val keySchema = expressions.map(_.dataType).toArray
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))

         shuffled.map(_._2)

@@ -179,7 +169,8 @@ case class Exchange(
           new ShuffledRDD[Row, Null, Null](rdd, part)
         }
         val keySchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(keySchema, null, numPartitions))
+        shuffled.setSerializer(
+          serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))

         shuffled.map(_._1)

@@ -199,7 +190,7 @@ case class Exchange(
         val partitioner = new HashPartitioner(1)
         val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
         val valueSchema = child.output.map(_.dataType).toArray
-        shuffled.setSerializer(serializer(null, valueSchema, 1))
+        shuffled.setSerializer(serializer(null, valueSchema, false, 1))
         shuffled.map(_._2)

       case _ => sys.error(s"Exchange not implemented for $newPartitioning")
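
Net effect in Exchange: a key ordering no longer disqualifies SparkSqlSerializer2; it is just threaded through as hasKeyOrdering (newOrdering.nonEmpty at the two shuffle call sites above, false for the single-partition case). A condensed sketch of the resulting decision, assuming it sits next to Exchange inside org.apache.spark.sql.execution; the noField definition is paraphrased because this hunk truncates it:

import org.apache.spark.SparkConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.sql.types.DataType

object SerializerChoiceSketch {
  def chooseSerializer(
      serializer2Enabled: Boolean,     // child.sqlContext.conf.useSqlSerializer2
      keySchema: Array[DataType],
      valueSchema: Array[DataType],
      hasKeyOrdering: Boolean,         // newOrdering.nonEmpty at the call sites
      sparkConf: SparkConf): Serializer = {
    // True when neither keys nor values have any field to write out (paraphrased).
    val noField =
      (keySchema == null || keySchema.isEmpty) &&
        (valueSchema == null || valueSchema.isEmpty)
    val useSerializer2 =
      serializer2Enabled &&
        SparkSqlSerializer2.support(keySchema) &&    // the schema of the key is supported
        SparkSqlSerializer2.support(valueSchema) &&  // the schema of the value is supported
        !noField
    if (useSerializer2) {
      // Key ordering now only decides whether the deserializer may reuse a mutable row.
      new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
    } else {
      new SparkSqlSerializer(sparkConf)
    }
  }
}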

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer2.scala

@@ -27,7 +27,7 @@ import scala.reflect.ClassTag
 import org.apache.spark.serializer._
 import org.apache.spark.Logging
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
 import org.apache.spark.sql.types._

 /**
@@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
     out: OutputStream)
   extends SerializationStream with Logging {

-  val rowOut = new DataOutputStream(out)
-  val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
-  val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
+  private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
+  private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
+  private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)

   override def writeObject[T: ClassTag](t: T): SerializationStream = {
     val kv = t.asInstanceOf[Product2[Row, Row]]
@@ -86,31 +86,44 @@ private[sql] class Serializer2SerializationStream(
 private[sql] class Serializer2DeserializationStream(
     keySchema: Array[DataType],
     valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean,
     in: InputStream)
   extends DeserializationStream with Logging {

-  val rowIn = new DataInputStream(new BufferedInputStream(in))
-  val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
-  val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
-  val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
-  val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
+  private val rowIn = new DataInputStream(new BufferedInputStream(in))
+
+  private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
+    if (schema == null) {
+      () => null
+    } else {
+      if (hasKeyOrdering) {
+        // We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
+        () => new GenericMutableRow(schema.length)
+      } else {
+        // It is safe to reuse the mutable row.
+        val mutableRow = new SpecificMutableRow(schema)
+        () => mutableRow
+      }
+    }
+  }
+
+  // Functions used to return rows for key and value.
+  private val getKey = rowGenerator(keySchema)
+  private val getValue = rowGenerator(valueSchema)
+
+  // Functions used to read a serialized row from the InputStream and deserialize it.
+  private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
+  private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)

   override def readObject[T: ClassTag](): T = {
-    readKeyFunc()
-    readValueFunc()
-
-    (key, value).asInstanceOf[T]
+    (readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
   }

   override def readKey[T: ClassTag](): T = {
-    readKeyFunc()
-    key.asInstanceOf[T]
+    readKeyFunc(getKey()).asInstanceOf[T]
   }

   override def readValue[T: ClassTag](): T = {
-    readValueFunc()
-    value.asInstanceOf[T]
+    readValueFunc(getValue()).asInstanceOf[T]
   }

   override def close(): Unit = {
@@ -118,9 +131,10 @@ private[sql] class Serializer2DeserializationStream(
   }
 }

-private[sql] class ShuffleSerializerInstance(
+private[sql] class SparkSqlSerializer2Instance(
     keySchema: Array[DataType],
-    valueSchema: Array[DataType])
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends SerializerInstance {

   def serialize[T: ClassTag](t: T): ByteBuffer =
@@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
   }

   def deserializeStream(s: InputStream): DeserializationStream = {
-    new Serializer2DeserializationStream(keySchema, valueSchema, s)
+    new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
   }
 }
@@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
  * The schema of keys is represented by `keySchema` and that of values is represented by
  * `valueSchema`.
  */
-private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
+private[sql] class SparkSqlSerializer2(
+    keySchema: Array[DataType],
+    valueSchema: Array[DataType],
+    hasKeyOrdering: Boolean)
   extends Serializer
   with Logging
   with Serializable{

-  def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
+  def newInstance(): SerializerInstance =
+    new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)

   override def supportsRelocationOfSerializedObjects: Boolean = {
     // SparkSqlSerializer2 is stateless and writes no stream headers
@@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
    */
   def createDeserializationFunction(
       schema: Array[DataType],
-      in: DataInputStream,
-      mutableRow: SpecificMutableRow): () => Unit = {
-    () => {
-      // If the schema is null, the returned function does nothing when it get called.
-      if (schema != null) {
+      in: DataInputStream): (MutableRow) => Row = {
+    if (schema == null) {
+      (mutableRow: MutableRow) => null
+    } else {
+      (mutableRow: MutableRow) => {
         var i = 0
         while (i < schema.length) {
           schema(i) match {
@@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
           }
           i += 1
         }
+
+        mutableRow
       }
     }
   }
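
The rowGenerator change above is what makes sort merge join safe: once the ShuffledRDD has a key ordering, downstream code buffers deserialized rows, and handing every read the same mutable row silently overwrites whatever was buffered. A Spark-free illustration of that aliasing problem (all names invented for the example):

object RowReuseDemo extends App {
  // Stand-in for a mutable row that a deserializer could either reuse or allocate fresh.
  final class MutableRecord(var value: Int)

  // Mimics rowGenerator's two strategies: reuse one shared row, or create a new row per record.
  def readAll(reuse: Boolean, input: Seq[Int]): Seq[MutableRecord] = {
    val shared = new MutableRecord(0)
    input.map { v =>
      val rec = if (reuse) shared else new MutableRecord(0)
      rec.value = v
      rec  // a buffering consumer (e.g. sorting rows before a sort merge join) keeps this reference
    }
  }

  // With reuse, every buffered element aliases the same object and ends up showing the last value.
  assert(readAll(reuse = true, Seq(1, 2, 3)).map(_.value) == Seq(3, 3, 3))

  // A fresh row per record (what hasKeyOrdering = true now forces) keeps buffered rows distinct.
  assert(readAll(reuse = false, Seq(1, 2, 3)).map(_.value) == Seq(1, 2, 3))
}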

sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlSerializer2Suite.scala

@@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
       table("shuffle").collect())
   }

+  test("key schema is null") {
+    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
+    val df = sql(s"SELECT $aggregations FROM shuffle")
+    checkSerializer(df.queryExecution.executedPlan, serializerClass)
+    checkAnswer(
+      df,
+      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+  }
+
   test("value schema is null") {
     val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
     checkSerializer(df.queryExecution.executedPlan, serializerClass)
@@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
   override def beforeAll(): Unit = {
     super.beforeAll()
     // Sort merge will not be triggered.
-    sql("set spark.sql.shuffle.partitions = 200")
-  }
-
-  test("key schema is null") {
-    val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
-    val df = sql(s"SELECT $aggregations FROM shuffle")
-    checkSerializer(df.queryExecution.executedPlan, serializerClass)
-    checkAnswer(
-      df,
-      Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
   }
 }

 /** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
 class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {

-  // We are expecting SparkSqlSerializer.
-  override val serializerClass: Class[Serializer] =
-    classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
-
   override def beforeAll(): Unit = {
     super.beforeAll()
     // To trigger the sort merge.
-    sql("set spark.sql.shuffle.partitions = 201")
+    val bypassMergeThreshold =
+      sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
+    sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
   }
 }
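
Both suites now derive the partition count from spark.shuffle.sort.bypassMergeThreshold instead of hard-coding 200 and 201. The condition they steer is paraphrased below as a sketch (not Spark code): in 1.x sort-based shuffle, with no map-side aggregation or ordering, the writer bypasses merge-sort when the number of reduce partitions is at most the threshold, and the merge-sort path above the threshold is the one that previously had to fall back to SparkSqlSerializer.

// Sketch of the relationship the two suites rely on.
def bypassesMergeSort(numReducePartitions: Int, bypassMergeThreshold: Int): Boolean =
  numReducePartitions <= bypassMergeThreshold

// threshold - 1 -> bypass path, no sort merge (SparkSqlSerializer2SortShuffleSuite)
// threshold + 1 -> merge-sort path, triggers sort merge (SparkSqlSerializer2SortMergeShuffleSuite)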