[SPARK-6986] [SQL] Use Serializer2 in more cases.
With0a2b15ce43
, the serialization stream and deserialization stream has enough information to determine it is handling a key-value pari, a key, or a value. It is safe to use `SparkSqlSerializer2` in more cases. Author: Yin Huai <yhuai@databricks.com> Closes #5849 from yhuai/serializer2MoreCases and squashes the following commits: 53a5eaa [Yin Huai] Josh's comments. 487f540 [Yin Huai] Use BufferedOutputStream. 8385f95 [Yin Huai] Always create a new row at the deserialization side to work with sort merge join. c7e2129 [Yin Huai] Update tests. 4513d13 [Yin Huai] Use Serializer2 in more places. (cherry picked from commit3af423c92f
) Signed-off-by: Yin Huai <yhuai@databricks.com>
This commit is contained in:
parent
28d4238708
commit
9d0d28940f
|
@ -84,18 +84,8 @@ case class Exchange(
|
||||||
def serializer(
|
def serializer(
|
||||||
keySchema: Array[DataType],
|
keySchema: Array[DataType],
|
||||||
valueSchema: Array[DataType],
|
valueSchema: Array[DataType],
|
||||||
|
hasKeyOrdering: Boolean,
|
||||||
numPartitions: Int): Serializer = {
|
numPartitions: Int): Serializer = {
|
||||||
// In ExternalSorter's spillToMergeableFile function, key-value pairs are written out
|
|
||||||
// through write(key) and then write(value) instead of write((key, value)). Because
|
|
||||||
// SparkSqlSerializer2 assumes that objects passed in are Product2, we cannot safely use
|
|
||||||
// it when spillToMergeableFile in ExternalSorter will be used.
|
|
||||||
// So, we will not use SparkSqlSerializer2 when
|
|
||||||
// - Sort-based shuffle is enabled and the number of reducers (numPartitions) is greater
|
|
||||||
// then the bypassMergeThreshold; or
|
|
||||||
// - newOrdering is defined.
|
|
||||||
val cannotUseSqlSerializer2 =
|
|
||||||
(sortBasedShuffleOn && numPartitions > bypassMergeThreshold) || newOrdering.nonEmpty
|
|
||||||
|
|
||||||
// It is true when there is no field that needs to be write out.
|
// It is true when there is no field that needs to be write out.
|
||||||
// For now, we will not use SparkSqlSerializer2 when noField is true.
|
// For now, we will not use SparkSqlSerializer2 when noField is true.
|
||||||
val noField =
|
val noField =
|
||||||
|
@ -104,14 +94,13 @@ case class Exchange(
|
||||||
|
|
||||||
val useSqlSerializer2 =
|
val useSqlSerializer2 =
|
||||||
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
|
child.sqlContext.conf.useSqlSerializer2 && // SparkSqlSerializer2 is enabled.
|
||||||
!cannotUseSqlSerializer2 && // Safe to use Serializer2.
|
|
||||||
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
|
SparkSqlSerializer2.support(keySchema) && // The schema of key is supported.
|
||||||
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
|
SparkSqlSerializer2.support(valueSchema) && // The schema of value is supported.
|
||||||
!noField
|
!noField
|
||||||
|
|
||||||
val serializer = if (useSqlSerializer2) {
|
val serializer = if (useSqlSerializer2) {
|
||||||
logInfo("Using SparkSqlSerializer2.")
|
logInfo("Using SparkSqlSerializer2.")
|
||||||
new SparkSqlSerializer2(keySchema, valueSchema)
|
new SparkSqlSerializer2(keySchema, valueSchema, hasKeyOrdering)
|
||||||
} else {
|
} else {
|
||||||
logInfo("Using SparkSqlSerializer.")
|
logInfo("Using SparkSqlSerializer.")
|
||||||
new SparkSqlSerializer(sparkConf)
|
new SparkSqlSerializer(sparkConf)
|
||||||
|
@ -154,7 +143,8 @@ case class Exchange(
|
||||||
}
|
}
|
||||||
val keySchema = expressions.map(_.dataType).toArray
|
val keySchema = expressions.map(_.dataType).toArray
|
||||||
val valueSchema = child.output.map(_.dataType).toArray
|
val valueSchema = child.output.map(_.dataType).toArray
|
||||||
shuffled.setSerializer(serializer(keySchema, valueSchema, numPartitions))
|
shuffled.setSerializer(
|
||||||
|
serializer(keySchema, valueSchema, newOrdering.nonEmpty, numPartitions))
|
||||||
|
|
||||||
shuffled.map(_._2)
|
shuffled.map(_._2)
|
||||||
|
|
||||||
|
@ -179,7 +169,8 @@ case class Exchange(
|
||||||
new ShuffledRDD[Row, Null, Null](rdd, part)
|
new ShuffledRDD[Row, Null, Null](rdd, part)
|
||||||
}
|
}
|
||||||
val keySchema = child.output.map(_.dataType).toArray
|
val keySchema = child.output.map(_.dataType).toArray
|
||||||
shuffled.setSerializer(serializer(keySchema, null, numPartitions))
|
shuffled.setSerializer(
|
||||||
|
serializer(keySchema, null, newOrdering.nonEmpty, numPartitions))
|
||||||
|
|
||||||
shuffled.map(_._1)
|
shuffled.map(_._1)
|
||||||
|
|
||||||
|
@ -199,7 +190,7 @@ case class Exchange(
|
||||||
val partitioner = new HashPartitioner(1)
|
val partitioner = new HashPartitioner(1)
|
||||||
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
|
val shuffled = new ShuffledRDD[Null, Row, Row](rdd, partitioner)
|
||||||
val valueSchema = child.output.map(_.dataType).toArray
|
val valueSchema = child.output.map(_.dataType).toArray
|
||||||
shuffled.setSerializer(serializer(null, valueSchema, 1))
|
shuffled.setSerializer(serializer(null, valueSchema, false, 1))
|
||||||
shuffled.map(_._2)
|
shuffled.map(_._2)
|
||||||
|
|
||||||
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
|
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
|
||||||
|
|
|
@ -27,7 +27,7 @@ import scala.reflect.ClassTag
|
||||||
import org.apache.spark.serializer._
|
import org.apache.spark.serializer._
|
||||||
import org.apache.spark.Logging
|
import org.apache.spark.Logging
|
||||||
import org.apache.spark.sql.Row
|
import org.apache.spark.sql.Row
|
||||||
import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
|
import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, GenericMutableRow}
|
||||||
import org.apache.spark.sql.types._
|
import org.apache.spark.sql.types._
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -49,9 +49,9 @@ private[sql] class Serializer2SerializationStream(
|
||||||
out: OutputStream)
|
out: OutputStream)
|
||||||
extends SerializationStream with Logging {
|
extends SerializationStream with Logging {
|
||||||
|
|
||||||
val rowOut = new DataOutputStream(out)
|
private val rowOut = new DataOutputStream(new BufferedOutputStream(out))
|
||||||
val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
|
private val writeKeyFunc = SparkSqlSerializer2.createSerializationFunction(keySchema, rowOut)
|
||||||
val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
|
private val writeValueFunc = SparkSqlSerializer2.createSerializationFunction(valueSchema, rowOut)
|
||||||
|
|
||||||
override def writeObject[T: ClassTag](t: T): SerializationStream = {
|
override def writeObject[T: ClassTag](t: T): SerializationStream = {
|
||||||
val kv = t.asInstanceOf[Product2[Row, Row]]
|
val kv = t.asInstanceOf[Product2[Row, Row]]
|
||||||
|
@ -86,31 +86,44 @@ private[sql] class Serializer2SerializationStream(
|
||||||
private[sql] class Serializer2DeserializationStream(
|
private[sql] class Serializer2DeserializationStream(
|
||||||
keySchema: Array[DataType],
|
keySchema: Array[DataType],
|
||||||
valueSchema: Array[DataType],
|
valueSchema: Array[DataType],
|
||||||
|
hasKeyOrdering: Boolean,
|
||||||
in: InputStream)
|
in: InputStream)
|
||||||
extends DeserializationStream with Logging {
|
extends DeserializationStream with Logging {
|
||||||
|
|
||||||
val rowIn = new DataInputStream(new BufferedInputStream(in))
|
private val rowIn = new DataInputStream(new BufferedInputStream(in))
|
||||||
|
|
||||||
val key = if (keySchema != null) new SpecificMutableRow(keySchema) else null
|
private def rowGenerator(schema: Array[DataType]): () => (MutableRow) = {
|
||||||
val value = if (valueSchema != null) new SpecificMutableRow(valueSchema) else null
|
if (schema == null) {
|
||||||
val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn, key)
|
() => null
|
||||||
val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn, value)
|
} else {
|
||||||
|
if (hasKeyOrdering) {
|
||||||
|
// We have key ordering specified in a ShuffledRDD, it is not safe to reuse a mutable row.
|
||||||
|
() => new GenericMutableRow(schema.length)
|
||||||
|
} else {
|
||||||
|
// It is safe to reuse the mutable row.
|
||||||
|
val mutableRow = new SpecificMutableRow(schema)
|
||||||
|
() => mutableRow
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Functions used to return rows for key and value.
|
||||||
|
private val getKey = rowGenerator(keySchema)
|
||||||
|
private val getValue = rowGenerator(valueSchema)
|
||||||
|
// Functions used to read a serialized row from the InputStream and deserialize it.
|
||||||
|
private val readKeyFunc = SparkSqlSerializer2.createDeserializationFunction(keySchema, rowIn)
|
||||||
|
private val readValueFunc = SparkSqlSerializer2.createDeserializationFunction(valueSchema, rowIn)
|
||||||
|
|
||||||
override def readObject[T: ClassTag](): T = {
|
override def readObject[T: ClassTag](): T = {
|
||||||
readKeyFunc()
|
(readKeyFunc(getKey()), readValueFunc(getValue())).asInstanceOf[T]
|
||||||
readValueFunc()
|
|
||||||
|
|
||||||
(key, value).asInstanceOf[T]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def readKey[T: ClassTag](): T = {
|
override def readKey[T: ClassTag](): T = {
|
||||||
readKeyFunc()
|
readKeyFunc(getKey()).asInstanceOf[T]
|
||||||
key.asInstanceOf[T]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def readValue[T: ClassTag](): T = {
|
override def readValue[T: ClassTag](): T = {
|
||||||
readValueFunc()
|
readValueFunc(getValue()).asInstanceOf[T]
|
||||||
value.asInstanceOf[T]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override def close(): Unit = {
|
override def close(): Unit = {
|
||||||
|
@ -118,9 +131,10 @@ private[sql] class Serializer2DeserializationStream(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private[sql] class ShuffleSerializerInstance(
|
private[sql] class SparkSqlSerializer2Instance(
|
||||||
keySchema: Array[DataType],
|
keySchema: Array[DataType],
|
||||||
valueSchema: Array[DataType])
|
valueSchema: Array[DataType],
|
||||||
|
hasKeyOrdering: Boolean)
|
||||||
extends SerializerInstance {
|
extends SerializerInstance {
|
||||||
|
|
||||||
def serialize[T: ClassTag](t: T): ByteBuffer =
|
def serialize[T: ClassTag](t: T): ByteBuffer =
|
||||||
|
@ -137,7 +151,7 @@ private[sql] class ShuffleSerializerInstance(
|
||||||
}
|
}
|
||||||
|
|
||||||
def deserializeStream(s: InputStream): DeserializationStream = {
|
def deserializeStream(s: InputStream): DeserializationStream = {
|
||||||
new Serializer2DeserializationStream(keySchema, valueSchema, s)
|
new Serializer2DeserializationStream(keySchema, valueSchema, hasKeyOrdering, s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,12 +162,16 @@ private[sql] class ShuffleSerializerInstance(
|
||||||
* The schema of keys is represented by `keySchema` and that of values is represented by
|
* The schema of keys is represented by `keySchema` and that of values is represented by
|
||||||
* `valueSchema`.
|
* `valueSchema`.
|
||||||
*/
|
*/
|
||||||
private[sql] class SparkSqlSerializer2(keySchema: Array[DataType], valueSchema: Array[DataType])
|
private[sql] class SparkSqlSerializer2(
|
||||||
|
keySchema: Array[DataType],
|
||||||
|
valueSchema: Array[DataType],
|
||||||
|
hasKeyOrdering: Boolean)
|
||||||
extends Serializer
|
extends Serializer
|
||||||
with Logging
|
with Logging
|
||||||
with Serializable{
|
with Serializable{
|
||||||
|
|
||||||
def newInstance(): SerializerInstance = new ShuffleSerializerInstance(keySchema, valueSchema)
|
def newInstance(): SerializerInstance =
|
||||||
|
new SparkSqlSerializer2Instance(keySchema, valueSchema, hasKeyOrdering)
|
||||||
|
|
||||||
override def supportsRelocationOfSerializedObjects: Boolean = {
|
override def supportsRelocationOfSerializedObjects: Boolean = {
|
||||||
// SparkSqlSerializer2 is stateless and writes no stream headers
|
// SparkSqlSerializer2 is stateless and writes no stream headers
|
||||||
|
@ -323,11 +341,11 @@ private[sql] object SparkSqlSerializer2 {
|
||||||
*/
|
*/
|
||||||
def createDeserializationFunction(
|
def createDeserializationFunction(
|
||||||
schema: Array[DataType],
|
schema: Array[DataType],
|
||||||
in: DataInputStream,
|
in: DataInputStream): (MutableRow) => Row = {
|
||||||
mutableRow: SpecificMutableRow): () => Unit = {
|
if (schema == null) {
|
||||||
() => {
|
(mutableRow: MutableRow) => null
|
||||||
// If the schema is null, the returned function does nothing when it get called.
|
} else {
|
||||||
if (schema != null) {
|
(mutableRow: MutableRow) => {
|
||||||
var i = 0
|
var i = 0
|
||||||
while (i < schema.length) {
|
while (i < schema.length) {
|
||||||
schema(i) match {
|
schema(i) match {
|
||||||
|
@ -440,6 +458,8 @@ private[sql] object SparkSqlSerializer2 {
|
||||||
}
|
}
|
||||||
i += 1
|
i += 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mutableRow
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -148,6 +148,15 @@ abstract class SparkSqlSerializer2Suite extends QueryTest with BeforeAndAfterAll
|
||||||
table("shuffle").collect())
|
table("shuffle").collect())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
test("key schema is null") {
|
||||||
|
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
|
||||||
|
val df = sql(s"SELECT $aggregations FROM shuffle")
|
||||||
|
checkSerializer(df.queryExecution.executedPlan, serializerClass)
|
||||||
|
checkAnswer(
|
||||||
|
df,
|
||||||
|
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
|
||||||
|
}
|
||||||
|
|
||||||
test("value schema is null") {
|
test("value schema is null") {
|
||||||
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
|
val df = sql(s"SELECT col0 FROM shuffle ORDER BY col0")
|
||||||
checkSerializer(df.queryExecution.executedPlan, serializerClass)
|
checkSerializer(df.queryExecution.executedPlan, serializerClass)
|
||||||
|
@ -167,29 +176,20 @@ class SparkSqlSerializer2SortShuffleSuite extends SparkSqlSerializer2Suite {
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
// Sort merge will not be triggered.
|
// Sort merge will not be triggered.
|
||||||
sql("set spark.sql.shuffle.partitions = 200")
|
val bypassMergeThreshold =
|
||||||
}
|
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
|
||||||
|
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold-1}")
|
||||||
test("key schema is null") {
|
|
||||||
val aggregations = allColumns.split(",").map(c => s"COUNT($c)").mkString(",")
|
|
||||||
val df = sql(s"SELECT $aggregations FROM shuffle")
|
|
||||||
checkSerializer(df.queryExecution.executedPlan, serializerClass)
|
|
||||||
checkAnswer(
|
|
||||||
df,
|
|
||||||
Row(1000, 1000, 0, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000, 1000))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
|
/** For now, we will use SparkSqlSerializer for sort based shuffle with sort merge. */
|
||||||
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
|
class SparkSqlSerializer2SortMergeShuffleSuite extends SparkSqlSerializer2Suite {
|
||||||
|
|
||||||
// We are expecting SparkSqlSerializer.
|
|
||||||
override val serializerClass: Class[Serializer] =
|
|
||||||
classOf[SparkSqlSerializer].asInstanceOf[Class[Serializer]]
|
|
||||||
|
|
||||||
override def beforeAll(): Unit = {
|
override def beforeAll(): Unit = {
|
||||||
super.beforeAll()
|
super.beforeAll()
|
||||||
// To trigger the sort merge.
|
// To trigger the sort merge.
|
||||||
sql("set spark.sql.shuffle.partitions = 201")
|
val bypassMergeThreshold =
|
||||||
|
sparkContext.conf.getInt("spark.shuffle.sort.bypassMergeThreshold", 200)
|
||||||
|
sql(s"set spark.sql.shuffle.partitions=${bypassMergeThreshold + 1}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue