[SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join

This PR introduces BytesToBytesMap to UnsafeHashedRelation and uses it in the executor for better performance.

It serializes all the keys and values from the java HashMap and puts them into a BytesToBytesMap while deserializing. All the values for the same key are stored contiguously for better memory locality.
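
As a hedged illustration only (not part of this patch), the sketch below shows roughly how one key record with two values is laid out on the wire, using hypothetical key/value sizes; the actual write path is UnsafeHashedRelation.writeExternal in the diff below.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.ByteOrder

// hypothetical inputs: one 16-byte UnsafeRow key and two 24-byte value rows with 2 fields each
val keyBytes = Array.fill(16)(1.toByte)
val values = Seq((2, Array.fill(24)(2.toByte)), (2, Array.fill(24)(3.toByte)))  // (numFields, row bytes)

val out = new DataOutputStream(new ByteArrayOutputStream())
// [key size] [size of all values in bytes] [key bytes] [one record per value]
val totalValuesSize = values.map { case (_, bytes) => 4 + 4 + bytes.length }.sum
out.writeInt(keyBytes.length)
out.writeInt(totalValuesSize)
out.write(keyBytes)
for ((numFields, rowBytes) <- values) {
  // [number of fields] [number of bytes] [underlying bytes of UnsafeRow];
  // the two header ints go in native byte order so the reader can decode them with
  // UNSAFE.getInt, while DataOutputStream itself writes big-endian
  if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
    out.writeInt(numFields)
    out.writeInt(rowBytes.length)
  } else {
    out.writeInt(Integer.reverseBytes(numFields))
    out.writeInt(Integer.reverseBytes(rowBytes.length))
  }
  out.write(rowBytes)
}

On deserialization, each such block of value records is inserted as a single BytesToBytesMap entry, which is what keeps all values for a key contiguous in memory.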

This PR also addresses the comments for #7480 and does some cleanup.

Author: Davies Liu <davies@databricks.com>

Closes #7592 from davies/unsafe_map2 and squashes the following commits:

42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_map2
fd09528 [Davies Liu] remove thread local cache and update docs
1c5ad8d [Davies Liu] fix test
5eb1b5a [Davies Liu] address comments in #7480
46f1f22 [Davies Liu] fix style
fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join
Davies Liu 2015-07-28 15:56:19 -07:00 committed by Davies Liu
parent 198d181dfb
commit 21825529ea
12 changed files with 215 additions and 122 deletions

@ -62,7 +62,7 @@ case class BroadcastHashJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = buildHashRelation(input.iterator)
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)

@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
val hashed = buildHashRelation(input.iterator)
val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
sparkContext.broadcast(hashed)
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)

@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
protected override def doExecute(): RDD[InternalRow] = {
val buildIter = right.execute().map(_.copy()).collect().toIterator
val input = right.execute().map(_.copy()).collect()
if (condition.isEmpty) {
val hashSet = buildKeyHashSet(buildIter)
val hashSet = buildKeyHashSet(input.toIterator)
val broadcastedRelation = sparkContext.broadcast(hashSet)
left.execute().mapPartitions { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value)
}
} else {
val hashRelation = buildHashRelation(buildIter)
val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
val broadcastedRelation = sparkContext.broadcast(hashRelation)
left.execute().mapPartitions { streamIter =>

@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin(
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
override def canProcessUnsafeRows: Boolean = true
@transient private[this] lazy val resultProjection: Projection = {
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
if (outputsUnsafeRows) {
UnsafeProjection.create(schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}
@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin(
var streamRowMatched = false
while (i < broadcastedRelation.value.size) {
// TODO: One bitset per partition instead of per row.
val broadcastedRow = broadcastedRelation.value(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin(
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) =>
buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
case (LeftOuter | FullOuter, BuildLeft) =>
buf += resultProjection(new JoinedRow(rel(i), rightNulls))
case _ =>
(joinType, buildSide) match {
case (RightOuter | FullOuter, BuildRight) =>
val joinedRow = new JoinedRow
joinedRow.withLeft(leftNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
buf += resultProjection(joinedRow.withRight(rel(i))).copy()
}
i += 1
}
}
i += 1
case (LeftOuter | FullOuter, BuildLeft) =>
val joinedRow = new JoinedRow
joinedRow.withRight(rightNulls)
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
}
i += 1
}
case _ =>
}
buf.toSeq
}

@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
trait HashJoin {
@ -44,16 +43,24 @@ trait HashJoin {
override def output: Seq[Attribute] = left.output ++ right.output
protected[this] def supportUnsafe: Boolean = {
protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def outputsUnsafeRows: Boolean = isUnsafeMode
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode
@transient protected lazy val buildSideKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}
@transient protected lazy val streamSideKeyGenerator: Projection =
if (supportUnsafe) {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newMutableProjection(streamedKeys, streamedPlan.output)()
@ -65,18 +72,16 @@ trait HashJoin {
{
new Iterator[InternalRow] {
private[this] var currentStreamedRow: InternalRow = _
private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
private[this] var currentHashMatches: Seq[InternalRow] = _
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
private[this] val joinRow = new JoinedRow
private[this] val resultProjection: Projection = {
if (supportUnsafe) {
private[this] val resultProjection: (InternalRow) => InternalRow = {
if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}
@ -122,12 +127,4 @@ trait HashJoin {
}
}
}
protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}

@ -75,30 +75,36 @@ trait HashOuterJoin {
s"HashOuterJoin should not take $x as the JoinType")
}
protected[this] def supportUnsafe: Boolean = {
protected[this] def isUnsafeMode: Boolean = {
(self.codegenEnabled && joinType != FullOuter
&& UnsafeProjection.canSupport(buildKeys)
&& UnsafeProjection.canSupport(self.schema))
}
override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def outputsUnsafeRows: Boolean = isUnsafeMode
override def canProcessUnsafeRows: Boolean = isUnsafeMode
override def canProcessSafeRows: Boolean = !isUnsafeMode
protected[this] def streamedKeyGenerator(): Projection = {
if (supportUnsafe) {
@transient protected lazy val buildKeyGenerator: Projection =
if (isUnsafeMode) {
UnsafeProjection.create(buildKeys, buildPlan.output)
} else {
newMutableProjection(buildKeys, buildPlan.output)()
}
@transient protected[this] lazy val streamedKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedPlan.output)
} else {
newProjection(streamedKeys, streamedPlan.output)
}
}
@transient private[this] lazy val resultProjection: Projection = {
if (supportUnsafe) {
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
if (isUnsafeMode) {
UnsafeProjection.create(self.schema)
} else {
new Projection {
override def apply(r: InternalRow): InternalRow = r
}
identity[InternalRow]
}
}
@ -230,12 +236,4 @@ trait HashOuterJoin {
hashTable
}
protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
} else {
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
}
}
}

@ -35,11 +35,13 @@ trait HashSemiJoin {
protected[this] def supportUnsafe: Boolean = {
(self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
&& UnsafeProjection.canSupport(rightKeys)
&& UnsafeProjection.canSupport(left.schema))
&& UnsafeProjection.canSupport(left.schema)
&& UnsafeProjection.canSupport(right.schema))
}
override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
override def outputsUnsafeRows: Boolean = supportUnsafe
override def canProcessUnsafeRows: Boolean = supportUnsafe
override def canProcessSafeRows: Boolean = !supportUnsafe
@transient protected lazy val leftKeyGenerator: Projection =
if (supportUnsafe) {
@ -87,14 +89,6 @@ trait HashSemiJoin {
})
}
protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
if (supportUnsafe) {
UnsafeHashedRelation(buildIter, rightKeys, right)
} else {
HashedRelation(buildIter, newProjection(rightKeys, right.output))
}
}
protected def hashSemiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {

@ -18,12 +18,15 @@
package org.apache.spark.sql.execution.joins
import java.io.{Externalizable, ObjectInput, ObjectOutput}
import java.nio.ByteOrder
import java.util.{HashMap => JavaHashMap}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.util.collection.CompactBuffer
@ -32,7 +35,7 @@ import org.apache.spark.util.collection.CompactBuffer
* object.
*/
private[joins] sealed trait HashedRelation {
def get(key: InternalRow): CompactBuffer[InternalRow]
def get(key: InternalRow): Seq[InternalRow]
// This is a helper method to implement Externalizable, and is used by
// GeneralHashedRelation and UniqueKeyHashedRelation
@ -59,9 +62,9 @@ private[joins] final class GeneralHashedRelation(
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
extends HashedRelation with Externalizable {
def this() = this(null) // Needed for serialization
private def this() = this(null) // Needed for serialization
override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
@ -81,9 +84,9 @@ private[joins]
final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
extends HashedRelation with Externalizable {
def this() = this(null) // Needed for serialization
private def this() = this(null) // Needed for serialization
override def get(key: InternalRow): CompactBuffer[InternalRow] = {
override def get(key: InternalRow): Seq[InternalRow] = {
val v = hashTable.get(key)
if (v eq null) null else CompactBuffer(v)
}
@ -109,6 +112,10 @@ private[joins] object HashedRelation {
keyGenerator: Projection,
sizeEstimate: Int = 64): HashedRelation = {
if (keyGenerator.isInstanceOf[UnsafeProjection]) {
return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
}
// TODO: Use Spark's HashMap implementation.
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
var currentRow: InternalRow = null
@ -149,31 +156,133 @@ private[joins] object HashedRelation {
}
}
/**
* A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
* sequence of values.
* A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
* into a sequence of values.
*
* TODO(davies): use BytesToBytesMap
* When it's created, it uses a HashMap. After it's serialized and deserialized, it switches to a
* BytesToBytesMap for better memory performance (multiple values for the same key are stored as a
* contiguous byte array).
*
* It's serialized in the following format:
* [number of keys]
* [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
* ...
*
* All the values are serialized as follows:
* [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
* ...
*/
private[joins] final class UnsafeHashedRelation(
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
extends HashedRelation with Externalizable {
def this() = this(null) // Needed for serialization
private[joins] def this() = this(null) // Needed for serialization
override def get(key: InternalRow): CompactBuffer[InternalRow] = {
// Use BytesToBytesMap in the executor for better performance (it's created during deserialization)
@transient private[this] var binaryMap: BytesToBytesMap = _
override def get(key: InternalRow): Seq[InternalRow] = {
val unsafeKey = key.asInstanceOf[UnsafeRow]
// Thanks to type eraser
hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
if (binaryMap != null) {
// Used in Broadcast join
val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
unsafeKey.getSizeInBytes)
if (loc.isDefined) {
val buffer = CompactBuffer[UnsafeRow]()
val base = loc.getValueAddress.getBaseObject
var offset = loc.getValueAddress.getBaseOffset
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
while (offset < last) {
val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
offset += 8
val row = new UnsafeRow
row.pointTo(base, offset, numFields, sizeInBytes)
buffer += row
offset += sizeInBytes
}
buffer
} else {
null
}
} else {
// Use the JavaHashMap in Local mode or ShuffleHashJoin
hashTable.get(unsafeKey)
}
}
override def writeExternal(out: ObjectOutput): Unit = {
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
out.writeInt(hashTable.size())
val iter = hashTable.entrySet().iterator()
while (iter.hasNext) {
val entry = iter.next()
val key = entry.getKey
val values = entry.getValue
// write all the values as a single byte array
var totalSize = 0L
var i = 0
while (i < values.size) {
totalSize += values(i).getSizeInBytes + 4 + 4
i += 1
}
assert(totalSize < Integer.MAX_VALUE, "values are too big")
// [key size] [values size] [key bytes] [values bytes]
out.writeInt(key.getSizeInBytes)
out.writeInt(totalSize.toInt)
out.write(key.getBytes)
i = 0
while (i < values.size) {
// [num of fields] [num of bytes] [row bytes]
// write the integers in native order, so they can be read by UNSAFE.getInt()
if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
out.writeInt(values(i).numFields())
out.writeInt(values(i).getSizeInBytes)
} else {
out.writeInt(Integer.reverseBytes(values(i).numFields()))
out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
}
out.write(values(i).getBytes)
i += 1
}
}
}
override def readExternal(in: ObjectInput): Unit = {
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
val nKeys = in.readInt()
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision
var i = 0
var keyBuffer = new Array[Byte](1024)
var valuesBuffer = new Array[Byte](1024)
while (i < nKeys) {
val keySize = in.readInt()
val valuesSize = in.readInt()
if (keySize > keyBuffer.size) {
keyBuffer = new Array[Byte](keySize)
}
in.readFully(keyBuffer, 0, keySize)
if (valuesSize > valuesBuffer.size) {
valuesBuffer = new Array[Byte](valuesSize)
}
in.readFully(valuesBuffer, 0, valuesSize)
// put it into binary map
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
assert(!loc.isDefined, "Duplicated key found!")
loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
i += 1
}
}
}
@ -181,33 +290,14 @@ private[joins] object UnsafeHashedRelation {
def apply(
input: Iterator[InternalRow],
buildKeys: Seq[Expression],
buildPlan: SparkPlan,
sizeEstimate: Int = 64): HashedRelation = {
val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
}
// Used for tests
def apply(
input: Iterator[InternalRow],
buildKeys: Seq[Expression],
rowSchema: StructType,
keyGenerator: UnsafeProjection,
sizeEstimate: Int): HashedRelation = {
// TODO: Use BytesToBytesMap.
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
val toUnsafe = UnsafeProjection.create(rowSchema)
val keyGenerator = UnsafeProjection.create(buildKeys)
// Create a mapping of buildKeys -> rows
while (input.hasNext) {
val currentRow = input.next()
val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
currentRow.asInstanceOf[UnsafeRow]
} else {
toUnsafe(currentRow)
}
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
val rowKey = keyGenerator(unsafeRow)
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)

@ -46,7 +46,7 @@ case class LeftSemiJoinHash(
val hashSet = buildKeyHashSet(buildIter)
hashSemiJoin(streamIter, hashSet)
} else {
val hashRelation = buildHashRelation(buildIter)
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
hashSemiJoin(streamIter, hashRelation)
}
}

@ -45,7 +45,7 @@ case class ShuffledHashJoin(
protected override def doExecute(): RDD[InternalRow] = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashed = buildHashRelation(buildIter)
val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
hashJoin(streamIter, hashed)
}
}

@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin(
// TODO this probably can be replaced by external sort (sort merged join?)
joinType match {
case LeftOuter =>
val hashed = buildHashRelation(rightIter)
val keyGenerator = streamedKeyGenerator()
val hashed = HashedRelation(rightIter, buildKeyGenerator)
val keyGenerator = streamedKeyGenerator
leftIter.flatMap( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withLeft(currentRow)
@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin(
})
case RightOuter =>
val hashed = buildHashRelation(leftIter)
val keyGenerator = streamedKeyGenerator()
val hashed = HashedRelation(leftIter, buildKeyGenerator)
val keyGenerator = streamedKeyGenerator
rightIter.flatMap ( currentRow => {
val rowKey = keyGenerator(currentRow)
joinedRow.withRight(currentRow)

@ -17,11 +17,12 @@
package org.apache.spark.sql.execution.joins
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite {
}
test("UnsafeHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val buildKey = Seq(BoundReference(0, IntegerType, false))
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy()).toArray
val buildKey = Seq(BoundReference(0, IntegerType, false))
val keyGenerator = UnsafeProjection.create(buildKey)
val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
assert(hashed.isInstanceOf[UnsafeHashedRelation])
val toUnsafeKey = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
assert(hashed.get(toUnsafe(InternalRow(10))) === null)
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
data2 += unsafeData(2).copy()
assert(hashed.get(unsafeData(2)) === data2)
val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
.asInstanceOf[UnsafeHashedRelation]
val os = new ByteArrayOutputStream()
val out = new ObjectOutputStream(os)
hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
out.flush()
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
val hashed2 = new UnsafeHashedRelation()
hashed2.readExternal(in)
assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
assert(hashed2.get(unsafeData(2)) === data2)
}
}