[SPARK-9247] [SQL] Use BytesToBytesMap for broadcast join
This PR introduce BytesToBytesMap to UnsafeHashedRelation, use it in executor for better performance. It serialize all the key and values from java HashMap, put them into a BytesToBytesMap while deserializing. All the values for a same key are stored continuous to have better memory locality. This PR also address the comments for #7480 , do some clean up. Author: Davies Liu <davies@databricks.com> Closes #7592 from davies/unsafe_map2 and squashes the following commits: 42c578a [Davies Liu] Merge branch 'master' of github.com:apache/spark into unsafe_map2 fd09528 [Davies Liu] remove thread local cache and update docs 1c5ad8d [Davies Liu] fix test 5eb1b5a [Davies Liu] address comments in #7480 46f1f22 [Davies Liu] fix style fc221e0 [Davies Liu] use BytesToBytesMap for broadcast join
This commit is contained in:
parent
198d181dfb
commit
21825529ea
|
@ -62,7 +62,7 @@ case class BroadcastHashJoin(
|
|||
private val broadcastFuture = future {
|
||||
// Note that we use .execute().collect() because we don't want to convert data to Scala types
|
||||
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
|
||||
val hashed = buildHashRelation(input.iterator)
|
||||
val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.size)
|
||||
sparkContext.broadcast(hashed)
|
||||
}(BroadcastHashJoin.broadcastHashJoinExecutionContext)
|
||||
|
||||
|
|
|
@ -61,7 +61,7 @@ case class BroadcastHashOuterJoin(
|
|||
private val broadcastFuture = future {
|
||||
// Note that we use .execute().collect() because we don't want to convert data to Scala types
|
||||
val input: Array[InternalRow] = buildPlan.execute().map(_.copy()).collect()
|
||||
val hashed = buildHashRelation(input.iterator)
|
||||
val hashed = HashedRelation(input.iterator, buildKeyGenerator, input.size)
|
||||
sparkContext.broadcast(hashed)
|
||||
}(BroadcastHashOuterJoin.broadcastHashOuterJoinExecutionContext)
|
||||
|
||||
|
|
|
@ -37,17 +37,17 @@ case class BroadcastLeftSemiJoinHash(
|
|||
condition: Option[Expression]) extends BinaryNode with HashSemiJoin {
|
||||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
val buildIter = right.execute().map(_.copy()).collect().toIterator
|
||||
val input = right.execute().map(_.copy()).collect()
|
||||
|
||||
if (condition.isEmpty) {
|
||||
val hashSet = buildKeyHashSet(buildIter)
|
||||
val hashSet = buildKeyHashSet(input.toIterator)
|
||||
val broadcastedRelation = sparkContext.broadcast(hashSet)
|
||||
|
||||
left.execute().mapPartitions { streamIter =>
|
||||
hashSemiJoin(streamIter, broadcastedRelation.value)
|
||||
}
|
||||
} else {
|
||||
val hashRelation = buildHashRelation(buildIter)
|
||||
val hashRelation = HashedRelation(input.toIterator, rightKeyGenerator, input.size)
|
||||
val broadcastedRelation = sparkContext.broadcast(hashRelation)
|
||||
|
||||
left.execute().mapPartitions { streamIter =>
|
||||
|
|
|
@ -47,13 +47,11 @@ case class BroadcastNestedLoopJoin(
|
|||
override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows
|
||||
override def canProcessUnsafeRows: Boolean = true
|
||||
|
||||
@transient private[this] lazy val resultProjection: Projection = {
|
||||
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
|
||||
if (outputsUnsafeRows) {
|
||||
UnsafeProjection.create(schema)
|
||||
} else {
|
||||
new Projection {
|
||||
override def apply(r: InternalRow): InternalRow = r
|
||||
}
|
||||
identity[InternalRow]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,7 +94,6 @@ case class BroadcastNestedLoopJoin(
|
|||
var streamRowMatched = false
|
||||
|
||||
while (i < broadcastedRelation.value.size) {
|
||||
// TODO: One bitset per partition instead of per row.
|
||||
val broadcastedRow = broadcastedRelation.value(i)
|
||||
buildSide match {
|
||||
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
|
||||
|
@ -135,17 +132,26 @@ case class BroadcastNestedLoopJoin(
|
|||
val buf: CompactBuffer[InternalRow] = new CompactBuffer()
|
||||
var i = 0
|
||||
val rel = broadcastedRelation.value
|
||||
while (i < rel.length) {
|
||||
if (!allIncludedBroadcastTuples.contains(i)) {
|
||||
(joinType, buildSide) match {
|
||||
case (RightOuter | FullOuter, BuildRight) =>
|
||||
buf += resultProjection(new JoinedRow(leftNulls, rel(i)))
|
||||
case (LeftOuter | FullOuter, BuildLeft) =>
|
||||
buf += resultProjection(new JoinedRow(rel(i), rightNulls))
|
||||
case _ =>
|
||||
(joinType, buildSide) match {
|
||||
case (RightOuter | FullOuter, BuildRight) =>
|
||||
val joinedRow = new JoinedRow
|
||||
joinedRow.withLeft(leftNulls)
|
||||
while (i < rel.length) {
|
||||
if (!allIncludedBroadcastTuples.contains(i)) {
|
||||
buf += resultProjection(joinedRow.withRight(rel(i))).copy()
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
i += 1
|
||||
case (LeftOuter | FullOuter, BuildLeft) =>
|
||||
val joinedRow = new JoinedRow
|
||||
joinedRow.withRight(rightNulls)
|
||||
while (i < rel.length) {
|
||||
if (!allIncludedBroadcastTuples.contains(i)) {
|
||||
buf += resultProjection(joinedRow.withLeft(rel(i))).copy()
|
||||
}
|
||||
i += 1
|
||||
}
|
||||
case _ =>
|
||||
}
|
||||
buf.toSeq
|
||||
}
|
||||
|
|
|
@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.joins
|
|||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.execution.SparkPlan
|
||||
import org.apache.spark.util.collection.CompactBuffer
|
||||
|
||||
|
||||
trait HashJoin {
|
||||
|
@ -44,16 +43,24 @@ trait HashJoin {
|
|||
|
||||
override def output: Seq[Attribute] = left.output ++ right.output
|
||||
|
||||
protected[this] def supportUnsafe: Boolean = {
|
||||
protected[this] def isUnsafeMode: Boolean = {
|
||||
(self.codegenEnabled && UnsafeProjection.canSupport(buildKeys)
|
||||
&& UnsafeProjection.canSupport(self.schema))
|
||||
}
|
||||
|
||||
override def outputsUnsafeRows: Boolean = supportUnsafe
|
||||
override def canProcessUnsafeRows: Boolean = supportUnsafe
|
||||
override def outputsUnsafeRows: Boolean = isUnsafeMode
|
||||
override def canProcessUnsafeRows: Boolean = isUnsafeMode
|
||||
override def canProcessSafeRows: Boolean = !isUnsafeMode
|
||||
|
||||
@transient protected lazy val buildSideKeyGenerator: Projection =
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(buildKeys, buildPlan.output)
|
||||
} else {
|
||||
newMutableProjection(buildKeys, buildPlan.output)()
|
||||
}
|
||||
|
||||
@transient protected lazy val streamSideKeyGenerator: Projection =
|
||||
if (supportUnsafe) {
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(streamedKeys, streamedPlan.output)
|
||||
} else {
|
||||
newMutableProjection(streamedKeys, streamedPlan.output)()
|
||||
|
@ -65,18 +72,16 @@ trait HashJoin {
|
|||
{
|
||||
new Iterator[InternalRow] {
|
||||
private[this] var currentStreamedRow: InternalRow = _
|
||||
private[this] var currentHashMatches: CompactBuffer[InternalRow] = _
|
||||
private[this] var currentHashMatches: Seq[InternalRow] = _
|
||||
private[this] var currentMatchPosition: Int = -1
|
||||
|
||||
// Mutable per row objects.
|
||||
private[this] val joinRow = new JoinedRow
|
||||
private[this] val resultProjection: Projection = {
|
||||
if (supportUnsafe) {
|
||||
private[this] val resultProjection: (InternalRow) => InternalRow = {
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(self.schema)
|
||||
} else {
|
||||
new Projection {
|
||||
override def apply(r: InternalRow): InternalRow = r
|
||||
}
|
||||
identity[InternalRow]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -122,12 +127,4 @@ trait HashJoin {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
|
||||
if (supportUnsafe) {
|
||||
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
|
||||
} else {
|
||||
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -75,30 +75,36 @@ trait HashOuterJoin {
|
|||
s"HashOuterJoin should not take $x as the JoinType")
|
||||
}
|
||||
|
||||
protected[this] def supportUnsafe: Boolean = {
|
||||
protected[this] def isUnsafeMode: Boolean = {
|
||||
(self.codegenEnabled && joinType != FullOuter
|
||||
&& UnsafeProjection.canSupport(buildKeys)
|
||||
&& UnsafeProjection.canSupport(self.schema))
|
||||
}
|
||||
|
||||
override def outputsUnsafeRows: Boolean = supportUnsafe
|
||||
override def canProcessUnsafeRows: Boolean = supportUnsafe
|
||||
override def outputsUnsafeRows: Boolean = isUnsafeMode
|
||||
override def canProcessUnsafeRows: Boolean = isUnsafeMode
|
||||
override def canProcessSafeRows: Boolean = !isUnsafeMode
|
||||
|
||||
protected[this] def streamedKeyGenerator(): Projection = {
|
||||
if (supportUnsafe) {
|
||||
@transient protected lazy val buildKeyGenerator: Projection =
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(buildKeys, buildPlan.output)
|
||||
} else {
|
||||
newMutableProjection(buildKeys, buildPlan.output)()
|
||||
}
|
||||
|
||||
@transient protected[this] lazy val streamedKeyGenerator: Projection = {
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(streamedKeys, streamedPlan.output)
|
||||
} else {
|
||||
newProjection(streamedKeys, streamedPlan.output)
|
||||
}
|
||||
}
|
||||
|
||||
@transient private[this] lazy val resultProjection: Projection = {
|
||||
if (supportUnsafe) {
|
||||
@transient private[this] lazy val resultProjection: InternalRow => InternalRow = {
|
||||
if (isUnsafeMode) {
|
||||
UnsafeProjection.create(self.schema)
|
||||
} else {
|
||||
new Projection {
|
||||
override def apply(r: InternalRow): InternalRow = r
|
||||
}
|
||||
identity[InternalRow]
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -230,12 +236,4 @@ trait HashOuterJoin {
|
|||
|
||||
hashTable
|
||||
}
|
||||
|
||||
protected[this] def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
|
||||
if (supportUnsafe) {
|
||||
UnsafeHashedRelation(buildIter, buildKeys, buildPlan)
|
||||
} else {
|
||||
HashedRelation(buildIter, newProjection(buildKeys, buildPlan.output))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -35,11 +35,13 @@ trait HashSemiJoin {
|
|||
protected[this] def supportUnsafe: Boolean = {
|
||||
(self.codegenEnabled && UnsafeProjection.canSupport(leftKeys)
|
||||
&& UnsafeProjection.canSupport(rightKeys)
|
||||
&& UnsafeProjection.canSupport(left.schema))
|
||||
&& UnsafeProjection.canSupport(left.schema)
|
||||
&& UnsafeProjection.canSupport(right.schema))
|
||||
}
|
||||
|
||||
override def outputsUnsafeRows: Boolean = right.outputsUnsafeRows
|
||||
override def outputsUnsafeRows: Boolean = supportUnsafe
|
||||
override def canProcessUnsafeRows: Boolean = supportUnsafe
|
||||
override def canProcessSafeRows: Boolean = !supportUnsafe
|
||||
|
||||
@transient protected lazy val leftKeyGenerator: Projection =
|
||||
if (supportUnsafe) {
|
||||
|
@ -87,14 +89,6 @@ trait HashSemiJoin {
|
|||
})
|
||||
}
|
||||
|
||||
protected def buildHashRelation(buildIter: Iterator[InternalRow]): HashedRelation = {
|
||||
if (supportUnsafe) {
|
||||
UnsafeHashedRelation(buildIter, rightKeys, right)
|
||||
} else {
|
||||
HashedRelation(buildIter, newProjection(rightKeys, right.output))
|
||||
}
|
||||
}
|
||||
|
||||
protected def hashSemiJoin(
|
||||
streamIter: Iterator[InternalRow],
|
||||
hashedRelation: HashedRelation): Iterator[InternalRow] = {
|
||||
|
|
|
@ -18,12 +18,15 @@
|
|||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import java.io.{Externalizable, ObjectInput, ObjectOutput}
|
||||
import java.nio.ByteOrder
|
||||
import java.util.{HashMap => JavaHashMap}
|
||||
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.unsafe.PlatformDependent
|
||||
import org.apache.spark.unsafe.map.BytesToBytesMap
|
||||
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
|
||||
import org.apache.spark.util.collection.CompactBuffer
|
||||
|
||||
|
||||
|
@ -32,7 +35,7 @@ import org.apache.spark.util.collection.CompactBuffer
|
|||
* object.
|
||||
*/
|
||||
private[joins] sealed trait HashedRelation {
|
||||
def get(key: InternalRow): CompactBuffer[InternalRow]
|
||||
def get(key: InternalRow): Seq[InternalRow]
|
||||
|
||||
// This is a helper method to implement Externalizable, and is used by
|
||||
// GeneralHashedRelation and UniqueKeyHashedRelation
|
||||
|
@ -59,9 +62,9 @@ private[joins] final class GeneralHashedRelation(
|
|||
private var hashTable: JavaHashMap[InternalRow, CompactBuffer[InternalRow]])
|
||||
extends HashedRelation with Externalizable {
|
||||
|
||||
def this() = this(null) // Needed for serialization
|
||||
private def this() = this(null) // Needed for serialization
|
||||
|
||||
override def get(key: InternalRow): CompactBuffer[InternalRow] = hashTable.get(key)
|
||||
override def get(key: InternalRow): Seq[InternalRow] = hashTable.get(key)
|
||||
|
||||
override def writeExternal(out: ObjectOutput): Unit = {
|
||||
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
|
||||
|
@ -81,9 +84,9 @@ private[joins]
|
|||
final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalRow, InternalRow])
|
||||
extends HashedRelation with Externalizable {
|
||||
|
||||
def this() = this(null) // Needed for serialization
|
||||
private def this() = this(null) // Needed for serialization
|
||||
|
||||
override def get(key: InternalRow): CompactBuffer[InternalRow] = {
|
||||
override def get(key: InternalRow): Seq[InternalRow] = {
|
||||
val v = hashTable.get(key)
|
||||
if (v eq null) null else CompactBuffer(v)
|
||||
}
|
||||
|
@ -109,6 +112,10 @@ private[joins] object HashedRelation {
|
|||
keyGenerator: Projection,
|
||||
sizeEstimate: Int = 64): HashedRelation = {
|
||||
|
||||
if (keyGenerator.isInstanceOf[UnsafeProjection]) {
|
||||
return UnsafeHashedRelation(input, keyGenerator.asInstanceOf[UnsafeProjection], sizeEstimate)
|
||||
}
|
||||
|
||||
// TODO: Use Spark's HashMap implementation.
|
||||
val hashTable = new JavaHashMap[InternalRow, CompactBuffer[InternalRow]](sizeEstimate)
|
||||
var currentRow: InternalRow = null
|
||||
|
@ -149,31 +156,133 @@ private[joins] object HashedRelation {
|
|||
}
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* A HashedRelation for UnsafeRow, which is backed by BytesToBytesMap that maps the key into a
|
||||
* sequence of values.
|
||||
* A HashedRelation for UnsafeRow, which is backed by HashMap or BytesToBytesMap that maps the key
|
||||
* into a sequence of values.
|
||||
*
|
||||
* TODO(davies): use BytesToBytesMap
|
||||
* When it's created, it uses HashMap. After it's serialized and deserialized, it switch to use
|
||||
* BytesToBytesMap for better memory performance (multiple values for the same are stored as a
|
||||
* continuous byte array.
|
||||
*
|
||||
* It's serialized in the following format:
|
||||
* [number of keys]
|
||||
* [size of key] [size of all values in bytes] [key bytes] [bytes for all values]
|
||||
* ...
|
||||
*
|
||||
* All the values are serialized as following:
|
||||
* [number of fields] [number of bytes] [underlying bytes of UnsafeRow]
|
||||
* ...
|
||||
*/
|
||||
private[joins] final class UnsafeHashedRelation(
|
||||
private var hashTable: JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]])
|
||||
extends HashedRelation with Externalizable {
|
||||
|
||||
def this() = this(null) // Needed for serialization
|
||||
private[joins] def this() = this(null) // Needed for serialization
|
||||
|
||||
override def get(key: InternalRow): CompactBuffer[InternalRow] = {
|
||||
// Use BytesToBytesMap in executor for better performance (it's created when deserialization)
|
||||
@transient private[this] var binaryMap: BytesToBytesMap = _
|
||||
|
||||
override def get(key: InternalRow): Seq[InternalRow] = {
|
||||
val unsafeKey = key.asInstanceOf[UnsafeRow]
|
||||
// Thanks to type eraser
|
||||
hashTable.get(unsafeKey).asInstanceOf[CompactBuffer[InternalRow]]
|
||||
|
||||
if (binaryMap != null) {
|
||||
// Used in Broadcast join
|
||||
val loc = binaryMap.lookup(unsafeKey.getBaseObject, unsafeKey.getBaseOffset,
|
||||
unsafeKey.getSizeInBytes)
|
||||
if (loc.isDefined) {
|
||||
val buffer = CompactBuffer[UnsafeRow]()
|
||||
|
||||
val base = loc.getValueAddress.getBaseObject
|
||||
var offset = loc.getValueAddress.getBaseOffset
|
||||
val last = loc.getValueAddress.getBaseOffset + loc.getValueLength
|
||||
while (offset < last) {
|
||||
val numFields = PlatformDependent.UNSAFE.getInt(base, offset)
|
||||
val sizeInBytes = PlatformDependent.UNSAFE.getInt(base, offset + 4)
|
||||
offset += 8
|
||||
|
||||
val row = new UnsafeRow
|
||||
row.pointTo(base, offset, numFields, sizeInBytes)
|
||||
buffer += row
|
||||
offset += sizeInBytes
|
||||
}
|
||||
buffer
|
||||
} else {
|
||||
null
|
||||
}
|
||||
|
||||
} else {
|
||||
// Use the JavaHashMap in Local mode or ShuffleHashJoin
|
||||
hashTable.get(unsafeKey)
|
||||
}
|
||||
}
|
||||
|
||||
override def writeExternal(out: ObjectOutput): Unit = {
|
||||
writeBytes(out, SparkSqlSerializer.serialize(hashTable))
|
||||
out.writeInt(hashTable.size())
|
||||
|
||||
val iter = hashTable.entrySet().iterator()
|
||||
while (iter.hasNext) {
|
||||
val entry = iter.next()
|
||||
val key = entry.getKey
|
||||
val values = entry.getValue
|
||||
|
||||
// write all the values as single byte array
|
||||
var totalSize = 0L
|
||||
var i = 0
|
||||
while (i < values.size) {
|
||||
totalSize += values(i).getSizeInBytes + 4 + 4
|
||||
i += 1
|
||||
}
|
||||
assert(totalSize < Integer.MAX_VALUE, "values are too big")
|
||||
|
||||
// [key size] [values size] [key bytes] [values bytes]
|
||||
out.writeInt(key.getSizeInBytes)
|
||||
out.writeInt(totalSize.toInt)
|
||||
out.write(key.getBytes)
|
||||
i = 0
|
||||
while (i < values.size) {
|
||||
// [num of fields] [num of bytes] [row bytes]
|
||||
// write the integer in native order, so they can be read by UNSAFE.getInt()
|
||||
if (ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN) {
|
||||
out.writeInt(values(i).numFields())
|
||||
out.writeInt(values(i).getSizeInBytes)
|
||||
} else {
|
||||
out.writeInt(Integer.reverseBytes(values(i).numFields()))
|
||||
out.writeInt(Integer.reverseBytes(values(i).getSizeInBytes))
|
||||
}
|
||||
out.write(values(i).getBytes)
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override def readExternal(in: ObjectInput): Unit = {
|
||||
hashTable = SparkSqlSerializer.deserialize(readBytes(in))
|
||||
val nKeys = in.readInt()
|
||||
// This is used in Broadcast, shared by multiple tasks, so we use on-heap memory
|
||||
val memoryManager = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
|
||||
binaryMap = new BytesToBytesMap(memoryManager, nKeys * 2) // reduce hash collision
|
||||
|
||||
var i = 0
|
||||
var keyBuffer = new Array[Byte](1024)
|
||||
var valuesBuffer = new Array[Byte](1024)
|
||||
while (i < nKeys) {
|
||||
val keySize = in.readInt()
|
||||
val valuesSize = in.readInt()
|
||||
if (keySize > keyBuffer.size) {
|
||||
keyBuffer = new Array[Byte](keySize)
|
||||
}
|
||||
in.readFully(keyBuffer, 0, keySize)
|
||||
if (valuesSize > valuesBuffer.size) {
|
||||
valuesBuffer = new Array[Byte](valuesSize)
|
||||
}
|
||||
in.readFully(valuesBuffer, 0, valuesSize)
|
||||
|
||||
// put it into binary map
|
||||
val loc = binaryMap.lookup(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize)
|
||||
assert(!loc.isDefined, "Duplicated key found!")
|
||||
loc.putNewKey(keyBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, keySize,
|
||||
valuesBuffer, PlatformDependent.BYTE_ARRAY_OFFSET, valuesSize)
|
||||
i += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -181,33 +290,14 @@ private[joins] object UnsafeHashedRelation {
|
|||
|
||||
def apply(
|
||||
input: Iterator[InternalRow],
|
||||
buildKeys: Seq[Expression],
|
||||
buildPlan: SparkPlan,
|
||||
sizeEstimate: Int = 64): HashedRelation = {
|
||||
val boundedKeys = buildKeys.map(BindReferences.bindReference(_, buildPlan.output))
|
||||
apply(input, boundedKeys, buildPlan.schema, sizeEstimate)
|
||||
}
|
||||
|
||||
// Used for tests
|
||||
def apply(
|
||||
input: Iterator[InternalRow],
|
||||
buildKeys: Seq[Expression],
|
||||
rowSchema: StructType,
|
||||
keyGenerator: UnsafeProjection,
|
||||
sizeEstimate: Int): HashedRelation = {
|
||||
|
||||
// TODO: Use BytesToBytesMap.
|
||||
val hashTable = new JavaHashMap[UnsafeRow, CompactBuffer[UnsafeRow]](sizeEstimate)
|
||||
val toUnsafe = UnsafeProjection.create(rowSchema)
|
||||
val keyGenerator = UnsafeProjection.create(buildKeys)
|
||||
|
||||
// Create a mapping of buildKeys -> rows
|
||||
while (input.hasNext) {
|
||||
val currentRow = input.next()
|
||||
val unsafeRow = if (currentRow.isInstanceOf[UnsafeRow]) {
|
||||
currentRow.asInstanceOf[UnsafeRow]
|
||||
} else {
|
||||
toUnsafe(currentRow)
|
||||
}
|
||||
val unsafeRow = input.next().asInstanceOf[UnsafeRow]
|
||||
val rowKey = keyGenerator(unsafeRow)
|
||||
if (!rowKey.anyNull) {
|
||||
val existingMatchList = hashTable.get(rowKey)
|
||||
|
|
|
@ -46,7 +46,7 @@ case class LeftSemiJoinHash(
|
|||
val hashSet = buildKeyHashSet(buildIter)
|
||||
hashSemiJoin(streamIter, hashSet)
|
||||
} else {
|
||||
val hashRelation = buildHashRelation(buildIter)
|
||||
val hashRelation = HashedRelation(buildIter, rightKeyGenerator)
|
||||
hashSemiJoin(streamIter, hashRelation)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ case class ShuffledHashJoin(
|
|||
|
||||
protected override def doExecute(): RDD[InternalRow] = {
|
||||
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
|
||||
val hashed = buildHashRelation(buildIter)
|
||||
val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
|
||||
hashJoin(streamIter, hashed)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -50,8 +50,8 @@ case class ShuffledHashOuterJoin(
|
|||
// TODO this probably can be replaced by external sort (sort merged join?)
|
||||
joinType match {
|
||||
case LeftOuter =>
|
||||
val hashed = buildHashRelation(rightIter)
|
||||
val keyGenerator = streamedKeyGenerator()
|
||||
val hashed = HashedRelation(rightIter, buildKeyGenerator)
|
||||
val keyGenerator = streamedKeyGenerator
|
||||
leftIter.flatMap( currentRow => {
|
||||
val rowKey = keyGenerator(currentRow)
|
||||
joinedRow.withLeft(currentRow)
|
||||
|
@ -59,8 +59,8 @@ case class ShuffledHashOuterJoin(
|
|||
})
|
||||
|
||||
case RightOuter =>
|
||||
val hashed = buildHashRelation(leftIter)
|
||||
val keyGenerator = streamedKeyGenerator()
|
||||
val hashed = HashedRelation(leftIter, buildKeyGenerator)
|
||||
val keyGenerator = streamedKeyGenerator
|
||||
rightIter.flatMap ( currentRow => {
|
||||
val rowKey = keyGenerator(currentRow)
|
||||
joinedRow.withRight(currentRow)
|
||||
|
|
|
@ -17,11 +17,12 @@
|
|||
|
||||
package org.apache.spark.sql.execution.joins
|
||||
|
||||
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
|
||||
|
||||
import org.apache.spark.SparkFunSuite
|
||||
import org.apache.spark.sql.catalyst.InternalRow
|
||||
import org.apache.spark.sql.catalyst.expressions._
|
||||
import org.apache.spark.sql.execution.SparkSqlSerializer
|
||||
import org.apache.spark.sql.types.{StructField, StructType, IntegerType}
|
||||
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
|
||||
import org.apache.spark.util.collection.CompactBuffer
|
||||
|
||||
|
||||
|
@ -64,27 +65,34 @@ class HashedRelationSuite extends SparkFunSuite {
|
|||
}
|
||||
|
||||
test("UnsafeHashedRelation") {
|
||||
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
|
||||
val buildKey = Seq(BoundReference(0, IntegerType, false))
|
||||
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
|
||||
val hashed = UnsafeHashedRelation(data.iterator, buildKey, schema, 1)
|
||||
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
|
||||
val toUnsafe = UnsafeProjection.create(schema)
|
||||
val unsafeData = data.map(toUnsafe(_).copy()).toArray
|
||||
|
||||
val buildKey = Seq(BoundReference(0, IntegerType, false))
|
||||
val keyGenerator = UnsafeProjection.create(buildKey)
|
||||
val hashed = UnsafeHashedRelation(unsafeData.iterator, keyGenerator, 1)
|
||||
assert(hashed.isInstanceOf[UnsafeHashedRelation])
|
||||
|
||||
val toUnsafeKey = UnsafeProjection.create(schema)
|
||||
val unsafeData = data.map(toUnsafeKey(_).copy()).toArray
|
||||
assert(hashed.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
|
||||
assert(hashed.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
|
||||
assert(hashed.get(toUnsafeKey(InternalRow(10))) === null)
|
||||
assert(hashed.get(toUnsafe(InternalRow(10))) === null)
|
||||
|
||||
val data2 = CompactBuffer[InternalRow](unsafeData(2).copy())
|
||||
data2 += unsafeData(2).copy()
|
||||
assert(hashed.get(unsafeData(2)) === data2)
|
||||
|
||||
val hashed2 = SparkSqlSerializer.deserialize(SparkSqlSerializer.serialize(hashed))
|
||||
.asInstanceOf[UnsafeHashedRelation]
|
||||
val os = new ByteArrayOutputStream()
|
||||
val out = new ObjectOutputStream(os)
|
||||
hashed.asInstanceOf[UnsafeHashedRelation].writeExternal(out)
|
||||
out.flush()
|
||||
val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
|
||||
val hashed2 = new UnsafeHashedRelation()
|
||||
hashed2.readExternal(in)
|
||||
assert(hashed2.get(unsafeData(0)) === CompactBuffer[InternalRow](unsafeData(0)))
|
||||
assert(hashed2.get(unsafeData(1)) === CompactBuffer[InternalRow](unsafeData(1)))
|
||||
assert(hashed2.get(toUnsafeKey(InternalRow(10))) === null)
|
||||
assert(hashed2.get(toUnsafe(InternalRow(10))) === null)
|
||||
assert(hashed2.get(unsafeData(2)) === data2)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue