[SPARK-16922] [SPARK-17211] [SQL] make the address of values portable in LongToUnsafeRowMap
## What changes were proposed in this pull request?

In LongToUnsafeRowMap, we use the offset of a value as its pointer, stored both in the index array and inside the page itself (to chain multiple values for the same key). This offset is not portable: it includes Platform.LONG_ARRAY_OFFSET, which differs across JVMs with different heap sizes (heaps of 32GB or more disable compressed oops, which changes the array header layout), so a LongToUnsafeRowMap serialized on one JVM can be corrupted when deserialized on another. This PR changes the encoding to a portable address, with Platform.LONG_ARRAY_OFFSET subtracted out at write time and the local base offset re-added on every read.

## How was this patch tested?

Added a test case with randomly generated keys to improve coverage. Note that this is not a regression test for the original bug: reproducing the corruption would require a Spark cluster whose driver or executor has a heap of at least 32GB.

Author: Davies Liu <davies@databricks.com>

Closes #14927 from davies/longmap.
Commit f7e26d7887 (parent bc2767df26)
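To see why the raw offset is not portable: Platform.LONG_ARRAY_OFFSET is just the running JVM's base offset for `long[]` arrays, an implementation detail of the VM's object layout. A minimal, self-contained sketch (not Spark code) that prints it directly via sun.misc.Unsafe, much as Spark's Platform does; on HotSpot the value is commonly 16 with compressed oops enabled and 24 without (e.g. with heaps of 32GB or more):

```scala
import java.lang.reflect.Field
import sun.misc.Unsafe

object ArrayBaseOffsetDemo {
  def main(args: Array[String]): Unit = {
    // Grab the Unsafe singleton reflectively, as Spark's Platform does.
    val f: Field = classOf[Unsafe].getDeclaredField("theUnsafe")
    f.setAccessible(true)
    val unsafe = f.get(null).asInstanceOf[Unsafe]
    // The first usable byte of a long[]. Any absolute offset that bakes this
    // value in is only meaningful on a JVM with the same array header layout.
    println(unsafe.arrayBaseOffset(classOf[Array[Long]]))
  }
}
```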
`sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala`:

```diff
@@ -447,10 +447,20 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   private def nextSlot(pos: Int): Int = (pos + 2) & mask
 
+  private[this] def toAddress(offset: Long, size: Int): Long = {
+    ((offset - Platform.LONG_ARRAY_OFFSET) << SIZE_BITS) | size
+  }
+
+  private[this] def toOffset(address: Long): Long = {
+    (address >>> SIZE_BITS) + Platform.LONG_ARRAY_OFFSET
+  }
+
+  private[this] def toSize(address: Long): Int = {
+    (address & SIZE_MASK).toInt
+  }
+
   private def getRow(address: Long, resultRow: UnsafeRow): UnsafeRow = {
-    val offset = address >>> SIZE_BITS
-    val size = address & SIZE_MASK
-    resultRow.pointTo(page, offset, size.toInt)
+    resultRow.pointTo(page, toOffset(address), toSize(address))
     resultRow
   }
```
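For reference, the address word packs a page-relative offset into the high bits and the row size into the low bits. A round-trip sketch of that packing, under the assumption that SIZE_BITS = 28 and SIZE_MASK = 0xfffffff, as defined elsewhere in HashedRelation.scala (so a value may be up to 256MB and the offset gets the remaining 36 bits):

```scala
// A round-trip sketch of the packed-address layout; the constants are
// assumptions matching what HashedRelation.scala appears to use.
object AddressPackingDemo {
  val SIZE_BITS = 28
  val SIZE_MASK = 0xfffffffL

  def pack(relativeOffset: Long, size: Int): Long = (relativeOffset << SIZE_BITS) | size
  def unpackOffset(address: Long): Long = address >>> SIZE_BITS
  def unpackSize(address: Long): Int = (address & SIZE_MASK).toInt

  def main(args: Array[String]): Unit = {
    val addr = pack(relativeOffset = 1024, size = 48)
    assert(unpackOffset(addr) == 1024 && unpackSize(addr) == 48)
    println(f"$addr%016x") // offset in the high 36 bits, size in the low 28
  }
}
```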
```diff
@@ -485,9 +495,9 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       var addr = address
       override def hasNext: Boolean = addr != 0
       override def next(): UnsafeRow = {
-        val offset = addr >>> SIZE_BITS
-        val size = addr & SIZE_MASK
-        resultRow.pointTo(page, offset, size.toInt)
+        val offset = toOffset(addr)
+        val size = toSize(addr)
+        resultRow.pointTo(page, offset, size)
         addr = Platform.getLong(page, offset + size)
         resultRow
       }
```
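The iterator above walks a linked chain: each value in the page is followed by an 8-byte word holding the address of the next value for the same key, and 0 terminates the chain. A self-contained sketch of the same traversal pattern over a plain array (indices stand in for addresses; not Spark code):

```scala
import scala.collection.mutable.ArrayBuffer

// Each record occupies a payload slot followed by a "next" slot; 0 ends the
// chain, so slot 0 is left unused.
object ChainDemo {
  def traverse(page: Array[Long], start: Int): Seq[Long] = {
    var addr = start
    val out = ArrayBuffer.empty[Long]
    while (addr != 0) {
      out += page(addr)           // the payload stored at this slot
      addr = page(addr + 1).toInt // follow the trailing next-pointer
    }
    out.toSeq
  }

  def main(args: Array[String]): Unit = {
    // records at indices 1 and 3; record 1 links to record 3, which ends the chain
    val page = Array(0L, 42L, 3L, 43L, 0L)
    println(traverse(page, 1)) // List(42, 43)
  }
}
```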
```diff
@@ -554,7 +564,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     Platform.putLong(page, cursor, 0)
     cursor += 8
     numValues += 1
-    updateIndex(key, (offset.toLong << SIZE_BITS) | row.getSizeInBytes)
+    updateIndex(key, toAddress(offset, row.getSizeInBytes))
   }
 
   /**
```
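This call site is where the base offset matters. A worked sketch (assumed bases of 16 and 24, matching the two compressed-oops cases above, and SIZE_BITS = 28 as before) of encoding an address on one JVM and decoding it on another:

```scala
object CrossJvmDemo {
  val SIZE_BITS = 28

  def main(args: Array[String]): Unit = {
    val baseWriter = 16L // LONG_ARRAY_OFFSET on the serializing JVM (assumed)
    val baseReader = 24L // LONG_ARRAY_OFFSET on the deserializing JVM (assumed)
    val cursor = baseWriter + 1024 // absolute offset of a row on the writer side
    val size = 48

    // Old scheme: the absolute offset goes into the address. The reader decodes
    // 16 + 1024, but its copy of the row lives at 24 + 1024 -- off by 8 bytes.
    val oldAddr = (cursor << SIZE_BITS) | size
    println(s"old: decoded ${oldAddr >>> SIZE_BITS}, actual ${baseReader + 1024}")

    // New scheme (toAddress/toOffset): store cursor - base and re-add the local
    // base when reading, so the decoded offset is correct on any JVM.
    val newAddr = ((cursor - baseWriter) << SIZE_BITS) | size
    println(s"new: decoded ${(newAddr >>> SIZE_BITS) + baseReader}") // 1048, correct
  }
}
```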
```diff
@@ -562,6 +572,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
    */
   private def updateIndex(key: Long, address: Long): Unit = {
     var pos = firstSlot(key)
+    assert(numKeys < array.length / 2)
     while (array(pos) != key && array(pos + 1) != 0) {
       pos = nextSlot(pos)
     }
```
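The new assert guards the probe loop that follows it: the index is an open-addressed table of (key, address) pairs, and keeping it at most half full guarantees that a linear probe always terminates at the key or at an empty slot. A toy sketch of the same probing scheme; the multiplicative hash, the flat pair layout, and "0 means empty" are assumptions for illustration, not Spark's exact code:

```scala
object ProbeDemo {
  def insert(table: Array[Long], key: Long, value: Long): Unit = {
    require(value != 0, "0 is reserved for empty slots in this sketch")
    val mask = table.length - 2                        // power-of-two length assumed
    var pos = ((key * 0x9E3779B97F4A7C15L) & mask).toInt // firstSlot (assumed hash)
    // With the table at most half full, this loop must reach either the key
    // itself or an empty slot; it cannot spin forever.
    while (table(pos) != key && table(pos + 1) != 0) {
      pos = (pos + 2) & mask // nextSlot: advance one (key, value) pair, wrapping
    }
    table(pos) = key
    table(pos + 1) = value
  }

  def main(args: Array[String]): Unit = {
    val table = new Array[Long](16) // room for 8 pairs; insert at most 4 keys
    Seq(3L -> 30L, 11L -> 110L, 19L -> 190L).foreach { case (k, v) =>
      insert(table, k, v)
    }
    println(table.toList)
  }
}
```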
```diff
@@ -582,7 +593,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       }
     } else {
       // there are some values for this key, put the address in the front of them.
-      val pointer = (address >>> SIZE_BITS) + (address & SIZE_MASK)
+      val pointer = toOffset(address) + toSize(address)
       Platform.putLong(page, pointer, array(pos + 1))
       array(pos + 1) = address
     }
```
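Here `pointer` is the absolute position of the 8-byte next-word that trails the newly appended value (written as 0 in append). Storing the previous chain head there and then pointing the index slot at the new value prepends it to the key's chain, which is why both decoding helpers are needed at this one site.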
`sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala`:

```diff
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.joins
 
 import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
 
+import scala.util.Random
+
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
 import org.apache.spark.serializer.KryoSerializer
```
```diff
@@ -197,6 +199,60 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     }
   }
 
+  test("LongToUnsafeRowMap with random keys") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set("spark.memory.offHeap.enabled", "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Seq(BoundReference(0, LongType, false)))
+
+    val N = 1000000
+    val rand = new Random
+    val keys = (0 to N).map(x => rand.nextLong()).toArray
+
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 10)
+    keys.foreach { k =>
+      map.append(k, unsafeProj(InternalRow(k)))
+    }
+    map.optimize()
+
+    val os = new ByteArrayOutputStream()
+    val out = new ObjectOutputStream(os)
+    map.writeExternal(out)
+    out.flush()
+    val in = new ObjectInputStream(new ByteArrayInputStream(os.toByteArray))
+    val map2 = new LongToUnsafeRowMap(taskMemoryManager, 1)
+    map2.readExternal(in)
+
+    val row = unsafeProj(InternalRow(0L)).copy()
+    keys.foreach { k =>
+      val r = map2.get(k, row)
+      assert(r.hasNext)
+      var c = 0
+      while (r.hasNext) {
+        val rr = r.next()
+        assert(rr.getLong(0) === k)
+        c += 1
+      }
+    }
+    var i = 0
+    while (i < N * 10) {
+      val k = rand.nextLong()
+      val r = map2.get(k, row)
+      if (r != null) {
+        assert(r.hasNext)
+        while (r.hasNext) {
+          assert(r.next().getLong(0) === k)
+        }
+      }
+      i += 1
+    }
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()
```
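Note that this serialize/deserialize round trip happens within a single JVM, so Platform.LONG_ARRAY_OFFSET is identical on both sides. That is why the description above calls this a coverage test rather than a regression test for the original corruption, which would need a second JVM with a different array header layout (for example, one with a heap of at least 32GB).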