[SPARK-17528][SQL] data should be copied properly before saving into InternalRow

## What changes were proposed in this pull request?

For performance reasons, `UnsafeRow.getString`, `getStruct`, etc. return a "pointer" that points to a memory region of this unsafe row. This makes the unsafe projection a little dangerous, because all of its output rows share one instance.

When we implement SQL operators, we should be careful to not cache the input rows because they may be produced by unsafe projection from child operator and thus its content may change overtime.

However, when we updating values of InternalRow(e.g. in mutable projection and safe projection), we only copy UTF8String, we should also copy InternalRow, ArrayData and MapData. This PR fixes this, and also fixes the copy of vairous InternalRow, ArrayData and MapData implementations.

## How was this patch tested?

new regression tests

Author: Wenchen Fan <wenchen@databricks.com>

Closes #18483 from cloud-fan/fix-copy.
This commit is contained in:
Wenchen Fan 2017-07-01 09:25:29 +08:00
parent fd13255225
commit 4eb41879ce
18 changed files with 212 additions and 113 deletions

View file

@ -1088,6 +1088,12 @@ public final class UTF8String implements Comparable<UTF8String>, Externalizable,
return fromBytes(getBytes()); return fromBytes(getBytes());
} }
public UTF8String copy() {
byte[] bytes = new byte[numBytes];
copyMemory(base, offset, bytes, BYTE_ARRAY_OFFSET, numBytes);
return fromBytes(bytes);
}
@Override @Override
public int compareTo(@Nonnull final UTF8String other) { public int compareTo(@Nonnull final UTF8String other) {
int len = Math.min(numBytes, other.numBytes); int len = Math.min(numBytes, other.numBytes);

View file

@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst package org.apache.spark.sql.catalyst
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types.{DataType, Decimal, StructType} import org.apache.spark.sql.types.{DataType, Decimal, StructType}
import org.apache.spark.unsafe.types.UTF8String
/** /**
* An abstract class for row used internally in Spark SQL, which only contains the columns as * An abstract class for row used internally in Spark SQL, which only contains the columns as
@ -33,6 +35,10 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
def setNullAt(i: Int): Unit def setNullAt(i: Int): Unit
/**
* Updates the value at column `i`. Note that after updating, the given value will be kept in this
* row, and the caller side should guarantee that this value won't be changed afterwards.
*/
def update(i: Int, value: Any): Unit def update(i: Int, value: Any): Unit
// default implementation (slow) // default implementation (slow)
@ -58,7 +64,15 @@ abstract class InternalRow extends SpecializedGetters with Serializable {
def copy(): InternalRow def copy(): InternalRow
/** Returns true if there are any NULL values in this row. */ /** Returns true if there are any NULL values in this row. */
def anyNull: Boolean def anyNull: Boolean = {
val len = numFields
var i = 0
while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
/* ---------------------- utility methods for Scala ---------------------- */ /* ---------------------- utility methods for Scala ---------------------- */
@ -94,4 +108,15 @@ object InternalRow {
/** Returns an empty [[InternalRow]]. */ /** Returns an empty [[InternalRow]]. */
val empty = apply() val empty = apply()
/**
* Copies the given value if it's string/struct/array/map type.
*/
def copyValue(value: Any): Any = value match {
case v: UTF8String => v.copy()
case v: InternalRow => v.copy()
case v: ArrayData => v.copy()
case v: MapData => v.copy()
case _ => value
}
} }

View file

@ -1047,7 +1047,7 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
final $rowClass $result = new $rowClass(${fieldsCasts.length}); final $rowClass $result = new $rowClass(${fieldsCasts.length});
final InternalRow $tmpRow = $c; final InternalRow $tmpRow = $c;
$fieldsEvalCode $fieldsEvalCode
$evPrim = $result.copy(); $evPrim = $result;
""" """
} }

View file

@ -17,7 +17,6 @@
package org.apache.spark.sql.catalyst.expressions package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types._ import org.apache.spark.sql.types._
/** /**
@ -220,17 +219,6 @@ final class SpecificInternalRow(val values: Array[MutableValue]) extends BaseGen
override def isNullAt(i: Int): Boolean = values(i).isNull override def isNullAt(i: Int): Boolean = values(i).isNull
override def copy(): InternalRow = {
val newValues = new Array[Any](values.length)
var i = 0
while (i < values.length) {
newValues(i) = values(i).boxed
i += 1
}
new GenericInternalRow(newValues)
}
override protected def genericGet(i: Int): Any = values(i).boxed override protected def genericGet(i: Int): Any = values(i).boxed
override def update(ordinal: Int, value: Any) { override def update(ordinal: Int, value: Any) {

View file

@ -52,7 +52,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper
// Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here. // Do not allow null values. We follow the semantics of Hive's collect_list/collect_set here.
// See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator // See: org.apache.hadoop.hive.ql.udf.generic.GenericUDAFMkCollectionEvaluator
if (value != null) { if (value != null) {
buffer += value buffer += InternalRow.copyValue(value)
} }
buffer buffer
} }

View file

@ -317,6 +317,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
* Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`. * Updates its aggregation buffer, located in `mutableAggBuffer`, based on the given `inputRow`.
* *
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
*
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
* some fields of the input row, as the values can be changed unexpectedly.
*/ */
def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit def update(mutableAggBuffer: InternalRow, inputRow: InternalRow): Unit
@ -326,6 +329,9 @@ abstract class ImperativeAggregate extends AggregateFunction with CodegenFallbac
* *
* Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`. * Use `fieldNumber + mutableAggBufferOffset` to access fields of `mutableAggBuffer`.
* Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`.
*
* Note that, the input row may be produced by unsafe projection and it may not be safe to cache
* some fields of the input row, as the values can be changed unexpectedly.
*/ */
def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit def merge(mutableAggBuffer: InternalRow, inputAggBuffer: InternalRow): Unit
} }

View file

@ -408,9 +408,11 @@ class CodegenContext {
dataType match { dataType match {
case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)" case _ if isPrimitiveType(jt) => s"$row.set${primitiveTypeName(jt)}($ordinal, $value)"
case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})" case t: DecimalType => s"$row.setDecimal($ordinal, $value, ${t.precision})"
// The UTF8String may came from UnsafeRow, otherwise clone is cheap (re-use the bytes)
case StringType => s"$row.update($ordinal, $value.clone())"
case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value) case udt: UserDefinedType[_] => setColumn(row, udt.sqlType, ordinal, value)
// The UTF8String, InternalRow, ArrayData and MapData may came from UnsafeRow, we should copy
// it to avoid keeping a "pointer" to a memory region which may get updated afterwards.
case StringType | _: StructType | _: ArrayType | _: MapType =>
s"$row.update($ordinal, $value.copy())"
case _ => s"$row.update($ordinal, $value)" case _ => s"$row.update($ordinal, $value)"
} }
} }

View file

@ -131,8 +131,6 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection]
case s: StructType => createCodeForStruct(ctx, input, s) case s: StructType => createCodeForStruct(ctx, input, s)
case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType) case ArrayType(elementType, _) => createCodeForArray(ctx, input, elementType)
case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType) case MapType(keyType, valueType, _) => createCodeForMap(ctx, input, keyType, valueType)
// UTF8String act as a pointer if it's inside UnsafeRow, so copy it to make it safe.
case StringType => ExprCode("", "false", s"$input.clone()")
case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType) case udt: UserDefinedType[_] => convertToSafe(ctx, input, udt.sqlType)
case _ => ExprCode("", "false", input) case _ => ExprCode("", "false", input)
} }

View file

@ -50,16 +50,6 @@ trait BaseGenericInternalRow extends InternalRow {
override def getMap(ordinal: Int): MapData = getAs(ordinal) override def getMap(ordinal: Int): MapData = getAs(ordinal)
override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
override def anyNull: Boolean = {
val len = numFields
var i = 0
while (i < len) {
if (isNullAt(i)) { return true }
i += 1
}
false
}
override def toString: String = { override def toString: String = {
if (numFields == 0) { if (numFields == 0) {
"[empty row]" "[empty row]"
@ -79,6 +69,17 @@ trait BaseGenericInternalRow extends InternalRow {
} }
} }
override def copy(): GenericInternalRow = {
val len = numFields
val newValues = new Array[Any](len)
var i = 0
while (i < len) {
newValues(i) = InternalRow.copyValue(genericGet(i))
i += 1
}
new GenericInternalRow(newValues)
}
override def equals(o: Any): Boolean = { override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[BaseGenericInternalRow]) { if (!o.isInstanceOf[BaseGenericInternalRow]) {
return false return false
@ -206,6 +207,4 @@ class GenericInternalRow(val values: Array[Any]) extends BaseGenericInternalRow
override def setNullAt(i: Int): Unit = { values(i) = null} override def setNullAt(i: Int): Unit = { values(i) = null}
override def update(i: Int, value: Any): Unit = { values(i) = value } override def update(i: Int, value: Any): Unit = { values(i) = value }
override def copy(): GenericInternalRow = this
} }

View file

@ -49,7 +49,15 @@ class GenericArrayData(val array: Array[Any]) extends ArrayData {
def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray)) def this(seqOrArray: Any) = this(GenericArrayData.anyToSeq(seqOrArray))
override def copy(): ArrayData = new GenericArrayData(array.clone()) override def copy(): ArrayData = {
val newValues = new Array[Any](array.length)
var i = 0
while (i < array.length) {
newValues(i) = InternalRow.copyValue(array(i))
i += 1
}
new GenericArrayData(newValues)
}
override def numElements(): Int = array.length override def numElements(): Int = array.length

View file

@ -121,10 +121,6 @@ class RowTest extends FunSpec with Matchers {
externalRow should be theSameInstanceAs externalRow.copy() externalRow should be theSameInstanceAs externalRow.copy()
} }
it("copy should return same ref for internal rows") {
internalRow should be theSameInstanceAs internalRow.copy()
}
it("toSeq should not expose internal state for external rows") { it("toSeq should not expose internal state for external rows") {
val modifiedValues = modifyValues(externalRow.toSeq) val modifiedValues = modifyValues(externalRow.toSeq)
externalRow.toSeq should not equal modifiedValues externalRow.toSeq should not equal modifiedValues

View file

@ -1,57 +0,0 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import scala.collection._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String
class MapDataSuite extends SparkFunSuite {
test("inequality tests") {
def u(str: String): UTF8String = UTF8String.fromString(str)
// test data
val testMap1 = Map(u("key1") -> 1)
val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
val testMap3 = Map(u("key1") -> 1)
val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)
// ArrayBasedMapData
val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
assert(testArrayMap1 !== testArrayMap3)
assert(testArrayMap2 !== testArrayMap4)
// UnsafeMapData
val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
val row = new GenericInternalRow(1)
def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
row.update(0, map)
val unsafeRow = unsafeConverter.apply(row)
unsafeRow.getMap(0).copy
}
assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
}
}

View file

@ -172,4 +172,40 @@ class GeneratedProjectionSuite extends SparkFunSuite {
assert(unsafe1 === unsafe3) assert(unsafe1 === unsafe3)
assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7)) assert(unsafe1.getStruct(1, 7) === unsafe3.getStruct(1, 7))
} }
test("MutableProjection should not cache content from the input row") {
val mutableProj = GenerateMutableProjection.generate(
Seq(BoundReference(0, new StructType().add("i", StringType), true)))
val row = new GenericInternalRow(1)
mutableProj.target(row)
val unsafeProj = GenerateUnsafeProjection.generate(
Seq(BoundReference(0, new StructType().add("i", StringType), true)))
val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
mutableProj.apply(unsafeRow)
assert(row.getStruct(0, 1).getString(0) == "a")
// Even if the input row of the mutable projection has been changed, the target mutable row
// should keep same.
unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
assert(row.getStruct(0, 1).getString(0).toString == "a")
}
test("SafeProjection should not cache content from the input row") {
val safeProj = GenerateSafeProjection.generate(
Seq(BoundReference(0, new StructType().add("i", StringType), true)))
val unsafeProj = GenerateUnsafeProjection.generate(
Seq(BoundReference(0, new StructType().add("i", StringType), true)))
val unsafeRow = unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("a"))))
val row = safeProj.apply(unsafeRow)
assert(row.getStruct(0, 1).getString(0) == "a")
// Even if the input row of the mutable projection has been changed, the target mutable row
// should keep same.
unsafeProj.apply(InternalRow(InternalRow(UTF8String.fromString("b"))))
assert(row.getStruct(0, 1).getString(0).toString == "a")
}
} }

View file

@ -0,0 +1,107 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.util
import scala.collection._
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{BoundReference, GenericInternalRow, SpecificInternalRow, UnsafeMapData, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String
class ComplexDataSuite extends SparkFunSuite {
def utf8(str: String): UTF8String = UTF8String.fromString(str)
test("inequality tests for MapData") {
// test data
val testMap1 = Map(utf8("key1") -> 1)
val testMap2 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
val testMap3 = Map(utf8("key1") -> 1)
val testMap4 = Map(utf8("key1") -> 1, utf8("key2") -> 2)
// ArrayBasedMapData
val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
assert(testArrayMap1 !== testArrayMap3)
assert(testArrayMap2 !== testArrayMap4)
// UnsafeMapData
val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
val row = new GenericInternalRow(1)
def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
row.update(0, map)
val unsafeRow = unsafeConverter.apply(row)
unsafeRow.getMap(0).copy
}
assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
}
test("GenericInternalRow.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(utf8("a")))
val genericRow = new GenericInternalRow(Array[Any](unsafeRow.getUTF8String(0)))
val copiedGenericRow = genericRow.copy()
assert(copiedGenericRow.getString(0) == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied internal row should not be changed externally.
assert(copiedGenericRow.getString(0) == "a")
}
test("SpecificMutableRow.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(utf8("a")))
val mutableRow = new SpecificInternalRow(Seq(StringType))
mutableRow(0) = unsafeRow.getUTF8String(0)
val copiedMutableRow = mutableRow.copy()
assert(copiedMutableRow.getString(0) == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied internal row should not be changed externally.
assert(copiedMutableRow.getString(0) == "a")
}
test("GenericArrayData.copy return a new instance that is independent from the old one") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(utf8("a")))
val genericArray = new GenericArrayData(Array[Any](unsafeRow.getUTF8String(0)))
val copiedGenericArray = genericArray.copy()
assert(copiedGenericArray.getUTF8String(0).toString == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied array data should not be changed externally.
assert(copiedGenericArray.getUTF8String(0).toString == "a")
}
test("copy on nested complex type") {
val project = GenerateUnsafeProjection.generate(Seq(BoundReference(0, StringType, true)))
val unsafeRow = project.apply(InternalRow(utf8("a")))
val arrayOfRow = new GenericArrayData(Array[Any](InternalRow(unsafeRow.getUTF8String(0))))
val copied = arrayOfRow.copy()
assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
project.apply(InternalRow(UTF8String.fromString("b")))
// The copied data should not be changed externally.
assert(copied.getStruct(0, 1).getUTF8String(0).toString == "a")
}
}

View file

@ -149,7 +149,7 @@ public final class ColumnarBatch {
} else if (dt instanceof DoubleType) { } else if (dt instanceof DoubleType) {
row.setDouble(i, getDouble(i)); row.setDouble(i, getDouble(i));
} else if (dt instanceof StringType) { } else if (dt instanceof StringType) {
row.update(i, getUTF8String(i)); row.update(i, getUTF8String(i).copy());
} else if (dt instanceof BinaryType) { } else if (dt instanceof BinaryType) {
row.update(i, getBinary(i)); row.update(i, getBinary(i));
} else if (dt instanceof DecimalType) { } else if (dt instanceof DecimalType) {

View file

@ -86,17 +86,6 @@ class SortBasedAggregationIterator(
// The aggregation buffer used by the sort-based aggregation. // The aggregation buffer used by the sort-based aggregation.
private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer private[this] val sortBasedAggregationBuffer: InternalRow = newBuffer
// This safe projection is used to turn the input row into safe row. This is necessary
// because the input row may be produced by unsafe projection in child operator and all the
// produced rows share one byte array. However, when we update the aggregate buffer according to
// the input row, we may cache some values from input row, e.g. `Max` will keep the max value from
// input row via MutableProjection, `CollectList` will keep all values in an array via
// ImperativeAggregate framework. These values may get changed unexpectedly if the underlying
// unsafe projection update the shared byte array. By applying a safe projection to the input row,
// we can cut down the connection from input row to the shared byte array, and thus it's safe to
// cache values from input row while updating the aggregation buffer.
private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType))
protected def initialize(): Unit = { protected def initialize(): Unit = {
if (inputIterator.hasNext) { if (inputIterator.hasNext) {
initializeBuffer(sortBasedAggregationBuffer) initializeBuffer(sortBasedAggregationBuffer)
@ -119,7 +108,7 @@ class SortBasedAggregationIterator(
// We create a variable to track if we see the next group. // We create a variable to track if we see the next group.
var findNextPartition = false var findNextPartition = false
// firstRowInNextGroup is the first row of this group. We first process it. // firstRowInNextGroup is the first row of this group. We first process it.
processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) processRow(sortBasedAggregationBuffer, firstRowInNextGroup)
// The search will stop when we see the next group or there is no // The search will stop when we see the next group or there is no
// input row left in the iter. // input row left in the iter.
@ -130,7 +119,7 @@ class SortBasedAggregationIterator(
// Check if the current row belongs the current input row. // Check if the current row belongs the current input row.
if (currentGroupingKey == groupingKey) { if (currentGroupingKey == groupingKey) {
processRow(sortBasedAggregationBuffer, safeProj(currentRow)) processRow(sortBasedAggregationBuffer, currentRow)
} else { } else {
// We find a new group. // We find a new group.
findNextPartition = true findNextPartition = true

View file

@ -56,7 +56,6 @@ class MutableUnsafeRow(val writer: UnsafeRowWriter) extends BaseGenericInternalR
// all other methods inherited from GenericMutableRow are not need // all other methods inherited from GenericMutableRow are not need
override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException override protected def genericGet(ordinal: Int): Any = throw new UnsupportedOperationException
override def numFields: Int = throw new UnsupportedOperationException override def numFields: Int = throw new UnsupportedOperationException
override def copy(): InternalRow = throw new UnsupportedOperationException
} }
/** /**

View file

@ -145,13 +145,10 @@ private[window] final class AggregateProcessor(
/** Update the buffer. */ /** Update the buffer. */
def update(input: InternalRow): Unit = { def update(input: InternalRow): Unit = {
// TODO(hvanhovell) this sacrifices performance for correctness. We should make sure that updateProjection(join(buffer, input))
// MutableProjection makes copies of the complex input objects it buffer.
val copy = input.copy()
updateProjection(join(buffer, copy))
var i = 0 var i = 0
while (i < numImperatives) { while (i < numImperatives) {
imperatives(i).update(buffer, copy) imperatives(i).update(buffer, input)
i += 1 i += 1
} }
} }