[SPARK-10065] [SQL] avoid the extra copy when generating an unsafe array

The reason for this extra copy is that we iterate over the array twice: once to calculate the data size of the elements, and once to copy the elements into the array buffer.

A simple solution is to follow `createCodeForStruct`: we can dynamically grow the buffer when needed, and thus don't need to know the data size ahead of time.

This PR also includes some typo and style fixes, and does some minor refactoring to make sure `input.primitive` is always a variable name, not code, when generating unsafe code.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #8496 from cloud-fan/avoid-copy.
This commit is contained in:
Wenchen Fan 2015-09-10 10:04:38 -07:00 committed by Davies Liu
parent 48817cc111
commit 4f1daa1ef6

View file

@ -206,11 +206,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
ctx.addMutableState("UnsafeArrayData", output, s"$output = new UnsafeArrayData();")
val buffer = ctx.freshName("buffer")
ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];")
val tmpBuffer = ctx.freshName("tmpBuffer")
val outputIsNull = ctx.freshName("isNull")
val numElements = ctx.freshName("numElements")
val fixedSize = ctx.freshName("fixedSize")
val numBytes = ctx.freshName("numBytes")
val elements = ctx.freshName("elements")
val cursor = ctx.freshName("cursor")
val index = ctx.freshName("index")
val elementName = ctx.freshName("elementName")
@ -224,57 +224,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
val convertedElement = createConvertCode(ctx, element, elementType)
// go through the input array to calculate how many bytes we need.
val calculateNumBytes = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
$numBytes += $elementSize * $numElements;
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
$numBytes += 8 * $numElements;
"""
case _ =>
val writer = getWriter(elementType)
val elementSize = s"$writer.getSize($elements[$index])"
// TODO(davies): avoid the copy
val unsafeType = elementType match {
case _: StructType => "UnsafeRow"
case _: ArrayType => "UnsafeArrayData"
case _: MapType => "UnsafeMapData"
case _ => ctx.javaType(elementType)
}
val copy = elementType match {
// We reuse the buffer during conversion, need copy it before process next element.
case _: StructType | _: ArrayType | _: MapType => ".copy()"
case _ => ""
}
val newElements = if (elementType == BinaryType) {
s"new byte[$numElements][]"
} else {
s"new $unsafeType[$numElements]"
}
s"""
final $unsafeType[] $elements = $newElements;
for (int $index = 0; $index < $numElements; $index++) {
${convertedElement.code}
if (!${convertedElement.isNull}) {
$elements[$index] = ${convertedElement.primitive}$copy;
$numBytes += $elementSize;
}
}
"""
}
val writeElement = elementType match {
case _ if ctx.isPrimitiveType(elementType) =>
// Should we do word align?
val elementSize = elementType.defaultSize
s"""
${convertedElement.code}
Platform.put${ctx.primitiveTypeName(elementType)}(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@ -283,7 +237,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
"""
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS =>
s"""
${convertedElement.code}
Platform.putLong(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
@ -296,15 +249,23 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
$cursor += $writer.write(
$buffer,
Platform.BYTE_ARRAY_OFFSET + $cursor,
$elements[$index]);
${convertedElement.primitive});
"""
}
val checkNull = elementType match {
case _ if ctx.isPrimitiveType(elementType) => s"${convertedElement.isNull}"
case t: DecimalType => s"$elements[$index] == null" +
s" || !$elements[$index].changePrecision(${t.precision}, ${t.scale})"
case _ => s"$elements[$index] == null"
val checkNull = convertedElement.isNull + (elementType match {
case t: DecimalType =>
s" || !${convertedElement.primitive}.changePrecision(${t.precision}, ${t.scale})"
case _ => ""
})
val elementSize = elementType match {
// Should we do word align for primitive types?
case _ if ctx.isPrimitiveType(elementType) => elementType.defaultSize.toString
case t: DecimalType if t.precision <= Decimal.MAX_LONG_DIGITS => "8"
case _ =>
val writer = getWriter(elementType)
s"$writer.getSize(${convertedElement.primitive})"
}
val code = s"""
@ -318,18 +279,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
final int $fixedSize = 4 * $numElements;
int $numBytes = $fixedSize;
$calculateNumBytes
if ($numBytes > $buffer.length) {
$buffer = new byte[$numBytes];
}
int $cursor = $fixedSize;
for (int $index = 0; $index < $numElements; $index++) {
${convertedElement.code}
if ($checkNull) {
// If element is null, write the negative value address into offset region.
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, -$cursor);
} else {
$numBytes += $elementSize;
if ($buffer.length < $numBytes) {
// This will not happen frequently, because the buffer is re-used.
byte[] $tmpBuffer = new byte[$numBytes * 2];
Platform.copyMemory($buffer, Platform.BYTE_ARRAY_OFFSET,
$tmpBuffer, Platform.BYTE_ARRAY_OFFSET, $buffer.length);
$buffer = $tmpBuffer;
}
Platform.putInt($buffer, Platform.BYTE_ARRAY_OFFSET + 4 * $index, $cursor);
$writeElement
}